Files
bci_algo/Zmq/zmqServer.py
2026-06-10 17:55:43 +08:00

446 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*-coding:utf-8 -*-
import ast
import numpy as np
import threading
import zmq
import json
import queue
from typing import Dict
import datetime
import time
from Zmq.dataBuffer import ParadigmRingBuffer
from Zmq.filterProcess import FilterRingBuffer
from PubLibrary.InifileHelper import IniRead
from logs.log import algo_log
zmqServer_host = str(IniRead('system', 'zmqServer_host', '127.0.0.1'))
class zmqServer(threading.Thread):
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
threading.Thread.__init__(self)
self.device_info = device_info
self.host = zmqServer_host
self.cmd_port = cmd_port # 命令交互端口收JSON命令 + 返JSON结果
self.data_port = data_port # 数据交互端口:收二进制原始脑电 + 返二进制滤波结果
self.running = False
# 原有业务状态变量
self.open_Impedance = False #当前系统处于阻抗检测状态
self.StartDecode = False
self.StartTrain = False
self.state_mode = None
self.currentLabel = -1
self.IsExitApp = False
self.daemon = True
# 双环形缓冲区
self.paradigmBuffer = ParadigmRingBuffer(
self.device_info['channel_nums'],
self.device_info['sample_rate'] * 10
)
self.filterBuffer = FilterRingBuffer(
self.device_info['channel_nums'],
self.device_info['sample_rate'] * 10
)
self.paradigmBufferLock = threading.Lock()
self.filterBufferLock = threading.Lock()
# ZMQ上下文与套接字
self.context = zmq.Context()
# 8099命令端口ROUTER
self.cmd_socket = self.context.socket(zmq.ROUTER)
self.cmd_socket.setsockopt(zmq.SocketOption.RCVHWM, 100)
self.cmd_socket.setsockopt(zmq.SocketOption.SNDHWM, 100)
self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
# 8100数据端口ROUTER
self.data_socket = self.context.socket(zmq.ROUTER)
self.data_socket.setsockopt(zmq.SocketOption.RCVHWM, 500)
self.data_socket.setsockopt(zmq.SocketOption.SNDHWM, 100) # 添加发送高水位线
self.data_socket.bind(f"tcp://{self.host}:{data_port}")
# Poller轮询器
self.poller = zmq.Poller()
self.poller.register(self.cmd_socket, zmq.POLLIN)
self.poller.register(self.data_socket, zmq.POLLIN)
# 业务变量
self.targetFreqs = []
self.changeTarget = False
self.labels = [0x01, 0x02, 0x03]
self.decoder_switch = False
self.decoder_class = None
# 客户端管理(单客户端场景)
self.cmd_clients = set()
self.data_clients = set()
self.current_data_client = None # 唯一数据客户端身份,用于发送滤波结果
# 发送队列(双端口分离)
self.cmd_send_queue = queue.Queue() # 8099端口命令结果队列
self.data_send_queue = queue.Queue() # 8100端口滤波数据队列
# 范式buffer与事件检测参数
self.predict_event = 99
self.events = [1, 2, self.predict_event]
self.latency = 50
self.train_latency = 50
self.count_events = {}
self.epoch_finished = False
self.pack_contain_event = False
self.event_inner_idx = -1
self.interval_inited = False
self.last_epoch_finish_time = None
def reset_state(self):
"""清空采集器状态和缓存数据"""
with self.paradigmBufferLock:
self.paradigmBuffer.resetAllPara()
self.count_events = {}
self.epoch_finished = False
self.pack_contain_event = False
self.event_inner_idx = -1
self.interval_inited = False
def interval_init(self, decoder_class):
if decoder_class == 'ssmvep':
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) # [0.2, 2.2]
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # [50, 550]
self.train_epoch = [
int(self.interval_epoch[0]),
int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])
] # [50, 575]
self.latency = (self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 #115包, 575个点
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 #120包 600个点
elif decoder_class == 'mi':
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) # [0.5, 4.5]
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] #[125, 1125]
self.train_epoch = self.interval_epoch.copy()
self.latency = self.interval_epoch[1] // 5 #225
self.train_latency = self.latency #225
algo_log(f"时间窗初始化完成: {interval_epoch}", level="INFO")
self.count_events: Dict[str, int] = {}
self.event_inner_idx = -1
self.epoch_finished = False
self.pack_contain_event = False
self.predict_event = 99
self.events = [1, 2, self.predict_event]
self.interval_inited = True
# -------------------------- 8099端口命令结果广播 --------------------------
def broadcast_message(self, method, params):
"""
向所有8099端口客户端广播JSON格式的命令结果
用于:解码结果、训练状态、错误提示、进度通知等
"""
self.cmd_send_queue.put((method, params))
def _process_cmd_send_queue(self):
"""处理8099端口发送队列在主线程执行保证ZMQ线程安全"""
while not self.cmd_send_queue.empty():
method, params = self.cmd_send_queue.get()
if not self.cmd_clients:
continue
try:
msg = {'method': method, 'params': params}
msg_bytes = json.dumps(msg).encode('utf-8')
algo_log(f"发送命令结果: {msg}", level="DEBUG")
# 广播到所有命令客户端
for client_id in list(self.cmd_clients):
try:
self.cmd_socket.send_multipart([client_id, b"", msg_bytes])
except Exception as e:
algo_log(f"向命令客户端{client_id}发送失败: {e}", level="ERROR")
self.cmd_clients.discard(client_id)
except Exception as e:
algo_log(f"命令结果打包失败: {e}", level="ERROR")
# -------------------------- 8100端口滤波结果发送 --------------------------
def send_filtered_data(self, filtered_data):
"""
向8100端口客户端发送二进制格式的滤波结果
用于:上位机实时绘图的脑电波形数据
:param filtered_data: 滤波后数据shape=(通道数, 50)float64格式
"""
if self.current_data_client is None:
algo_log("数据客户端未连接,跳过滤波数据发送", level="WARNING")
return
# 转置为上位机需要的[50, 通道数]格式
filtered_data = filtered_data.T.astype(np.float64)
send_buf = filtered_data.tobytes()
algo_log(f"发送滤波数据,长度: {len(send_buf)}字节, filtered_data.shape: {filtered_data.shape}", level="DEBUG", record_once=True)
self.data_send_queue.put(send_buf)
def _process_data_send_queue(self):
"""处理8100端口发送队列在主线程执行保证ZMQ线程安全"""
while not self.data_send_queue.empty():
send_buf = self.data_send_queue.get()
if self.current_data_client is None:
continue
try:
# 标准ROUTER发送格式[客户端ID, 空分隔帧, 数据帧]
self.data_socket.send_multipart([
self.current_data_client,
b"",
send_buf
])
algo_log(f"发送滤波数据成功,长度: {len(send_buf)}字节", level="DEBUG", record_once=True)
except Exception as e:
algo_log(f"发送滤波数据失败: {e}", level="ERROR")
# 客户端断开,重置身份
self.current_data_client = None
self.data_clients.clear()
# -------------------------- 命令端口消息处理 --------------------------
def _handle_cmd_message(self, frames):
"""处理8099端口JSON命令消息"""
if len(frames) < 3:
algo_log(f"无效命令帧长度不足3帧实际{len(frames)}", level="ERROR")
return
ident, _, message_bytes = frames[:3]
# 注册新的命令客户端
if ident not in self.cmd_clients:
self.cmd_clients.add(ident)
algo_log(f"新命令客户端连接成功: {ident}", level="INFO")
# 解析JSON命令
try:
message = json.loads(message_bytes.decode('utf-8'))
except json.JSONDecodeError:
algo_log(f"无效JSON命令: {message_bytes.hex()}", level="ERROR")
self.broadcast_message("error", {"code": 400, "message": "无效JSON格式"})
return
algo_log(f"收到命令: {message}", level="INFO")
method = message.get("method")
params = message.get("params")
# 命令处理逻辑
if method == "sync":
self.state_mode = 'sync'
elif method == "targetFreqs":
if not isinstance(params, list):
algo_log(f"targetFreqs must be a list")
return
if params != self.targetFreqs:
self.targetFreqs = params
self.changeTarget = True
elif method == "decoderClass":
if not isinstance(params, str):
algo_log(f"decoderClass必须是字符串")
return
if params != self.decoder_class:
self.decoder_class = params
self.decoder_switch = True
elif method == "train":
self.state_mode = 'train'
resp = {
"method": "train_response",
"params": {
"code": 200,
"message": "ok"
}
}
try:
resp_bytes = json.dumps(resp, ensure_ascii=False).encode("utf-8")
self.cmd_socket.send_multipart([ident, b"", resp_bytes])
algo_log(f"train 命令已即时回复客户端 {ident}", level="DEBUG")
except Exception as e:
algo_log(f"train 命令回复失败: {e}", level="ERROR")
return
elif method == "predict":
self.state_mode = 'predict'
if params == 1: #开始解码
self.StartDecode = True
elif params == 2: #停止解码
self.IsExitApp = True
self.running = False
elif method == "rest":
self.state_mode = 'rest'
elif method == "impedance":
if params == 1:
self.open_Impedance = True
elif params == 2:
self.open_Impedance = False
else:
self.broadcast_message("error", {"code": 404, "message": f"未知命令: {method}"})
# -------------------------- 数据端口消息处理 --------------------------
def _handle_data_message(self, frames):
"""处理8100端口二进制脑电数据消息"""
algo_log(f"收到数据帧,总帧数:{len(frames)}", level="DEBUG", record_once=True)
# 然后再进行解析
if len(frames) == 4:
# 你的上位机格式
ident, sender_ident, empty_sep, data_bytes = frames[:4]
elif len(frames) == 3:
# 标准格式
ident, empty_sep, data_bytes = frames[:3]
elif len(frames) == 2:
ident, data_bytes = frames[:2]
else:
return
# 注册新的数据客户端(单客户端场景,自动覆盖旧身份)
if ident not in self.data_clients:
self.data_clients.clear() # 单客户端,只保留最新连接
self.data_clients.add(ident)
self.current_data_client = ident
algo_log(f"新数据客户端连接成功: {ident}", level="INFO")
try:
# 精确长度校验
EXPECTED_BYTES = self.device_info['frame_points'] * self.device_info['channel_nums'] * np.dtype(np.float64).itemsize
if len(data_bytes) != EXPECTED_BYTES:
algo_log(f"数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节", level="ERROR")
return
# 零拷贝解析 + 维度转换
data_np = np.frombuffer(data_bytes, dtype=np.float64)
data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums'])
data_np = data_np.T.astype(np.float64)
# 写入滤波缓冲区
with self.filterBufferLock:
self.filterBuffer.appendBuffer(data_np)
# 写入范式缓冲区
with self.paradigmBufferLock:
if self.interval_inited:
self.epoch_finished = self.detect_event(data_np)
if self.pack_contain_event:
self.paradigmBuffer.resetAllPara()
self.paradigmBuffer.appendBuffer(data_np)
if self.epoch_finished:
now = datetime.datetime.now()
time_diff_str = ""
# 计算与上一次Epoch完成的时间差
if self.last_epoch_finish_time is not None:
# 时间差 单位保留3位小数
delta_seconds = (now - self.last_epoch_finish_time).total_seconds()
time_diff_str = f" | 与上一次间隔: {delta_seconds:.3f} s"
# 拼接日志,增加时间差信息
log_msg = f"Epoch采集完成: {now.strftime('%H:%M:%S.%f')[:-3]}{time_diff_str}"
algo_log(log_msg, level="DEBUG")
# 更新上一次Epoch完成时间为当前时间
self.last_epoch_finish_time = now
else:
self.paradigmBuffer.appendBuffer(data_np)
except Exception as e:
algo_log(f"数据处理失败: {str(e)}", level="ERROR")
if IniRead('system', 'algo_log_level', 'INFO') == 'DEBUG':
import traceback
traceback.print_exc()
# -------------------------- 事件检测 --------------------------
def detect_event(self, samples):
self.pack_contain_event = False
# 第65通道为事件通道
event = int(samples[-2][0])
# for idx, event in enumerate(events):
if event in self.events:
new_key = "".join(
[
str(event),
datetime.datetime.now().strftime("%Y-%m-%d \
-%H-%M-%S"),
]
)
self.currentLabel = event
if event == self.predict_event:
self.count_events[new_key] = self.latency + 1
else:
self.count_events[new_key] = self.train_latency + 1
self.event_inner_idx = self.device_info['frame_points'] - 1
# algo_log(f"事件检测到: {event},索引: {idx}", level="DEBUG")
self.pack_contain_event = True
# 倒计时并清理过期事件
drop_items = []
for key, value in self.count_events.items():
value -= 1
if value == 0:
drop_items.append(key)
self.count_events[key] = value
for key in drop_items:
del self.count_events[key]
if drop_items:
return True
return False
# -------------------------- 主循环 --------------------------
def run(self):
self.running = True
algo_log(f"ZMQ服务器启动成功 - host: {self.host}, 命令端口: {self.cmd_port}, 数据端口: {self.data_port}", level="INFO")
try:
while self.running:
# 1. 处理两个端口的发送队列(必须在主线程执行)
self._process_cmd_send_queue()
self._process_data_send_queue()
# 2. 轮询监听两个端口的输入事件
socks = dict(self.poller.poll(50))
# 处理8099命令端口消息
if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN:
frames = self.cmd_socket.recv_multipart()
self._handle_cmd_message(frames)
# 处理8100数据端口消息排空积压消除标签延迟
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
while True:
try:
frames = self.data_socket.recv_multipart(zmq.NOBLOCK)
self._handle_data_message(frames)
except zmq.Again:
break
except Exception as e:
algo_log(f"服务器主循环异常: {e}", level="ERROR")
finally:
self.running = False
# 优雅关闭所有资源
self.cmd_socket.close()
self.data_socket.close()
self.context.term()
algo_log("ZMQ服务器已关闭", level="INFO")
def stop(self):
"""显式关闭服务器"""
self.running = False
self.cmd_socket.close()
self.data_socket.close()
self.context.term()
algo_log(f"服务器已显式关闭 - 命令端口: {self.cmd_port}, 数据端口: {self.data_port}", level="INFO")
if __name__ == '__main__':
# 初始化并启动服务器
server = zmqServer()
server.start()
# 保持主线程运行
try:
while server.running:
threading.Event().wait(1)
except KeyboardInterrupt:
algo_log("收到键盘中断信号,正在停止服务器...", level="INFO")
server.stop()