del train

This commit is contained in:
2026-06-09 10:57:28 +08:00
parent f47e7d914f
commit 07560304ca
6 changed files with 537 additions and 71 deletions

View File

@@ -96,7 +96,7 @@ class Decoder_main(threading.Thread):
elif decoder_class == 'ssmvep': elif decoder_class == 'ssmvep':
self.zmqServer.interval_init(decoder_class) self.zmqServer.interval_init(decoder_class)
self.n_chan = 8 self.n_chan = 8
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) # [0.2, 2.2]
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位 self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
self.single_train = 10 # 单类别数量 self.single_train = 10 # 单类别数量
self.num_target = 2 # 分类目标数目 self.num_target = 2 # 分类目标数目
@@ -268,16 +268,16 @@ class Decoder_main(threading.Thread):
'''训练阶段采集数据''' '''训练阶段采集数据'''
if self.zmqServer.state_mode == 'train': # 训练状态 if self.zmqServer.state_mode == 'train': # 训练状态
if self.zmqServer.StartTrain:
if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
self.train_epoch[1] + self.zmqServer.event_inner_idx:
self.currentLabel = self.zmqServer.currentLabel self.currentLabel = self.zmqServer.currentLabel
self.zmqServer.StartTrain = False
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.train_epoch[1] \
+ self.zmqServer.event_inner_idx:
time.sleep(0.0001)
return
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount()) print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据 trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.zmqServer.event_inner_idx]) print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.zmqServer.event_inner_idx])
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理 trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[ trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
@@ -288,6 +288,9 @@ class Decoder_main(threading.Thread):
and self.trainLabel.count(self.currentLabel) < self.single_train: and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial) self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel) self.trainLabel.append(self.currentLabel)
else:
time.sleep(0.0001)
return
elif self.zmqServer.state_mode == 'predict': # 测试状态 elif self.zmqServer.state_mode == 'predict': # 测试状态
if self.load_model == False: # 模型尚未训练完成 if self.load_model == False: # 模型尚未训练完成

View File

@@ -19,3 +19,4 @@ source activate 3in1Py310
python runDecoder.py python runDecoder.py
python datamock.py python datamock.py
python ZeroMQClient_mock.py python ZeroMQClient_mock.py
python system_test.py

View File

@@ -21,6 +21,10 @@ class zmqServer(threading.Thread):
self.device_info = device_info self.device_info = device_info
self.host = host self.host = host
test_host = "10.200.27.140"
self.host = test_host
self.cmd_port = cmd_port # 命令交互端口收JSON命令 + 返JSON结果 self.cmd_port = cmd_port # 命令交互端口收JSON命令 + 返JSON结果
self.data_port = data_port # 数据交互端口:收二进制原始脑电 + 返二进制滤波结果 self.data_port = data_port # 数据交互端口:收二进制原始脑电 + 返二进制滤波结果
self.running = False self.running = False
@@ -105,14 +109,14 @@ class zmqServer(threading.Thread):
def interval_init(self, decoder_class): def interval_init(self, decoder_class):
if decoder_class == 'ssmvep': if decoder_class == 'ssmvep':
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) # [0.2, 2.2]
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # [50, 550]
self.train_epoch = [ self.train_epoch = [
int(self.interval_epoch[0]), int(self.interval_epoch[0]),
int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])
] ] # [50, 575]
self.latency = (self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 self.latency = (self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 #115包, 575个点
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 #120包 600个点
elif decoder_class == 'mi': elif decoder_class == 'mi':
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
@@ -246,8 +250,6 @@ class zmqServer(threading.Thread):
self.decoder_switch = True self.decoder_switch = True
elif method == "train": elif method == "train":
self.state_mode = 'train' self.state_mode = 'train'
self.StartTrain = True
self.currentLabel = params
elif method == "predict": elif method == "predict":
self.state_mode = 'predict' self.state_mode = 'predict'
if params == 1: #开始解码 if params == 1: #开始解码
@@ -322,9 +324,9 @@ class zmqServer(threading.Thread):
def detect_event(self, samples): def detect_event(self, samples):
self.pack_contain_event = False self.pack_contain_event = False
# 第65通道为事件通道 # 第65通道为事件通道
events = samples[-2].tolist() event = int(samples[-2][0])
for idx, event in enumerate(events): # for idx, event in enumerate(events):
if int(event) in self.events: if event in self.events:
new_key = "".join( new_key = "".join(
[ [
str(event), str(event),
@@ -332,11 +334,13 @@ class zmqServer(threading.Thread):
-%H-%M-%S"), -%H-%M-%S"),
] ]
) )
self.currentLabel = event
if event == self.predict_event: if event == self.predict_event:
self.count_events[new_key] = self.latency + 1 self.count_events[new_key] = self.latency + 1
else: else:
self.count_events[new_key] = self.train_latency + 1 self.count_events[new_key] = self.train_latency + 1
self.event_inner_idx = idx self.event_inner_idx = self.device_info['frame_points'] - 1
# algo_log(f"事件检测到: {event},索引: {idx}", level="DEBUG")
self.pack_contain_event = True self.pack_contain_event = True
# 倒计时并清理过期事件 # 倒计时并清理过期事件

View File

@@ -11,6 +11,7 @@ N_CHAN = 66 # 通道数: 64 EEG + 1 标签值 + 1 标签序号
EEG_FREQ = 10 # EEG 正弦波频率 Hz EEG_FREQ = 10 # EEG 正弦波频率 Hz
EEG_AMP = 100.0 # EEG 幅值 100μV EEG_AMP = 100.0 # EEG 幅值 100μV
LABEL_INTERVAL = 5 # 标签间隔秒数 LABEL_INTERVAL = 5 # 标签间隔秒数
# SERVER_ADDR = 'tcp://127.0.0.1:8100'
SERVER_ADDR = 'tcp://127.0.0.1:8100' SERVER_ADDR = 'tcp://127.0.0.1:8100'
# 发送间隔: 每包 5 采样点 / 250Hz = 20ms # 发送间隔: 每包 5 采样点 / 250Hz = 20ms

View File

@@ -1,24 +1,54 @@
import os import os
from datetime import datetime from datetime import datetime, timedelta
import logging import logging
from logging.handlers import RotatingFileHandler from logging.handlers import RotatingFileHandler
import inspect # 新增导入 import inspect
from PubLibrary.InifileHelper import IniRead from PubLibrary.InifileHelper import IniRead
# 全局配置
console_output = IniRead('system', 'console_output', '1') console_output = IniRead('system', 'console_output', '1')
log_level = IniRead('system', 'algo_log_level', 'INFO') log_level = IniRead('system', 'algo_log_level', 'INFO')
log_once_cache = set() log_once_cache = set()
# 缓存已经创建过的logger避免重复创建handler
logger_cache = {} logger_cache = {}
LOG_RETENTION_DAYS = 3
LOG_DIR = './logs/'
LOG_FILE_PREFIX = 'algo_log_'
# 日志格式:时间 - 日志器名 - 级别 - 文件名:行号 - 函数名 - 日志内容
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
def clean_old_logs():
"""清理超过指定天数的旧日志文件"""
try:
if not os.path.exists(LOG_DIR):
return
expire_date = datetime.now() - timedelta(days=LOG_RETENTION_DAYS)
for filename in os.listdir(LOG_DIR):
if not filename.startswith(LOG_FILE_PREFIX) or not filename.endswith('.log'):
continue
date_str = filename[len(LOG_FILE_PREFIX):-4]
try:
file_date = datetime.strptime(date_str, '%Y-%m-%d')
if file_date < expire_date:
file_path = os.path.join(LOG_DIR, filename)
os.remove(file_path)
print(f"清理过期日志: {file_path}")
except ValueError:
continue
except Exception as e:
print(f"清理旧日志异常: {str(e)}")
def init_module_logger(logger_name): def init_module_logger(logger_name):
log_dir = './logs/' """初始化日志器 + 清理旧日志"""
os.makedirs(log_dir, exist_ok=True) os.makedirs(LOG_DIR, exist_ok=True)
log_file = os.path.join(log_dir, f'algo_log_{datetime.now().strftime("%Y-%m-%d")}.log') clean_old_logs()
current_date = datetime.now().strftime("%Y-%m-%d")
log_file = os.path.join(LOG_DIR, f"{LOG_FILE_PREFIX}{current_date}.log")
# 已创建直接返回
if logger_name in logger_cache: if logger_name in logger_cache:
return logger_cache[logger_name] return logger_cache[logger_name]
@@ -28,19 +58,18 @@ def init_module_logger(logger_name):
logger_cache[logger_name] = logger logger_cache[logger_name] = logger
return logger return logger
# 文件输出处理器
file_handler = RotatingFileHandler( file_handler = RotatingFileHandler(
log_file, log_file,
maxBytes=10 * 1024 * 1024, maxBytes=10 * 1024 * 1024,
backupCount=10, backupCount=10,
encoding='utf-8' encoding='utf-8'
) )
formatter = logging.Formatter( formatter = logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT)
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)
logger.addHandler(file_handler) logger.addHandler(file_handler)
# 控制台输出
if console_output: if console_output:
console_handler = logging.StreamHandler() console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter) console_handler.setFormatter(formatter)
@@ -51,29 +80,35 @@ def init_module_logger(logger_name):
def algo_log(content, level="INFO", record_once=False): def algo_log(content, level="INFO", record_once=False):
# 向上回溯1层栈拿到调用algo_log的代码文件信息 """
frame = inspect.currentframe().f_back 日志入口函数
file_path = frame.f_code.co_filename 自动记录:调用文件名、代码行号、所在函数
# 提取py文件名不带后缀/带后缀自选) """
file_name = os.path.basename(file_path) # 例zmqServer.py # 回溯栈帧,获取真正调用 algo_log 的代码位置
# file_name = os.path.splitext(os.path.basename(file_path))[0] # 例zmqServer # f_back(1) -> algo_log 自身f_back(2) -> 业务调用处
frame = inspect.currentframe().f_back.f_back
if not frame:
file_name = "unknown"
else:
file_name = os.path.basename(frame.f_code.co_filename)
logger = init_module_logger(file_name) logger = init_module_logger(file_name)
# 单次日志去重
if record_once: if record_once:
log_key = f"{level.upper()}_{content}" log_key = f"{level.upper()}_{content}"
if log_key in log_once_cache: if log_key in log_once_cache:
return return
log_once_cache.add(log_key) log_once_cache.add(log_key)
# 日志级别分发
level_upper = level.upper() level_upper = level.upper()
if level_upper == "DEBUG": log_map = {
logger.debug(content) "DEBUG": logger.debug,
elif level_upper == "WARNING": "WARNING": logger.warning,
logger.warning(content) "ERROR": logger.error,
elif level_upper == "ERROR": "FATAL": logger.fatal,
logger.error(content) "INFO": logger.info
elif level_upper == "FATAL": }
logger.fatal(content) log_func = log_map.get(level_upper, logger.info)
else: log_func(content)
logger.info(content)

422
system_test.py Normal file
View File

@@ -0,0 +1,422 @@
# -*- coding: utf-8 -*-
"""
ZMQ 脑电数据测试工具【语法错误修复版】
修复点:
1. dataclass 可变列表默认值报错
2. threading.Thread daemon 参数语法错误
适配Python3.10、全链路 float64、ZMQ DEALER<->ROUTER
端口8099(命令) / 8100(数据)
"""
import zmq
import time
import threading
import numpy as np
import matplotlib.pyplot as plt
import json
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union, Tuple
from matplotlib.animation import FuncAnimation
# ===================== 1. 配置管理 =====================
@dataclass(frozen=True) # 冻结配置类
class TestConfig:
# 网络配置
SERVER_IP: str = "127.0.0.1"
CMD_PORT: int = 8099
DATA_PORT: int = 8100
# 硬件与时序
SAMPLE_RATE: int = 250
FRAME_INTERVAL_MS: int = 20
SEND_INTERVAL: float = FRAME_INTERVAL_MS / 1000
CHANNEL_NUMS: int = 66
FRAME_POINTS: int = 5
FILTER_OUT_CHAN: int = 64
FILTER_FRAME_POINTS: int = 50
# 数据类型 & 字节数 (float64 8字节)
DATA_DTYPE: np.dtype = np.float64
RAW_FRAME_BYTES: int = CHANNEL_NUMS * FRAME_POINTS * 8 # 66*5*8 = 2640
FILTER_FRAME_BYTES: int = FILTER_OUT_CHAN * FILTER_FRAME_POINTS * 8 # 25600
# 事件通道索引
EVENT_CHANNEL_IDX: int = -2
# 列表类型 使用 default_factory 规避可变默认值报错
EVENT_TAGS: List[int] = field(default_factory=lambda: [1, 2, 99])
SIM_SIGNAL_FREQ: List[float] = field(default_factory=lambda: [8.0, 9.0])
# 仿真噪声
NOISE_STD: float = 0.25
# 可视化配置
PLOT_TARGET_CHAN: int = 0
PLOT_WINDOW_LEN: int = 400
PLOT_REFRESH_INTERVAL: int = 50
# 日志限流
FRAME_ERR_INTERVAL: float = 3.0
# ZMQ 配置
SEND_RETRY_MAX: int = 3
SEND_RETRY_SLEEP: float = 0.01
ZMQ_HWM: int = 1000
# 初始化全局配置
CONFIG = TestConfig()
# ===================== 2. 全局状态管理 =====================
class GlobalState:
def __init__(self):
self.run_flag: bool = True
self.last_frame_err_time: float = 0.0
GLOBAL_STATE = GlobalState()
# ===================== 3. Matplotlib 中文初始化 =====================
def init_matplotlib():
# Windows 黑体Linux/Mac 自行替换字体
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False # 修复负号乱码
init_matplotlib()
# ===================== 4. ZMQ DEALER 客户端 =====================
class ZmqDealerClient:
"""适配 ROUTER 的 DEALER 客户端,高频流式数据专用"""
def __init__(self, server_ip: str, port: int):
self.ctx: zmq.Context = zmq.Context()
self.socket: zmq.Socket = self.ctx.socket(zmq.DEALER)
self._configure_socket()
self.socket.connect(f"tcp://{server_ip}:{port}")
def _configure_socket(self):
"""套接字参数配置"""
self.socket.setsockopt(zmq.RCVHWM, CONFIG.ZMQ_HWM)
self.socket.setsockopt(zmq.SNDHWM, CONFIG.ZMQ_HWM)
self.socket.setsockopt(zmq.RCVTIMEO, 0)
self.socket.setsockopt(zmq.SNDTIMEO, 0)
def send_json(self, data: Dict) -> bool:
"""发送JSON命令带重试机制"""
try:
payload = json.dumps(data, ensure_ascii=False).encode("utf-8")
except Exception as e:
print(f"[JSON序列化失败] {e}")
return False
for _ in range(CONFIG.SEND_RETRY_MAX):
try:
self.socket.send_multipart([b"", payload])
return True
except zmq.Again:
time.sleep(CONFIG.SEND_RETRY_SLEEP)
except Exception as e:
print(f"[JSON发送异常] {e}")
time.sleep(CONFIG.SEND_RETRY_SLEEP)
print(f"[JSON发送重试失败]")
return False
def send_bytes(self, data: bytes) -> bool:
"""发送二进制脑电数据,带重试"""
for _ in range(CONFIG.SEND_RETRY_MAX):
try:
self.socket.send_multipart([b"", data])
return True
except zmq.Again:
time.sleep(CONFIG.SEND_RETRY_SLEEP)
except Exception as e:
print(f"[二进制发送异常] {e}")
time.sleep(CONFIG.SEND_RETRY_SLEEP)
print(f"[二进制发送重试失败]")
return False
def recv_json(self) -> Optional[Dict]:
"""接收JSON命令响应标准3帧"""
try:
frames = self.socket.recv_multipart()
if len(frames) < 3:
self._log_frame_err(f"帧数异常: {len(frames)}")
return None
payload = frames[2].decode("utf-8")
return json.loads(payload)
except json.JSONDecodeError:
self._log_frame_err("JSON解析失败")
return None
except Exception as e:
self._log_frame_err(f"接收异常: {e}")
return None
def recv_bytes(self) -> Optional[bytes]:
"""接收滤波数据兼容3/4帧格式"""
try:
frames = self.socket.recv_multipart()
frame_len = len(frames)
if frame_len == 3:
payload = frames[2]
elif frame_len == 4:
payload = frames[3]
else:
self._log_frame_err(f"帧数异常: {frame_len}")
return None
if len(payload) != CONFIG.FILTER_FRAME_BYTES:
self._log_frame_err(f"字节不匹配: 期望{CONFIG.FILTER_FRAME_BYTES}, 实际{len(payload)}")
return None
return payload
except Exception as e:
self._log_frame_err(f"数据接收异常: {e}")
return None
def _log_frame_err(self, msg: str):
"""日志限流,防止刷屏"""
now = time.time()
if now - GLOBAL_STATE.last_frame_err_time > CONFIG.FRAME_ERR_INTERVAL:
print(f"[帧异常] {msg}")
GLOBAL_STATE.last_frame_err_time = now
def close(self):
"""优雅释放ZMQ资源"""
try:
self.socket.close(linger=0)
self.ctx.term()
except Exception as e:
print(f"[资源释放异常] {e}")
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
# ===================== 5. 仿真脑电数据生成 =====================
def generate_raw_eeg_frame(add_event: bool = False) -> np.ndarray:
"""生成单帧float64仿真脑电数据"""
t = np.linspace(
0, CONFIG.FRAME_POINTS / CONFIG.SAMPLE_RATE,
CONFIG.FRAME_POINTS, endpoint=False
)
eeg_frame = np.zeros(
(CONFIG.CHANNEL_NUMS, CONFIG.FRAME_POINTS),
dtype=CONFIG.DATA_DTYPE
)
# 模拟脑电信号 + 高斯噪声
for freq in CONFIG.SIM_SIGNAL_FREQ:
eeg_frame[:CONFIG.FILTER_OUT_CHAN] += np.sin(2 * np.pi * freq * t)
eeg_frame[:CONFIG.FILTER_OUT_CHAN] += np.random.normal(
0, CONFIG.NOISE_STD,
size=(CONFIG.FILTER_OUT_CHAN, CONFIG.FRAME_POINTS)
)
# 事件通道处理
eeg_frame[CONFIG.EVENT_CHANNEL_IDX] = 0.0
if add_event:
event_pos = np.random.randint(0, CONFIG.FRAME_POINTS)
eeg_frame[CONFIG.EVENT_CHANNEL_IDX, event_pos] = np.random.choice(CONFIG.EVENT_TAGS)
# 预留通道置0
eeg_frame[-1] = 0.0
return eeg_frame
# ===================== 6. 后台工作线程 =====================
def start_cmd_response_thread(cmd_client: ZmqDealerClient):
"""命令响应接收线程"""
print("[线程-命令接收] 已启动")
while GLOBAL_STATE.run_flag:
msg = cmd_client.recv_json()
if msg:
print(f"\n【命令响应】{json.dumps(msg, ensure_ascii=False, indent=2)}")
time.sleep(0.01)
print("[线程-命令接收] 已退出")
def start_raw_eeg_send_thread(data_client: ZmqDealerClient):
"""原始脑电发送线程20ms/帧)"""
print(f"[线程-原始数据发送] 20ms/帧 | 单帧{CONFIG.RAW_FRAME_BYTES}字节 | float64")
frame_count = 0
while GLOBAL_STATE.run_flag:
insert_event = (frame_count % 20 == 0)
eeg_frame = generate_raw_eeg_frame(add_event=insert_event)
frame_bytes = eeg_frame.tobytes()
# 字节校验
if len(frame_bytes) != CONFIG.RAW_FRAME_BYTES:
print(f"[字节警告] 期望{CONFIG.RAW_FRAME_BYTES}, 实际{len(frame_bytes)}")
time.sleep(CONFIG.SEND_INTERVAL)
frame_count += 1
continue
data_client.send_bytes(frame_bytes)
frame_count += 1
time.sleep(CONFIG.SEND_INTERVAL)
print("[线程-原始数据发送] 已退出")
def start_filter_data_recv_thread(data_client: ZmqDealerClient, plot_queue: List[np.ndarray]):
"""滤波数据接收线程"""
print(f"[线程-滤波数据接收] 单包{CONFIG.FILTER_FRAME_BYTES}字节 | float64")
while GLOBAL_STATE.run_flag:
raw_bytes = data_client.recv_bytes()
if not raw_bytes:
time.sleep(0.01)
continue
try:
filter_arr = np.frombuffer(raw_bytes, dtype=CONFIG.DATA_DTYPE)
filter_arr = filter_arr.reshape(CONFIG.FILTER_FRAME_POINTS, CONFIG.FILTER_OUT_CHAN)
plot_queue.append(filter_arr[:, CONFIG.PLOT_TARGET_CHAN])
except Exception as e:
print(f"[滤波数据解析异常] {e}")
continue
print("[线程-滤波数据接收] 已退出")
# ===================== 7. 实时波形可视化 =====================
def start_wave_visualization(plot_queue: List[np.ndarray]):
"""启动实时滤波波形绘图"""
fig, ax = plt.subplots(figsize=(14, 4))
x_axis = np.arange(0, CONFIG.PLOT_WINDOW_LEN)
wave_data = np.zeros(CONFIG.PLOT_WINDOW_LEN, dtype=CONFIG.DATA_DTYPE)
line, = ax.plot(x_axis, wave_data, color="#2E86AB", linewidth=1.2)
ax.set_title(
f"实时滤波脑电波形 | 通道 {CONFIG.PLOT_TARGET_CHAN} | {CONFIG.SAMPLE_RATE}Hz | float64",
fontsize=12
)
ax.set_ylim(-3.0, 3.0)
ax.grid(True, alpha=0.3, linestyle="--")
plt.tight_layout()
def update_plot(_):
nonlocal wave_data
if plot_queue:
new_wave = plot_queue.pop(0)
wave_data = np.roll(wave_data, -len(new_wave))
wave_data[-len(new_wave)] = new_wave
line.set_ydata(wave_data)
return (line,)
ani = FuncAnimation(
fig, update_plot,
interval=CONFIG.PLOT_REFRESH_INTERVAL,
blit=True,
cache_frame_data=False
)
plt.show()
# ===================== 8. 全量业务测试用例 =====================
def run_full_test_cases(cmd_client: ZmqDealerClient):
"""全覆盖 zmqServer 所有命令sync/targetFreqs/decoderClass/impedance/train/predict/rest"""
print("\n" + "="*60)
print("开始执行全量命令测试用例")
print("="*60)
time.sleep(2)
# 1. 同步命令
print("\n[用例 1] 发送 sync 命令")
cmd_client.send_json({"method": "sync", "params": {}})
time.sleep(1)
# 2. 设置目标频率
print("\n[用例 2] 发送 targetFreqs = [8.0, 9.0]")
cmd_client.send_json({"method": "targetFreqs", "params": [8.0, 9.0]})
time.sleep(1)
# 3. 切换解码器
print("\n[用例 3] 切换解码器为 ssmvep")
cmd_client.send_json({"method": "decoderClass", "params": "ssmvep"})
time.sleep(2)
print("\n[用例 3-2] 切换解码器为 mi")
cmd_client.send_json({"method": "decoderClass", "params": "mi"})
time.sleep(2)
# 4. 阻抗检测开关
print("\n[用例 4] 开启阻抗检测 impedance=1")
cmd_client.send_json({"method": "impedance", "params": 1})
time.sleep(1)
print("\n[用例 4-2] 关闭阻抗检测 impedance=2")
cmd_client.send_json({"method": "impedance", "params": 2})
time.sleep(1)
# 5. 训练模式
print("\n[用例 5] 启动训练 train标签=1")
cmd_client.send_json({"method": "train", "params": 1})
time.sleep(3)
# # 6. 休息模式
# print("\n[用例 6] 切换 rest 休息模式")
# cmd_client.send_json({"method": "rest", "params": {}})
# time.sleep(1)
# 7. 启动解码
print("\n[用例 7] 启动解码 predict=1")
cmd_client.send_json({"method": "predict", "params": 1})
time.sleep(4)
# # 8. 非法命令(异常测试)
# print("\n[用例 8] 发送非法命令 test_cmd_illegal")
# cmd_client.send_json({"method": "test_cmd_illegal", "params": {}})
# time.sleep(1)
# # 9. 停止解码
# print("\n[用例 9] 停止解码 predict=2")
# cmd_client.send_json({"method": "predict", "params": 2})
# time.sleep(2)
print("\n" + "="*60)
print("所有测试用例执行完毕")
print("="*60)
# ===================== 主程序入口(修复线程语法) =====================
if __name__ == "__main__":
print("="*60)
print("ZMQ 脑电仿真测试工具 启动")
print(f"命令端口: {CONFIG.CMD_PORT} | 数据端口: {CONFIG.DATA_PORT}")
print(f"原始帧{CONFIG.RAW_FRAME_BYTES}字节 | 滤波帧{CONFIG.FILTER_FRAME_BYTES}字节 | float64")
print("="*60)
try:
with ZmqDealerClient(CONFIG.SERVER_IP, CONFIG.CMD_PORT) as cmd_client, \
ZmqDealerClient(CONFIG.SERVER_IP, CONFIG.DATA_PORT) as data_client:
plot_queue = []
# ========== 重点修复线程语法daemon 移出 args ==========
# 命令接收线程
t_cmd = threading.Thread(
target=start_cmd_response_thread,
args=(cmd_client,), # 单元素元组保留逗号
daemon=True
)
# 原始数据发送线程
t_eeg = threading.Thread(
target=start_raw_eeg_send_thread,
args=(data_client,),
daemon=True
)
# 滤波数据接收线程
t_filter = threading.Thread(
target=start_filter_data_recv_thread,
args=(data_client, plot_queue),
daemon=True
)
# 启动线程
t_cmd.start()
t_eeg.start()
t_filter.start()
# 执行测试用例
run_full_test_cases(cmd_client)
# 启动可视化(阻塞主线程)
print("\n[提示] 波形窗口已启动,关闭窗口 / Ctrl+C 退出程序")
start_wave_visualization(plot_queue)
except KeyboardInterrupt:
print("\n\n[用户中断] 接收到 Ctrl+C准备退出...")
except Exception as e:
print(f"\n[程序异常] {e}")
finally:
# 停止所有后台线程
GLOBAL_STATE.run_flag = False
time.sleep(0.2)
print("程序已安全退出")