diff --git a/Zmq/zmqServer1.py b/Zmq/zmqServer1.py deleted file mode 100644 index 79cee04..0000000 --- a/Zmq/zmqServer1.py +++ /dev/null @@ -1,445 +0,0 @@ -import numpy as np -import zmq -import threading -import json -import queue -import time -from Device.SunnyLinker import SunnyLinker64, RingBuffer -from collections import deque - - -class zmqServer(threading.Thread): - def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100): - threading.Thread.__init__(self) - self.host = host - self.cmd_port = cmd_port - self.data_port = data_port - self.running = False - self.get_Impedance = False - self.open_Impedance = None - self.StartDecode = False - self.StartTrain = False - self.state_mode = None - self.currentLabel = -1 - self.IsExitApp = False - self.getReport = False - self.daemon = True - - # ZMQ Context - self.context = zmq.Context() - - # 指令通道 (8099) - ROUTER - self.cmd_socket = self.context.socket(zmq.ROUTER) - self.cmd_socket.setsockopt(zmq.RCVHWM, 1000) - self.cmd_socket.setsockopt(zmq.SNDHWM, 1000) - self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}") - - # 数据通道 (8100)) - ROUTER - self.data_socket = self.context.socket(zmq.ROUTER) - self.data_socket.setsockopt(zmq.RCVHWM, 1000) - self.data_socket.setsockopt(zmq.RCVTIMEO, 50) - self.data_socket.bind(f"tcp://{self.host}:{data_port}") - - self.targetFreqs = [] - self.changeTarget = False - self.sunnyLinker = SunnyLinker64(None, None, None, None, None) - self.labels = [0x01, 0x02, 0x03] - - self.decoder_switch = False - self.decoder_class = None - self.cmd_clients = set() - self.data_clients = set() - self.send_queue = queue.Queue() - - # ========== 数据缓冲区 (RingBuffer) ========== - # 与 SunnyLinker 保持一致,使用 RingBuffer - # 66 = 64 EEG通道 + 1 事件通道(第65) + 1 标签序号通道(第66) - # 缓存约 10 秒数据 (250Hz * 10s = 2500 点) - self.n_chan = 66 - self.t_buffer = 10.0 # 缓冲区时长(秒) - self.__ringBuffer = RingBuffer(self.n_chan, int(self.t_buffer * 250)) - - # 事件检测相关 - self._event_lock = threading.Lock() - self._epoch_finished = False - self._event_inner_idx = -1 - self.pack_contain_event = False - self.predict_event = 99 - self.events = [1, 2, self.predict_event] - self.count_events = {} - self.latency = 50 - self.train_latency = 50 - - # 当前事件标签序号 (从第66通道获取) - self.current_label_index = 0 - - # 初始化标志 - self._interval_inited = False - self._currentLabel = -1 - - # 注册的客户端(兼容旧接口) - self.clients = set() - - # ========== 事件属性:线程安全访问 ========== - @property - def epoch_finished(self): - with self._event_lock: - return self._epoch_finished - - @epoch_finished.setter - def epoch_finished(self, value): - with self._event_lock: - self._epoch_finished = value - - @property - def event_inner_idx(self): - with self._event_lock: - return self._event_inner_idx - - @event_inner_idx.setter - def event_inner_idx(self, value): - with self._event_lock: - self._event_inner_idx = value - - @property - def interval_inited(self): - return self._interval_inited - - @interval_inited.setter - def interval_inited(self, value): - self._interval_inited = value - - @property - def currentLabel(self): - return self._currentLabel - - @currentLabel.setter - def currentLabel(self, value): - self._currentLabel = value - - def broadcast_message(self, method, params): - """Put message into queue to be sent to all connected clients""" - self.send_queue.put((method, params)) - - # ========== 数据缓冲区操作接口 ========== - def GetDataLenCount(self): - """返回缓冲区当前数据点数""" - return self.__ringBuffer.nUpdate - - def getData(self, count): - """获取最新count个数据点,不消费(只读)""" - with self.__ringBuffer.RingBufferLock: - count = min(count, self.__ringBuffer.nUpdate) - if count == 0: - return np.zeros((self.n_chan, 0)) - - # 计算读取范围(从尾部取最新数据) - read_end = (self.__ringBuffer.currentPtr - 1) % self.__ringBuffer.n_points - read_start = (read_end - count + 1) % self.__ringBuffer.n_points - - if self.__ringBuffer.currentPtr == 0: - read_start = self.__ringBuffer.n_points - count - read_end = self.__ringBuffer.n_points - 1 - - if read_start <= read_end: - data = self.__ringBuffer.buffer[:, read_start:read_end + 1] - else: - part1 = self.__ringBuffer.buffer[:, read_start:] - part2 = self.__ringBuffer.buffer[:, :read_end + 1] - data = np.concatenate((part1, part2), axis=1) - - return data - - def consumeData(self, count): - """消费(丢弃)指定数量的数据点,从头部移除""" - with self.__ringBuffer.RingBufferLock: - count = min(count, self.__ringBuffer.nUpdate) - self.__ringBuffer.readPtr = (self.__ringBuffer.readPtr + count) % self.__ringBuffer.n_points - self.__ringBuffer.nUpdate -= count - - def ResetAll(self): - """重置缓冲区""" - with self.__ringBuffer.RingBufferLock: - self.__ringBuffer.resetAllPara() - with self._event_lock: - self._epoch_finished = False - self._event_inner_idx = -1 - self.pack_contain_event = False - self.count_events.clear() - self.current_label_index = 0 - - def reset_data_buffer(self): - self.ResetAll() - - def reset_state(self): - self.ResetAll() - - def interval_init(self, decoder_class): - """初始化事件检测参数""" - import ast - from PubLibrary.InifileHelper import IniRead - - if decoder_class == 'ssmvep': - interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) - self.interval_epoch = [int(i * 250) for i in interval_epoch] - self.train_epoch = [int(self.interval_epoch[0]), - int(self.interval_epoch[1] + 0.1 * 250)] - self.latency = (self.interval_epoch[1] + 0.1 * 250) // 5 - self.train_latency = (self.train_epoch[1] + 0.1 * 250) // 5 - - elif decoder_class == 'mi': - interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) - self.interval_epoch = [int(i * 250) for i in interval_epoch] - self.train_epoch = self.interval_epoch.copy() - self.latency = self.interval_epoch[1] // 5 - self.train_latency = self.latency - - self.count_events = {} - self._event_inner_idx = -1 - self._epoch_finished = False - self.pack_contain_event = False - self.predict_event = 99 - self.events = [1, 2, self.predict_event] - self._interval_inited = True - - # ========== 事件检测 ========== - def detect_event(self, data_matrix): - """ - 检测事件通道中的触发信号 - - @param data_matrix: shape (66, N) - N个采样点的数据 - 第65行(索引64) = 事件通道 - 第66行(索引65) = 标签通道 - @return: 是否检测到事件 - """ - if data_matrix.shape[1] == 0: - return False - - self.pack_contain_event = False - event_channel = data_matrix[64, :] # 第65通道 = 标签值(event值) - label_channel = data_matrix[65, :] # 第66通道 = 标签序号(label index) - - events = event_channel.tolist() - - with self._event_lock: - self._event_inner_idx = -1 - self.current_event_label = 0 - - for idx, event in enumerate(events): - if int(event) in self.events: - self._event_inner_idx = idx - self.current_label_index = int(label_channel[idx]) - self.pack_contain_event = True - - new_key = f"{event}_{time.time()}" - latency = self.latency if event == self.predict_event else self.train_latency - self.count_events[new_key] = latency + 1 - - # 延迟计数递减 - drop_items = [] - for key, value in self.count_events.items(): - value = value - 1 - if value == 0: - drop_items.append(key) - self.count_events[key] = value - for key in drop_items: - del self.count_events[key] - - if drop_items: - self._epoch_finished = True - # 检测到事件时,清除RingBuffer中之前的数据,只保留当前包 - if self.pack_contain_event: - self.__ringBuffer.resetAllPara() - return True - - self._epoch_finished = False - return False - - def run(self): - self.running = True - print(f"Server running - CMD: {self.cmd_port}, DATA: {self.data_port}") - - cmd_poller = zmq.Poller() - cmd_poller.register(self.cmd_socket, zmq.POLLIN) - - data_poller = zmq.Poller() - data_poller.register(self.data_socket, zmq.POLLIN) - - try: - while self.running: - # --- 处理发送队列 (指令通道) --- - while not self.send_queue.empty(): - method, params = self.send_queue.get() - if self.cmd_clients: - try: - msg = {'method': method, 'params': params} - msg_bytes = json.dumps(msg).encode('utf-8') - for client_id in list(self.cmd_clients): - try: - self.cmd_socket.send_multipart([client_id, b'', msg_bytes]) - except Exception: - pass - except Exception: - pass - - # --- 处理指令通道 --- - socks = dict(cmd_poller.poll(10)) - if self.cmd_socket in socks: - self._handle_cmd_socket() - - # --- 处理数据通道 --- - socks = dict(data_poller.poll(10)) - if self.data_socket in socks: - self._handle_data_socket() - - except Exception as e: - print(f"Server error: {e}") - finally: - self.running = False - self.cmd_socket.close() - self.data_socket.close() - self.context.term() - - def _handle_cmd_socket(self): - """处理指令通道消息""" - try: - frames = self.cmd_socket.recv_multipart() - if len(frames) < 3: - return - ident, _, message_bytes = frames[:3] - self.cmd_clients.add(ident) - self.clients.add(ident) - - message = json.loads(message_bytes.decode('utf-8')) - method = message.get("method") - params = message.get("params") - - print(f"[CMD] {method}: {params}") - - if method == "sync": - self.state_mode = 'sync' - elif method == "targetFreqs": - if isinstance(params, list) and params != self.targetFreqs: - self.targetFreqs = params - self.changeTarget = True - elif method == "decoderClass": - if isinstance(params, str) and params != self.decoder_class: - self.decoder_class = params - self.decoder_switch = True - elif method == "getReport": - self.getReport = True - elif method == "train": - self.state_mode = 'train' - self.StartTrain = True - self.currentLabel = params - elif method == "predict": - self.state_mode = 'predict' - if params == 1: - self.StartDecode = True - elif params == 2: - self.IsExitApp = True - self.running = False - elif method == "rest": - self.state_mode = 'rest' - elif method == "impedance": - if params == 1: - self.open_Impedance = True - self.get_Impedance = True - elif params == 2: - self.open_Impedance = False - self.get_Impedance = False - - except Exception as e: - print(f"CMD socket error: {e}") - - def _handle_data_socket(self): - """处理数据通道消息 (EEG数据) - - 上位机数据格式: - - 数据帧: [identity, '', meta_json, data_buffer] - data_buffer = [N, 66] float32 -> 转置为 [66, N] - """ - try: - frames = self.data_socket.recv_multipart() - if len(frames) < 4: - return - ident, _, message_bytes = frames[:3] - self.data_clients.add(ident) - - meta = json.loads(message_bytes.decode('utf-8')) - - # data: [N, 66] -> 转置 -> [66, N] - raw_data = np.frombuffer(frames[3], dtype=np.float32) - n_samples, n_channels = meta.get('shape', [5, 66]) - data_matrix = raw_data.reshape(n_samples, n_channels).T.astype(np.float32) - - # 写入 RingBuffer - with self.__ringBuffer.RingBufferLock: - self.__ringBuffer.appendBuffer(data_matrix) - - # 事件检测 - self.detect_event(data_matrix) - - except Exception as e: - print(f"DATA socket error: {e}") - - # ========== 各范式数据访问接口 ========== - def get_MIData(self): - """获取MI导联数据 (21通道 + 事件)""" - data = self.getData(self.GetDataLenCount()) - rows_to_extract = [8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10, 19, 20, 61, 51, 60, 21, 64, 65] - row_to_select = np.array(rows_to_extract) - if data.shape[1] > 0: - return data[row_to_select, :] - return np.zeros((len(rows_to_extract), 0)) - - def get_SSMVEPData(self): - """获取SSMVEP导联数据 (8通道 + 事件)""" - data = self.getData(self.GetDataLenCount()) - rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64, 65] - row_to_select = np.array(rows_to_extract) - if data.shape[1] > 0: - return data[row_to_select, :] - return np.zeros((len(rows_to_extract), 0)) - - def getDataViaSSVEP(self, count): - """获取SSVEP数据 (8通道 + 事件)""" - data = self.getData(count) - rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64] - row_to_select = np.array(rows_to_extract) - if data.shape[1] > 0: - return data[row_to_select, :] - return np.zeros((len(rows_to_extract), 0)) - - def get_concentrateData(self, count): - """获取专注力数据 (2通道)""" - data = self.getData(count) - rows_to_extract = [0, 1] - row_to_select = np.array(rows_to_extract) - if data.shape[1] > 0: - return data[row_to_select, :] - return np.zeros((len(rows_to_extract), 0)) - - def get_blinkData(self, count): - """获取眨眼数据 (2通道)""" - data = self.getData(count) - rows_to_extract = [0, 1] - row_to_select = np.array(rows_to_extract) - if data.shape[1] > 0: - return data[row_to_select, :] - return np.zeros((len(rows_to_extract), 0)) - - def getImpedance(self, data, decoder_class): - """计算阻抗(ZMQ模式下不可用)""" - return np.zeros(8) - - def stop(self): - self.running = False - self.cmd_socket.close() - self.data_socket.close() - self.context.term() - - -if __name__ == '__main__': - server = zmqServer() - server.start()