# -*-coding:utf-8 -*- import ast import numpy as np import threading import zmq import json import queue from typing import Dict import datetime import time from Zmq.dataBuffer import ParadigmRingBuffer from Zmq.filterProcess import FilterRingBuffer from PubLibrary.InifileHelper import IniRead from logs.log import algo_log zmqServer_host = str(IniRead('system', 'zmqServer_host', '127.0.0.1')) class zmqServer(threading.Thread): def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None): threading.Thread.__init__(self) self.device_info = device_info self.host = zmqServer_host self.cmd_port = cmd_port # 命令交互端口:收JSON命令 + 返JSON结果 self.data_port = data_port # 数据交互端口:收二进制原始脑电 + 返二进制滤波结果 self.running = False # 原有业务状态变量 self.open_Impedance = False #当前系统处于阻抗检测状态 self.StartDecode = False self.StartTrain = False self.state_mode = None self.currentLabel = -1 self.IsExitApp = False self.daemon = True # 双环形缓冲区 self.paradigmBuffer = ParadigmRingBuffer( self.device_info['channel_nums'], self.device_info['sample_rate'] * 10 ) self.filterBuffer = FilterRingBuffer( self.device_info['channel_nums'], self.device_info['sample_rate'] * 10 ) self.paradigmBufferLock = threading.Lock() self.filterBufferLock = threading.Lock() # ZMQ上下文与套接字 self.context = zmq.Context() # 8099命令端口:ROUTER self.cmd_socket = self.context.socket(zmq.ROUTER) self.cmd_socket.setsockopt(zmq.SocketOption.RCVHWM, 100) self.cmd_socket.setsockopt(zmq.SocketOption.SNDHWM, 100) self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}") # 8100数据端口:ROUTER self.data_socket = self.context.socket(zmq.ROUTER) self.data_socket.setsockopt(zmq.SocketOption.RCVHWM, 500) self.data_socket.setsockopt(zmq.SocketOption.SNDHWM, 100) # 添加发送高水位线 self.data_socket.bind(f"tcp://{self.host}:{data_port}") # Poller轮询器 self.poller = zmq.Poller() self.poller.register(self.cmd_socket, zmq.POLLIN) self.poller.register(self.data_socket, zmq.POLLIN) # 业务变量 self.targetFreqs = [] self.changeTarget = False self.labels = [0x01, 0x02, 0x03] self.decoder_switch = False self.decoder_class = None # 客户端管理(单客户端场景) self.cmd_clients = set() self.data_clients = set() self.current_data_client = None # 唯一数据客户端身份,用于发送滤波结果 # 发送队列(双端口分离) self.cmd_send_queue = queue.Queue() # 8099端口命令结果队列 self.data_send_queue = queue.Queue() # 8100端口滤波数据队列 # 范式buffer与事件检测参数 self.predict_event = 99 self.events = [1, 2, self.predict_event] self.latency = 50 self.train_latency = 50 self.count_events = {} self.epoch_finished = False self.pack_contain_event = False self.event_inner_idx = -1 self.interval_inited = False self.last_epoch_finish_time = None def reset_state(self): """清空采集器状态和缓存数据""" with self.paradigmBufferLock: self.paradigmBuffer.resetAllPara() self.count_events = {} self.epoch_finished = False self.pack_contain_event = False self.event_inner_idx = -1 self.interval_inited = False def interval_init(self, decoder_class): if decoder_class == 'ssmvep': interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) # [0.2, 2.2] self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # [50, 550] self.train_epoch = [ int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) ] # [50, 575] self.latency = (self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 #115包, 575个点 self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 #120包 600个点 elif decoder_class == 'mi': interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) # [0.5, 4.5] self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] #[125, 1125] self.train_epoch = self.interval_epoch.copy() self.latency = self.interval_epoch[1] // 5 #225 self.train_latency = self.latency #225 algo_log(f"时间窗初始化完成: {interval_epoch}", level="INFO") self.count_events: Dict[str, int] = {} self.event_inner_idx = -1 self.epoch_finished = False self.pack_contain_event = False self.predict_event = 99 self.events = [1, 2, self.predict_event] self.interval_inited = True # -------------------------- 8099端口:命令结果广播 -------------------------- def broadcast_message(self, method, params): """ 向所有8099端口客户端广播JSON格式的命令结果 用于:解码结果、训练状态、错误提示、进度通知等 """ self.cmd_send_queue.put((method, params)) def _process_cmd_send_queue(self): """处理8099端口发送队列,在主线程执行(保证ZMQ线程安全)""" while not self.cmd_send_queue.empty(): method, params = self.cmd_send_queue.get() if not self.cmd_clients: continue try: msg = {'method': method, 'params': params} msg_bytes = json.dumps(msg).encode('utf-8') if msg['method'] == 'beta_psd': algo_log(f"发送命令结果: {msg}", level="DEBUG", record_once=True) else: algo_log(f"发送命令结果: {msg}", level="DEBUG") # 广播到所有命令客户端 for client_id in list(self.cmd_clients): try: self.cmd_socket.send_multipart([client_id, b"", msg_bytes]) except Exception as e: algo_log(f"向命令客户端{client_id}发送失败: {e}", level="ERROR") self.cmd_clients.discard(client_id) except Exception as e: algo_log(f"命令结果打包失败: {e}", level="ERROR") # -------------------------- 8100端口:滤波结果发送 -------------------------- def send_filtered_data(self, filtered_data): """ 向8100端口客户端发送二进制格式的滤波结果 用于:上位机实时绘图的脑电波形数据 :param filtered_data: 滤波后数据,shape=(通道数, 50),float64格式 """ if self.current_data_client is None: algo_log("数据客户端未连接,跳过滤波数据发送", level="WARNING") return # 转置为上位机需要的[50, 通道数]格式 filtered_data = filtered_data.T.astype(np.float64) send_buf = filtered_data.tobytes() # algo_log(f"发送滤波数据,长度: {len(send_buf)}字节, filtered_data.shape: {filtered_data.shape}", level="DEBUG", record_once=True) self.data_send_queue.put(send_buf) def _process_data_send_queue(self): """处理8100端口发送队列,在主线程执行(保证ZMQ线程安全)""" while not self.data_send_queue.empty(): send_buf = self.data_send_queue.get() if self.current_data_client is None: continue try: # 标准ROUTER发送格式:[客户端ID, 空分隔帧, 数据帧] self.data_socket.send_multipart([ self.current_data_client, b"", send_buf ]) algo_log(f"发送滤波数据成功,长度: {len(send_buf)}字节", level="DEBUG", record_once=True) except Exception as e: algo_log(f"发送滤波数据失败: {e}", level="ERROR") # 客户端断开,重置身份 self.current_data_client = None self.data_clients.clear() # -------------------------- 命令端口消息处理 -------------------------- def _handle_cmd_message(self, frames): """处理8099端口JSON命令消息""" if len(frames) < 3: algo_log(f"无效命令帧:长度不足3帧,实际{len(frames)}", level="ERROR") return ident, _, message_bytes = frames[:3] # 注册新的命令客户端 if ident not in self.cmd_clients: self.cmd_clients.add(ident) algo_log(f"新命令客户端连接成功: {ident}", level="INFO") # 解析JSON命令 try: message = json.loads(message_bytes.decode('utf-8')) except json.JSONDecodeError: algo_log(f"无效JSON命令: {message_bytes.hex()}", level="ERROR") self.broadcast_message("error", {"code": 400, "message": "无效JSON格式"}) return except Exception as e: algo_log(f"_handle_cmd_message exception: {e}", level="ERROR") return algo_log(f"收到命令: {message}", level="INFO") method = message.get("method") params = message.get("params") # 命令处理逻辑 if method == "sync": self.state_mode = 'sync' elif method == "targetFreqs": if not isinstance(params, list): algo_log(f"targetFreqs must be a list") return if params != self.targetFreqs: self.targetFreqs = params self.changeTarget = True elif method == "decoderClass": if not isinstance(params, str): algo_log(f"decoderClass必须是字符串") return if params != self.decoder_class: self.decoder_class = params self.decoder_switch = True elif method == "train": self.state_mode = 'train' resp = { "method": "train_response", "params": { "code": 200, "message": "ok" } } try: resp_bytes = json.dumps(resp, ensure_ascii=False).encode("utf-8") self.cmd_socket.send_multipart([ident, b"", resp_bytes]) algo_log(f"train 命令已即时回复客户端 {ident}", level="DEBUG") except Exception as e: algo_log(f"train 命令回复失败: {e}", level="ERROR") return elif method == "predict": self.state_mode = 'predict' if params == 1: #开始解码 self.StartDecode = True elif params == 2: #停止解码 self.IsExitApp = True self.running = False elif method == "rest": self.state_mode = 'rest' elif method == "impedance": if params == 1: self.open_Impedance = True elif params == 2: self.open_Impedance = False else: self.broadcast_message("error", {"code": 404, "message": f"未知命令: {method}"}) # -------------------------- 数据端口消息处理 -------------------------- def _handle_data_message(self, frames): """处理8100端口二进制脑电数据消息""" algo_log(f"收到数据帧,总帧数:{len(frames)}", level="DEBUG", record_once=True) # 然后再进行解析 if len(frames) == 4: # 你的上位机格式 ident, sender_ident, empty_sep, data_bytes = frames[:4] elif len(frames) == 3: # 标准格式 ident, empty_sep, data_bytes = frames[:3] elif len(frames) == 2: ident, data_bytes = frames[:2] else: return # 注册新的数据客户端(单客户端场景,自动覆盖旧身份) if ident not in self.data_clients: self.data_clients.clear() # 单客户端,只保留最新连接 self.data_clients.add(ident) self.current_data_client = ident algo_log(f"新数据客户端连接成功: {ident}", level="INFO") try: # 精确长度校验 EXPECTED_BYTES = self.device_info['frame_points'] * self.device_info['channel_nums'] * np.dtype(np.float64).itemsize if len(data_bytes) != EXPECTED_BYTES: algo_log(f"数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节", level="ERROR") return # 零拷贝解析 + 维度转换 data_np = np.frombuffer(data_bytes, dtype=np.float64) data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums']) data_np = data_np.T.astype(np.float64) # 写入滤波缓冲区 with self.filterBufferLock: self.filterBuffer.appendBuffer(data_np) # 写入范式缓冲区 with self.paradigmBufferLock: if self.interval_inited: self.epoch_finished = self.detect_event(data_np) if self.pack_contain_event: self.paradigmBuffer.resetAllPara() self.paradigmBuffer.appendBuffer(data_np) if self.epoch_finished: now = datetime.datetime.now() time_diff_str = "" # 计算与上一次Epoch完成的时间差 if self.last_epoch_finish_time is not None: # 时间差 单位:秒,保留3位小数 delta_seconds = (now - self.last_epoch_finish_time).total_seconds() time_diff_str = f" | 与上一次间隔: {delta_seconds:.3f} s" # 拼接日志,增加时间差信息 log_msg = f"Epoch采集完成: {now.strftime('%H:%M:%S.%f')[:-3]}{time_diff_str}" algo_log(log_msg, level="DEBUG") # 更新上一次Epoch完成时间为当前时间 self.last_epoch_finish_time = now else: self.paradigmBuffer.appendBuffer(data_np) except Exception as e: algo_log(f"数据处理失败: {str(e)}", level="ERROR") if IniRead('system', 'algo_log_level', 'INFO') == 'DEBUG': import traceback traceback.print_exc() # -------------------------- 事件检测 -------------------------- def detect_event(self, samples): self.pack_contain_event = False # 第65通道为事件通道 event = int(samples[-2][0]) # for idx, event in enumerate(events): if event in self.events: new_key = "".join( [ str(event), datetime.datetime.now().strftime("%Y-%m-%d \ -%H-%M-%S"), ] ) self.currentLabel = event if event == self.predict_event: self.count_events[new_key] = self.latency + 1 else: self.count_events[new_key] = self.train_latency + 1 self.event_inner_idx = self.device_info['frame_points'] - 1 # algo_log(f"事件检测到: {event},索引: {idx}", level="DEBUG") self.pack_contain_event = True # 倒计时并清理过期事件 drop_items = [] for key, value in self.count_events.items(): value -= 1 if value == 0: drop_items.append(key) self.count_events[key] = value for key in drop_items: del self.count_events[key] if drop_items: return True return False # -------------------------- 主循环 -------------------------- def run(self): self.running = True algo_log(f"ZMQ服务器启动成功 - host: {self.host}, 命令端口: {self.cmd_port}, 数据端口: {self.data_port}", level="INFO") try: while self.running: # 1. 处理两个端口的发送队列(必须在主线程执行) self._process_cmd_send_queue() self._process_data_send_queue() # 2. 轮询监听两个端口的输入事件 socks = dict(self.poller.poll(50)) # 处理8099命令端口消息 if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN: frames = self.cmd_socket.recv_multipart() self._handle_cmd_message(frames) # 处理8100数据端口消息(排空积压,消除标签延迟) if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN: while True: try: frames = self.data_socket.recv_multipart(zmq.NOBLOCK) self._handle_data_message(frames) except zmq.Again: break except Exception as e: algo_log(f"服务器主循环异常: {str(e)}", level="ERROR") return finally: self.running = False # 优雅关闭所有资源 self.cmd_socket.close() self.data_socket.close() self.context.term() algo_log("ZMQ服务器已关闭", level="INFO") def stop(self): """显式关闭服务器""" self.running = False self.cmd_socket.close() self.data_socket.close() self.context.term() algo_log(f"服务器已显式关闭 - 命令端口: {self.cmd_port}, 数据端口: {self.data_port}", level="INFO") if __name__ == '__main__': # 初始化并启动服务器 server = zmqServer() server.start() # 保持主线程运行 try: while server.running: threading.Event().wait(1) except KeyboardInterrupt: algo_log("收到键盘中断信号,正在停止服务器...", level="INFO") server.stop()