Files
bci_algo/system_test.py

422 lines
14 KiB
Python
Raw Normal View History

2026-06-09 10:57:28 +08:00
# -*- coding: utf-8 -*-
"""
ZMQ 脑电数据测试工具语法错误修复版
修复点
1. dataclass 可变列表默认值报错
2. threading.Thread daemon 参数语法错误
适配Python3.10全链路 float64ZMQ DEALER<->ROUTER
端口8099(命令) / 8100(数据)
"""
import zmq
import time
import threading
import numpy as np
import matplotlib.pyplot as plt
import json
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union, Tuple
from matplotlib.animation import FuncAnimation
# ===================== 1. 配置管理 =====================
@dataclass(frozen=True) # 冻结配置类
class TestConfig:
# 网络配置
SERVER_IP: str = "127.0.0.1"
CMD_PORT: int = 8099
DATA_PORT: int = 8100
# 硬件与时序
SAMPLE_RATE: int = 250
FRAME_INTERVAL_MS: int = 20
SEND_INTERVAL: float = FRAME_INTERVAL_MS / 1000
CHANNEL_NUMS: int = 66
FRAME_POINTS: int = 5
FILTER_OUT_CHAN: int = 64
FILTER_FRAME_POINTS: int = 50
# 数据类型 & 字节数 (float64 8字节)
DATA_DTYPE: np.dtype = np.float64
RAW_FRAME_BYTES: int = CHANNEL_NUMS * FRAME_POINTS * 8 # 66*5*8 = 2640
FILTER_FRAME_BYTES: int = FILTER_OUT_CHAN * FILTER_FRAME_POINTS * 8 # 25600
# 事件通道索引
EVENT_CHANNEL_IDX: int = -2
# 列表类型 使用 default_factory 规避可变默认值报错
EVENT_TAGS: List[int] = field(default_factory=lambda: [1, 2, 99])
SIM_SIGNAL_FREQ: List[float] = field(default_factory=lambda: [8.0, 9.0])
# 仿真噪声
NOISE_STD: float = 0.25
# 可视化配置
PLOT_TARGET_CHAN: int = 0
PLOT_WINDOW_LEN: int = 400
PLOT_REFRESH_INTERVAL: int = 50
# 日志限流
FRAME_ERR_INTERVAL: float = 3.0
# ZMQ 配置
SEND_RETRY_MAX: int = 3
SEND_RETRY_SLEEP: float = 0.01
ZMQ_HWM: int = 1000
# 初始化全局配置
CONFIG = TestConfig()
# ===================== 2. 全局状态管理 =====================
class GlobalState:
def __init__(self):
self.run_flag: bool = True
self.last_frame_err_time: float = 0.0
GLOBAL_STATE = GlobalState()
# ===================== 3. Matplotlib 中文初始化 =====================
def init_matplotlib():
# Windows 黑体Linux/Mac 自行替换字体
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False # 修复负号乱码
init_matplotlib()
# ===================== 4. ZMQ DEALER 客户端 =====================
class ZmqDealerClient:
"""适配 ROUTER 的 DEALER 客户端,高频流式数据专用"""
def __init__(self, server_ip: str, port: int):
self.ctx: zmq.Context = zmq.Context()
self.socket: zmq.Socket = self.ctx.socket(zmq.DEALER)
self._configure_socket()
self.socket.connect(f"tcp://{server_ip}:{port}")
def _configure_socket(self):
"""套接字参数配置"""
self.socket.setsockopt(zmq.RCVHWM, CONFIG.ZMQ_HWM)
self.socket.setsockopt(zmq.SNDHWM, CONFIG.ZMQ_HWM)
self.socket.setsockopt(zmq.RCVTIMEO, 0)
self.socket.setsockopt(zmq.SNDTIMEO, 0)
def send_json(self, data: Dict) -> bool:
"""发送JSON命令带重试机制"""
try:
payload = json.dumps(data, ensure_ascii=False).encode("utf-8")
except Exception as e:
print(f"[JSON序列化失败] {e}")
return False
for _ in range(CONFIG.SEND_RETRY_MAX):
try:
self.socket.send_multipart([b"", payload])
return True
except zmq.Again:
time.sleep(CONFIG.SEND_RETRY_SLEEP)
except Exception as e:
print(f"[JSON发送异常] {e}")
time.sleep(CONFIG.SEND_RETRY_SLEEP)
print(f"[JSON发送重试失败]")
return False
def send_bytes(self, data: bytes) -> bool:
"""发送二进制脑电数据,带重试"""
for _ in range(CONFIG.SEND_RETRY_MAX):
try:
self.socket.send_multipart([b"", data])
return True
except zmq.Again:
time.sleep(CONFIG.SEND_RETRY_SLEEP)
except Exception as e:
print(f"[二进制发送异常] {e}")
time.sleep(CONFIG.SEND_RETRY_SLEEP)
print(f"[二进制发送重试失败]")
return False
def recv_json(self) -> Optional[Dict]:
"""接收JSON命令响应标准3帧"""
try:
frames = self.socket.recv_multipart()
if len(frames) < 3:
self._log_frame_err(f"帧数异常: {len(frames)}")
return None
payload = frames[2].decode("utf-8")
return json.loads(payload)
except json.JSONDecodeError:
self._log_frame_err("JSON解析失败")
return None
except Exception as e:
self._log_frame_err(f"接收异常: {e}")
return None
def recv_bytes(self) -> Optional[bytes]:
"""接收滤波数据兼容3/4帧格式"""
try:
frames = self.socket.recv_multipart()
frame_len = len(frames)
if frame_len == 3:
payload = frames[2]
elif frame_len == 4:
payload = frames[3]
else:
self._log_frame_err(f"帧数异常: {frame_len}")
return None
if len(payload) != CONFIG.FILTER_FRAME_BYTES:
self._log_frame_err(f"字节不匹配: 期望{CONFIG.FILTER_FRAME_BYTES}, 实际{len(payload)}")
return None
return payload
except Exception as e:
self._log_frame_err(f"数据接收异常: {e}")
return None
def _log_frame_err(self, msg: str):
"""日志限流,防止刷屏"""
now = time.time()
if now - GLOBAL_STATE.last_frame_err_time > CONFIG.FRAME_ERR_INTERVAL:
print(f"[帧异常] {msg}")
GLOBAL_STATE.last_frame_err_time = now
def close(self):
"""优雅释放ZMQ资源"""
try:
self.socket.close(linger=0)
self.ctx.term()
except Exception as e:
print(f"[资源释放异常] {e}")
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
# ===================== 5. 仿真脑电数据生成 =====================
def generate_raw_eeg_frame(add_event: bool = False) -> np.ndarray:
"""生成单帧float64仿真脑电数据"""
t = np.linspace(
0, CONFIG.FRAME_POINTS / CONFIG.SAMPLE_RATE,
CONFIG.FRAME_POINTS, endpoint=False
)
eeg_frame = np.zeros(
(CONFIG.CHANNEL_NUMS, CONFIG.FRAME_POINTS),
dtype=CONFIG.DATA_DTYPE
)
# 模拟脑电信号 + 高斯噪声
for freq in CONFIG.SIM_SIGNAL_FREQ:
eeg_frame[:CONFIG.FILTER_OUT_CHAN] += np.sin(2 * np.pi * freq * t)
eeg_frame[:CONFIG.FILTER_OUT_CHAN] += np.random.normal(
0, CONFIG.NOISE_STD,
size=(CONFIG.FILTER_OUT_CHAN, CONFIG.FRAME_POINTS)
)
# 事件通道处理
eeg_frame[CONFIG.EVENT_CHANNEL_IDX] = 0.0
if add_event:
event_pos = np.random.randint(0, CONFIG.FRAME_POINTS)
eeg_frame[CONFIG.EVENT_CHANNEL_IDX, event_pos] = np.random.choice(CONFIG.EVENT_TAGS)
# 预留通道置0
eeg_frame[-1] = 0.0
return eeg_frame
# ===================== 6. 后台工作线程 =====================
def start_cmd_response_thread(cmd_client: ZmqDealerClient):
"""命令响应接收线程"""
print("[线程-命令接收] 已启动")
while GLOBAL_STATE.run_flag:
msg = cmd_client.recv_json()
if msg:
print(f"\n【命令响应】{json.dumps(msg, ensure_ascii=False, indent=2)}")
time.sleep(0.01)
print("[线程-命令接收] 已退出")
def start_raw_eeg_send_thread(data_client: ZmqDealerClient):
"""原始脑电发送线程20ms/帧)"""
print(f"[线程-原始数据发送] 20ms/帧 | 单帧{CONFIG.RAW_FRAME_BYTES}字节 | float64")
frame_count = 0
while GLOBAL_STATE.run_flag:
insert_event = (frame_count % 20 == 0)
eeg_frame = generate_raw_eeg_frame(add_event=insert_event)
frame_bytes = eeg_frame.tobytes()
# 字节校验
if len(frame_bytes) != CONFIG.RAW_FRAME_BYTES:
print(f"[字节警告] 期望{CONFIG.RAW_FRAME_BYTES}, 实际{len(frame_bytes)}")
time.sleep(CONFIG.SEND_INTERVAL)
frame_count += 1
continue
data_client.send_bytes(frame_bytes)
frame_count += 1
time.sleep(CONFIG.SEND_INTERVAL)
print("[线程-原始数据发送] 已退出")
def start_filter_data_recv_thread(data_client: ZmqDealerClient, plot_queue: List[np.ndarray]):
"""滤波数据接收线程"""
print(f"[线程-滤波数据接收] 单包{CONFIG.FILTER_FRAME_BYTES}字节 | float64")
while GLOBAL_STATE.run_flag:
raw_bytes = data_client.recv_bytes()
if not raw_bytes:
time.sleep(0.01)
continue
try:
filter_arr = np.frombuffer(raw_bytes, dtype=CONFIG.DATA_DTYPE)
filter_arr = filter_arr.reshape(CONFIG.FILTER_FRAME_POINTS, CONFIG.FILTER_OUT_CHAN)
plot_queue.append(filter_arr[:, CONFIG.PLOT_TARGET_CHAN])
except Exception as e:
print(f"[滤波数据解析异常] {e}")
continue
print("[线程-滤波数据接收] 已退出")
# ===================== 7. 实时波形可视化 =====================
def start_wave_visualization(plot_queue: List[np.ndarray]):
"""启动实时滤波波形绘图"""
fig, ax = plt.subplots(figsize=(14, 4))
x_axis = np.arange(0, CONFIG.PLOT_WINDOW_LEN)
wave_data = np.zeros(CONFIG.PLOT_WINDOW_LEN, dtype=CONFIG.DATA_DTYPE)
line, = ax.plot(x_axis, wave_data, color="#2E86AB", linewidth=1.2)
ax.set_title(
f"实时滤波脑电波形 | 通道 {CONFIG.PLOT_TARGET_CHAN} | {CONFIG.SAMPLE_RATE}Hz | float64",
fontsize=12
)
ax.set_ylim(-3.0, 3.0)
ax.grid(True, alpha=0.3, linestyle="--")
plt.tight_layout()
def update_plot(_):
nonlocal wave_data
if plot_queue:
new_wave = plot_queue.pop(0)
wave_data = np.roll(wave_data, -len(new_wave))
wave_data[-len(new_wave)] = new_wave
line.set_ydata(wave_data)
return (line,)
ani = FuncAnimation(
fig, update_plot,
interval=CONFIG.PLOT_REFRESH_INTERVAL,
blit=True,
cache_frame_data=False
)
plt.show()
# ===================== 8. 全量业务测试用例 =====================
def run_full_test_cases(cmd_client: ZmqDealerClient):
"""全覆盖 zmqServer 所有命令sync/targetFreqs/decoderClass/impedance/train/predict/rest"""
print("\n" + "="*60)
print("开始执行全量命令测试用例")
print("="*60)
time.sleep(2)
# 1. 同步命令
print("\n[用例 1] 发送 sync 命令")
cmd_client.send_json({"method": "sync", "params": {}})
time.sleep(1)
# 2. 设置目标频率
print("\n[用例 2] 发送 targetFreqs = [8.0, 9.0]")
cmd_client.send_json({"method": "targetFreqs", "params": [8.0, 9.0]})
time.sleep(1)
# 3. 切换解码器
print("\n[用例 3] 切换解码器为 ssmvep")
cmd_client.send_json({"method": "decoderClass", "params": "ssmvep"})
time.sleep(2)
print("\n[用例 3-2] 切换解码器为 mi")
cmd_client.send_json({"method": "decoderClass", "params": "mi"})
time.sleep(2)
# 4. 阻抗检测开关
print("\n[用例 4] 开启阻抗检测 impedance=1")
cmd_client.send_json({"method": "impedance", "params": 1})
time.sleep(1)
print("\n[用例 4-2] 关闭阻抗检测 impedance=2")
cmd_client.send_json({"method": "impedance", "params": 2})
time.sleep(1)
# 5. 训练模式
print("\n[用例 5] 启动训练 train标签=1")
cmd_client.send_json({"method": "train", "params": 1})
time.sleep(3)
# # 6. 休息模式
# print("\n[用例 6] 切换 rest 休息模式")
# cmd_client.send_json({"method": "rest", "params": {}})
# time.sleep(1)
# 7. 启动解码
print("\n[用例 7] 启动解码 predict=1")
cmd_client.send_json({"method": "predict", "params": 1})
time.sleep(4)
# # 8. 非法命令(异常测试)
# print("\n[用例 8] 发送非法命令 test_cmd_illegal")
# cmd_client.send_json({"method": "test_cmd_illegal", "params": {}})
# time.sleep(1)
# # 9. 停止解码
# print("\n[用例 9] 停止解码 predict=2")
# cmd_client.send_json({"method": "predict", "params": 2})
# time.sleep(2)
print("\n" + "="*60)
print("所有测试用例执行完毕")
print("="*60)
# ===================== 主程序入口(修复线程语法) =====================
if __name__ == "__main__":
print("="*60)
print("ZMQ 脑电仿真测试工具 启动")
print(f"命令端口: {CONFIG.CMD_PORT} | 数据端口: {CONFIG.DATA_PORT}")
print(f"原始帧{CONFIG.RAW_FRAME_BYTES}字节 | 滤波帧{CONFIG.FILTER_FRAME_BYTES}字节 | float64")
print("="*60)
try:
with ZmqDealerClient(CONFIG.SERVER_IP, CONFIG.CMD_PORT) as cmd_client, \
ZmqDealerClient(CONFIG.SERVER_IP, CONFIG.DATA_PORT) as data_client:
plot_queue = []
# ========== 重点修复线程语法daemon 移出 args ==========
# 命令接收线程
t_cmd = threading.Thread(
target=start_cmd_response_thread,
args=(cmd_client,), # 单元素元组保留逗号
daemon=True
)
# 原始数据发送线程
t_eeg = threading.Thread(
target=start_raw_eeg_send_thread,
args=(data_client,),
daemon=True
)
# 滤波数据接收线程
t_filter = threading.Thread(
target=start_filter_data_recv_thread,
args=(data_client, plot_queue),
daemon=True
)
# 启动线程
t_cmd.start()
t_eeg.start()
t_filter.start()
# 执行测试用例
run_full_test_cases(cmd_client)
# 启动可视化(阻塞主线程)
print("\n[提示] 波形窗口已启动,关闭窗口 / Ctrl+C 退出程序")
start_wave_visualization(plot_queue)
except KeyboardInterrupt:
print("\n\n[用户中断] 接收到 Ctrl+C准备退出...")
except Exception as e:
print(f"\n[程序异常] {e}")
finally:
# 停止所有后台线程
GLOBAL_STATE.run_flag = False
time.sleep(0.2)
print("程序已安全退出")