fix 提取数据不成功
This commit is contained in:
19
Decoder.py
19
Decoder.py
@@ -276,8 +276,14 @@ class Decoder_main(threading.Thread):
|
|||||||
|
|
||||||
'''训练阶段采集数据'''
|
'''训练阶段采集数据'''
|
||||||
if self.zmqServer.state_mode == 'train': # 训练状态
|
if self.zmqServer.state_mode == 'train': # 训练状态
|
||||||
if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
|
if self.zmqServer.pack_contain_event:
|
||||||
self.train_epoch[1] + self.zmqServer.event_inner_idx:
|
with self.zmqServer.paradigmBufferLock:
|
||||||
|
self.zmqServer.paradigmBuffer.resetAllPara()
|
||||||
|
self.zmqServer.pack_contain_event = False
|
||||||
|
|
||||||
|
if self.zmqServer.epoch_finished:
|
||||||
|
data_length = self.zmqServer.paradigmBuffer.GetDataLenCount()
|
||||||
|
if data_length >= self.train_epoch[1] + self.zmqServer.event_inner_idx:
|
||||||
self.currentLabel = self.zmqServer.currentLabel
|
self.currentLabel = self.zmqServer.currentLabel
|
||||||
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
|
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
|
||||||
algo_log(f"取出的:{trainTrial.shape},event:{trainTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
algo_log(f"取出的:{trainTrial.shape},event:{trainTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
||||||
@@ -291,7 +297,10 @@ class Decoder_main(threading.Thread):
|
|||||||
self.trainLabel.append(self.currentLabel)
|
self.trainLabel.append(self.currentLabel)
|
||||||
algo_log(f"SSMVEP训练集:{np.shape(self.trainData)}", level="DEBUG")
|
algo_log(f"SSMVEP训练集:{np.shape(self.trainData)}", level="DEBUG")
|
||||||
else:
|
else:
|
||||||
time.sleep(0.0001)
|
algo_log(f"epoch_finished {self.zmqServer.epoch_finished}, 数据长度不足 {data_length}", level="DEBUG")
|
||||||
|
self.zmqServer.epoch_finished = False
|
||||||
|
else:
|
||||||
|
time.sleep(0.001)
|
||||||
return
|
return
|
||||||
|
|
||||||
elif self.zmqServer.state_mode == 'predict': # 测试状态
|
elif self.zmqServer.state_mode == 'predict': # 测试状态
|
||||||
@@ -395,7 +404,7 @@ class Decoder_main(threading.Thread):
|
|||||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
|
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
|
||||||
self.plotLabel.append(self.currentLabel)
|
self.plotLabel.append(self.currentLabel)
|
||||||
else:
|
else:
|
||||||
time.sleep(0.0001)
|
time.sleep(0.001)
|
||||||
return
|
return
|
||||||
|
|
||||||
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
|
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
|
||||||
@@ -408,7 +417,7 @@ class Decoder_main(threading.Thread):
|
|||||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||||
self.interval_epoch[1] \
|
self.interval_epoch[1] \
|
||||||
+ self.zmqServer.event_inner_idx:
|
+ self.zmqServer.event_inner_idx:
|
||||||
time.sleep(0.0001)
|
time.sleep(0.001)
|
||||||
return
|
return
|
||||||
originalData = self.zmqServer.paradigmBuffer.get_MIData() # 读取全部数据
|
originalData = self.zmqServer.paradigmBuffer.get_MIData() # 读取全部数据
|
||||||
algo_log(f"取出的:{originalData.shape},event: {originalData[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
algo_log(f"取出的:{originalData.shape},event: {originalData[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
||||||
|
|||||||
@@ -40,4 +40,6 @@ python upperHost_stimmock/MI_headless.py
|
|||||||
## MI
|
## MI
|
||||||
Epoch采集完成|收到命令: {'method': 'train'|取出的
|
Epoch采集完成|收到命令: {'method': 'train'|取出的
|
||||||
|
|
||||||
收到命令: {'method': 'train'|收到命令: {'method': 'train'|收到命令: {'method': 'predict'|事件检测到
|
收到命令: {'method': 'train'|收到命令: {'method': 'predict'|Epoch采集完成|事件检测到
|
||||||
|
|
||||||
|
收到命令: {'method': 'train|Epoch采集完成|事件检测到|取出的|SSMVEP训练集
|
||||||
@@ -339,29 +339,12 @@ class zmqServer(threading.Thread):
|
|||||||
|
|
||||||
# 写入范式缓冲区
|
# 写入范式缓冲区
|
||||||
with self.paradigmBufferLock:
|
with self.paradigmBufferLock:
|
||||||
|
self.paradigmBuffer.appendBuffer(data_np)
|
||||||
if self.interval_inited:
|
if self.interval_inited:
|
||||||
self.epoch_finished = self.detect_event(data_np)
|
self.pack_contain_event, self.epoch_finished = self.detect_event(data_np)
|
||||||
if self.pack_contain_event:
|
|
||||||
self.paradigmBuffer.resetAllPara()
|
|
||||||
self.paradigmBuffer.appendBuffer(data_np)
|
|
||||||
|
|
||||||
if self.epoch_finished:
|
if self.epoch_finished:
|
||||||
now = datetime.datetime.now()
|
algo_log(f"Epoch采集完成, 当前数据长度{self.paradigmBuffer.GetDataLenCount()}", level="DEBUG")
|
||||||
time_diff_str = ""
|
|
||||||
# 计算与上一次Epoch完成的时间差
|
|
||||||
if self.last_epoch_finish_time is not None:
|
|
||||||
# 时间差 单位:秒,保留3位小数
|
|
||||||
delta_seconds = (now - self.last_epoch_finish_time).total_seconds()
|
|
||||||
time_diff_str = f" | 与上一次间隔: {delta_seconds:.3f} s"
|
|
||||||
|
|
||||||
# 拼接日志,增加时间差信息
|
|
||||||
log_msg = f"Epoch采集完成: {now.strftime('%H:%M:%S.%f')[:-3]}{time_diff_str}"
|
|
||||||
algo_log(log_msg, level="DEBUG")
|
|
||||||
|
|
||||||
# 更新上一次Epoch完成时间为当前时间
|
|
||||||
self.last_epoch_finish_time = now
|
|
||||||
else:
|
|
||||||
self.paradigmBuffer.appendBuffer(data_np)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
algo_log(f"数据处理失败: {str(e)}", level="ERROR")
|
algo_log(f"数据处理失败: {str(e)}", level="ERROR")
|
||||||
@@ -371,7 +354,8 @@ class zmqServer(threading.Thread):
|
|||||||
|
|
||||||
# -------------------------- 事件检测 --------------------------
|
# -------------------------- 事件检测 --------------------------
|
||||||
def detect_event(self, samples):
|
def detect_event(self, samples):
|
||||||
self.pack_contain_event = False
|
pack_contain_event = False
|
||||||
|
epoch_finished = False
|
||||||
# 第65通道为事件通道
|
# 第65通道为事件通道
|
||||||
events = np.array(samples[-2], dtype=np.int32).tolist()
|
events = np.array(samples[-2], dtype=np.int32).tolist()
|
||||||
for idx, event in enumerate(events):
|
for idx, event in enumerate(events):
|
||||||
@@ -383,14 +367,20 @@ class zmqServer(threading.Thread):
|
|||||||
-%H-%M-%S"),
|
-%H-%M-%S"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
if len(self.count_events) > 0:
|
||||||
|
algo_log(f"当前有事件未采集完成,新事件{new_key}非法,被忽略")
|
||||||
|
return pack_contain_event, epoch_finished
|
||||||
|
else:
|
||||||
self.currentLabel = event
|
self.currentLabel = event
|
||||||
|
pack_contain_event = True
|
||||||
if event == self.predict_event:
|
if event == self.predict_event:
|
||||||
self.count_events[new_key] = self.latency + 1
|
self.count_events[new_key] = self.latency + 1
|
||||||
else:
|
else:
|
||||||
self.count_events[new_key] = self.train_latency + 1
|
self.count_events[new_key] = self.train_latency + 1
|
||||||
self.event_inner_idx = idx
|
self.event_inner_idx = idx
|
||||||
algo_log(f"事件检测到: {events},索引: {idx}", level="DEBUG")
|
algo_log(f"事件检测到: {events},索引: {idx}", level="DEBUG")
|
||||||
self.pack_contain_event = True
|
else:
|
||||||
|
pack_contain_event = False
|
||||||
|
|
||||||
# 倒计时并清理过期事件
|
# 倒计时并清理过期事件
|
||||||
drop_items = []
|
drop_items = []
|
||||||
@@ -403,9 +393,13 @@ class zmqServer(threading.Thread):
|
|||||||
for key in drop_items:
|
for key in drop_items:
|
||||||
del self.count_events[key]
|
del self.count_events[key]
|
||||||
|
|
||||||
if drop_items:
|
if len(drop_items) > 0:
|
||||||
return True
|
epoch_finished = True
|
||||||
return False
|
else:
|
||||||
|
epoch_finished = False
|
||||||
|
return pack_contain_event, epoch_finished
|
||||||
|
|
||||||
|
|
||||||
# -------------------------- 主循环 --------------------------
|
# -------------------------- 主循环 --------------------------
|
||||||
def run(self):
|
def run(self):
|
||||||
self.running = True
|
self.running = True
|
||||||
|
|||||||
14
datamock.py
14
datamock.py
@@ -17,6 +17,10 @@ LABEL_CMD_ADDR = 'tcp://127.0.0.1:8101' # 接收来自上位机范式的标签
|
|||||||
# 发送间隔: 每包 5 采样点 / 250Hz = 20ms
|
# 发送间隔: 每包 5 采样点 / 250Hz = 20ms
|
||||||
PKT_INTERVAL = N_SAMPLES_PER_PKT / FS
|
PKT_INTERVAL = N_SAMPLES_PER_PKT / FS
|
||||||
|
|
||||||
|
POINT_PER_3S = FS * 3 # 750
|
||||||
|
# 3秒对应的总包数(关键:每150包 = 3s)
|
||||||
|
PKT_PER_3S = POINT_PER_3S // N_SAMPLES_PER_PKT # 150
|
||||||
|
|
||||||
|
|
||||||
def build_packet(global_sample_idx):
|
def build_packet(global_sample_idx):
|
||||||
"""
|
"""
|
||||||
@@ -34,6 +38,16 @@ def build_packet(global_sample_idx):
|
|||||||
|
|
||||||
# Ch64: 标签值通道,初始化为 0
|
# Ch64: 标签值通道,初始化为 0
|
||||||
event = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64)
|
event = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64)
|
||||||
|
current_pkt_idx = global_sample_idx // N_SAMPLES_PER_PKT
|
||||||
|
|
||||||
|
# 判断是否为 3s 整数倍对应的包
|
||||||
|
if current_pkt_idx % PKT_PER_3S == 0:
|
||||||
|
# 当前是第 N 个3s节点:1、2、1、2...交替
|
||||||
|
cycle_num = (current_pkt_idx // PKT_PER_3S) % 2
|
||||||
|
if cycle_num == 0:
|
||||||
|
event[0, 0] = 1.0
|
||||||
|
else:
|
||||||
|
event[0, 0] = 2.0
|
||||||
|
|
||||||
# Ch65: 标签序号通道,初始化为 0
|
# Ch65: 标签序号通道,初始化为 0
|
||||||
label_idx = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64)
|
label_idx = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64)
|
||||||
|
|||||||
Reference in New Issue
Block a user