From 4faeae0ff321ee0c3c8e5047230aaf12c3446e12 Mon Sep 17 00:00:00 2001 From: lizhao Date: Mon, 8 Jun 2026 11:56:42 +0800 Subject: [PATCH] add filter process --- Decoder.py | 59 +++---- README.md | 7 - Zmq/filterProcess.py | 172 ++++++++++---------- Zmq/zmqServer.py | 365 ++++++++++++++++++++++++------------------- runDecoder.py | 2 +- 5 files changed, 306 insertions(+), 299 deletions(-) diff --git a/Decoder.py b/Decoder.py index 092a126..2452ccd 100644 --- a/Decoder.py +++ b/Decoder.py @@ -42,49 +42,33 @@ MODEL_FOLDER = "online_Models" class Decoder_main(threading.Thread): def __init__(self, device_info=None): threading.Thread.__init__(self) - self.device_info = { - 'sample_rate': device_info['sample_rate'], - 'frame_points': device_info['frame_points'], - 'channel_nums': device_info['channel_nums'], - 'channel_names': device_info['channel_names'], - 'channel_index': device_info['channel_index'], - } + self.device_info = device_info self.Runing=True self.decoder = None self.decoder_class = None #解码器类别 - - # 与采集设备通信的状态码,0为异常,1为正常 - # self.status_code = 0 - # self.device_info['sample_rate'] = 250 # 采样率 - # self.energy = 0 # 电量 - - self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果 + + self.zmqServer = None + self.sliding_filter = None + + self._init_threads() + + def _init_threads(self): + """初始化ZMQ服务和滤波线程""" + # 1. 初始化ZMQServer并启动 self.zmqServer = zmqServer(device_info=self.device_info) - self.zmqServer.start() - - self.filter = SlidingFilter() + self.zmqServer.start() # 启动ZMQ接收线程 - # self.zmqClient = zmqClient(_upper_host, _upper_port) - # self.zmqClient.set_zmq_server(self.zmqServer) - # self.zmqClient.connect() + # 2. 初始化滤波线程(关联ZMQServer的环形缓存) + self.sliding_filter = SlidingFilter( + ring_buffer=self.zmqServer.filterBuffer, + n_chan=self.zmqServer.device_info['channel_nums'], + srate=self.zmqServer.device_info['sample_rate'] + ) - - # def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None): - # self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type')) - # _device_host = device_host if device_host is not None else str(IniRead('system', 'Device_Host')) - # _device_port = device_port if device_port is not None else int(IniRead('system', 'Device_Port')) - # _upper_host = upper_host if upper_host is not None else str(IniRead('system', 'Upper_Host')) - # _upper_port = upper_port if upper_port is not None else int(IniRead('system', 'Upper_Port')) - - # if self.DeviceType == 1: - # self.thread_data_server = SunnyLinker64(_device_host, _device_port, self.device_info['sample_rate'], 64, method='tcp') - # self.thread_data_server.host = _device_host - # self.thread_data_server.port = _device_port - - # self.thread_data_server.toUv = True - # self.thread_data_server.start() + # 注册滤波结果回调(示例:打印数据形状) + self.sliding_filter.set_result_callback(self.zmqServer.send_filtered_data) def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号 @@ -210,6 +194,10 @@ class Decoder_main(threading.Thread): def run(self): while self.Runing: + # 当滤波数据大于5秒时,启动滤波线程 + if self.zmqServer.filterBuffer.GetDataLenCount() > self.device_info['sample_rate'] * 5: + self.sliding_filter.start() + if self.zmqServer.decoder_switch or self.zmqServer.changeTarget: print(f"Decoder_class Switch Detected: {self.zmqServer.decoder_class}") self.zmqServer.decoder_switch = False @@ -487,6 +475,7 @@ class Decoder_main(threading.Thread): @return: ''' self.zmqServer.stop() + self.sliding_filter.stop() self.Runing=False def reset_state(self): diff --git a/README.md b/README.md index d1fcc39..f233ccc 100644 --- a/README.md +++ b/README.md @@ -13,13 +13,6 @@ Debug_64ch_Decoder_Optimize is an updated version that fixes several issues and 6. decoder class切换问题 7. decoder_class切换时,数据重置、各类参数重置 -# update -2026年6月5日13:55:34 - -# 遗留问题 -1. 之前当处于阻抗检测状态时,Decoder在空跑。当前无法判断是否处于阻抗检测状态。 - - 解决方法,保留之前发阻抗命令 - # 常用命令 source activate 3in1Py310 diff --git a/Zmq/filterProcess.py b/Zmq/filterProcess.py index 522e736..402a971 100644 --- a/Zmq/filterProcess.py +++ b/Zmq/filterProcess.py @@ -3,6 +3,7 @@ 数据滤波模块 """ import numpy as np +import time import threading from scipy import signal from logs.log import algo_log @@ -10,7 +11,7 @@ from logs.log import algo_log class FilterRingBuffer: def __init__(self, n_chan, n_points): """ - 初始化纯数据环形缓存 + 初始化纯数据环形缓存(线程安全) :param n_chan: 通道数 :param n_points: 总缓存点数(与paradigmRingBuffer参数完全一致) """ @@ -18,11 +19,9 @@ class FilterRingBuffer: self.n_points = n_points self.buffer = np.zeros((n_chan, n_points), dtype=np.float32) - self.current_ptr = 0 # 写入指针 + self.current_ptr = 0 # 写入指针:指向下一个要写入的位置 self.total_samples = 0 # 已写入总点数 - - # 线程安全锁(多线程环境必须) - self.lock = threading.Lock() + self.lock = threading.Lock() # 线程安全锁 def appendBuffer(self, data): """ @@ -33,8 +32,8 @@ class FilterRingBuffer: n = data.shape[1] if n == 0: return - - # 环形写入逻辑 + + # 环形写入逻辑:指针到末尾则绕回 write_end = self.current_ptr + n if write_end <= self.n_points: self.buffer[:, self.current_ptr:write_end] = data @@ -42,14 +41,15 @@ class FilterRingBuffer: split = self.n_points - self.current_ptr self.buffer[:, self.current_ptr:] = data[:, :split] self.buffer[:, :write_end - self.n_points] = data[:, split:] - - # 更新指针和计数 + + # 更新指针(取模保证环形)和计数(不超过缓存总长度) self.current_ptr = write_end % self.n_points self.total_samples = min(self.total_samples + n, self.n_points) def getData(self, count): """ - 从读指针位置读取count个点(与paradigmRingBuffer接口一致) + 从最新位置向前读取count个点(环形读取) + 核心逻辑:current_ptr是下一个写入位置 → 最新数据在current_ptr之前 :param count: 读取点数 :return: np.ndarray, shape=(n_chan, count) """ @@ -57,14 +57,15 @@ class FilterRingBuffer: count = min(count, self.total_samples) if count == 0: return np.zeros((self.n_chan, 0)) - - # 环形读取逻辑(与paradigmRingBuffer完全相同) + + # 环形读取:end是当前写入指针(最新数据的下一位),start是end - count end = self.current_ptr start = end - count if start >= 0: return self.buffer[:, start:end].copy() else: - part1 = self.buffer[:, start:] + # 跨环形边界:前半部分从缓存末尾取,后半部分从开头取 + part1 = self.buffer[:, start:] # start为负,等价于n_points + start part2 = self.buffer[:, :end] return np.concatenate((part1, part2), axis=1) @@ -72,7 +73,7 @@ class FilterRingBuffer: """ 扩展方法:获取最新的n个点(不移动读指针,用于滑动窗口) :param n: 点数 - :return: np.ndarray, shape=(n_chan, n) + :return: np.ndarray, shape=(n_chan, n) | None(数据不足时) """ with self.lock: if self.total_samples < n: @@ -93,43 +94,37 @@ class FilterRingBuffer: # ----------------------------------------------------------------------------- # 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现) -# 可替换任意缓存实现,只要实现appendBuffer、get_latest_n_points接口 # ----------------------------------------------------------------------------- -class SlidingFilter: +class SlidingFilter(threading.Thread): def __init__( self, + ring_buffer: FilterRingBuffer, n_chan=66, srate=250, - buffer_sec=5, window_sec=3, - step_sec=0.2, + step_sec=0.2, # 200ms滑动步长 packet_size=5 ): - """ - 初始化滑动滤波器 - :param n_chan: 通道数 - :param srate: 采样率 - :param buffer_sec: 总缓存时长(秒) - :param window_sec: 滤波窗口时长(秒) - :param step_sec: 滑动步长/输出时长(秒) - :param packet_size: 每包数据点数(20ms一包=5点) - """ + super().__init__(daemon=True) # 核心参数 self.n_chan = n_chan self.srate = srate - self.buffer_size = int(srate * buffer_sec) - self.window_size = int(srate * window_sec) - self.step_size = int(srate * step_sec) + self.step_sec = step_sec # 200ms滑动步长 + self.window_sec = window_sec # 3秒窗口 + self.step_sec = step_sec # 200ms滑动步长 + self.window_size = int(srate * window_sec) # 3秒点数:250*3=750 + self.step_size = int(srate * step_sec) # 200ms点数:250*0.2=50 self.packet_size = packet_size - - # 初始化纯数据缓存(解耦核心) - self.buffer = FilterRingBuffer(n_chan, self.buffer_size) - - # 滤波触发计数器 - self.packet_count = 0 - self.ready_to_filter = False - - # 预计算滤波器系数 + + # 关联ZMQServer的环形缓存(解耦:仅依赖接口) + self.ring_buffer = ring_buffer + # 线程控制 + self.running = threading.Event() + self.running.set() + # 滤波结果回调(外部可注册,获取滤波后的数据) + self.filter_result_callback = None + + # 预计算滤波器系数(仅执行一次) self._init_filters() def _init_filters(self): @@ -145,65 +140,60 @@ class SlidingFilter: ) self.a_bp = np.array([1.0]) - def append_and_check_trigger(self, raw_data): - """ - 追加单包原始数据并检查是否触发滤波 - :param raw_data: 上位机原始数据,shape=(packet_size, n_chan) - :return: bool: 是否触发本次滤波 - """ - # 转置为标准格式:(通道数, 点数) - data = raw_data.T.astype(np.float64) - - # 写入缓存(纯缓存操作) - self.buffer.appendBuffer(data) - - # 更新包计数器 - self.packet_count += 1 - - # 检查滤波触发条件:数据≥窗口长度 且 累计满一个步长的包数 - packets_per_step = int(self.step_size / self.packet_size) # 10包=200ms - if (self.buffer.GetDataLenCount() >= self.window_size - and self.packet_count >= packets_per_step): - self.packet_count = 0 - self.ready_to_filter = True - return True - return False - - def filter_and_get_output(self): - """ - 执行滤波并返回无边界效应的输出数据 - :return: np.ndarray: 滤波后数据,shape=(n_chan, step_size) - """ - if not self.ready_to_filter: - return None - - # 获取最新的完整滤波窗口数据 - window_data = self.buffer.get_latest_n_points(self.window_size) - if window_data is None: - self.ready_to_filter = False - return None - + def _filter_window_data(self, window_data): + """对3秒窗口数据执行滤波,返回无边界效应的200ms数据""" # 零相位滤波(无延迟,无边界效应) filtered = window_data - np.mean(window_data, axis=-1, keepdims=True) filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1) filtered = signal.filtfilt(self.b_bp, self.a_bp, filtered, axis=-1) - - # 提取倒数第二个步长的数据(完全避开两端边界效应) + + # 提取倒数第二个200ms的数据(完全避开两端边界效应) + # 窗口长度750,步长50 → start=750-100=650,end=750-50=700 start_idx = self.window_size - 2 * self.step_size end_idx = self.window_size - self.step_size output_data = filtered[:, start_idx:end_idx].copy() - - # 重置触发标志 - self.ready_to_filter = False - return output_data - def reset(self): - """重置滤波器和缓存""" - self.buffer.resetAllPara() - self.packet_count = 0 - self.ready_to_filter = False + def run(self): + """线程主逻辑:精确200ms触发一次滤波""" + # 精确定时核心:基于perf_counter计算下一次执行时间,补偿sleep误差 + interval = self.step_sec # 200ms = 0.2秒 + next_run_time = time.perf_counter() - def get_buffer_length(self): - """获取当前缓存数据长度""" - return self.buffer.GetDataLenCount() \ No newline at end of file + while self.running.is_set(): + # 1. 等待到下一次执行时间(精确定时) + current_time = time.perf_counter() + if current_time < next_run_time: + time.sleep(next_run_time - current_time) + next_run_time += interval # 补偿:下次执行时间基于上一次目标时间 + else: + # 若超时(如滤波耗时超过200ms),重置下一次时间(避免累积误差) + algo_log("滤波耗时超过200ms,定时偏移", level='debug') + next_run_time = time.perf_counter() + interval + + # 2. 执行滤波逻辑 + try: + # 获取最新的3秒窗口数据 + window_data = self.ring_buffer.get_latest_n_points(self.window_size) + if window_data is None: + algo_log(f"缓存数据不足,当前缓存{self.ring_buffer.GetDataLenCount()}点,需{self.window_size}点", level='debug') + continue + + # 滤波并提取无边界效应的200ms数据 + filtered_data = self._filter_window_data(window_data) + + # 回调返回结果(外部可处理) + if self.filter_result_callback is not None: + self.filter_result_callback(filtered_data[:64, :]) # 只发送前64通道数据 + + except Exception as e: + algo_log(f"滤波执行异常: {e}", level='error') + + def set_result_callback(self, callback): + """注册滤波结果回调函数""" + self.filter_result_callback = callback + + def stop(self): + """停止滤波线程""" + self.running.clear() + self.join(timeout=1) diff --git a/Zmq/zmqServer.py b/Zmq/zmqServer.py index f18d185..34e980c 100644 --- a/Zmq/zmqServer.py +++ b/Zmq/zmqServer.py @@ -1,3 +1,4 @@ +# -*-coding:utf-8 -*- import ast import numpy as np import threading @@ -7,7 +8,6 @@ from typing import Dict import datetime import time -# from Device.SunnyLinker import SunnyLinker64 from Zmq.dataBuffer import ParadigmRingBuffer from Zmq.filterProcess import FilterRingBuffer from PubLibrary.InifileHelper import IniRead @@ -21,63 +21,68 @@ class zmqServer(threading.Thread): self.device_info = device_info self.host = host - self.cmd_port = cmd_port # 命令交互端口 - self.data_port = data_port # 数据接收端口 + self.cmd_port = cmd_port # 命令交互端口:收JSON命令 + 返JSON结果 + self.data_port = data_port # 数据交互端口:收二进制原始脑电 + 返二进制滤波结果 self.running = False # 原有业务状态变量 - # self.get_Impedance = False # 是否返回阻抗值 - self.open_Impedance = False # 是否开启阻抗检测功能 - self.StartDecode = False # false 停止解码,true=开始解码 - self.StartTrain = False # False未进入训练状态,True处于训练状态 - self.state_mode = None # 'train'为训练状态,’rest'为休息状态,'test'为测试状态 - self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签 - self.IsExitApp = False # 当socket收到2的时候,就置为True,代表要退出系统了。 - # self.getReport = False # 获取训练报告内容 + self.open_Impedance = False #当前系统处于阻抗检测状态 + self.StartDecode = False + self.StartTrain = False + self.state_mode = None + self.currentLabel = -1 + self.IsExitApp = False self.daemon = True - # 范式数据缓存 - self.paradigmBuffer = ParadigmRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10) - self.filterBuffer = FilterRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10) - self.paradigmBufferLock= threading.Lock() + # 双环形缓冲区 + self.paradigmBuffer = ParadigmRingBuffer( + self.device_info['channel_nums'], + self.device_info['sample_rate'] * 10 + ) + self.filterBuffer = FilterRingBuffer( + self.device_info['channel_nums'], + self.device_info['sample_rate'] * 10 + ) + self.paradigmBufferLock = threading.Lock() + self.filterBufferLock = threading.Lock() - - # 命令与数据通信 + # ZMQ上下文与套接字 self.context = zmq.Context() - # 指令通道 (8099) - ROUTER:短JSON命令,低频率 + + # 8099命令端口:ROUTER self.cmd_socket = self.context.socket(zmq.ROUTER) - # 通用套接字选项:仍在 SocketOption 中 self.cmd_socket.setsockopt(zmq.SocketOption.RCVHWM, 100) self.cmd_socket.setsockopt(zmq.SocketOption.SNDHWM, 100) self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}") - # 数据通道 (8100) - ROUTER:高频脑电二进制流 + # 8100数据端口:ROUTER self.data_socket = self.context.socket(zmq.ROUTER) self.data_socket.setsockopt(zmq.SocketOption.RCVHWM, 500) + self.data_socket.setsockopt(zmq.SocketOption.SNDHWM, 100) # 添加发送高水位线 self.data_socket.bind(f"tcp://{self.host}:{data_port}") - # Poller 轮训器(保持不变) + # Poller轮询器 self.poller = zmq.Poller() self.poller.register(self.cmd_socket, zmq.POLLIN) self.poller.register(self.data_socket, zmq.POLLIN) # 业务变量 self.targetFreqs = [] - self.changeTarget = False # 更换目标频率 - # self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类,已在Decoder实例化 - self.labels = [0x01, 0x02,0x03] - self.decoder_switch = False #更换解码器 - self.decoder_class = None #解码器类别 'ssvep','ssmvep','mi' + self.changeTarget = False + self.labels = [0x01, 0x02, 0x03] + self.decoder_switch = False + self.decoder_class = None - # 客户端管理 - 区分命令/数据客户端 - self.cmd_clients = set() # 命令端口客户端ID - self.data_clients = set() # 数据端口客户端ID - self.send_queue = queue.Queue() # 发送队列(仅用于命令端口广播) - + # 客户端管理(单客户端场景) + self.cmd_clients = set() + self.data_clients = set() + self.current_data_client = None # 唯一数据客户端身份,用于发送滤波结果 - # 范式buffer参数, 事件检测相关 - self._event_lock = threading.Lock() - + # 发送队列(双端口分离) + self.cmd_send_queue = queue.Queue() # 8099端口命令结果队列 + self.data_send_queue = queue.Queue() # 8100端口滤波数据队列 + + # 范式buffer与事件检测参数 self.predict_event = 99 self.events = [1, 2, self.predict_event] self.latency = 50 @@ -98,60 +103,131 @@ class zmqServer(threading.Thread): self.event_inner_idx = -1 self.interval_inited = False - def interval_init(self, decoder_class): if decoder_class == 'ssmvep': interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) - self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息 - self.train_epoch = [int(self.interval_epoch[0]), - int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch - self.latency = (self.interval_epoch[ - 1] + 0.1 * self.device_info['sample_rate']) // 5 # 提取epoch的延迟标记,5代表每次解包得到的5位采样点;0.1表示比实际需要的长度多取0.1,会被截掉 + self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] + self.train_epoch = [ + int(self.interval_epoch[0]), + int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) + ] + self.latency = (self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 elif decoder_class == 'mi': interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) - self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息 + self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] self.train_epoch = self.interval_epoch.copy() - self.latency = (self.interval_epoch[1]) // 5 # 提取epoch的延迟标记,5代表每次解包得到的5位采样点; + self.latency = self.interval_epoch[1] // 5 self.train_latency = self.latency - print('时间窗:', (interval_epoch)) - self.count_events: Dict[str, int] = {} # 表示包延迟的计数信息 - self.event_inner_idx = -1 # event在5位数据包内部的idx - self.epoch_finished = False # 接收epoch是否完整 - self.pack_contain_event = False # 当前包是否含有event + algo_log(f"时间窗初始化完成: {interval_epoch}", level="INFO") + self.count_events: Dict[str, int] = {} + self.event_inner_idx = -1 + self.epoch_finished = False + self.pack_contain_event = False self.predict_event = 99 self.events = [1, 2, self.predict_event] self.interval_inited = True + # -------------------------- 8099端口:命令结果广播 -------------------------- def broadcast_message(self, method, params): - """Put message into queue to be sent to all command clients""" - self.send_queue.put((method, params)) + """ + 向所有8099端口客户端广播JSON格式的命令结果 + 用于:解码结果、训练状态、错误提示、进度通知等 + """ + self.cmd_send_queue.put((method, params)) - def _handle_cmd_message(self, frames): - """处理命令端口消息(原有命令交互逻辑)""" - if len(frames) < 3: + def _process_cmd_send_queue(self): + """处理8099端口发送队列,在主线程执行(保证ZMQ线程安全)""" + while not self.cmd_send_queue.empty(): + method, params = self.cmd_send_queue.get() + if not self.cmd_clients: + continue + + try: + msg = {'method': method, 'params': params} + msg_bytes = json.dumps(msg).encode('utf-8') + + algo_log(f"发送命令结果: {msg}", level="DEBUG") + + # 广播到所有命令客户端 + for client_id in list(self.cmd_clients): + try: + self.cmd_socket.send_multipart([client_id, b"", msg_bytes]) + except Exception as e: + algo_log(f"向命令客户端{client_id}发送失败: {e}", level="ERROR") + self.cmd_clients.discard(client_id) + + except Exception as e: + algo_log(f"命令结果打包失败: {e}", level="ERROR") + + # -------------------------- 8100端口:滤波结果发送 -------------------------- + def send_filtered_data(self, filtered_data): + """ + 向8100端口客户端发送二进制格式的滤波结果 + 用于:上位机实时绘图的脑电波形数据 + :param filtered_data: 滤波后数据,shape=(通道数, 50),float64格式 + """ + if self.current_data_client is None: + algo_log("数据客户端未连接,跳过滤波数据发送", level="WARNING") return + + # 转置为上位机需要的[50, 通道数]格式 + filtered_data = filtered_data.T.astype(np.float32) + send_buf = filtered_data.tobytes() + algo_log(f"发送滤波数据,长度: {len(send_buf)}字节, filtered_data.shape: {filtered_data.shape}", level="DEBUG") + self.data_send_queue.put(send_buf) + + def _process_data_send_queue(self): + """处理8100端口发送队列,在主线程执行(保证ZMQ线程安全)""" + while not self.data_send_queue.empty(): + send_buf = self.data_send_queue.get() + if self.current_data_client is None: + continue + + try: + # 标准ROUTER发送格式:[客户端ID, 空分隔帧, 数据帧] + self.data_socket.send_multipart([ + self.current_data_client, + b"", + send_buf + ]) + algo_log(f"发送滤波数据成功,长度: {len(send_buf)}字节", level="DEBUG") + + except Exception as e: + algo_log(f"发送滤波数据失败: {e}", level="ERROR") + # 客户端断开,重置身份 + self.current_data_client = None + self.data_clients.clear() + + # -------------------------- 命令端口消息处理 -------------------------- + def _handle_cmd_message(self, frames): + """处理8099端口JSON命令消息""" + if len(frames) < 3: + algo_log(f"无效命令帧:长度不足3帧,实际{len(frames)}", level="ERROR") + return + ident, _, message_bytes = frames[:3] # 注册新的命令客户端 if ident not in self.cmd_clients: self.cmd_clients.add(ident) - algo_log(f"New CMD Client Connected: {ident} (port: {self.cmd_port})") + algo_log(f"新命令客户端连接成功: {ident}", level="INFO") - # 解析消息 + # 解析JSON命令 try: message = json.loads(message_bytes.decode('utf-8')) except json.JSONDecodeError: - algo_log(f"Invalid JSON from CMD client {ident}") + algo_log(f"无效JSON命令: {message_bytes.hex()}", level="ERROR") + self.broadcast_message("error", {"code": 400, "message": "无效JSON格式"}) return - algo_log(f"Received CMD request: {message}") - + + algo_log(f"收到命令: {message}", level="INFO") method = message.get("method") params = message.get("params") - # 原有命令处理逻辑 + # 命令处理逻辑 if method == "sync": self.state_mode = 'sync' elif method == "targetFreqs": @@ -163,108 +239,89 @@ class zmqServer(threading.Thread): self.changeTarget = True elif method == "decoderClass": if not isinstance(params, str): - algo_log(f"decoderClass must be a str") + algo_log(f"decoderClass必须是字符串") return if params != self.decoder_class: self.decoder_class = params self.decoder_switch = True - elif method == "train":#训练状态 + elif method == "train": self.state_mode = 'train' self.StartTrain = True - self.currentLabel = params # 当前刺激端的训练标签 - # self.sunnyLinker.push_trigger(self.labels[self.currentLabel]) - elif method == "predict":#预测状态 + self.currentLabel = params + elif method == "predict": self.state_mode = 'predict' if params == 1: #开始解码 self.StartDecode = True - # self.sunnyLinker.push_trigger(0x63) elif params == 2: #停止解码 self.IsExitApp = True self.running = False - elif method == "rest": #休息状态 + elif method == "rest": self.state_mode = 'rest' elif method == "impedance": if params == 1: - self.open_Impedance = True # 开启阻抗 - # self.get_Impedance = True # 返回阻抗 + self.open_Impedance = True elif params == 2: - self.open_Impedance = False # 关闭阻抗 + self.open_Impedance = False else: - algo_log(f"未知命令:{method}", level="WARNING") - - # elif method == "getReport": - # self.getReport = True - - # elif params == 2: - # self.open_Impedance = False # 关闭阻抗 - # self.get_Impedance = False # 停止返回阻抗 + self.broadcast_message("error", {"code": 404, "message": f"未知命令: {method}"}) + # -------------------------- 数据端口消息处理 -------------------------- def _handle_data_message(self, frames): - """ - 处理8100端口原始脑电二进制数据 - 固定格式:上位机发送 (5,66) float32 二维数组字节流(已转换为微伏物理量)→ 转置为 (66,5) 写入双缓冲区 - """ - # 1. 校验ZMQ消息帧完整性(ROUTER接收DEALER消息的帧格式:[客户端ID, 发送方ID, 空帧, 数据帧]) - if len(frames) < 4: # 至少需要4帧 - algo_log(f"Invalid data frame: 帧数量不足,期望≥4,实际{len(frames)}", level="ERROR") + """处理8100端口二进制脑电数据消息""" + algo_log(f"收到数据帧,总帧数:{len(frames)}", level="DEBUG", record_once=True) + # 然后再进行解析 + if len(frames) == 4: + # 你的上位机格式 + ident, sender_ident, empty_sep, data_bytes = frames[:4] + elif len(frames) == 3: + # 标准格式 + ident, empty_sep, data_bytes = frames[:3] + else: return - - # 2. 正确解析帧(适配DEALER→ROUTER的帧格式) - client_ident, sender_ident, empty_sep, data_bytes = frames[:4] - if empty_sep != b'': # 校验空分隔帧 - algo_log(f"Invalid frame separator: 期望空字节,实际{empty_sep}", level="ERROR") - return - - # 3. 客户端管理(单客户端场景,自动更新最新身份) - if client_ident not in self.data_clients: - self.data_clients.add(client_ident) - self.current_data_client = client_ident # 保存唯一客户端身份,用于后续回复滤波结果 - print(f"[INFO] 新数据客户端连接成功:{client_ident}") - + # 注册新的数据客户端(单客户端场景,自动覆盖旧身份) + if ident not in self.data_clients: + self.data_clients.clear() # 单客户端,只保留最新连接 + self.data_clients.add(ident) + self.current_data_client = ident + algo_log(f"新数据客户端连接成功: {ident}", level="INFO") try: - # 4. 精确长度校验(核心:固定(5,66) float32 = 5*66*4=1320字节) - EXPECTED_BYTES = self.device_info['frame_points'] * self.device_info['channel_nums'] * 4 # 每个float32占4字节 + # 精确长度校验 + EXPECTED_BYTES = self.device_info['frame_points'] * self.device_info['channel_nums'] * 4 if len(data_bytes) != EXPECTED_BYTES: - algo_log(f"[ERROR] 数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节", level="ERROR") + algo_log(f"数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节", level="ERROR") return - # 5. 零拷贝二进制解析 + 维度转换 - + # 零拷贝解析 + 维度转换 data_np = np.frombuffer(data_bytes, dtype=np.float32) data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums']) data_np = data_np.T.astype(np.float64) - # 6. 写入滤波缓冲区 - self.filterBuffer.appendBuffer(data_np) + # 写入滤波缓冲区 + with self.filterBufferLock: + self.filterBuffer.appendBuffer(data_np) - # 7. 写入范式缓冲区 - try: - with self.paradigmBufferLock: - if self.interval_inited: - self.epoch_finished = self.detect_event(data_np) - if self.pack_contain_event: - self.paradigmBuffer.resetAllPara() # 检测到当前pack含有event,清除ringbuffer中之前的数据 - self.paradigmBuffer.appendBuffer(data_np) - if self.epoch_finished: - time.sleep(0.005) - algo_log('epoch_finished: ' + datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3], level="DEBUG") - else: - self.paradigmBuffer.appendBuffer(data_np) - except Exception as e: - print("锁:写入异常",e) - self.paradigmBuffer.appendBuffer(data_np) - - # algo_log(f"数据写入成功:shape={data_np.shape}, 范围=[{data_np.min():.2f}, {data_np.max():.2f}] μV", level="DEBUG") + # 写入范式缓冲区 + with self.paradigmBufferLock: + if self.interval_inited: + self.epoch_finished = self.detect_event(data_np) + if self.pack_contain_event: + self.paradigmBuffer.resetAllPara() + self.paradigmBuffer.appendBuffer(data_np) + if self.epoch_finished: + algo_log('Epoch采集完成: ' + datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3], level="DEBUG") + else: + self.paradigmBuffer.appendBuffer(data_np) except Exception as e: - algo_log(f"数据处理失败:{str(e)}", level="ERROR") + algo_log(f"数据处理失败: {str(e)}", level="ERROR") if IniRead('system', 'algo_log_level', 'INFO') == 'DEBUG': import traceback traceback.print_exc() - # 检测是否含有标签 + # -------------------------- 事件检测 -------------------------- def detect_event(self, samples): self.pack_contain_event = False + # 第65通道为事件通道 events = np.array(samples[-2])[0].tolist() for idx, event in enumerate(events): if int(event) in self.events: @@ -281,76 +338,54 @@ class zmqServer(threading.Thread): self.count_events[new_key] = self.train_latency + 1 self.event_inner_idx = idx self.pack_contain_event = True + + # 倒计时并清理过期事件 drop_items = [] for key, value in self.count_events.items(): - value = value - 1 + value -= 1 if value == 0: drop_items.append(key) self.count_events[key] = value + for key in drop_items: del self.count_events[key] + if drop_items: return True return False - - - - def _process_send_queue(self): - """处理发送队列,向所有命令客户端广播消息""" - while not self.send_queue.empty(): - method, params = self.send_queue.get() - if self.cmd_clients: - try: - msg = {'method': method, 'params': params} - msg_bytes = json.dumps(msg).encode('utf-8') - - # 打印日志(隐藏大尺寸数据) - if method in ['single_trial_plot', 'miReport']: - print(f"{{'method': '{method}', 'params': }}") - else: - print(f"Sending CMD message: {msg}") - - # 广播到所有命令客户端 - for client_id in list(self.cmd_clients): - try: - self.cmd_socket.send_multipart([client_id, b'', msg_bytes]) - except Exception as e: - print(f"Error sending to CMD client {client_id}: {e}") - self.cmd_clients.discard(client_id) # 移除失效客户端 - except Exception as e: - print(f"Error preparing broadcast: {e}") - + # -------------------------- 主循环 -------------------------- def run(self): self.running = True - algo_log(f"algo ZMQ Server started - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}", level="INFO") + algo_log(f"ZMQ服务器启动成功 - 命令端口: {self.cmd_port}, 数据端口: {self.data_port}", level="INFO") try: while self.running: - # 1. 处理发送队列(命令端口广播) - self._process_send_queue() + # 1. 处理两个端口的发送队列(必须在主线程执行) + self._process_cmd_send_queue() + self._process_data_send_queue() - # 2. 轮训监听两个Socket的输入事件 + # 2. 轮询监听两个端口的输入事件 socks = dict(self.poller.poll(50)) - # 处理命令端口消息 + # 处理8099命令端口消息 if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN: frames = self.cmd_socket.recv_multipart() self._handle_cmd_message(frames) - # 处理数据端口消息 + # 处理8100数据端口消息 if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN: frames = self.data_socket.recv_multipart() self._handle_data_message(frames) except Exception as e: - print(f"Server error occurred: {e}") + algo_log(f"服务器主循环异常: {e}", level="ERROR") finally: self.running = False - # 关闭所有Socket和上下文 + # 优雅关闭所有资源 self.cmd_socket.close() self.data_socket.close() self.context.term() - print("Server sockets and context closed.") + algo_log("ZMQ服务器已关闭", level="INFO") def stop(self): """显式关闭服务器""" @@ -358,10 +393,10 @@ class zmqServer(threading.Thread): self.cmd_socket.close() self.data_socket.close() self.context.term() - print(f"Server closed explicitly - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}") + algo_log(f"服务器已显式关闭 - 命令端口: {self.cmd_port}, 数据端口: {self.data_port}", level="INFO") if __name__ == '__main__': - # 初始化并启动服务器(默认cmd=8099, data=8100) + # 初始化并启动服务器 server = zmqServer() server.start() @@ -370,5 +405,5 @@ if __name__ == '__main__': while server.running: threading.Event().wait(1) except KeyboardInterrupt: - print("Received KeyboardInterrupt, stopping server...") - server.stop() + algo_log("收到键盘中断信号,正在停止服务器...", level="INFO") + server.stop() \ No newline at end of file diff --git a/runDecoder.py b/runDecoder.py index 2999992..aa91bbc 100644 --- a/runDecoder.py +++ b/runDecoder.py @@ -42,7 +42,7 @@ if __name__ == "__main__": # ) device_info= get_device_info(1) - algo_log(f"device_info: {device_info}", level="INFO") + algo_log(f"device_info: {device_info}", level="DEBUG") decoder = Decoder_main(device_info=device_info) try: