Files
bci_algo/Zmq/zmqServer.py
2026-06-07 11:05:24 +08:00

375 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import ast
import numpy as np
import threading
import json
import queue
from typing import Dict
import datetime
import time
# from Device.SunnyLinker import SunnyLinker64
from Zmq.dataBuffer import ParadigmRingBuffer
from Zmq.filterProcess import FilterRingBuffer
from PubLibrary.InifileHelper import IniRead
from logs.log import algo_log
import zmq
class zmqServer(threading.Thread):
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
threading.Thread.__init__(self)
self.device_info = device_info
self.host = host
self.cmd_port = cmd_port # 命令交互端口
self.data_port = data_port # 数据接收端口
self.running = False
# 原有业务状态变量
# self.get_Impedance = False # 是否返回阻抗值
self.open_Impedance = False # 是否开启阻抗检测功能
self.StartDecode = False # false 停止解码true=开始解码
self.StartTrain = False # False未进入训练状态True处于训练状态
self.state_mode = None # 'train'为训练状态rest'为休息状态,'test'为测试状态
self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签
self.IsExitApp = False # 当socket收到2的时候就置为True代表要退出系统了。
# self.getReport = False # 获取训练报告内容
self.daemon = True
# 范式数据缓存
self.paradigmBuffer = ParadigmRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10)
self.filterBuffer = FilterRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10)
self.paradigmBufferLock= threading.Lock()
# 命令与数据通信
self.context = zmq.Context()
# 指令通道 (8099) - ROUTER短JSON命令低频率
self.cmd_socket = self.context.socket(zmq.ROUTER)
# 通用套接字选项:仍在 SocketOption 中
self.cmd_socket.setsockopt(zmq.SocketOption.RCVHWM, 100)
self.cmd_socket.setsockopt(zmq.SocketOption.SNDHWM, 100)
self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
# 数据通道 (8100) - ROUTER高频脑电二进制流
self.data_socket = self.context.socket(zmq.ROUTER)
self.data_socket.setsockopt(zmq.SocketOption.RCVHWM, 500)
self.data_socket.bind(f"tcp://{self.host}:{data_port}")
# Poller 轮训器(保持不变)
self.poller = zmq.Poller()
self.poller.register(self.cmd_socket, zmq.POLLIN)
self.poller.register(self.data_socket, zmq.POLLIN)
# 业务变量
self.targetFreqs = []
self.changeTarget = False # 更换目标频率
# self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类已在Decoder实例化
self.labels = [0x01, 0x02,0x03]
self.decoder_switch = False #更换解码器
self.decoder_class = None #解码器类别 'ssvep','ssmvep','mi'
# 客户端管理 - 区分命令/数据客户端
self.cmd_clients = set() # 命令端口客户端ID
self.data_clients = set() # 数据端口客户端ID
self.send_queue = queue.Queue() # 发送队列(仅用于命令端口广播)
# 范式buffer参数, 事件检测相关
self._event_lock = threading.Lock()
self.predict_event = 99
self.events = [1, 2, self.predict_event]
self.latency = 50
self.train_latency = 50
self.count_events = {}
self.epoch_finished = False
self.pack_contain_event = False
self.event_inner_idx = -1
self.interval_inited = False
def reset_state(self):
"""清空采集器状态和缓存数据"""
with self.paradigmBufferLock:
self.paradigmBuffer.resetAllPara()
self.count_events = {}
self.epoch_finished = False
self.pack_contain_event = False
self.event_inner_idx = -1
self.interval_inited = False
def interval_init(self, decoder_class):
if decoder_class == 'ssmvep':
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息
self.train_epoch = [int(self.interval_epoch[0]),
int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch
self.latency = (self.interval_epoch[
1] + 0.1 * self.device_info['sample_rate']) // 5 # 提取epoch的延迟标记5代表每次解包得到的5位采样点;0.1表示比实际需要的长度多取0.1,会被截掉
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5
elif decoder_class == 'mi':
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息
self.train_epoch = self.interval_epoch.copy()
self.latency = (self.interval_epoch[1]) // 5 # 提取epoch的延迟标记5代表每次解包得到的5位采样点;
self.train_latency = self.latency
print('时间窗:', (interval_epoch))
self.count_events: Dict[str, int] = {} # 表示包延迟的计数信息
self.event_inner_idx = -1 # event在5位数据包内部的idx
self.epoch_finished = False # 接收epoch是否完整
self.pack_contain_event = False # 当前包是否含有event
self.predict_event = 99
self.events = [1, 2, self.predict_event]
self.interval_inited = True
def broadcast_message(self, method, params):
"""Put message into queue to be sent to all command clients"""
self.send_queue.put((method, params))
def _handle_cmd_message(self, frames):
"""处理命令端口消息(原有命令交互逻辑)"""
if len(frames) < 3:
return
ident, _, message_bytes = frames[:3]
# 注册新的命令客户端
if ident not in self.cmd_clients:
self.cmd_clients.add(ident)
algo_log(f"New CMD Client Connected: {ident} (port: {self.cmd_port})")
# 解析消息
try:
message = json.loads(message_bytes.decode('utf-8'))
except json.JSONDecodeError:
algo_log(f"Invalid JSON from CMD client {ident}")
return
algo_log(f"Received CMD request: {message}")
method = message.get("method")
params = message.get("params")
# 原有命令处理逻辑
if method == "sync":
self.state_mode = 'sync'
elif method == "targetFreqs":
if not isinstance(params, list):
algo_log(f"targetFreqs must be a list")
return
if params != self.targetFreqs:
self.targetFreqs = params
self.changeTarget = True
elif method == "decoderClass":
if not isinstance(params, str):
algo_log(f"decoderClass must be a str")
return
if params != self.decoder_class:
self.decoder_class = params
self.decoder_switch = True
elif method == "train":#训练状态
self.state_mode = 'train'
self.StartTrain = True
self.currentLabel = params # 当前刺激端的训练标签
# self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
elif method == "predict":#预测状态
self.state_mode = 'predict'
if params == 1: #开始解码
self.StartDecode = True
# self.sunnyLinker.push_trigger(0x63)
elif params == 2: #停止解码
self.IsExitApp = True
self.running = False
elif method == "rest": #休息状态
self.state_mode = 'rest'
elif method == "impedance":
if params == 1:
self.open_Impedance = True # 开启阻抗
# self.get_Impedance = True # 返回阻抗
elif params == 2:
self.open_Impedance = False # 关闭阻抗
else:
algo_log(f"未知命令:{method}", level="WARNING")
# elif method == "getReport":
# self.getReport = True
# elif params == 2:
# self.open_Impedance = False # 关闭阻抗
# self.get_Impedance = False # 停止返回阻抗
def _handle_data_message(self, frames):
"""
处理8100端口原始脑电二进制数据
固定格式:上位机发送 (5,66) float32 二维数组字节流(已转换为微伏物理量)→ 转置为 (66,5) 写入双缓冲区
"""
# 1. 校验ZMQ消息帧完整性ROUTER接收DEALER消息的帧格式[客户端ID, 发送方ID, 空帧, 数据帧]
if len(frames) < 4: # 至少需要4帧
algo_log(f"Invalid data frame: 帧数量不足期望≥4实际{len(frames)}", level="ERROR")
return
# 2. 正确解析帧适配DEALER→ROUTER的帧格式
client_ident, sender_ident, empty_sep, data_bytes = frames[:4]
if empty_sep != b'': # 校验空分隔帧
algo_log(f"Invalid frame separator: 期望空字节,实际{empty_sep}", level="ERROR")
return
# 3. 客户端管理(单客户端场景,自动更新最新身份)
if client_ident not in self.data_clients:
self.data_clients.add(client_ident)
self.current_data_client = client_ident # 保存唯一客户端身份,用于后续回复滤波结果
print(f"[INFO] 新数据客户端连接成功:{client_ident}")
try:
# 4. 精确长度校验(核心:固定(5,66) float32 = 5*66*4=1320字节
EXPECTED_BYTES = self.device_info['frame_points'] * self.device_info['channel_nums'] * 4 # 每个float32占4字节
if len(data_bytes) != EXPECTED_BYTES:
algo_log(f"[ERROR] 数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节", level="ERROR")
return
# 5. 零拷贝二进制解析 + 维度转换
data_np = np.frombuffer(data_bytes, dtype=np.float32)
data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums'])
data_np = data_np.T.astype(np.float64)
# 6. 写入滤波缓冲区
self.filterBuffer.appendBuffer(data_np)
# 7. 写入范式缓冲区
try:
with self.paradigmBufferLock:
if self.interval_inited:
self.epoch_finished = self.detect_event(data_np)
if self.pack_contain_event:
self.paradigmBuffer.resetAllPara() # 检测到当前pack含有event清除ringbuffer中之前的数据
self.paradigmBuffer.appendBuffer(data_np)
if self.epoch_finished:
time.sleep(0.005)
algo_log('epoch_finished: ' + datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3], level="DEBUG")
else:
self.paradigmBuffer.appendBuffer(data_np)
except Exception as e:
print("锁:写入异常",e)
self.paradigmBuffer.appendBuffer(data_np)
# algo_log(f"数据写入成功shape={data_np.shape}, 范围=[{data_np.min():.2f}, {data_np.max():.2f}] μV", level="DEBUG")
except Exception as e:
algo_log(f"数据处理失败:{str(e)}", level="ERROR")
if IniRead('system', 'algo_log_level', 'INFO') == 'DEBUG':
import traceback
traceback.print_exc()
# 检测是否含有标签
def detect_event(self, samples):
self.pack_contain_event = False
events = np.array(samples[-2])[0].tolist()
for idx, event in enumerate(events):
if int(event) in self.events:
new_key = "".join(
[
str(event),
datetime.datetime.now().strftime("%Y-%m-%d \
-%H-%M-%S"),
]
)
if event == self.predict_event:
self.count_events[new_key] = self.latency + 1
else:
self.count_events[new_key] = self.train_latency + 1
self.event_inner_idx = idx
self.pack_contain_event = True
drop_items = []
for key, value in self.count_events.items():
value = value - 1
if value == 0:
drop_items.append(key)
self.count_events[key] = value
for key in drop_items:
del self.count_events[key]
if drop_items:
return True
return False
def _process_send_queue(self):
"""处理发送队列,向所有命令客户端广播消息"""
while not self.send_queue.empty():
method, params = self.send_queue.get()
if self.cmd_clients:
try:
msg = {'method': method, 'params': params}
msg_bytes = json.dumps(msg).encode('utf-8')
# 打印日志(隐藏大尺寸数据)
if method in ['single_trial_plot', 'miReport']:
print(f"{{'method': '{method}', 'params': <Base64 Image Data>}}")
else:
print(f"Sending CMD message: {msg}")
# 广播到所有命令客户端
for client_id in list(self.cmd_clients):
try:
self.cmd_socket.send_multipart([client_id, b'', msg_bytes])
except Exception as e:
print(f"Error sending to CMD client {client_id}: {e}")
self.cmd_clients.discard(client_id) # 移除失效客户端
except Exception as e:
print(f"Error preparing broadcast: {e}")
def run(self):
self.running = True
algo_log(f"algo ZMQ Server started - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}", level="INFO")
try:
while self.running:
# 1. 处理发送队列(命令端口广播)
self._process_send_queue()
# 2. 轮训监听两个Socket的输入事件
socks = dict(self.poller.poll(50))
# 处理命令端口消息
if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN:
frames = self.cmd_socket.recv_multipart()
self._handle_cmd_message(frames)
# 处理数据端口消息
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
frames = self.data_socket.recv_multipart()
self._handle_data_message(frames)
except Exception as e:
print(f"Server error occurred: {e}")
finally:
self.running = False
# 关闭所有Socket和上下文
self.cmd_socket.close()
self.data_socket.close()
self.context.term()
print("Server sockets and context closed.")
def stop(self):
"""显式关闭服务器"""
self.running = False
self.cmd_socket.close()
self.data_socket.close()
self.context.term()
print(f"Server closed explicitly - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}")
if __name__ == '__main__':
# 初始化并启动服务器默认cmd=8099, data=8100
server = zmqServer()
server.start()
# 保持主线程运行
try:
while server.running:
threading.Event().wait(1)
except KeyboardInterrupt:
print("Received KeyboardInterrupt, stopping server...")
server.stop()