diff --git a/Decoder.py b/Decoder.py index 6aa4bec..6b83fed 100644 --- a/Decoder.py +++ b/Decoder.py @@ -276,22 +276,31 @@ class Decoder_main(threading.Thread): '''训练阶段采集数据''' if self.zmqServer.state_mode == 'train': # 训练状态 - if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \ - self.train_epoch[1] + self.zmqServer.event_inner_idx: - self.currentLabel = self.zmqServer.currentLabel - trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据 - algo_log(f"取出的:{trainTrial.shape},event:{trainTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG") - trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理 - trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[ - 0]:self.zmqServer.event_inner_idx + self.train_epoch[1]] - if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance( - self.trainLabel, list) \ - and self.trainLabel.count(self.currentLabel) < self.single_train: - self.trainData.append(trainTrial) - self.trainLabel.append(self.currentLabel) - algo_log(f"SSMVEP训练集:{np.shape(self.trainData)}", level="DEBUG") + if self.zmqServer.pack_contain_event: + with self.zmqServer.paradigmBufferLock: + self.zmqServer.paradigmBuffer.resetAllPara() + self.zmqServer.pack_contain_event = False + + if self.zmqServer.epoch_finished: + data_length = self.zmqServer.paradigmBuffer.GetDataLenCount() + if data_length >= self.train_epoch[1] + self.zmqServer.event_inner_idx: + self.currentLabel = self.zmqServer.currentLabel + trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据 + algo_log(f"取出的:{trainTrial.shape},event:{trainTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG") + trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理 + trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[ + 0]:self.zmqServer.event_inner_idx + self.train_epoch[1]] + if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance( + self.trainLabel, list) \ + and self.trainLabel.count(self.currentLabel) < self.single_train: + self.trainData.append(trainTrial) + self.trainLabel.append(self.currentLabel) + algo_log(f"SSMVEP训练集:{np.shape(self.trainData)}", level="DEBUG") + else: + algo_log(f"epoch_finished {self.zmqServer.epoch_finished}, 数据长度不足 {data_length}", level="DEBUG") + self.zmqServer.epoch_finished = False else: - time.sleep(0.0001) + time.sleep(0.001) return elif self.zmqServer.state_mode == 'predict': # 测试状态 @@ -395,7 +404,7 @@ class Decoder_main(threading.Thread): 0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]) self.plotLabel.append(self.currentLabel) else: - time.sleep(0.0001) + time.sleep(0.001) return elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态 @@ -408,7 +417,7 @@ class Decoder_main(threading.Thread): if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \ self.interval_epoch[1] \ + self.zmqServer.event_inner_idx: - time.sleep(0.0001) + time.sleep(0.001) return originalData = self.zmqServer.paradigmBuffer.get_MIData() # 读取全部数据 algo_log(f"取出的:{originalData.shape},event: {originalData[-2, self.zmqServer.event_inner_idx]}", level="DEBUG") diff --git a/README.md b/README.md index e5f05dc..ad87930 100644 --- a/README.md +++ b/README.md @@ -40,4 +40,6 @@ python upperHost_stimmock/MI_headless.py ## MI Epoch采集完成|收到命令: {'method': 'train'|取出的 -收到命令: {'method': 'train'|收到命令: {'method': 'train'|收到命令: {'method': 'predict'|事件检测到 \ No newline at end of file +收到命令: {'method': 'train'|收到命令: {'method': 'predict'|Epoch采集完成|事件检测到 + +收到命令: {'method': 'train|Epoch采集完成|事件检测到|取出的|SSMVEP训练集 \ No newline at end of file diff --git a/Zmq/zmqServer.py b/Zmq/zmqServer.py index faf12d6..15cc266 100644 --- a/Zmq/zmqServer.py +++ b/Zmq/zmqServer.py @@ -339,29 +339,12 @@ class zmqServer(threading.Thread): # 写入范式缓冲区 with self.paradigmBufferLock: + self.paradigmBuffer.appendBuffer(data_np) if self.interval_inited: - self.epoch_finished = self.detect_event(data_np) - if self.pack_contain_event: - self.paradigmBuffer.resetAllPara() - self.paradigmBuffer.appendBuffer(data_np) - + self.pack_contain_event, self.epoch_finished = self.detect_event(data_np) if self.epoch_finished: - now = datetime.datetime.now() - time_diff_str = "" - # 计算与上一次Epoch完成的时间差 - if self.last_epoch_finish_time is not None: - # 时间差 单位:秒,保留3位小数 - delta_seconds = (now - self.last_epoch_finish_time).total_seconds() - time_diff_str = f" | 与上一次间隔: {delta_seconds:.3f} s" - - # 拼接日志,增加时间差信息 - log_msg = f"Epoch采集完成: {now.strftime('%H:%M:%S.%f')[:-3]}{time_diff_str}" - algo_log(log_msg, level="DEBUG") - - # 更新上一次Epoch完成时间为当前时间 - self.last_epoch_finish_time = now - else: - self.paradigmBuffer.appendBuffer(data_np) + algo_log(f"Epoch采集完成, 当前数据长度{self.paradigmBuffer.GetDataLenCount()}", level="DEBUG") + except Exception as e: algo_log(f"数据处理失败: {str(e)}", level="ERROR") @@ -371,7 +354,8 @@ class zmqServer(threading.Thread): # -------------------------- 事件检测 -------------------------- def detect_event(self, samples): - self.pack_contain_event = False + pack_contain_event = False + epoch_finished = False # 第65通道为事件通道 events = np.array(samples[-2], dtype=np.int32).tolist() for idx, event in enumerate(events): @@ -383,14 +367,20 @@ class zmqServer(threading.Thread): -%H-%M-%S"), ] ) - self.currentLabel = event - if event == self.predict_event: - self.count_events[new_key] = self.latency + 1 + if len(self.count_events) > 0: + algo_log(f"当前有事件未采集完成,新事件{new_key}非法,被忽略") + return pack_contain_event, epoch_finished else: - self.count_events[new_key] = self.train_latency + 1 - self.event_inner_idx = idx - algo_log(f"事件检测到: {events},索引: {idx}", level="DEBUG") - self.pack_contain_event = True + self.currentLabel = event + pack_contain_event = True + if event == self.predict_event: + self.count_events[new_key] = self.latency + 1 + else: + self.count_events[new_key] = self.train_latency + 1 + self.event_inner_idx = idx + algo_log(f"事件检测到: {events},索引: {idx}", level="DEBUG") + else: + pack_contain_event = False # 倒计时并清理过期事件 drop_items = [] @@ -403,9 +393,13 @@ class zmqServer(threading.Thread): for key in drop_items: del self.count_events[key] - if drop_items: - return True - return False + if len(drop_items) > 0: + epoch_finished = True + else: + epoch_finished = False + return pack_contain_event, epoch_finished + + # -------------------------- 主循环 -------------------------- def run(self): self.running = True diff --git a/datamock.py b/datamock.py index b624f8d..fdf08be 100644 --- a/datamock.py +++ b/datamock.py @@ -17,6 +17,10 @@ LABEL_CMD_ADDR = 'tcp://127.0.0.1:8101' # 接收来自上位机范式的标签 # 发送间隔: 每包 5 采样点 / 250Hz = 20ms PKT_INTERVAL = N_SAMPLES_PER_PKT / FS +POINT_PER_3S = FS * 3 # 750 +# 3秒对应的总包数(关键:每150包 = 3s) +PKT_PER_3S = POINT_PER_3S // N_SAMPLES_PER_PKT # 150 + def build_packet(global_sample_idx): """ @@ -34,6 +38,16 @@ def build_packet(global_sample_idx): # Ch64: 标签值通道,初始化为 0 event = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64) + current_pkt_idx = global_sample_idx // N_SAMPLES_PER_PKT + + # 判断是否为 3s 整数倍对应的包 + if current_pkt_idx % PKT_PER_3S == 0: + # 当前是第 N 个3s节点:1、2、1、2...交替 + cycle_num = (current_pkt_idx // PKT_PER_3S) % 2 + if cycle_num == 0: + event[0, 0] = 1.0 + else: + event[0, 0] = 2.0 # Ch65: 标签序号通道,初始化为 0 label_idx = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64)