import ast import numpy as np import threading import json import queue from typing import Dict # from Device.SunnyLinker import SunnyLinker64 from dataBuffer import ParadigmRingBuffer from filterProcess import FilterRingBuffer from PubLibrary.InifileHelper import IniRead from logs.log import algo_log import zmq class zmqServer(threading.Thread): def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None): threading.Thread.__init__(self) self.device_info = device_info self.host = host self.cmd_port = cmd_port # 命令交互端口 self.data_port = data_port # 数据接收端口 self.running = False # 原有业务状态变量 # self.get_Impedance = False # 是否返回阻抗值 # self.open_Impedance = None # 是否开启阻抗检测功能 self.StartDecode = False # false 停止解码,true=开始解码 self.StartTrain = False # False未进入训练状态,True处于训练状态 self.state_mode = None # 'train'为训练状态,’rest'为休息状态,'test'为测试状态 self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签 self.IsExitApp = False # 当socket收到2的时候,就置为True,代表要退出系统了。 # self.getReport = False # 获取训练报告内容 self.daemon = True # 范式数据缓存 self.paradigmBuffer = ParadigmRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10) self.filterBuffer = FilterRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10) # 命令与数据通信 self.context = zmq.Context() # 指令通道 (8099) - ROUTER:短JSON命令,低频率 self.cmd_socket = self.context.socket(zmq.ROUTER) self.cmd_socket.setsockopt(zmq.RCVHWM, 100) # 指令不需要大缓存,100条足够 self.cmd_socket.setsockopt(zmq.SNDHWM, 100) self.cmd_socket.setsockopt(zmq.TCP_NODELAY, 1) # 禁用Nagle算法,降低指令延迟 self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}") # 数据通道 (8100) - ROUTER:高频脑电二进制流 self.data_socket = self.context.socket(zmq.ROUTER) self.data_socket.setsockopt(zmq.RCVHWM, 500) # 500包=10秒缓存,足够应对短时卡顿 self.data_socket.setsockopt(zmq.TCP_NODELAY, 1) # 禁用Nagle算法,减少数据传输延迟 self.data_socket.bind(f"tcp://{self.host}:{data_port}") # Poller 轮训器(保持不变) self.poller = zmq.Poller() self.poller.register(self.cmd_socket, zmq.POLLIN) self.poller.register(self.data_socket, zmq.POLLIN) # 业务变量 self.targetFreqs = [] self.changeTarget = False # 更换目标频率 # self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类,已在Decoder实例化 self.labels = [0x01, 0x02,0x03] self.decoder_switch = False #更换解码器 self.decoder_class = None #解码器类别 'ssvep','ssmvep','mi' # 客户端管理 - 区分命令/数据客户端 self.cmd_clients = set() # 命令端口客户端ID self.data_clients = set() # 数据端口客户端ID self.send_queue = queue.Queue() # 发送队列(仅用于命令端口广播) # 范式buffer参数, 事件检测相关 self._event_lock = threading.Lock() self._epoch_finished = False self._event_inner_idx = -1 self.pack_contain_event = False self.predict_event = 99 self.events = [1, 2, self.predict_event] self.count_events = {} self.latency = 50 self.train_latency = 50 self._interval_inited = False @property def interval_inited(self): return self._interval_inited @interval_inited.setter def interval_inited(self, value): self._interval_inited = value @property def epoch_finished(self): with self._event_lock: return self._epoch_finished @epoch_finished.setter def epoch_finished(self, value): with self._event_lock: self._epoch_finished = value @property def event_inner_idx(self): with self._event_lock: return self._event_inner_idx @event_inner_idx.setter def event_inner_idx(self, value): with self._event_lock: self._event_inner_idx = value def interval_init(self, decoder_class): if decoder_class == 'ssmvep': interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息 self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch self.latency = (self.interval_epoch[ 1] + 0.1 * self.device_info['sample_rate']) // 5 # 提取epoch的延迟标记,5代表每次解包得到的5位采样点;0.1表示比实际需要的长度多取0.1,会被截掉 self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 elif decoder_class == 'mi': interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息 self.train_epoch = self.interval_epoch.copy() self.latency = (self.interval_epoch[1]) // 5 # 提取epoch的延迟标记,5代表每次解包得到的5位采样点; self.train_latency = self.latency print('时间窗:', (interval_epoch)) self.count_events: Dict[str, int] = {} # 表示包延迟的计数信息 self.event_inner_idx = -1 # event在5位数据包内部的idx self.epoch_finished = False # 接收epoch是否完整 self.pack_contain_event = False # 当前包是否含有event self.predict_event = 99 self.events = [1, 2, self.predict_event] self.interval_inited = True # if getattr(self, 'serial', None) and self.serial.is_open: # self.serial.close() # self.serial = serial.Serial(self.serial_port, 460800, timeout=1) # 连接同步器串口 def broadcast_message(self, method, params): """Put message into queue to be sent to all command clients""" self.send_queue.put((method, params)) def _handle_cmd_message(self, frames): """处理命令端口消息(原有命令交互逻辑)""" if len(frames) < 3: return ident, _, message_bytes = frames[:3] # 注册新的命令客户端 if ident not in self.cmd_clients: self.cmd_clients.add(ident) algo_log(f"New CMD Client Connected: {ident} (port: {self.cmd_port})") # 解析消息 try: message = json.loads(message_bytes.decode('utf-8')) except json.JSONDecodeError: algo_log(f"Invalid JSON from CMD client {ident}") return algo_log(f"Received CMD request: {message}") method = message.get("method") params = message.get("params") # 原有命令处理逻辑 if method == "sync": self.state_mode = 'sync' elif method == "targetFreqs": if not isinstance(params, list): algo_log(f"targetFreqs must be a list") return if params != self.targetFreqs: self.targetFreqs = params self.changeTarget = True elif method == "decoderClass": if not isinstance(params, str): algo_log(f"decoderClass must be a str") return if params != self.decoder_class: self.decoder_class = params self.decoder_switch = True elif method == "train":#训练状态 self.state_mode = 'train' self.StartTrain = True self.currentLabel = params # 当前刺激端的训练标签 # self.sunnyLinker.push_trigger(self.labels[self.currentLabel]) elif method == "predict":#预测状态 self.state_mode = 'predict' if params == 1: #开始解码 self.StartDecode = True # self.sunnyLinker.push_trigger(0x63) elif params == 2: #停止解码 self.IsExitApp = True self.running = False elif method == "rest": #休息状态 self.state_mode = 'rest' else: algo_log(f"未知命令:{method}", level="WARNING") # elif method == "getReport": # self.getReport = True # elif method == "impedance": # if params == 1: # self.open_Impedance = True # 开启阻抗 # self.get_Impedance = True # 返回阻抗 # elif params == 2: # self.open_Impedance = False # 关闭阻抗 # self.get_Impedance = False # 停止返回阻抗 def _handle_data_message(self, frames): """ 处理8100端口原始脑电二进制数据 固定格式:上位机发送 (5,66) float32 二维数组字节流(已转换为微伏物理量)→ 转置为 (66,5) 写入双缓冲区 """ # 1. 校验ZMQ消息帧完整性 if len(frames) < 3: print(f"[ERROR] 无效数据帧:长度不足3帧,实际长度={len(frames)}") return ident, _, data_bytes = frames[:3] # 2. 客户端管理(单客户端场景,自动更新最新身份) if ident not in self.data_clients: self.data_clients.add(ident) self.current_data_client = ident # 保存唯一客户端身份,用于后续回复滤波结果 print(f"[INFO] 新数据客户端连接成功:{ident}") try: # 3. 精确长度校验(核心:固定(5,66) float32 = 5*66*4=1320字节,与int32字节数相同) EXPECTED_BYTES = self.device_info['frame_points'] * self.device_info['channel_nums'] * 4 # 每个float32占4字节 if len(data_bytes) != EXPECTED_BYTES: print(f"[ERROR] 数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节") return # 4. 零拷贝二进制解析 + 维度转换 # 步骤:字节流 → (330,) float32数组 → (5,66) 原始格式 → 转置为 (66,5) 缓冲区标准格式 data_np = np.frombuffer(data_bytes, dtype=np.float32) # 重塑为上位机原始维度 data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums']) # 转置为(通道数, 采样点数)标准格式,转换为float64保证滤波运算精度 data_np = data_np.T.astype(np.float64) # 5. 同时写入双环形缓冲区(方法名与现有类保持一致:appendBuffer) # 注意:上位机已发送微伏物理量,无需再乘以增益系数 self.paradigmBuffer.appendBuffer(data_np) self.filterBuffer.appendBuffer(data_np) # 生产环境必须注释!每秒50次打印会导致CPU占用飙升30%以上 algo_log(f"数据写入成功:shape={data_np.shape}, 范围=[{data_np.min():.2f}, {data_np.max():.2f}] μV", level="DEBUG", record_once=True) except Exception as e: algo_log(f"数据处理失败:{str(e)}", level="ERROR") # 调试阶段临时打开,生产环境务必注释 import traceback traceback.print_exc() def _process_send_queue(self): """处理发送队列,向所有命令客户端广播消息""" while not self.send_queue.empty(): method, params = self.send_queue.get() if self.cmd_clients: try: msg = {'method': method, 'params': params} msg_bytes = json.dumps(msg).encode('utf-8') # 打印日志(隐藏大尺寸数据) if method in ['single_trial_plot', 'miReport']: print(f"{{'method': '{method}', 'params': }}") else: print(f"Sending CMD message: {msg}") # 广播到所有命令客户端 for client_id in list(self.cmd_clients): try: self.cmd_socket.send_multipart([client_id, b'', msg_bytes]) except Exception as e: print(f"Error sending to CMD client {client_id}: {e}") self.cmd_clients.discard(client_id) # 移除失效客户端 except Exception as e: print(f"Error preparing broadcast: {e}") def run(self): self.running = True print(f"ZMQ Server started - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}") try: while self.running: # 1. 处理发送队列(命令端口广播) self._process_send_queue() # 2. 轮训监听两个Socket的输入事件(10ms超时,避免阻塞) socks = dict(self.poller.poll(50)) # 处理命令端口消息 if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN: frames = self.cmd_socket.recv_multipart() self._handle_cmd_message(frames) # 处理数据端口消息 if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN: frames = self.data_socket.recv_multipart() self._handle_data_message(frames) except Exception as e: print(f"Server error occurred: {e}") finally: self.running = False # 关闭所有Socket和上下文 self.cmd_socket.close() self.data_socket.close() self.context.term() print("Server sockets and context closed.") def stop(self): """显式关闭服务器""" self.running = False self.cmd_socket.close() self.data_socket.close() self.context.term() print(f"Server closed explicitly - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}") if __name__ == '__main__': # 初始化并启动服务器(默认cmd=8099, data=8100) server = zmqServer() server.start() # 保持主线程运行 try: while server.running: threading.Event().wait(1) except KeyboardInterrupt: print("Received KeyboardInterrupt, stopping server...") server.stop()