Compare commits
43 Commits
main
...
ba4ae92647
| Author | SHA1 | Date | |
|---|---|---|---|
| ba4ae92647 | |||
| 43adc6fb42 | |||
| b329989181 | |||
| 68106d8aed | |||
|
|
506ebfd973 | ||
|
|
5a2cc82100 | ||
|
|
81a8d78ab2 | ||
| 73e01782df | |||
| b78e583bec | |||
| 504e89ee47 | |||
|
|
a9dbe7261b | ||
| 7b5f4f6eb9 | |||
| 0cffd1ae02 | |||
| 0e5e79fcdd | |||
| 694321b52c | |||
| 9f034d1105 | |||
| 07560304ca | |||
| f47e7d914f | |||
| af4fb48737 | |||
| fdddc814c7 | |||
| d741e3548f | |||
| 509fc5a1d7 | |||
|
|
67587f354b | ||
|
|
d5ef2311a1 | ||
| 31d91d6cc7 | |||
| ac0de93e31 | |||
|
|
140fd9a487 | ||
| 4faeae0ff3 | |||
| 880caa9f7b | |||
| d576cae3c0 | |||
|
|
9c9b522443 | ||
|
|
540d0c361f | ||
|
|
30c690e4e3 | ||
|
|
853037726d | ||
| 8a9d9a5c78 | |||
| 29b6118f11 | |||
| fce7d93d5e | |||
| 949801198e | |||
|
|
494515463d | ||
|
|
9a655ffdeb | ||
|
|
a9fd51e935 | ||
|
|
4b7e48be38 | ||
| 2d190d6431 |
10
.gitignore
vendored
10
.gitignore
vendored
@@ -4,8 +4,11 @@ __pycache__/
|
|||||||
# Distribution / packaging
|
# Distribution / packaging
|
||||||
build/
|
build/
|
||||||
dist/
|
dist/
|
||||||
|
dist_nuitka/
|
||||||
# Environments
|
upperHost_stim/
|
||||||
|
.vscode/
|
||||||
|
#!upperHost_stim/MI_headless.py
|
||||||
|
#!upperHost_stim/ssmvep_headless.py
|
||||||
.env
|
.env
|
||||||
.venv
|
.venv
|
||||||
env/
|
env/
|
||||||
@@ -24,7 +27,8 @@ venv.bak/
|
|||||||
*.xlsx
|
*.xlsx
|
||||||
*.mat
|
*.mat
|
||||||
*.json
|
*.json
|
||||||
|
*.txt
|
||||||
|
*.pth
|
||||||
|
|
||||||
# PyCharm
|
# PyCharm
|
||||||
# JetBrains specific template is maintained in a separate repository that is not distributed with PyCharm itself
|
# JetBrains specific template is maintained in a separate repository that is not distributed with PyCharm itself
|
||||||
|
|||||||
496
Decoder.py
496
Decoder.py
@@ -1,4 +1,7 @@
|
|||||||
import ast
|
import ast
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
import threading
|
import threading
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
@@ -8,58 +11,57 @@ import torch
|
|||||||
from queue import Empty
|
from queue import Empty
|
||||||
from scipy import signal
|
from scipy import signal
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
from Device.SunnyLinker import SunnyLinker64
|
# from Device.SunnyLinker import SunnyLinker64
|
||||||
from SSMVEP.algorithm.tdca import TDCA
|
from SSMVEP.algorithm.tdca import TDCA
|
||||||
from SSMVEP.algorithm.base import generate_cca_references
|
from SSMVEP.algorithm.base import generate_cca_references
|
||||||
from concentration.algorithm.calculate_focus import Calculate
|
# from concentration.algorithm.calculate_focus import Calculate
|
||||||
from blinkdetection.algorithm.eye_detection import blink_detection
|
# from blinkdetection.algorithm.eye_detection import blink_detection
|
||||||
from Zmq.zmqServer import zmqServer
|
from Zmq.zmqServer import zmqServer
|
||||||
from Zmq.zmqClient import zmqClient
|
from Zmq.zmqClient import zmqClient
|
||||||
from MI.Algorithm.conformer_2class import onlineTrain
|
from MI.Algorithm.conformer_2class import onlineTrain
|
||||||
from PubLibrary.InifileHelper import IniRead
|
from PubLibrary.InifileHelper import IniRead
|
||||||
|
from logs.log import algo_log
|
||||||
from SSVEP.dwfbcca import FbccaDw
|
from SSVEP.dwfbcca import FbccaDw
|
||||||
from Tools.plot_MI_EEG import plotMain
|
# from Tools.plot_MI_EEG import plotMain
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from Zmq.filterProcess import SlidingFilter
|
||||||
|
|
||||||
class Decoder_main(threading.Thread, device_type):
|
save_train_data = int(IniRead('system', 'save_train_data', 0))
|
||||||
def __init__(self, device_type=None):
|
|
||||||
|
def get_root_path():
|
||||||
|
"""
|
||||||
|
Nuitka 打包专用:获取程序根目录(.py 或 .exe 所在目录)
|
||||||
|
"""
|
||||||
|
if getattr(sys, 'frozen', False):
|
||||||
|
# 打包后:返回 exe 所在目录
|
||||||
|
return os.path.dirname(sys.executable)
|
||||||
|
else:
|
||||||
|
# 开发时:返回 py 文件所在目录
|
||||||
|
return os.path.dirname(os.path.abspath(__file__))
|
||||||
|
MODEL_FOLDER = "online_Models"
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder_main(threading.Thread):
|
||||||
|
def __init__(self, device_info=None):
|
||||||
threading.Thread.__init__(self)
|
threading.Thread.__init__(self)
|
||||||
|
self.device_info = device_info
|
||||||
self.Runing=True
|
self.Runing=True
|
||||||
self.decoder = None
|
self.decoder = None
|
||||||
|
|
||||||
self.fs = 250 # 采样率
|
|
||||||
self.energy = 0 # 电量
|
|
||||||
self.status_code = 0 # 与采集设备通信的状态码,0为异常,1为正常
|
|
||||||
self.decoder_class = None #解码器类别
|
self.decoder_class = None #解码器类别
|
||||||
|
|
||||||
self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果
|
self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果
|
||||||
self.device_info = {
|
|
||||||
'device_type': None,
|
|
||||||
'sample_rate': None,
|
|
||||||
'channel_num': None,
|
|
||||||
}
|
|
||||||
|
|
||||||
def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None):
|
|
||||||
self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type'))
|
|
||||||
_device_host = device_host if device_host is not None else str(IniRead('system', 'Device_Host'))
|
|
||||||
_device_port = device_port if device_port is not None else int(IniRead('system', 'Device_Port'))
|
|
||||||
_upper_host = upper_host if upper_host is not None else str(IniRead('system', 'Upper_Host'))
|
|
||||||
_upper_port = upper_port if upper_port is not None else int(IniRead('system', 'Upper_Port'))
|
|
||||||
|
|
||||||
if self.DeviceType == 1:
|
self.zmqServer = zmqServer(device_info=self.device_info)
|
||||||
self.thread_data_server = SunnyLinker64(_device_host, _device_port, self.fs, 64, method='tcp')
|
self.zmqServer.start() # 启动ZMQ接收线程
|
||||||
self.thread_data_server.host = _device_host
|
|
||||||
self.thread_data_server.port = _device_port
|
|
||||||
|
|
||||||
self.thread_data_server.toUv = True
|
self.sliding_filter = SlidingFilter(
|
||||||
self.thread_data_server.start()
|
ring_buffer=self.zmqServer.filterBuffer,
|
||||||
|
n_chan=self.zmqServer.device_info['channel_nums'],
|
||||||
|
srate=self.zmqServer.device_info['sample_rate']
|
||||||
|
)
|
||||||
|
|
||||||
self.zmqServer = zmqServer()
|
# 注册滤波结果回调(示例:打印数据形状)
|
||||||
self.zmqServer.start()
|
self.sliding_filter.filter_result_callback = self.zmqServer.send_filtered_data
|
||||||
|
|
||||||
self.zmqClient = zmqClient(_upper_host, _upper_port)
|
|
||||||
self.zmqClient.set_zmq_server(self.zmqServer)
|
|
||||||
self.zmqClient.connect()
|
|
||||||
|
|
||||||
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
|
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
|
||||||
# data: (chans, samples)
|
# data: (chans, samples)
|
||||||
@@ -74,45 +76,44 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
:return:
|
:return:
|
||||||
'''
|
'''
|
||||||
self.decoder_class = decoder_class
|
self.decoder_class = decoder_class
|
||||||
if decoder_class == 'ssvep' or decoder_class == 'pvs':
|
if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
|
||||||
self.n_chan = 8
|
self.n_chan = 8
|
||||||
self.thread_data_server.interval_inited = False
|
# self.thread_data_server.interval_inited = False
|
||||||
DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
|
DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
|
||||||
self.ListFreq = self.zmqServer.targetFreqs
|
self.ListFreq = self.zmqServer.targetFreqs
|
||||||
self.num_target = len(self.ListFreq)
|
self.num_target = len(self.ListFreq)
|
||||||
if self.num_target == 0:
|
if self.num_target == 0:
|
||||||
return
|
return
|
||||||
# 初始化对象 二代算法
|
# 初始化对象 二代算法
|
||||||
self.dw = FbccaDw(self.fs, self.num_target, self.n_chan, 5, 5,
|
self.dw = FbccaDw(self.device_info['sample_rate'], self.num_target, self.n_chan, 5, 5,
|
||||||
0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method)
|
0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method)
|
||||||
# frequence band
|
# frequence band
|
||||||
self.dw.filterFrequenceBank()
|
self.dw.filterFrequenceBank()
|
||||||
self.dw.setNotchFilterPara()
|
self.dw.setNotchFilterPara()
|
||||||
self.calculateCount = 0
|
self.calculateCount = 0
|
||||||
self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.fs),
|
self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.device_info['sample_rate']), 5)
|
||||||
5)
|
|
||||||
self.dw.filterInit()
|
self.dw.filterInit()
|
||||||
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
||||||
|
|
||||||
elif decoder_class == 'ssmvep':
|
elif decoder_class == 'ssmvep':
|
||||||
self.thread_data_server.interval_init(decoder_class)
|
self.zmqServer.interval_init(decoder_class)
|
||||||
self.n_chan = 8
|
self.n_chan = 8
|
||||||
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) # [0.2, 2.2]
|
||||||
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
||||||
self.single_train = 10 # 单类别数量
|
self.single_train = 10 # 单类别数量
|
||||||
self.num_target = 2 # 分类目标数目
|
self.num_target = 2 # 分类目标数目
|
||||||
self.list_freqs = np.array([8, 9]) # 刺激频率
|
self.list_freqs = np.array([8, 9]) # 刺激频率
|
||||||
self.list_phase = np.array([0, 0]) # 相位
|
self.list_phase = np.array([0, 0]) # 相位
|
||||||
self.tdca = TDCA(padding_len=5, n_components=1)
|
self.tdca = TDCA(padding_len=5, n_components=1)
|
||||||
self.Yf = generate_cca_references(self.list_freqs, srate=self.fs, T=self.sample_length,
|
self.Yf = generate_cca_references(self.list_freqs, srate=self.device_info['sample_rate'], T=self.sample_length,
|
||||||
phases=self.list_phase, n_harmonics=5)
|
phases=self.list_phase, n_harmonics=5)
|
||||||
self.parameter_init(5,45)
|
self.parameter_init(5,45)
|
||||||
|
|
||||||
elif decoder_class == 'mi' or decoder_class == 'ma':
|
elif decoder_class == 'mi' or decoder_class == 'ma':
|
||||||
self.thread_data_server.interval_init(decoder_class)
|
self.zmqServer.interval_init(decoder_class)
|
||||||
self.n_chan = 21
|
self.n_chan = 21
|
||||||
self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) # [0.5, 4.5]
|
||||||
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度4s,# 精确到小数点后6位
|
||||||
self.single_train = 40 # 单类别数量
|
self.single_train = 40 # 单类别数量
|
||||||
self.num_target = 2 # 分类目标数目
|
self.num_target = 2 # 分类目标数目
|
||||||
|
|
||||||
@@ -124,7 +125,7 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
# self.win_len = 10
|
# self.win_len = 10
|
||||||
# self.win_step = 1
|
# self.win_step = 1
|
||||||
# self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue'))
|
# self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue'))
|
||||||
# self.calculate = Calculate(self.low_threshold, self.high_threshold, self.fs, self.win_len)
|
# self.calculate = Calculate(self.low_threshold, self.high_threshold, self.device_info['sample_rate'], self.win_len)
|
||||||
# self.interval_epoch = [0, 1]
|
# self.interval_epoch = [0, 1]
|
||||||
# self.parameter_init(2, 40)
|
# self.parameter_init(2, 40)
|
||||||
# # self.eegQueue moved to Calculate class
|
# # self.eegQueue moved to Calculate class
|
||||||
@@ -136,8 +137,8 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
# self.total_samples = 0 # 总采样点数
|
# self.total_samples = 0 # 总采样点数
|
||||||
# self.window_ms = 600 # 检测窗口大小 (ms)
|
# self.window_ms = 600 # 检测窗口大小 (ms)
|
||||||
# self.step_ms = 100 # 滑动步长 (ms)
|
# self.step_ms = 100 # 滑动步长 (ms)
|
||||||
# self.window_samples = int(self.window_ms * self.fs / 1000) # 150个样本点
|
# self.window_samples = int(self.window_ms * self.device_info['sample_rate'] / 1000) # 150个样本点
|
||||||
# self.step_samples = int(self.step_ms * self.fs / 1000) # 25个样本点
|
# self.step_samples = int(self.step_ms * self.device_info['sample_rate'] / 1000) # 25个样本点
|
||||||
# self.buffer_size = self.window_samples + self.step_samples * 5
|
# self.buffer_size = self.window_samples + self.step_samples * 5
|
||||||
# self.fp1_buffer = deque(maxlen=self.buffer_size)
|
# self.fp1_buffer = deque(maxlen=self.buffer_size)
|
||||||
# self.fp2_buffer = deque(maxlen=self.buffer_size)
|
# self.fp2_buffer = deque(maxlen=self.buffer_size)
|
||||||
@@ -151,11 +152,11 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
# self.double_blink_events = [] # 连续眨眼事件记录
|
# self.double_blink_events = [] # 连续眨眼事件记录
|
||||||
# self.last_double_blink_time = 0 # 上次检测到连续眨眼的时间戳
|
# self.last_double_blink_time = 0 # 上次检测到连续眨眼的时间戳
|
||||||
# self.blink_events = []
|
# self.blink_events = []
|
||||||
# self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.fs / 2), self.h_freq / (self.fs / 2)], btype='band')
|
# self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.device_info['sample_rate'] / 2), self.h_freq / (self.device_info['sample_rate'] / 2)], btype='band')
|
||||||
|
|
||||||
def parameter_init(self,bandPass_low,bandPass_high):
|
def parameter_init(self,bandPass_low,bandPass_high):
|
||||||
self.interval_epoch = [int(i * self.fs) for i in self.interval_epoch] # epoch截取信息
|
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in self.interval_epoch] # epoch截取信息
|
||||||
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.fs)] # 训练样本epoch
|
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch
|
||||||
self.trainData = [] #训练数据
|
self.trainData = [] #训练数据
|
||||||
self.trainLabel = [] #训练标签
|
self.trainLabel = [] #训练标签
|
||||||
self.plotData = [] #报告分析数据
|
self.plotData = [] #报告分析数据
|
||||||
@@ -163,13 +164,15 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
self.currentLabel = -1 #刺激界面当前显示的训练标签
|
self.currentLabel = -1 #刺激界面当前显示的训练标签
|
||||||
self.train_started = False #是否开始训练模型
|
self.train_started = False #是否开始训练模型
|
||||||
self.load_model = False # 调用模型是否完成的标志
|
self.load_model = False # 调用模型是否完成的标志
|
||||||
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30) # 50Hz工频陷波,250是采样率,30是质量因子
|
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.device_info['sample_rate']/2), 30) # 50Hz工频陷波,250是采样率,30是质量因子
|
||||||
self.b_design = signal.firwin(65, [bandPass_low / (self.fs/2), bandPass_high / (self.fs/2)], pass_zero=False) # 设计8-30Hz带通滤波器
|
self.b_design = signal.firwin(65, [bandPass_low / (self.device_info['sample_rate']/2), bandPass_high / (self.device_info['sample_rate']/2)], pass_zero=False) # 设计8-30Hz带通滤波器
|
||||||
|
filePath = os.path.join(get_root_path(), MODEL_FOLDER) + os.sep
|
||||||
|
for old_pth in glob.glob(os.path.join(filePath, '*.pth')):
|
||||||
|
os.remove(old_pth)
|
||||||
fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
||||||
filePath = './online_Models/'
|
|
||||||
self.modelPath = ''.join([filePath, fileName, '.pth'])
|
self.modelPath = ''.join([filePath, fileName, '.pth'])
|
||||||
self.mp_data_queue = mp.Queue() #多进程传参队列
|
self.mp_data_queue = mp.Queue()
|
||||||
self.mp_result_queue = mp.Queue() #多进程结果队列
|
self.mp_result_queue = mp.Queue()
|
||||||
|
|
||||||
def preprocess(self, signal_data):
|
def preprocess(self, signal_data):
|
||||||
# # 计算每行的平均值
|
# # 计算每行的平均值
|
||||||
@@ -183,8 +186,13 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
while self.Runing:
|
while self.Runing:
|
||||||
|
# 当滤波数据大于5秒时,启动滤波线程
|
||||||
|
if not self.sliding_filter.is_alive() and self.zmqServer.filterBuffer.GetDataLenCount() > self.device_info['sample_rate'] * 5:
|
||||||
|
algo_log("启动滤波线程", level="DEBUG")
|
||||||
|
self.sliding_filter.start()
|
||||||
|
|
||||||
if self.zmqServer.decoder_switch or self.zmqServer.changeTarget:
|
if self.zmqServer.decoder_switch or self.zmqServer.changeTarget:
|
||||||
print(f"Decoder_class Switch Detected: {self.zmqServer.decoder_class}")
|
algo_log(f"Decoder_class Switch Detected: {self.zmqServer.decoder_class}", level="DEBUG")
|
||||||
self.zmqServer.decoder_switch = False
|
self.zmqServer.decoder_switch = False
|
||||||
self.zmqServer.changeTarget = False
|
self.zmqServer.changeTarget = False
|
||||||
self.reset_state() # 切换前先统一清理旧状态
|
self.reset_state() # 切换前先统一清理旧状态
|
||||||
@@ -192,57 +200,9 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
|
|
||||||
# 同步信息
|
# 同步信息
|
||||||
if self.zmqServer.state_mode == 'sync':
|
if self.zmqServer.state_mode == 'sync':
|
||||||
self.zmqClient.send_to_all('sync', self.zmqClient.state)
|
# self.zmqClient.send_to_all('sync', self.zmqClient.state)
|
||||||
self.zmqServer.state_mode = 'rest'
|
self.zmqServer.state_mode = 'rest'
|
||||||
# 状态异常,报告上位机
|
|
||||||
if self.status_code != self.thread_data_server.status_code:
|
|
||||||
self.status_code = self.thread_data_server.status_code
|
|
||||||
self.zmqClient.send_to_all('status_code', int(self.status_code))
|
|
||||||
print('status code')
|
|
||||||
|
|
||||||
# 返回电量
|
|
||||||
if self.energy != self.thread_data_server.energy:
|
|
||||||
self.energy = self.thread_data_server.energy
|
|
||||||
self.zmqClient.send_to_all('energy', int(self.energy))
|
|
||||||
print('energy')
|
|
||||||
|
|
||||||
if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
|
|
||||||
self.thread_data_server.Impedance(True)
|
|
||||||
print('Impedance')
|
|
||||||
self.zmqServer.open_Impedance = -1
|
|
||||||
elif self.zmqServer.open_Impedance == False:
|
|
||||||
self.thread_data_server.Impedance(False)
|
|
||||||
self.zmqServer.open_Impedance = -1
|
|
||||||
|
|
||||||
if self.zmqServer.get_Impedance: # 返回阻抗值
|
|
||||||
# print(self.zmqServer.get_Impedance)
|
|
||||||
# print(self.thread_data_server.GetDataLenCount())
|
|
||||||
if self.thread_data_server.GetDataLenCount() > 250:
|
|
||||||
Impe_data = self.thread_data_server.getData(250)
|
|
||||||
# 计算阻抗
|
|
||||||
imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class)
|
|
||||||
self.zmqClient.send_to_all('impedance', imps.tolist())
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
if self.zmqServer.getReport: #返回训练报告内容
|
|
||||||
self.zmqServer.getReport = False
|
|
||||||
allData = np.array(self.plotData)
|
|
||||||
allLabel = np.array(self.plotLabel) + 1
|
|
||||||
nTrials = min(len(allLabel),len(allData))
|
|
||||||
if nTrials < 30:
|
|
||||||
self.zmqClient.send_to_all('miReport',0)
|
|
||||||
else:
|
|
||||||
allData = allData[:nTrials]
|
|
||||||
allLabel = allLabel[:nTrials]
|
|
||||||
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1',
|
|
||||||
'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4']
|
|
||||||
compare_names = ['C3', 'CZ', 'C4']
|
|
||||||
miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2,
|
|
||||||
fs=self.fs)
|
|
||||||
self.zmqClient.send_to_all('miReport',miReport)
|
|
||||||
|
|
||||||
|
|
||||||
# --- 取数优先:先执行 decoder(消费环形缓冲),再处理 plot/report 等重负载 ---
|
|
||||||
try:
|
try:
|
||||||
if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
|
if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
|
||||||
self.decoder_SSVEP()
|
self.decoder_SSVEP()
|
||||||
@@ -250,92 +210,83 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
self.decoder_SSMVEP()
|
self.decoder_SSMVEP()
|
||||||
elif self.decoder_class == 'mi':
|
elif self.decoder_class == 'mi':
|
||||||
self.decoder_MI()
|
self.decoder_MI()
|
||||||
elif self.decoder_class == 'concentration':
|
|
||||||
self.decoder_concentration()
|
|
||||||
elif self.decoder_class == 'blink':
|
|
||||||
self.decoder_blink()
|
|
||||||
else:
|
else:
|
||||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||||
if self.thread_data_server.GetDataLenCount() < 25:
|
|
||||||
time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
continue;
|
continue;
|
||||||
self.thread_data_server.getData(25)
|
self.zmqServer.paradigmBuffer.getData(25)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Decoder Loop Error: {e}")
|
algo_log(f"Decoder Loop Error: {e}")
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
time.sleep(0.1) # Prevent CPU spin if error is persistent
|
time.sleep(0.1) # Prevent CPU spin if error is persistent
|
||||||
|
|
||||||
def decoder_SSVEP(self):
|
def decoder_SSVEP(self):
|
||||||
if self.zmqServer.StartDecode:
|
if self.zmqServer.StartDecode:
|
||||||
self.zmqServer.StartDecode = False
|
self.zmqServer.StartDecode = False
|
||||||
self.decodingSteps = 1
|
self.decodingSteps = 1
|
||||||
self.thread_data_server.ResetAll()
|
self.zmqServer.paradigmBuffer.resetAllPara()
|
||||||
print('启动预测')
|
algo_log('启动SSVEP预测', level="DEBUG")
|
||||||
if self.thread_data_server.GetDataLenCount() < 50:
|
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 50:
|
||||||
time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
return
|
return
|
||||||
if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
|
if self.zmqServer.open_Impedance: # 阻抗检测状态不解码
|
||||||
return
|
return
|
||||||
data = self.thread_data_server.getDataViaSSVEP(50)
|
data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50)
|
||||||
|
# algo_log(f"SSVEP取出的:{data.shape}, data = {data[:20]}", level="DEBUG")
|
||||||
data = data[:self.n_chan, :]
|
data = data[:self.n_chan, :]
|
||||||
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
|
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
|
||||||
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
||||||
self.dw.warmFilter(data) # 预热
|
self.dw.warmFilter(data) # 预热
|
||||||
self.decodingSteps = 2
|
self.decodingSteps = 2
|
||||||
print('预热数据完成。开始预测')
|
algo_log('SSVEP预热数据完成。开始预测', level="DEBUG")
|
||||||
return
|
return
|
||||||
if self.decodingSteps == 2 and hasattr(self,'dw'): # 解码中
|
if self.decodingSteps == 2 and hasattr(self,'dw'): # 解码中
|
||||||
choosenNum = self.dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount)
|
choosenNum = self.dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount)
|
||||||
self.calculateCount += 1
|
self.calculateCount += 1
|
||||||
if choosenNum != -1 and self.is_valid_signal(data):
|
if choosenNum != -1 and self.is_valid_signal(data):
|
||||||
self.decodingSteps = 3
|
self.decodingSteps = 3
|
||||||
print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount))
|
algo_log('SSVEP预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount), level="DEBUG")
|
||||||
self.calculateCount = 0
|
self.calculateCount = 0
|
||||||
if self.decodingSteps == 3: # 发送解码后的信息
|
if self.decodingSteps == 3: # 发送解码后的信息
|
||||||
self.zmqClient.send_to_all('result', int(choosenNum))
|
self.zmqServer.broadcast_message('result', int(choosenNum))
|
||||||
self.decodingSteps = 0
|
self.decodingSteps = 0
|
||||||
print('发送给界面完成。')
|
algo_log('SSVEP发送给界面完成。', level="DEBUG")
|
||||||
|
|
||||||
def decoder_SSMVEP(self):
|
def decoder_SSMVEP(self):
|
||||||
'''模型训练'''
|
'''模型训练'''
|
||||||
if self.load_model == False and all(
|
if self.load_model == False and all(
|
||||||
self.trainLabel.count(i) >= self.single_train for i in range(len(self.list_freqs))): # 模型尚未训练完成
|
self.trainLabel.count(i) >= self.single_train for i in [1, 2]): # 模型尚未训练完成
|
||||||
self.trainData = np.array(self.trainData)
|
self.trainData = np.array(self.trainData)
|
||||||
self.trainLabel = np.array(self.trainLabel)
|
self.trainLabel = np.array(self.trainLabel)
|
||||||
print(np.shape(self.trainData), (self.trainLabel))
|
algo_log(f"开始SSMVEP模型训练,数据形状:{np.shape(self.trainData)},标签形状:{self.trainLabel.shape}", level="DEBUG")
|
||||||
# 保存多个数组到文件
|
if save_train_data == 1:
|
||||||
# np.savez('20250520_yy.npz', array1=self.trainData, array2=self.trainLabel)
|
now_str = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
# self.decoder = self.fbtdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
|
save_path = f"{now_str}.npz"
|
||||||
|
np.savez(save_path, array1=self.trainData, array2=self.trainLabel)
|
||||||
self.decoder = self.tdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
|
self.decoder = self.tdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||||
print('模型训练完成', formatted_time)
|
algo_log(f"SSMVEP模型训练完成,时间:{formatted_time}", level="DEBUG")
|
||||||
self.load_model = True
|
self.load_model = True
|
||||||
self.zmqClient.send_to_all('paradigm', 1)
|
self.zmqServer.broadcast_message('paradigm', 1)
|
||||||
|
|
||||||
'''训练阶段采集数据'''
|
'''训练阶段采集数据'''
|
||||||
if self.zmqServer.state_mode == 'train': # 训练状态
|
if self.zmqServer.state_mode == 'train': # 训练状态
|
||||||
if self.zmqServer.StartTrain:
|
if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
|
||||||
|
self.train_epoch[1] + self.zmqServer.event_inner_idx:
|
||||||
self.currentLabel = self.zmqServer.currentLabel
|
self.currentLabel = self.zmqServer.currentLabel
|
||||||
self.zmqServer.StartTrain = False
|
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
|
||||||
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
algo_log(f"取出的:{trainTrial.shape},event:{trainTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
||||||
self.train_epoch[1] \
|
|
||||||
+ self.thread_data_server.event_inner_idx:
|
|
||||||
time.sleep(0.0001)
|
|
||||||
return
|
|
||||||
print('训练队列数据:', self.thread_data_server.GetDataLenCount())
|
|
||||||
trainTrial = self.thread_data_server.get_SSMVEPData() # 取出所有数据
|
|
||||||
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.thread_data_server.event_inner_idx])
|
|
||||||
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
|
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
|
||||||
trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.train_epoch[
|
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
|
||||||
0]:self.thread_data_server.event_inner_idx + self.train_epoch[1]]
|
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]]
|
||||||
print('trial: ', self.thread_data_server.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
|
|
||||||
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
|
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
|
||||||
self.trainLabel, list) \
|
self.trainLabel, list) \
|
||||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||||
self.trainData.append(trainTrial)
|
self.trainData.append(trainTrial)
|
||||||
self.trainLabel.append(self.currentLabel)
|
self.trainLabel.append(self.currentLabel)
|
||||||
|
else:
|
||||||
|
time.sleep(0.0001)
|
||||||
|
return
|
||||||
|
|
||||||
elif self.zmqServer.state_mode == 'predict': # 测试状态
|
elif self.zmqServer.state_mode == 'predict': # 测试状态
|
||||||
if self.load_model == False: # 模型尚未训练完成
|
if self.load_model == False: # 模型尚未训练完成
|
||||||
@@ -346,45 +297,47 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
self.zmqServer.StartDecode = False
|
self.zmqServer.StartDecode = False
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||||
print('启动预测 ', formatted_time)
|
algo_log(f"SSMVEP模型启动预测 {formatted_time}", level="DEBUG")
|
||||||
|
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||||
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
|
||||||
self.interval_epoch[1] \
|
self.interval_epoch[1] \
|
||||||
+ self.thread_data_server.event_inner_idx:
|
+ self.zmqServer.event_inner_idx:
|
||||||
|
# algo_log(f"SSMVEP模型启动预测 {self.zmqServer.epoch_finished}", level="DEBUG")
|
||||||
time.sleep(0.0001)
|
time.sleep(0.0001)
|
||||||
return
|
return
|
||||||
data = self.thread_data_server.get_SSMVEPData() # 读取全部数据
|
data = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 读取全部数据
|
||||||
print('取出的: ', data.shape, 'event: ', data[-2, self.thread_data_server.event_inner_idx])
|
algo_log(f"取出的:{data.shape}, event: {data[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
||||||
data = self.preprocess(data[:self.n_chan, :]) # 预处理
|
data = self.preprocess(data[:self.n_chan, :]) # 预处理
|
||||||
data = data[:,
|
data = data[:,
|
||||||
self.thread_data_server.event_inner_idx + self.interval_epoch[
|
self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
||||||
pad_eeg_test = np.zeros(
|
pad_eeg_test = np.zeros(
|
||||||
(data.shape[0], int((self.sample_length + 0.1) * self.fs)))
|
(data.shape[0], int((self.sample_length + 0.1) * self.device_info['sample_rate'])))
|
||||||
pad_eeg_test[:, :int(self.sample_length * self.fs)] = data
|
pad_eeg_test[:, :int(self.sample_length * self.device_info['sample_rate'])] = data
|
||||||
choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
|
choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
|
||||||
if isinstance(choosenNum, np.ndarray):
|
if isinstance(choosenNum, np.ndarray):
|
||||||
choosenNum = choosenNum[0]
|
choosenNum = choosenNum[0]
|
||||||
print('结果:', choosenNum, 'rho: ', sorted(features_2[0]),
|
algo_log(f"结果:{choosenNum}, rho: {sorted(features_2[0])[-1] - sorted(features_2[0])[-2]}", level="DEBUG")
|
||||||
sorted(features_2[0])[-1] - sorted(features_2[0])[-2])
|
self.zmqServer.broadcast_message('result', int(choosenNum))
|
||||||
self.zmqClient.send_to_all('result', int(choosenNum))
|
algo_log("SSMVEP发送给界面完成。", level="DEBUG")
|
||||||
print('发送给界面完成。')
|
|
||||||
else: # 休息状态
|
else: # 休息状态
|
||||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||||
if self.thread_data_server.GetDataLenCount() < 25:
|
|
||||||
time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
return
|
return
|
||||||
self.thread_data_server.getData(25)
|
self.zmqServer.paradigmBuffer.getData(25)
|
||||||
|
|
||||||
def decoder_MI(self):
|
def decoder_MI(self):
|
||||||
'''模型训练'''
|
'''模型训练'''
|
||||||
if self.train_started == False and all(
|
if self.train_started == False and all(
|
||||||
self.trainLabel.count(i) >= self.single_train for i in range(self.num_target)): # 模型尚未训练
|
self.trainLabel.count(i) >= self.single_train for i in [1, 2]): # 模型尚未训练
|
||||||
self.zmqClient.send_to_all('paradigm', 2) # 模型训练前,训练集采集完毕,通知上位机
|
self.zmqServer.broadcast_message('paradigm', 2) # 模型训练前,训练集采集完毕,通知上位机
|
||||||
self.train_started = True
|
self.train_started = True
|
||||||
self.trainData = np.array(self.trainData)
|
self.trainData = np.array(self.trainData)
|
||||||
self.trainLabel = np.array(self.trainLabel) + 1
|
self.trainLabel = np.array(self.trainLabel)
|
||||||
# print('训练集:',np.shape(self.trainData), (self.trainLabel))
|
algo_log(f"MI开始训练,训练集:{np.shape(self.trainData)},标签shape:{np.shape(self.trainLabel)}", level="DEBUG")
|
||||||
|
if save_train_data == 1:
|
||||||
|
now_str = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
save_path = f"{now_str}.npz"
|
||||||
|
np.savez(save_path, array1=self.trainData, array2=self.trainLabel)
|
||||||
p = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue)) # 开启子进程,训练模型
|
p = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue)) # 开启子进程,训练模型
|
||||||
p.start()
|
p.start()
|
||||||
self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel, 'modelPath': self.modelPath,
|
self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel, 'modelPath': self.modelPath,
|
||||||
@@ -395,7 +348,7 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
try:
|
try:
|
||||||
result = self.mp_result_queue.get_nowait()
|
result = self.mp_result_queue.get_nowait()
|
||||||
if result['status'] == 'success':
|
if result['status'] == 'success':
|
||||||
print("模型训练完成,加载新模型")
|
algo_log("MI模型训练完成,加载新模型", level="DEBUG")
|
||||||
# 调用模型
|
# 调用模型
|
||||||
self.model = torch.load(self.modelPath, weights_only=False)
|
self.model = torch.load(self.modelPath, weights_only=False)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
@@ -406,63 +359,61 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
_ = self.model(warmup_data)
|
_ = self.model(warmup_data)
|
||||||
self.load_model = True
|
self.load_model = True
|
||||||
self.zmqClient.send_to_all('paradigm', 1) # 模型调用完毕,通知上位机
|
self.zmqServer.broadcast_message('paradigm', 1) # 模型调用完毕,通知上位机
|
||||||
else:
|
else:
|
||||||
print("训练失败:", result['msg'])
|
algo_log("MI训练失败: " + result['msg'], level="DEBUG")
|
||||||
except Empty:
|
except Empty:
|
||||||
pass # 还没完成
|
pass # 还没完成
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print('模型调用失败: ', e)
|
algo_log("MI模型训练失败: " + str(e), level="DEBUG")
|
||||||
|
|
||||||
'''训练阶段采集数据'''
|
'''训练阶段采集数据'''
|
||||||
if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态
|
if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态
|
||||||
if self.zmqServer.StartTrain:
|
if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
|
||||||
self.currentLabel = self.zmqServer.currentLabel
|
self.zmqServer.train_epoch[1] + self.zmqServer.event_inner_idx:
|
||||||
self.zmqServer.StartTrain = False
|
self.currentLabel = self.zmqServer.currentLabel # 同步当前标签
|
||||||
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
algo_log(f"训练队列数据:{self.zmqServer.paradigmBuffer.GetDataLenCount()}", level="DEBUG")
|
||||||
self.interval_epoch[1] \
|
originalTrial = self.zmqServer.paradigmBuffer.get_MIData() # 取出MI导联数据
|
||||||
+ self.thread_data_server.event_inner_idx:
|
algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
||||||
time.sleep(0.0001)
|
|
||||||
return
|
|
||||||
print('训练队列数据:', self.thread_data_server.GetDataLenCount())
|
|
||||||
originalTrial = self.thread_data_server.get_MIData() # 取出MI导联数据
|
|
||||||
print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.thread_data_server.event_inner_idx])
|
|
||||||
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
|
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
|
||||||
trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
||||||
print('trial: ', self.thread_data_server.event_inner_idx, self.interval_epoch[0], self.interval_epoch[1])
|
# algo_log(f"trial: {self.zmqServer.event_inner_idx},{self.interval_epoch[0]},{self.interval_epoch[1]}", level="DEBUG")
|
||||||
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
|
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
|
||||||
list) \
|
list) \
|
||||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||||
self.trainData.append(trainTrial)
|
self.trainData.append(trainTrial)
|
||||||
self.trainLabel.append(self.currentLabel)
|
self.trainLabel.append(self.currentLabel)
|
||||||
print('训练集:', np.shape(self.trainData))
|
algo_log(f"训练集:{np.shape(self.trainData)}", level="DEBUG")
|
||||||
self.plotData.append(originalTrial[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
self.plotData.append(originalTrial[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]])
|
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
|
||||||
self.plotLabel.append(self.currentLabel)
|
self.plotLabel.append(self.currentLabel)
|
||||||
|
else:
|
||||||
|
time.sleep(0.0001)
|
||||||
|
return
|
||||||
|
|
||||||
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
|
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
|
||||||
if self.zmqServer.StartDecode:
|
if self.zmqServer.StartDecode:
|
||||||
self.zmqServer.StartDecode = False
|
self.zmqServer.StartDecode = False
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||||
print('启动预测 ', formatted_time)
|
algo_log(f"MI启动预测 {formatted_time}", level="DEBUG")
|
||||||
|
|
||||||
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||||
self.interval_epoch[1] \
|
self.interval_epoch[1] \
|
||||||
+ self.thread_data_server.event_inner_idx:
|
+ self.zmqServer.event_inner_idx:
|
||||||
time.sleep(0.0001)
|
time.sleep(0.0001)
|
||||||
return
|
return
|
||||||
originalData = self.thread_data_server.get_MIData() # 读取全部数据
|
originalData = self.zmqServer.paradigmBuffer.get_MIData() # 读取全部数据
|
||||||
print('取出的: ', originalData.shape, 'event: ', originalData[-2, self.thread_data_server.event_inner_idx])
|
algo_log(f"取出的:{originalData.shape},event: {originalData[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
||||||
start = time.time()
|
start = time.time()
|
||||||
data = self.preprocess(originalData[:self.n_chan, :]) # 预处理
|
data = self.preprocess(originalData[:self.n_chan, :]) # 预处理
|
||||||
data = data[:,
|
data = data[:,
|
||||||
self.thread_data_server.event_inner_idx + self.interval_epoch[
|
self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
||||||
self.plotData.append(
|
self.plotData.append(
|
||||||
originalData[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
originalData[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]])
|
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
|
||||||
|
|
||||||
test_data = data[np.newaxis, np.newaxis, :, :]
|
test_data = data[np.newaxis, np.newaxis, :, :]
|
||||||
test_data = torch.from_numpy(test_data)
|
test_data = torch.from_numpy(test_data)
|
||||||
@@ -471,134 +422,40 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
Cls = self.model(test_data)
|
Cls = self.model(test_data)
|
||||||
y_pred = torch.max(Cls, 1)[1]
|
y_pred = torch.max(Cls, 1)[1]
|
||||||
self.plotLabel.append(int(y_pred.item()))
|
self.plotLabel.append(int(y_pred.item()))
|
||||||
print('运动意图识别: ', y_pred)
|
algo_log(f"MI运动意图识别: {y_pred}")
|
||||||
self.zmqClient.send_to_all('result', int(y_pred.item()))
|
self.zmqServer.broadcast_message('paradigm', int(y_pred.item()))
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print(f'发送给界面完成,耗时{end - start:.3f}s。')
|
algo_log(f'MI发送给界面完成,耗时{end - start:.3f}s。')
|
||||||
else: # 休息状态
|
else: # 休息状态
|
||||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||||
if self.thread_data_server.GetDataLenCount() < 25:
|
|
||||||
time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
return
|
return
|
||||||
self.thread_data_server.getData(25)
|
self.zmqServer.paradigmBuffer.getData(25)
|
||||||
|
|
||||||
def decoder_concentration(self):
|
# def decoder_concentration(self):
|
||||||
if self.zmqServer.state_mode == 'predict':
|
# if self.zmqServer.state_mode == 'predict':
|
||||||
if self.zmqServer.StartDecode:
|
# if self.zmqServer.StartDecode:
|
||||||
self.zmqServer.StartDecode = False
|
# self.zmqServer.StartDecode = False
|
||||||
self.thread_data_server.ResetAll()
|
# self.thread_data_server.ResetAll()
|
||||||
now = datetime.now()
|
# now = datetime.now()
|
||||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
# formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||||
print('启动专注力预测 ', formatted_time)
|
# print('启动专注力预测 ', formatted_time)
|
||||||
if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.fs): # 每win_step得出一次结果
|
# if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.device_info['sample_rate']): # 每win_step得出一次结果
|
||||||
time.sleep(0.005)
|
# time.sleep(0.005)
|
||||||
return
|
# return
|
||||||
if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
|
# if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
|
||||||
return
|
# return
|
||||||
data = self.thread_data_server.get_concentrateData(int(self.win_step * self.fs)) # 修改每次读取的数据
|
# data = self.thread_data_server.get_concentrateData(int(self.win_step * self.device_info['sample_rate'])) # 修改每次读取的数据
|
||||||
result = self.calculate.queueOpt(data)
|
# result = self.calculate.queueOpt(data)
|
||||||
if result is not None:
|
# if result is not None:
|
||||||
self.zmqClient.send_to_all('result', int(result))
|
# self.zmqClient.send_to_all('result', int(result))
|
||||||
else: # 休息状态
|
# else: # 休息状态
|
||||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
# if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
if self.thread_data_server.GetDataLenCount() < 25:
|
# if self.thread_data_server.GetDataLenCount() < 25:
|
||||||
time.sleep(0.005)
|
# time.sleep(0.005)
|
||||||
return
|
# return
|
||||||
self.thread_data_server.getData(25)
|
# self.thread_data_server.getData(25)
|
||||||
|
|
||||||
#### Blink detection #####
|
|
||||||
def check_double_blink(self, current_time):
|
|
||||||
"""
|
|
||||||
检查是否检测到连续两次眨眼
|
|
||||||
@param current_time: 当前眨眼时间戳
|
|
||||||
@return: True表示检测到连续两次眨眼
|
|
||||||
"""
|
|
||||||
if len(self.blink_timestamps) < 2:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 检查是否在去抖期内
|
|
||||||
if self.last_double_blink_time > 0:
|
|
||||||
time_since_last_double_blink = current_time - self.last_double_blink_time
|
|
||||||
if time_since_last_double_blink < self.double_blink_jitter:
|
|
||||||
return False # 在去抖期内,忽略连续眨眼检测
|
|
||||||
last_time = self.blink_timestamps[-1] # 当前眨眼
|
|
||||||
prev_time = self.blink_timestamps[-2] # 上次眨眼
|
|
||||||
|
|
||||||
interval = last_time - prev_time
|
|
||||||
if interval <= self.double_blink_interval:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def process_blink_detection(self):
|
|
||||||
"""
|
|
||||||
在缓冲区数据上执行,单次眨眼检测
|
|
||||||
"""
|
|
||||||
if len(self.fp1_buffer) < self.window_samples:
|
|
||||||
return
|
|
||||||
|
|
||||||
fp1_data = np.array(list(self.fp1_buffer)[-self.window_samples:])
|
|
||||||
fp2_data = np.array(list(self.fp2_buffer)[-self.window_samples:])
|
|
||||||
# 计算FP1和FP2的平均
|
|
||||||
fp12_mean = (fp1_data + fp2_data) / 2.0
|
|
||||||
# 带通滤波
|
|
||||||
try:
|
|
||||||
fp12_filtered = signal.filtfilt(self.blink_b, self.blink_a, fp12_mean)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Filter error: {e}")
|
|
||||||
return
|
|
||||||
F = np.diff(fp12_filtered)
|
|
||||||
if len(F) < 3:
|
|
||||||
return
|
|
||||||
b, d, e = blink_detection(F, self.fs, self.Dmin, self.Dmax, self.EMin, self.EMax)
|
|
||||||
|
|
||||||
if b == 1:
|
|
||||||
samples_since_last = self.total_samples - self.last_blink_time
|
|
||||||
time_since_last_ms = (samples_since_last / self.fs) * 1000
|
|
||||||
if time_since_last_ms >= self.jitterwin: # self.jitterwin 单次眨眼去抖 using time_since_last_ms
|
|
||||||
self.blink_count += 1
|
|
||||||
self.last_blink_time = self.total_samples
|
|
||||||
current_time = time.time()
|
|
||||||
self.blink_timestamps.append(current_time)
|
|
||||||
blink_event = {
|
|
||||||
'count': self.blink_count,
|
|
||||||
'time': current_time,
|
|
||||||
'sample_index': self.total_samples,
|
|
||||||
'duration_ms': d,
|
|
||||||
'energy': e
|
|
||||||
}
|
|
||||||
self.blink_events.append(blink_event)
|
|
||||||
self.zmqClient.send_to_all('result', 1) # 检测到眨眼信号,通知上位机
|
|
||||||
if self.check_double_blink(current_time):
|
|
||||||
self.double_blink_count += 1
|
|
||||||
interval = self.blink_timestamps[-1] - self.blink_timestamps[-2]
|
|
||||||
double_blink_event = {
|
|
||||||
'double_blink_count': self.double_blink_count,
|
|
||||||
'blink1_time': self.blink_timestamps[-2],
|
|
||||||
'blink2_time': self.blink_timestamps[-1],
|
|
||||||
'interval': interval
|
|
||||||
}
|
|
||||||
self.double_blink_events.append(double_blink_event)
|
|
||||||
self.last_double_blink_time = current_time
|
|
||||||
self.zmqClient.send_to_all('result', 2) # 发送双次眨眼事件
|
|
||||||
|
|
||||||
def decoder_blink(self):
|
|
||||||
if self.thread_data_server.GetDataLenCount() < 50:
|
|
||||||
time.sleep(0.005)
|
|
||||||
return
|
|
||||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
|
||||||
data = self.thread_data_server.get_blinkData(50)
|
|
||||||
fp1_data = data[0, :] # ch1 (相当于FP1)
|
|
||||||
fp2_data = data[1, :] # ch2 (相当于FP2)
|
|
||||||
for i in range(len(fp1_data)):
|
|
||||||
self.fp1_buffer.append(fp1_data[i])
|
|
||||||
self.fp2_buffer.append(fp2_data[i])
|
|
||||||
self.total_samples += 1
|
|
||||||
self.sample_counter += 1
|
|
||||||
|
|
||||||
if self.sample_counter >= self.step_samples:
|
|
||||||
self.process_blink_detection()
|
|
||||||
self.sample_counter = 0
|
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
'''
|
'''
|
||||||
@@ -606,12 +463,13 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
@return:
|
@return:
|
||||||
'''
|
'''
|
||||||
self.zmqServer.stop()
|
self.zmqServer.stop()
|
||||||
|
self.sliding_filter.stop()
|
||||||
self.Runing=False
|
self.Runing=False
|
||||||
|
|
||||||
def reset_state(self):
|
def reset_state(self):
|
||||||
"""清空解码器状态和缓存数据"""
|
"""清空解码器状态和缓存数据"""
|
||||||
# 重置设备层缓存
|
# 重置设备层缓存
|
||||||
self.thread_data_server.reset_state()
|
self.zmqServer.reset_state()
|
||||||
|
|
||||||
# 重置解码状态
|
# 重置解码状态
|
||||||
self.decodingSteps = 0
|
self.decodingSteps = 0
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ cudnn.benchmark = True
|
|||||||
cudnn.deterministic = True
|
cudnn.deterministic = True
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
# writer = SummaryWriter('./TensorBoardX/')
|
# writer = SummaryWriter('./TensorBoardX/')
|
||||||
|
from logs.log import algo_log
|
||||||
|
|
||||||
# Convolution module
|
# Convolution module
|
||||||
# use conv to capture local features, instead of postion embedding.
|
# use conv to capture local features, instead of postion embedding.
|
||||||
@@ -82,7 +82,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
|
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||||
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
|
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
fill_value = torch.finfo(torch.float32).min
|
fill_value = torch.finfo(torch.float64).min
|
||||||
energy.mask_fill(~mask, fill_value)
|
energy.mask_fill(~mask, fill_value)
|
||||||
|
|
||||||
scaling = self.emb_size ** (1 / 2)
|
scaling = self.emb_size ** (1 / 2)
|
||||||
@@ -318,11 +318,11 @@ class ExP():
|
|||||||
train_pred = torch.max(outputs, 1)[1]
|
train_pred = torch.max(outputs, 1)[1]
|
||||||
train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
|
train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
|
||||||
|
|
||||||
print('Epoch:', e,
|
algo_log('Epoch:', e,
|
||||||
' Train loss: %.6f' % loss.detach().cpu().numpy(),
|
' Train loss: %.6f' % loss.detach().cpu().numpy(),
|
||||||
' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
|
' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
|
||||||
' Train accuracy %.6f' % train_acc,
|
' Train accuracy %.6f' % train_acc,
|
||||||
' Test accuracy is %.6f' % acc)
|
' Test accuracy is %.6f' % acc, level="debug")
|
||||||
|
|
||||||
self.log_write.write(str(e) + " " + str(acc) + "\n")
|
self.log_write.write(str(e) + " " + str(acc) + "\n")
|
||||||
num = num + 1
|
num = num + 1
|
||||||
@@ -335,8 +335,8 @@ class ExP():
|
|||||||
|
|
||||||
torch.save(self.model, model_path)
|
torch.save(self.model, model_path)
|
||||||
averAcc = averAcc / num
|
averAcc = averAcc / num
|
||||||
print('The average accuracy is:', averAcc)
|
algo_log('The average accuracy is:', averAcc, level="debug")
|
||||||
print('The best accuracy is:', bestAcc)
|
algo_log('The best accuracy is:', bestAcc, level="debug")
|
||||||
self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
|
self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
|
||||||
self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
|
self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
|
||||||
|
|
||||||
@@ -346,10 +346,10 @@ class ExP():
|
|||||||
|
|
||||||
def onlineTrain(data_queue,result_queue):
|
def onlineTrain(data_queue,result_queue):
|
||||||
import torch
|
import torch
|
||||||
print(f"[DEBUG] torch.__version__ = {torch.__version__}")
|
algo_log(f"[DEBUG] torch.__version__ = {torch.__version__}", level="debug")
|
||||||
print(f"[DEBUG] torch.cuda.is_available() = {torch.cuda.is_available()}")
|
algo_log(f"[DEBUG] torch.cuda.is_available() = {torch.cuda.is_available()}", level="debug")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
print(f"[DEBUG] GPU = {torch.cuda.get_device_name(0)}")
|
algo_log(f"[DEBUG] GPU = {torch.cuda.get_device_name(0)}", level="debug")
|
||||||
try:
|
try:
|
||||||
starttime = datetime.datetime.now()
|
starttime = datetime.datetime.now()
|
||||||
|
|
||||||
@@ -366,12 +366,12 @@ def onlineTrain(data_queue,result_queue):
|
|||||||
data = data_queue.get(timeout=30)
|
data = data_queue.get(timeout=30)
|
||||||
all_data, all_label,model_path,n_chan = data['data'], data['label'],data['modelPath'],data['n_chan']
|
all_data, all_label,model_path,n_chan = data['data'], data['label'],data['modelPath'],data['n_chan']
|
||||||
exp = ExP(n_chan)
|
exp = ExP(n_chan)
|
||||||
print('训练参数: ',np.shape(all_data),np.shape(all_label),model_path)
|
algo_log('训练参数: ',np.shape(all_data),np.shape(all_label),model_path, level="debug")
|
||||||
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path)
|
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path)
|
||||||
print('THE BEST ACCURACY IS ' + str(bestAcc))
|
algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug")
|
||||||
|
|
||||||
endtime = datetime.datetime.now()
|
endtime = datetime.datetime.now()
|
||||||
print('train duration: ',str(endtime - starttime))
|
algo_log('train duration: ',str(endtime - starttime), level="debug")
|
||||||
|
|
||||||
# 将模型或参数传回
|
# 将模型或参数传回
|
||||||
result_queue.put({
|
result_queue.put({
|
||||||
@@ -387,7 +387,7 @@ def offlineTrain(all_data,all_label,modelPath):
|
|||||||
|
|
||||||
# seed_n = np.random.randint(2025)
|
# seed_n = np.random.randint(2025)
|
||||||
seed_n = 1877
|
seed_n = 1877
|
||||||
print('seed is ' + str(seed_n))
|
algo_log('seed is ' + str(seed_n), level="debug")
|
||||||
random.seed(seed_n)
|
random.seed(seed_n)
|
||||||
np.random.seed(seed_n)
|
np.random.seed(seed_n)
|
||||||
torch.manual_seed(seed_n)
|
torch.manual_seed(seed_n)
|
||||||
@@ -397,13 +397,12 @@ def offlineTrain(all_data,all_label,modelPath):
|
|||||||
exp = ExP()
|
exp = ExP()
|
||||||
|
|
||||||
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath)
|
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath)
|
||||||
print('THE BEST ACCURACY IS ' + str(bestAcc))
|
algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug")
|
||||||
|
|
||||||
endtime = datetime.datetime.now()
|
endtime = datetime.datetime.now()
|
||||||
print('train duration: ',str(endtime - starttime))
|
algo_log('train duration: ',str(endtime - starttime), level="debug")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print(time.asctime(time.localtime(time.time())))
|
algo_log(f"[DEBUG] time.asctime(time.localtime(time.time())) = {time.asctime(time.localtime(time.time()))}", level="debug")
|
||||||
print(time.asctime(time.localtime(time.time())))
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from einops import rearrange
|
|||||||
from einops.layers.torch import Rearrange, Reduce
|
from einops.layers.torch import Rearrange, Reduce
|
||||||
from torch.backends import cudnn
|
from torch.backends import cudnn
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
from logs.log import algo_log
|
||||||
# writer = SummaryWriter('./TensorBoardX/')
|
# writer = SummaryWriter('./TensorBoardX/')
|
||||||
|
|
||||||
|
|
||||||
@@ -71,7 +72,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
|
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||||
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
|
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
fill_value = torch.finfo(torch.float32).min
|
fill_value = torch.finfo(torch.float64).min
|
||||||
energy.mask_fill(~mask, fill_value)
|
energy.mask_fill(~mask, fill_value)
|
||||||
|
|
||||||
scaling = self.emb_size ** (1 / 2)
|
scaling = self.emb_size ** (1 / 2)
|
||||||
@@ -190,7 +191,7 @@ class ExP():
|
|||||||
# 自动选择设备:有 GPU 用 GPU,否则用 CPU
|
# 自动选择设备:有 GPU 用 GPU,否则用 CPU
|
||||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
# self.device = torch.device("cpu")
|
# self.device = torch.device("cpu")
|
||||||
print(f"Using device: {self.device}")
|
algo_log(f"Using device: {self.device}", level="debug")
|
||||||
|
|
||||||
# 定义张量类型(不再强制使用 cuda)
|
# 定义张量类型(不再强制使用 cuda)
|
||||||
self.Tensor = torch.FloatTensor
|
self.Tensor = torch.FloatTensor
|
||||||
|
|||||||
15
README.md
15
README.md
@@ -13,5 +13,16 @@ Debug_64ch_Decoder_Optimize is an updated version that fixes several issues and
|
|||||||
6. decoder class切换问题
|
6. decoder class切换问题
|
||||||
7. decoder_class切换时,数据重置、各类参数重置
|
7. decoder_class切换时,数据重置、各类参数重置
|
||||||
|
|
||||||
# update
|
|
||||||
2026年6月5日13:55:34
|
# 常用命令
|
||||||
|
source activate 3in1Py310
|
||||||
|
python runDecoder.py
|
||||||
|
python datamock.py
|
||||||
|
python ZeroMQClient_mock.py
|
||||||
|
python filter_test.py
|
||||||
|
python upperHost_stimmock/MI_headless.py
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 遗留问题
|
||||||
|
1. mvep是否要把list freq 开放到config
|
||||||
@@ -12,16 +12,17 @@ from scipy.io import loadmat
|
|||||||
from scipy.linalg import qr
|
from scipy.linalg import qr
|
||||||
from scipy.signal import filtfilt, lfilter
|
from scipy.signal import filtfilt, lfilter
|
||||||
# from numpy.linalg import _umath_linalg
|
# from numpy.linalg import _umath_linalg
|
||||||
|
from logs.log import algo_log
|
||||||
|
|
||||||
|
|
||||||
class FbccaDw:
|
class FbccaDw:
|
||||||
def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum,method):
|
def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum,method):
|
||||||
print('******************************************')
|
algo_log('******************************************', level="debug")
|
||||||
print('parameter list')
|
algo_log('parameter list', level="debug")
|
||||||
print('target:', num_target)
|
algo_log('target:', num_target, level="debug")
|
||||||
print('number of filter bank:', num_filter)
|
algo_log('number of filter bank:', num_filter, level="debug")
|
||||||
print('parameter:', parameter)
|
algo_log('parameter:', parameter, level="debug")
|
||||||
print('width:', width)
|
algo_log('width:', width, level="debug")
|
||||||
self.phase = 0
|
self.phase = 0
|
||||||
self.bandWidth = width
|
self.bandWidth = width
|
||||||
self.winNum = winNum
|
self.winNum = winNum
|
||||||
@@ -237,7 +238,7 @@ class FbccaDw:
|
|||||||
dataFiltered, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=self.notchZh[0])
|
dataFiltered, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=self.notchZh[0])
|
||||||
return np.asmatrix(dataFiltered)
|
return np.asmatrix(dataFiltered)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(Exception)
|
algo_log(f"Exception: {Exception}", level="debug")
|
||||||
|
|
||||||
'''
|
'''
|
||||||
getDataQ
|
getDataQ
|
||||||
|
|||||||
73
Tools/beta_calculate.py
Normal file
73
Tools/beta_calculate.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
import numpy as np
|
||||||
|
from scipy.signal import welch
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Beta_Calculate():
|
||||||
|
def __init__(self, Threshold_value_low, Threshold_value_high, fs=250, win_len=5, config=None):
|
||||||
|
self.Threshold_value_low = Threshold_value_low
|
||||||
|
self.Threshold_value_high = Threshold_value_high
|
||||||
|
self.fs = fs
|
||||||
|
self.beta_result = []
|
||||||
|
self.eegQueue = deque(maxlen=win_len)
|
||||||
|
|
||||||
|
def calculate_all(self, data, fs, nperseg=1000):
|
||||||
|
mean_x = np.mean(data, axis=-1, keepdims=True)
|
||||||
|
data = data - mean_x
|
||||||
|
freqs, psd = self.compute_psd_multichannel(data, fs, nperseg)
|
||||||
|
beta_psd = np.sum(self.band_psd(freqs, psd, (13, 30)))
|
||||||
|
alpha_psd = np.sum(self.band_psd(freqs, psd, (8, 13)))
|
||||||
|
theta_psd = np.sum(self.band_psd(freqs, psd, (4, 8)))
|
||||||
|
|
||||||
|
print(f"[功率] β={beta_psd:.2f} | α={alpha_psd:.2f} | θ={theta_psd:.2f}")
|
||||||
|
|
||||||
|
return beta_psd, alpha_psd, theta_psd
|
||||||
|
|
||||||
|
def compute_psd_multichannel(self, data, fs=250, nperseg=1000):
|
||||||
|
n_samples = data.shape[-1]
|
||||||
|
if n_samples < nperseg:
|
||||||
|
nperseg = n_samples
|
||||||
|
|
||||||
|
noverlap = 500
|
||||||
|
if noverlap >= nperseg:
|
||||||
|
noverlap = int(nperseg / 2)
|
||||||
|
|
||||||
|
if nperseg == 0:
|
||||||
|
return np.array([]), np.zeros((data.shape[0], 0))
|
||||||
|
|
||||||
|
freqs, psd = welch(data, fs=fs, nperseg=nperseg, noverlap=noverlap, axis=-1)
|
||||||
|
return freqs, psd
|
||||||
|
|
||||||
|
def band_psd(self, freqs, psd, band):
|
||||||
|
idx = np.logical_and(freqs >= band[0], freqs <= band[1])
|
||||||
|
return np.sum(psd[:, idx], axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_queue(self):
|
||||||
|
self.eegQueue.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def queueOpt(self, data):
|
||||||
|
if data is None or data.size == 0:
|
||||||
|
return None
|
||||||
|
if len(self.eegQueue) < self.eegQueue.maxlen:
|
||||||
|
self.eegQueue.append(data)
|
||||||
|
else:
|
||||||
|
self.eegQueue.append(data)
|
||||||
|
|
||||||
|
if len(self.eegQueue) == self.eegQueue.maxlen:
|
||||||
|
eegData = np.hstack([self.eegQueue[i] for i in range(len(self.eegQueue))])
|
||||||
|
if eegData.size == 0:
|
||||||
|
return None
|
||||||
|
eegData -= np.mean(eegData, axis=-1, keepdims=True)
|
||||||
|
|
||||||
|
beta_psd, alpha_psd, theta_psd = self.calculate_all(eegData, fs=self.fs, nperseg=1000)
|
||||||
|
|
||||||
|
return (beta_psd)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
167
ZeroMQClient_mock.py
Normal file
167
ZeroMQClient_mock.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
import zmq
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
|
||||||
|
def receive_messages(socket, stop_event):
|
||||||
|
"""
|
||||||
|
后台线程函数,用于持续接收服务器消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
socket (zmq.Socket): ZeroMQ套接字
|
||||||
|
stop_event (threading.Event): 停止事件,用于通知线程退出
|
||||||
|
"""
|
||||||
|
print("开始持续接收服务器数据...")
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
|
while not stop_event.is_set():
|
||||||
|
try:
|
||||||
|
# 设置接收超时为1秒,避免阻塞
|
||||||
|
socket.setsockopt(zmq.RCVTIMEO, 1000)
|
||||||
|
# 接收服务器的消息
|
||||||
|
frames = socket.recv_multipart()
|
||||||
|
|
||||||
|
# DEALER 套接字接收消息格式:[身份标识, 空帧, 消息内容]
|
||||||
|
# 使用frames[-1]获取最后一帧,无论中间有多少空帧
|
||||||
|
if len(frames) >= 2:
|
||||||
|
message = frames[-1].decode('utf-8')
|
||||||
|
|
||||||
|
# 尝试解析为JSON格式
|
||||||
|
try:
|
||||||
|
json_message = json.loads(message)
|
||||||
|
# 检查消息长度
|
||||||
|
json_str = str(json_message)
|
||||||
|
if len(json_str) > 100:
|
||||||
|
print(f"收到服务器数据 (JSON): {json_str[:100]}...")
|
||||||
|
else:
|
||||||
|
print(f"收到服务器数据 (JSON): {json_message}")
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# 检查消息长度
|
||||||
|
if len(message) > 100:
|
||||||
|
print(f"收到服务器数据 (原始): {message[:100]}...")
|
||||||
|
else:
|
||||||
|
print(f"收到服务器数据 (原始): {message}")
|
||||||
|
else:
|
||||||
|
print(f"收到服务器数据 (格式异常): {frames}")
|
||||||
|
|
||||||
|
except zmq.Again:
|
||||||
|
# 接收超时,继续循环
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
print(f"接收消息时发生错误: {e}")
|
||||||
|
# 短暂暂停后继续接收
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
print("接收线程已停止。")
|
||||||
|
|
||||||
|
def zero_mq_client(server_address="tcp://127.0.0.1:8099"):
|
||||||
|
"""
|
||||||
|
ZeroMQ客户端函数,用于与服务器通信
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_address (str): 服务器地址,格式为"tcp://IP:端口"
|
||||||
|
"""
|
||||||
|
# 创建 ZeroMQ 上下文
|
||||||
|
context = zmq.Context()
|
||||||
|
|
||||||
|
# 创建 DEALER 套接字
|
||||||
|
socket = context.socket(zmq.DEALER)
|
||||||
|
|
||||||
|
# 生成唯一的身份标识
|
||||||
|
identity = str('wdd').encode('utf-8')
|
||||||
|
socket.setsockopt(zmq.IDENTITY, identity)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 连接到服务器
|
||||||
|
print(f"连接到服务器 {server_address}...")
|
||||||
|
socket.connect(server_address)
|
||||||
|
|
||||||
|
# 定义消息集
|
||||||
|
message_set = [
|
||||||
|
{"method": "sync", "params": 1},
|
||||||
|
{"method": "decoderClass", "params": "mi"},
|
||||||
|
{"method": "decoderClass", "params": "ssvep"},
|
||||||
|
{"method": "decoderClass", "params": "ssmvep"},
|
||||||
|
{"method": "decoderClass", "params": "blink"},
|
||||||
|
{"method": "decoderClass", "params": "concentration"},
|
||||||
|
{"method": "train", "params": 0},
|
||||||
|
{"method": "train", "params": 1},
|
||||||
|
{"method": "rest", "params": 0},
|
||||||
|
{"method": "predict", "params": 1},
|
||||||
|
{"method": "getReport", "params": 0},
|
||||||
|
{"method": "targetFreqs", "params": [11, 12, 13]}
|
||||||
|
]
|
||||||
|
|
||||||
|
# 打印消息集
|
||||||
|
print("消息集:")
|
||||||
|
for i, msg in enumerate(message_set):
|
||||||
|
print(f"[{i}] {msg}")
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
|
# 创建停止事件
|
||||||
|
stop_event = threading.Event()
|
||||||
|
|
||||||
|
# 启动接收线程
|
||||||
|
receive_thread = threading.Thread(target=receive_messages, args=(socket, stop_event))
|
||||||
|
receive_thread.daemon = True # 设置为守护线程,主线程退出时自动退出
|
||||||
|
receive_thread.start()
|
||||||
|
|
||||||
|
# 主线程处理控制台输入
|
||||||
|
print("输入消息序号发送对应消息,输入'q'退出程序:")
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# 获取用户输入
|
||||||
|
user_input = input("请输入消息序号: ")
|
||||||
|
|
||||||
|
# 检查是否退出
|
||||||
|
if user_input.lower() == 'q':
|
||||||
|
print("正在退出程序...")
|
||||||
|
break
|
||||||
|
|
||||||
|
# 尝试转换为整数
|
||||||
|
msg_index = int(user_input)
|
||||||
|
|
||||||
|
# 检查序号是否有效
|
||||||
|
if 0 <= msg_index < len(message_set):
|
||||||
|
# 获取对应的消息
|
||||||
|
selected_message = message_set[msg_index]
|
||||||
|
|
||||||
|
# 将消息转换为 JSON 字符串
|
||||||
|
json_message = json.dumps(selected_message)
|
||||||
|
|
||||||
|
# 打印发送信息
|
||||||
|
print(f"\n发送消息 (大小: {len(json_message)} 字节)...")
|
||||||
|
print(f"消息方法: {selected_message['method']}")
|
||||||
|
print(f"参数值: {selected_message['params']}")
|
||||||
|
|
||||||
|
# DEALER 套接字发送消息,包含身份标识和空帧
|
||||||
|
socket.send_multipart([identity, json_message.encode('utf-8')])
|
||||||
|
print("消息发送完成!")
|
||||||
|
print("-" * 50)
|
||||||
|
else:
|
||||||
|
print(f"无效的消息序号,请输入 0-{len(message_set)-1} 之间的数字。")
|
||||||
|
print("消息集:")
|
||||||
|
for i, msg in enumerate(message_set):
|
||||||
|
print(f"[{i}] {msg}")
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
print("请输入有效的数字或'q'退出。")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理输入时发生错误: {e}")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n程序被手动终止。")
|
||||||
|
finally:
|
||||||
|
# 停止接收线程
|
||||||
|
stop_event.set()
|
||||||
|
# 等待接收线程停止
|
||||||
|
time.sleep(1)
|
||||||
|
# 关闭套接字和上下文
|
||||||
|
socket.close()
|
||||||
|
context.term()
|
||||||
|
print("客户端已关闭。")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
zero_mq_client()
|
||||||
@@ -5,12 +5,13 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import signal
|
from scipy import signal
|
||||||
import threading
|
import threading
|
||||||
|
from logs.log import algo_log
|
||||||
|
|
||||||
class ParadigmRingBuffer:
|
class ParadigmRingBuffer:
|
||||||
def __init__(self, n_chan, n_points):
|
def __init__(self, n_chan, n_points):
|
||||||
self.n_chan = n_chan
|
self.n_chan = n_chan
|
||||||
self.n_points = n_points
|
self.n_points = n_points
|
||||||
self.buffer = np.zeros((n_chan, n_points))
|
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
|
||||||
self.currentPtr = 0
|
self.currentPtr = 0
|
||||||
self.readPtr = 0
|
self.readPtr = 0
|
||||||
self.nUpdate = 0
|
self.nUpdate = 0
|
||||||
@@ -19,7 +20,8 @@ class ParadigmRingBuffer:
|
|||||||
## append buffer and update current pointer
|
## append buffer and update current pointer
|
||||||
def appendBuffer(self, data):
|
def appendBuffer(self, data):
|
||||||
if self.nUpdate == self.n_points:
|
if self.nUpdate == self.n_points:
|
||||||
raise Exception("Buffer is full")
|
# raise Exception("Buffer is full")
|
||||||
|
algo_log("ParadigmRingBuffer is full", record_once=True)
|
||||||
|
|
||||||
n = data.shape[1]
|
n = data.shape[1]
|
||||||
|
|
||||||
@@ -65,13 +67,56 @@ class ParadigmRingBuffer:
|
|||||||
'''
|
'''
|
||||||
return self.nUpdate
|
return self.nUpdate
|
||||||
|
|
||||||
|
# ========== 各范式数据访问接口 ==========
|
||||||
|
def get_MIData(self):
|
||||||
|
"""获取MI导联数据 (21通道 + 事件)"""
|
||||||
|
data = self.getData(self.GetDataLenCount())
|
||||||
|
rows_to_extract = [8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10, 19, 20, 61, 51, 60, 21, 64, 65]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
if data.shape[1] > 0:
|
||||||
|
return data[row_to_select, :]
|
||||||
|
return np.zeros((len(rows_to_extract), 0))
|
||||||
|
|
||||||
|
def get_SSMVEPData(self):
|
||||||
|
"""获取SSMVEP导联数据 (8通道 + 事件)"""
|
||||||
|
data = self.getData(self.GetDataLenCount())
|
||||||
|
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64, 65]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
if data.shape[1] > 0:
|
||||||
|
return data[row_to_select, :]
|
||||||
|
return np.zeros((len(rows_to_extract), 0))
|
||||||
|
|
||||||
|
def getDataViaSSVEP(self, count):
|
||||||
|
"""获取SSVEP数据 (8通道 + 事件)"""
|
||||||
|
data = self.getData(count)
|
||||||
|
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
if data.shape[1] > 0:
|
||||||
|
return data[row_to_select, :]
|
||||||
|
return np.zeros((len(rows_to_extract), 0))
|
||||||
|
|
||||||
|
def get_concentrateData(self, count):
|
||||||
|
"""获取专注力数据 (2通道)"""
|
||||||
|
data = self.getData(count)
|
||||||
|
rows_to_extract = [0, 1]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
if data.shape[1] > 0:
|
||||||
|
return data[row_to_select, :]
|
||||||
|
return np.zeros((len(rows_to_extract), 0))
|
||||||
|
|
||||||
|
def get_blinkData(self, count):
|
||||||
|
"""获取眨眼数据 (2通道)"""
|
||||||
|
data = self.getData(count)
|
||||||
|
rows_to_extract = [0, 1]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
if data.shape[1] > 0:
|
||||||
|
return data[row_to_select, :]
|
||||||
|
return np.zeros((len(rows_to_extract), 0))
|
||||||
|
|
||||||
# reset buffer
|
# reset buffer
|
||||||
def resetAllPara(self):
|
def resetAllPara(self):
|
||||||
self.nUpdate = 0
|
self.nUpdate = 0
|
||||||
self.currentPtr = 0
|
self.currentPtr = 0
|
||||||
self.readPtr = 0 # add by lizhenhua 清空读指针
|
self.readPtr = 0
|
||||||
self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区
|
self.buffer.fill(0.0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,132 +3,122 @@
|
|||||||
数据滤波模块
|
数据滤波模块
|
||||||
"""
|
"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import time
|
||||||
import threading
|
import threading
|
||||||
|
from scipy import signal
|
||||||
from logs.log import algo_log
|
from logs.log import algo_log
|
||||||
|
|
||||||
class FilterRingBuffer:
|
class FilterRingBuffer:
|
||||||
def __init__(self, n_chan, n_points):
|
def __init__(self, n_chan, n_points):
|
||||||
"""
|
|
||||||
初始化纯数据环形缓存
|
|
||||||
:param n_chan: 通道数
|
|
||||||
:param n_points: 总缓存点数(与paradigmRingBuffer参数完全一致)
|
|
||||||
"""
|
|
||||||
self.n_chan = n_chan
|
self.n_chan = n_chan
|
||||||
self.n_points = n_points
|
self.n_points = n_points
|
||||||
|
|
||||||
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
|
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
|
||||||
self.current_ptr = 0 # 写入指针
|
self.current_ptr = 0
|
||||||
self.total_samples = 0 # 已写入总点数
|
self.total_samples = 0
|
||||||
|
self.lock = threading.Lock() # 仅保护元数据
|
||||||
# 线程安全锁(多线程环境必须)
|
self.has_new_data = False
|
||||||
self.lock = threading.Lock()
|
|
||||||
|
|
||||||
def appendBuffer(self, data):
|
def appendBuffer(self, data):
|
||||||
"""
|
|
||||||
追加数据到缓存(与paradigmRingBuffer接口一致)
|
|
||||||
:param data: 输入数据,shape=(n_chan, n_samples)
|
|
||||||
"""
|
|
||||||
with self.lock:
|
|
||||||
n = data.shape[1]
|
n = data.shape[1]
|
||||||
if n == 0:
|
if n == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 环形写入逻辑
|
# 仅加锁读取/更新元数据
|
||||||
write_end = self.current_ptr + n
|
with self.lock:
|
||||||
|
old_ptr = self.current_ptr
|
||||||
|
new_ptr = (old_ptr + n) % self.n_points
|
||||||
|
new_total = min(self.total_samples + n, self.n_points)
|
||||||
|
self.has_new_data = True
|
||||||
|
|
||||||
|
# 数组写入(耗时操作,移出锁外)
|
||||||
|
write_end = old_ptr + n
|
||||||
if write_end <= self.n_points:
|
if write_end <= self.n_points:
|
||||||
self.buffer[:, self.current_ptr:write_end] = data
|
self.buffer[:, old_ptr:write_end] = data
|
||||||
else:
|
else:
|
||||||
split = self.n_points - self.current_ptr
|
split = self.n_points - old_ptr
|
||||||
self.buffer[:, self.current_ptr:] = data[:, :split]
|
self.buffer[:, old_ptr:] = data[:, :split]
|
||||||
self.buffer[:, :write_end - self.n_points] = data[:, split:]
|
self.buffer[:, :write_end - self.n_points] = data[:, split:]
|
||||||
|
|
||||||
# 更新指针和计数
|
# 再次加锁更新最终元数据
|
||||||
self.current_ptr = write_end % self.n_points
|
with self.lock:
|
||||||
self.total_samples = min(self.total_samples + n, self.n_points)
|
self.current_ptr = new_ptr
|
||||||
|
self.total_samples = new_total
|
||||||
|
|
||||||
|
# ========== 新增:获取&清空新数据标记的方法 ==========
|
||||||
|
def check_and_clear_new_data(self):
|
||||||
|
"""检查是否有新数据,并一次性清空标记(消费后重置)"""
|
||||||
|
with self.lock:
|
||||||
|
flag = self.has_new_data
|
||||||
|
if flag:
|
||||||
|
self.has_new_data = False
|
||||||
|
return flag
|
||||||
|
|
||||||
def getData(self, count):
|
def getData(self, count):
|
||||||
"""
|
# 加锁获取最新元数据
|
||||||
从读指针位置读取count个点(与paradigmRingBuffer接口一致)
|
|
||||||
:param count: 读取点数
|
|
||||||
:return: np.ndarray, shape=(n_chan, count)
|
|
||||||
"""
|
|
||||||
with self.lock:
|
with self.lock:
|
||||||
count = min(count, self.total_samples)
|
count = min(count, self.total_samples)
|
||||||
if count == 0:
|
if count == 0:
|
||||||
return np.zeros((self.n_chan, 0))
|
return np.zeros((self.n_chan, 0))
|
||||||
|
|
||||||
# 环形读取逻辑(与paradigmRingBuffer完全相同)
|
|
||||||
end = self.current_ptr
|
end = self.current_ptr
|
||||||
start = end - count
|
start = end - count
|
||||||
|
|
||||||
|
# 数据读取、切片、拼接(无锁)
|
||||||
if start >= 0:
|
if start >= 0:
|
||||||
return self.buffer[:, start:end].copy()
|
res = self.buffer[:, start:end].copy()
|
||||||
else:
|
else:
|
||||||
part1 = self.buffer[:, start:]
|
part1 = self.buffer[:, start:]
|
||||||
part2 = self.buffer[:, :end]
|
part2 = self.buffer[:, :end]
|
||||||
return np.concatenate((part1, part2), axis=1)
|
res = np.concatenate((part1, part2), axis=1).copy()
|
||||||
|
return res
|
||||||
|
|
||||||
def get_latest_n_points(self, n):
|
def get_latest_n_points(self, n):
|
||||||
"""
|
|
||||||
扩展方法:获取最新的n个点(不移动读指针,用于滑动窗口)
|
|
||||||
:param n: 点数
|
|
||||||
:return: np.ndarray, shape=(n_chan, n)
|
|
||||||
"""
|
|
||||||
with self.lock:
|
with self.lock:
|
||||||
if self.total_samples < n:
|
if self.total_samples < n:
|
||||||
return None
|
return None
|
||||||
return self.getData(n)
|
return self.getData(n)
|
||||||
|
|
||||||
def GetDataLenCount(self):
|
def GetDataLenCount(self):
|
||||||
"""获取当前缓存总点数(兼容原有接口)"""
|
|
||||||
with self.lock:
|
with self.lock:
|
||||||
return self.total_samples
|
return self.total_samples
|
||||||
|
|
||||||
def resetAllPara(self):
|
def resetAllPara(self):
|
||||||
"""重置所有缓存和指针(兼容原有接口)"""
|
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.buffer.fill(0.0)
|
self.buffer.fill(0.0)
|
||||||
self.current_ptr = 0
|
self.current_ptr = 0
|
||||||
self.total_samples = 0
|
self.total_samples = 0
|
||||||
|
self.has_new_data = False # 重置时清空新数据标记
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
|
# 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
|
||||||
# 可替换任意缓存实现,只要实现appendBuffer、get_latest_n_points接口
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class SlidingFilter:
|
class SlidingFilter(threading.Thread):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
ring_buffer: FilterRingBuffer,
|
||||||
n_chan=66,
|
n_chan=66,
|
||||||
srate=250,
|
srate=250,
|
||||||
buffer_sec=5,
|
|
||||||
window_sec=3,
|
window_sec=3,
|
||||||
step_sec=0.2,
|
step_sec=0.2
|
||||||
packet_size=5
|
|
||||||
):
|
):
|
||||||
"""
|
super().__init__(daemon=True)
|
||||||
初始化滑动滤波器
|
|
||||||
:param n_chan: 通道数
|
|
||||||
:param srate: 采样率
|
|
||||||
:param buffer_sec: 总缓存时长(秒)
|
|
||||||
:param window_sec: 滤波窗口时长(秒)
|
|
||||||
:param step_sec: 滑动步长/输出时长(秒)
|
|
||||||
:param packet_size: 每包数据点数(20ms一包=5点)
|
|
||||||
"""
|
|
||||||
# 核心参数
|
# 核心参数
|
||||||
self.n_chan = n_chan
|
self.n_chan = n_chan
|
||||||
self.srate = srate
|
self.srate = srate
|
||||||
self.buffer_size = int(srate * buffer_sec)
|
self.step_sec = step_sec # 200ms滑动步长
|
||||||
self.window_size = int(srate * window_sec)
|
self.window_sec = window_sec # 3秒窗口
|
||||||
self.step_size = int(srate * step_sec)
|
self.step_sec = step_sec # 200ms滑动步长
|
||||||
self.packet_size = packet_size
|
self.window_size = int(srate * window_sec) # 3秒点数:250*3=750
|
||||||
|
self.step_size = int(srate * step_sec) # 200ms点数:250*0.2=50
|
||||||
|
|
||||||
# 初始化纯数据缓存(解耦核心)
|
# 关联ZMQServer的环形缓存(解耦:仅依赖接口)
|
||||||
self.buffer = FilterRingBuffer(n_chan, self.buffer_size)
|
self.ring_buffer = ring_buffer
|
||||||
|
# 线程控制
|
||||||
|
self.running = threading.Event()
|
||||||
|
self.running.set()
|
||||||
|
# 滤波结果回调(外部可注册,获取滤波后的数据)
|
||||||
|
self.filter_result_callback = None
|
||||||
|
|
||||||
# 滤波触发计数器
|
# 预计算滤波器系数(仅执行一次)
|
||||||
self.packet_count = 0
|
|
||||||
self.ready_to_filter = False
|
|
||||||
|
|
||||||
# 预计算滤波器系数
|
|
||||||
self._init_filters()
|
self._init_filters()
|
||||||
|
|
||||||
def _init_filters(self):
|
def _init_filters(self):
|
||||||
@@ -138,71 +128,76 @@ class SlidingFilter:
|
|||||||
# 8~30Hz带通FIR(65阶,线性相位)
|
# 8~30Hz带通FIR(65阶,线性相位)
|
||||||
self.b_bp = signal.firwin(
|
self.b_bp = signal.firwin(
|
||||||
numtaps=65,
|
numtaps=65,
|
||||||
cutoff=[8/(self.srate/2), 30/(self.srate/2)],
|
cutoff=[0.5/(self.srate/2), 45/(self.srate/2)],
|
||||||
pass_zero=False,
|
pass_zero=False,
|
||||||
window='hamming'
|
window='hamming'
|
||||||
)
|
)
|
||||||
self.a_bp = np.array([1.0])
|
self.a_bp = np.array([1.0])
|
||||||
|
|
||||||
def append_and_check_trigger(self, raw_data):
|
def _filter_window_data(self, window_data):
|
||||||
"""
|
"""对3秒窗口数据执行滤波,返回无边界效应的200ms数据"""
|
||||||
追加单包原始数据并检查是否触发滤波
|
|
||||||
:param raw_data: 上位机原始数据,shape=(packet_size, n_chan)
|
|
||||||
:return: bool: 是否触发本次滤波
|
|
||||||
"""
|
|
||||||
# 转置为标准格式:(通道数, 点数)
|
|
||||||
data = raw_data.T.astype(np.float64)
|
|
||||||
|
|
||||||
# 写入缓存(纯缓存操作)
|
|
||||||
self.buffer.appendBuffer(data)
|
|
||||||
|
|
||||||
# 更新包计数器
|
|
||||||
self.packet_count += 1
|
|
||||||
|
|
||||||
# 检查滤波触发条件:数据≥窗口长度 且 累计满一个步长的包数
|
|
||||||
packets_per_step = int(self.step_size / self.packet_size) # 10包=200ms
|
|
||||||
if (self.buffer.GetDataLenCount() >= self.window_size
|
|
||||||
and self.packet_count >= packets_per_step):
|
|
||||||
self.packet_count = 0
|
|
||||||
self.ready_to_filter = True
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def filter_and_get_output(self):
|
|
||||||
"""
|
|
||||||
执行滤波并返回无边界效应的输出数据
|
|
||||||
:return: np.ndarray: 滤波后数据,shape=(n_chan, step_size)
|
|
||||||
"""
|
|
||||||
if not self.ready_to_filter:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 获取最新的完整滤波窗口数据
|
|
||||||
window_data = self.buffer.get_latest_n_points(self.window_size)
|
|
||||||
if window_data is None:
|
|
||||||
self.ready_to_filter = False
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 零相位滤波(无延迟,无边界效应)
|
# 零相位滤波(无延迟,无边界效应)
|
||||||
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
|
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
|
||||||
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
|
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
|
||||||
filtered = signal.filtfilt(self.b_bp, self.a_bp, filtered, axis=-1)
|
filtered = signal.filtfilt(self.b_bp, self.a_bp, filtered, axis=-1)
|
||||||
|
|
||||||
# 提取倒数第二个步长的数据(完全避开两端边界效应)
|
# 提取倒数第二个200ms的数据(完全避开两端边界效应)
|
||||||
|
# 窗口长度750,步长50 → start=750-100=650,end=750-50=700
|
||||||
start_idx = self.window_size - 2 * self.step_size
|
start_idx = self.window_size - 2 * self.step_size
|
||||||
end_idx = self.window_size - self.step_size
|
end_idx = self.window_size - self.step_size
|
||||||
output_data = filtered[:, start_idx:end_idx].copy()
|
output_data = filtered[:, start_idx:end_idx].copy()
|
||||||
|
|
||||||
# 重置触发标志
|
|
||||||
self.ready_to_filter = False
|
|
||||||
|
|
||||||
return output_data
|
return output_data
|
||||||
|
|
||||||
def reset(self):
|
def run(self):
|
||||||
"""重置滤波器和缓存"""
|
"""线程主逻辑:精确200ms触发一次滤波"""
|
||||||
self.buffer.resetAllPara()
|
interval = self.step_sec # 200ms = 0.2秒
|
||||||
self.packet_count = 0
|
next_run_time = time.perf_counter()
|
||||||
self.ready_to_filter = False
|
while self.running.is_set():
|
||||||
|
# 1. 精确定时等待
|
||||||
|
current_time = time.perf_counter()
|
||||||
|
if current_time < next_run_time:
|
||||||
|
time.sleep(next_run_time - current_time)
|
||||||
|
next_run_time += interval
|
||||||
|
else:
|
||||||
|
algo_log("滤波耗时超过200ms,定时偏移", level='debug')
|
||||||
|
next_run_time = time.perf_counter() + interval
|
||||||
|
|
||||||
def get_buffer_length(self):
|
# ========== 新增核心判断:无新数据则直接跳过 ==========
|
||||||
"""获取当前缓存数据长度"""
|
if not self.ring_buffer.check_and_clear_new_data():
|
||||||
return self.buffer.GetDataLenCount()
|
# 无新数据,不执行滤波、不发送数据
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 2. 有新数据,才执行原有滤波逻辑
|
||||||
|
try:
|
||||||
|
window_data = self.ring_buffer.get_latest_n_points(self.window_size)
|
||||||
|
if window_data is None:
|
||||||
|
algo_log(f"缓存数据不足,当前缓存{self.ring_buffer.GetDataLenCount()}点,需{self.window_size}点", level='debug')
|
||||||
|
continue
|
||||||
|
|
||||||
|
filtered_data = self._filter_window_data(window_data)
|
||||||
|
# algo_log(f"滤波后{filtered_data.shape}数据", level='debug')
|
||||||
|
|
||||||
|
if self.filter_result_callback is not None:
|
||||||
|
self.filter_result_callback(filtered_data[:64, :])
|
||||||
|
except Exception as e:
|
||||||
|
algo_log(f"滤波执行异常: {e}", level='error')
|
||||||
|
|
||||||
|
def set_result_callback(self, callback):
|
||||||
|
"""注册滤波结果回调函数"""
|
||||||
|
self.filter_result_callback = callback
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""停止滤波线程(安全版)"""
|
||||||
|
# 1. 先设置停止标志(Event.clear()是线程安全的)
|
||||||
|
self.running.clear()
|
||||||
|
|
||||||
|
# 2. 核心修复:只有线程已启动且正在运行时才调用join
|
||||||
|
if self.is_alive():
|
||||||
|
# 等待线程正常退出,最多1秒
|
||||||
|
self.join(timeout=1)
|
||||||
|
# 超时未退出时打印警告,便于排查问题
|
||||||
|
if self.is_alive():
|
||||||
|
algo_log("警告:滤波线程在1秒内未正常退出,可能存在阻塞操作", level="WARNING")
|
||||||
|
|
||||||
|
# 3. 无论线程是否启动,都打印停止日志
|
||||||
|
algo_log("滤波线程已停止")
|
||||||
|
|||||||
451
Zmq/zmqServer.py
451
Zmq/zmqServer.py
@@ -1,241 +1,424 @@
|
|||||||
|
# -*-coding:utf-8 -*-
|
||||||
|
import ast
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import zmq
|
|
||||||
import threading
|
import threading
|
||||||
|
import zmq
|
||||||
import json
|
import json
|
||||||
import queue
|
import queue
|
||||||
# from Device.SunnyLinker import SunnyLinker64
|
from typing import Dict
|
||||||
from dataBuffer import ParadigmRingBuffer
|
import datetime
|
||||||
from filterProcess import FilterRingBuffer
|
import time
|
||||||
|
|
||||||
|
from Zmq.dataBuffer import ParadigmRingBuffer
|
||||||
|
from Zmq.filterProcess import FilterRingBuffer
|
||||||
|
from PubLibrary.InifileHelper import IniRead
|
||||||
from logs.log import algo_log
|
from logs.log import algo_log
|
||||||
|
|
||||||
|
zmqServer_host = str(IniRead('system', 'zmqServer_host', '127.0.0.1'))
|
||||||
|
|
||||||
class zmqServer(threading.Thread):
|
class zmqServer(threading.Thread):
|
||||||
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
|
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
|
||||||
threading.Thread.__init__(self)
|
threading.Thread.__init__(self)
|
||||||
self.host = host
|
self.device_info = device_info
|
||||||
self.cmd_port = cmd_port # 命令交互端口
|
|
||||||
self.data_port = data_port # 数据接收端口
|
self.host = zmqServer_host
|
||||||
|
|
||||||
|
self.cmd_port = cmd_port # 命令交互端口:收JSON命令 + 返JSON结果
|
||||||
|
self.data_port = data_port # 数据交互端口:收二进制原始脑电 + 返二进制滤波结果
|
||||||
self.running = False
|
self.running = False
|
||||||
|
|
||||||
# 原有业务状态变量
|
# 原有业务状态变量
|
||||||
# self.get_Impedance = False # 是否返回阻抗值
|
self.open_Impedance = False #当前系统处于阻抗检测状态
|
||||||
# self.open_Impedance = None # 是否开启阻抗检测功能
|
self.StartDecode = False
|
||||||
self.StartDecode = False # false 停止解码,true=开始解码
|
self.StartTrain = False
|
||||||
self.StartTrain = False # False未进入训练状态,True处于训练状态
|
self.state_mode = None
|
||||||
self.state_mode = None # 'train'为训练状态,’rest'为休息状态,'test'为测试状态
|
self.currentLabel = -1
|
||||||
self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签
|
self.IsExitApp = False
|
||||||
self.IsExitApp = False # 当socket收到2的时候,就置为True,代表要退出系统了。
|
|
||||||
# self.getReport = False # 获取训练报告内容
|
|
||||||
self.daemon = True
|
self.daemon = True
|
||||||
|
|
||||||
# 范式数据缓存
|
# 双环形缓冲区
|
||||||
self.paradigmBuffer = ParadigmRingBuffer(66, 2500)
|
self.paradigmBuffer = ParadigmRingBuffer(
|
||||||
self.filterBuffer = FilterRingBuffer(66, 2500)
|
self.device_info['channel_nums'],
|
||||||
|
self.device_info['sample_rate'] * 10
|
||||||
|
)
|
||||||
|
self.filterBuffer = FilterRingBuffer(
|
||||||
|
self.device_info['channel_nums'],
|
||||||
|
self.device_info['sample_rate'] * 10
|
||||||
|
)
|
||||||
|
self.paradigmBufferLock = threading.Lock()
|
||||||
|
self.filterBufferLock = threading.Lock()
|
||||||
|
|
||||||
|
# ZMQ上下文与套接字
|
||||||
# 命令与数据通信
|
|
||||||
self.context = zmq.Context()
|
self.context = zmq.Context()
|
||||||
# 指令通道 (8099) - ROUTER:短JSON命令,低频率
|
|
||||||
|
# 8099命令端口:ROUTER
|
||||||
self.cmd_socket = self.context.socket(zmq.ROUTER)
|
self.cmd_socket = self.context.socket(zmq.ROUTER)
|
||||||
self.cmd_socket.setsockopt(zmq.RCVHWM, 100) # 指令不需要大缓存,100条足够
|
self.cmd_socket.setsockopt(zmq.SocketOption.RCVHWM, 100)
|
||||||
self.cmd_socket.setsockopt(zmq.SNDHWM, 100)
|
self.cmd_socket.setsockopt(zmq.SocketOption.SNDHWM, 100)
|
||||||
self.cmd_socket.setsockopt(zmq.TCP_NODELAY, 1) # 禁用Nagle算法,降低指令延迟
|
|
||||||
self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
|
self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
|
||||||
|
|
||||||
# 数据通道 (8100) - ROUTER:高频脑电二进制流
|
# 8100数据端口:ROUTER
|
||||||
self.data_socket = self.context.socket(zmq.ROUTER)
|
self.data_socket = self.context.socket(zmq.ROUTER)
|
||||||
self.data_socket.setsockopt(zmq.RCVHWM, 500) # 500包=10秒缓存,足够应对短时卡顿
|
self.data_socket.setsockopt(zmq.SocketOption.RCVHWM, 500)
|
||||||
self.data_socket.setsockopt(zmq.TCP_NODELAY, 1) # 禁用Nagle算法,减少数据传输延迟
|
self.data_socket.setsockopt(zmq.SocketOption.SNDHWM, 100) # 添加发送高水位线
|
||||||
self.data_socket.bind(f"tcp://{self.host}:{data_port}")
|
self.data_socket.bind(f"tcp://{self.host}:{data_port}")
|
||||||
|
|
||||||
# Poller 轮训器(保持不变)
|
# Poller轮询器
|
||||||
self.poller = zmq.Poller()
|
self.poller = zmq.Poller()
|
||||||
self.poller.register(self.cmd_socket, zmq.POLLIN)
|
self.poller.register(self.cmd_socket, zmq.POLLIN)
|
||||||
self.poller.register(self.data_socket, zmq.POLLIN)
|
self.poller.register(self.data_socket, zmq.POLLIN)
|
||||||
|
|
||||||
# 业务变量
|
# 业务变量
|
||||||
self.targetFreqs = []
|
self.targetFreqs = []
|
||||||
self.changeTarget = False # 更换目标频率
|
self.changeTarget = False
|
||||||
# self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类,已在Decoder实例化
|
|
||||||
self.labels = [0x01, 0x02, 0x03]
|
self.labels = [0x01, 0x02, 0x03]
|
||||||
self.decoder_switch = False #更换解码器
|
self.decoder_switch = False
|
||||||
self.decoder_class = None #解码器类别 'ssvep','ssmvep','mi'
|
self.decoder_class = None
|
||||||
|
|
||||||
# 客户端管理 - 区分命令/数据客户端
|
# 客户端管理(单客户端场景)
|
||||||
self.cmd_clients = set() # 命令端口客户端ID
|
self.cmd_clients = set()
|
||||||
self.data_clients = set() # 数据端口客户端ID
|
self.data_clients = set()
|
||||||
self.send_queue = queue.Queue() # 发送队列(仅用于命令端口广播)
|
self.current_data_client = None # 唯一数据客户端身份,用于发送滤波结果
|
||||||
|
|
||||||
|
# 发送队列(双端口分离)
|
||||||
|
self.cmd_send_queue = queue.Queue() # 8099端口命令结果队列
|
||||||
|
self.data_send_queue = queue.Queue() # 8100端口滤波数据队列
|
||||||
|
|
||||||
|
# 范式buffer与事件检测参数
|
||||||
|
self.predict_event = 99
|
||||||
|
self.events = [1, 2, self.predict_event]
|
||||||
|
self.latency = 50
|
||||||
|
self.train_latency = 50
|
||||||
|
self.count_events = {}
|
||||||
|
self.epoch_finished = False
|
||||||
|
self.pack_contain_event = False
|
||||||
|
self.event_inner_idx = -1
|
||||||
|
self.interval_inited = False
|
||||||
|
self.last_epoch_finish_time = None
|
||||||
|
|
||||||
|
def reset_state(self):
|
||||||
|
"""清空采集器状态和缓存数据"""
|
||||||
|
with self.paradigmBufferLock:
|
||||||
|
self.paradigmBuffer.resetAllPara()
|
||||||
|
self.count_events = {}
|
||||||
|
self.epoch_finished = False
|
||||||
|
self.pack_contain_event = False
|
||||||
|
self.event_inner_idx = -1
|
||||||
|
self.interval_inited = False
|
||||||
|
|
||||||
|
def interval_init(self, decoder_class):
|
||||||
|
if decoder_class == 'ssmvep':
|
||||||
|
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) # [0.2, 2.2]
|
||||||
|
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # [50, 550]
|
||||||
|
self.train_epoch = [
|
||||||
|
int(self.interval_epoch[0]),
|
||||||
|
int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])
|
||||||
|
] # [50, 575]
|
||||||
|
self.latency = (self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 #115包, 575个点
|
||||||
|
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 #120包 600个点
|
||||||
|
|
||||||
|
elif decoder_class == 'mi':
|
||||||
|
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) # [0.5, 4.5]
|
||||||
|
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] #[125, 1125]
|
||||||
|
self.train_epoch = self.interval_epoch.copy()
|
||||||
|
self.latency = self.interval_epoch[1] // 5 #225
|
||||||
|
self.train_latency = self.latency #225
|
||||||
|
|
||||||
|
algo_log(f"时间窗初始化完成: {interval_epoch}", level="INFO")
|
||||||
|
self.count_events: Dict[str, int] = {}
|
||||||
|
self.event_inner_idx = -1
|
||||||
|
self.epoch_finished = False
|
||||||
|
self.pack_contain_event = False
|
||||||
|
self.predict_event = 99
|
||||||
|
self.events = [1, 2, self.predict_event]
|
||||||
|
self.interval_inited = True
|
||||||
|
|
||||||
|
# -------------------------- 8099端口:命令结果广播 --------------------------
|
||||||
def broadcast_message(self, method, params):
|
def broadcast_message(self, method, params):
|
||||||
"""Put message into queue to be sent to all command clients"""
|
"""
|
||||||
self.send_queue.put((method, params))
|
向所有8099端口客户端广播JSON格式的命令结果
|
||||||
|
用于:解码结果、训练状态、错误提示、进度通知等
|
||||||
|
"""
|
||||||
|
self.cmd_send_queue.put((method, params))
|
||||||
|
|
||||||
def _handle_cmd_message(self, frames):
|
def _process_cmd_send_queue(self):
|
||||||
"""处理命令端口消息(原有命令交互逻辑)"""
|
"""处理8099端口发送队列,在主线程执行(保证ZMQ线程安全)"""
|
||||||
if len(frames) < 3:
|
while not self.cmd_send_queue.empty():
|
||||||
|
method, params = self.cmd_send_queue.get()
|
||||||
|
if not self.cmd_clients:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
msg = {'method': method, 'params': params}
|
||||||
|
msg_bytes = json.dumps(msg).encode('utf-8')
|
||||||
|
|
||||||
|
algo_log(f"发送命令结果: {msg}", level="DEBUG")
|
||||||
|
|
||||||
|
# 广播到所有命令客户端
|
||||||
|
for client_id in list(self.cmd_clients):
|
||||||
|
try:
|
||||||
|
self.cmd_socket.send_multipart([client_id, b"", msg_bytes])
|
||||||
|
except Exception as e:
|
||||||
|
algo_log(f"向命令客户端{client_id}发送失败: {e}", level="ERROR")
|
||||||
|
self.cmd_clients.discard(client_id)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
algo_log(f"命令结果打包失败: {e}", level="ERROR")
|
||||||
|
|
||||||
|
# -------------------------- 8100端口:滤波结果发送 --------------------------
|
||||||
|
def send_filtered_data(self, filtered_data):
|
||||||
|
"""
|
||||||
|
向8100端口客户端发送二进制格式的滤波结果
|
||||||
|
用于:上位机实时绘图的脑电波形数据
|
||||||
|
:param filtered_data: 滤波后数据,shape=(通道数, 50),float64格式
|
||||||
|
"""
|
||||||
|
if self.current_data_client is None:
|
||||||
|
algo_log("数据客户端未连接,跳过滤波数据发送", level="WARNING")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 转置为上位机需要的[50, 通道数]格式
|
||||||
|
filtered_data = filtered_data.T.astype(np.float64)
|
||||||
|
send_buf = filtered_data.tobytes()
|
||||||
|
algo_log(f"发送滤波数据,长度: {len(send_buf)}字节, filtered_data.shape: {filtered_data.shape}", level="DEBUG", record_once=True)
|
||||||
|
self.data_send_queue.put(send_buf)
|
||||||
|
|
||||||
|
def _process_data_send_queue(self):
|
||||||
|
"""处理8100端口发送队列,在主线程执行(保证ZMQ线程安全)"""
|
||||||
|
while not self.data_send_queue.empty():
|
||||||
|
send_buf = self.data_send_queue.get()
|
||||||
|
if self.current_data_client is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 标准ROUTER发送格式:[客户端ID, 空分隔帧, 数据帧]
|
||||||
|
self.data_socket.send_multipart([
|
||||||
|
self.current_data_client,
|
||||||
|
b"",
|
||||||
|
send_buf
|
||||||
|
])
|
||||||
|
algo_log(f"发送滤波数据成功,长度: {len(send_buf)}字节", level="DEBUG", record_once=True)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
algo_log(f"发送滤波数据失败: {e}", level="ERROR")
|
||||||
|
# 客户端断开,重置身份
|
||||||
|
self.current_data_client = None
|
||||||
|
self.data_clients.clear()
|
||||||
|
|
||||||
|
# -------------------------- 命令端口消息处理 --------------------------
|
||||||
|
def _handle_cmd_message(self, frames):
|
||||||
|
"""处理8099端口JSON命令消息"""
|
||||||
|
if len(frames) < 3:
|
||||||
|
algo_log(f"无效命令帧:长度不足3帧,实际{len(frames)}", level="ERROR")
|
||||||
|
return
|
||||||
|
|
||||||
ident, _, message_bytes = frames[:3]
|
ident, _, message_bytes = frames[:3]
|
||||||
|
|
||||||
# 注册新的命令客户端
|
# 注册新的命令客户端
|
||||||
if ident not in self.cmd_clients:
|
if ident not in self.cmd_clients:
|
||||||
self.cmd_clients.add(ident)
|
self.cmd_clients.add(ident)
|
||||||
print(f"New CMD Client Connected: {ident} (port: {self.cmd_port})")
|
algo_log(f"新命令客户端连接成功: {ident}", level="INFO")
|
||||||
|
|
||||||
# 解析消息
|
# 解析JSON命令
|
||||||
try:
|
try:
|
||||||
message = json.loads(message_bytes.decode('utf-8'))
|
message = json.loads(message_bytes.decode('utf-8'))
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
print(f"Invalid JSON from CMD client {ident}")
|
algo_log(f"无效JSON命令: {message_bytes.hex()}", level="ERROR")
|
||||||
continue
|
self.broadcast_message("error", {"code": 400, "message": "无效JSON格式"})
|
||||||
print(f"Received CMD request: {message}")
|
return
|
||||||
|
|
||||||
|
algo_log(f"收到命令: {message}", level="INFO")
|
||||||
method = message.get("method")
|
method = message.get("method")
|
||||||
params = message.get("params")
|
params = message.get("params")
|
||||||
|
|
||||||
# 原有命令处理逻辑
|
# 命令处理逻辑
|
||||||
if method == "sync":
|
if method == "sync":
|
||||||
self.state_mode = 'sync'
|
self.state_mode = 'sync'
|
||||||
if method == "targetFreqs":
|
elif method == "targetFreqs":
|
||||||
if not isinstance(params, list):
|
if not isinstance(params, list):
|
||||||
print('targetFreqs must be a list')
|
algo_log(f"targetFreqs must be a list")
|
||||||
continue
|
return
|
||||||
if params != self.targetFreqs:
|
if params != self.targetFreqs:
|
||||||
self.targetFreqs = params
|
self.targetFreqs = params
|
||||||
self.changeTarget = True
|
self.changeTarget = True
|
||||||
if method == "decoderClass":
|
elif method == "decoderClass":
|
||||||
if not isinstance(params, str):
|
if not isinstance(params, str):
|
||||||
print('decoderClass must be a str')
|
algo_log(f"decoderClass必须是字符串")
|
||||||
continue
|
return
|
||||||
if params != self.decoder_class:
|
if params != self.decoder_class:
|
||||||
self.decoder_class = params
|
self.decoder_class = params
|
||||||
self.decoder_switch = True
|
self.decoder_switch = True
|
||||||
if method == "getReport":
|
elif method == "train":
|
||||||
self.getReport = True
|
|
||||||
if method == "train":#训练状态
|
|
||||||
self.state_mode = 'train'
|
self.state_mode = 'train'
|
||||||
self.StartTrain = True
|
resp = {
|
||||||
self.currentLabel = params # 当前刺激端的训练标签
|
"method": "train_response",
|
||||||
self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
|
"params": {
|
||||||
elif method == "predict":#预测状态
|
"code": 200,
|
||||||
|
"message": "ok"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
resp_bytes = json.dumps(resp, ensure_ascii=False).encode("utf-8")
|
||||||
|
self.cmd_socket.send_multipart([ident, b"", resp_bytes])
|
||||||
|
algo_log(f"train 命令已即时回复客户端 {ident}", level="DEBUG")
|
||||||
|
except Exception as e:
|
||||||
|
algo_log(f"train 命令回复失败: {e}", level="ERROR")
|
||||||
|
return
|
||||||
|
elif method == "predict":
|
||||||
self.state_mode = 'predict'
|
self.state_mode = 'predict'
|
||||||
if params == 1: #开始解码
|
if params == 1: #开始解码
|
||||||
self.StartDecode = True
|
self.StartDecode = True
|
||||||
self.sunnyLinker.push_trigger(0x63)
|
|
||||||
elif params == 2: #停止解码
|
elif params == 2: #停止解码
|
||||||
self.IsExitApp = True
|
self.IsExitApp = True
|
||||||
self.running = False
|
self.running = False
|
||||||
elif method == "rest": #休息状态
|
elif method == "rest":
|
||||||
self.state_mode = 'rest'
|
self.state_mode = 'rest'
|
||||||
# elif method == "impedance":
|
elif method == "impedance":
|
||||||
# if params == 1:
|
if params == 1:
|
||||||
# self.open_Impedance = True # 开启阻抗
|
self.open_Impedance = True
|
||||||
# self.get_Impedance = True # 返回阻抗
|
elif params == 2:
|
||||||
# elif params == 2:
|
self.open_Impedance = False
|
||||||
# self.open_Impedance = False # 关闭阻抗
|
else:
|
||||||
# self.get_Impedance = False # 停止返回阻抗
|
self.broadcast_message("error", {"code": 404, "message": f"未知命令: {method}"})
|
||||||
|
|
||||||
|
# -------------------------- 数据端口消息处理 --------------------------
|
||||||
def _handle_data_message(self, frames):
|
def _handle_data_message(self, frames):
|
||||||
"""
|
"""处理8100端口二进制脑电数据消息"""
|
||||||
处理8100端口原始脑电二进制数据
|
algo_log(f"收到数据帧,总帧数:{len(frames)}", level="DEBUG", record_once=True)
|
||||||
固定格式:上位机发送 (5,66) float32 二维数组字节流(已转换为微伏物理量)→ 转置为 (66,5) 写入双缓冲区
|
# 然后再进行解析
|
||||||
"""
|
if len(frames) == 4:
|
||||||
# 1. 校验ZMQ消息帧完整性
|
# 你的上位机格式
|
||||||
if len(frames) < 3:
|
ident, sender_ident, empty_sep, data_bytes = frames[:4]
|
||||||
print(f"[ERROR] 无效数据帧:长度不足3帧,实际长度={len(frames)}")
|
elif len(frames) == 3:
|
||||||
|
# 标准格式
|
||||||
|
ident, empty_sep, data_bytes = frames[:3]
|
||||||
|
elif len(frames) == 2:
|
||||||
|
ident, data_bytes = frames[:2]
|
||||||
|
else:
|
||||||
return
|
return
|
||||||
|
# 注册新的数据客户端(单客户端场景,自动覆盖旧身份)
|
||||||
ident, _, data_bytes = frames[:3]
|
|
||||||
|
|
||||||
# 2. 客户端管理(单客户端场景,自动更新最新身份)
|
|
||||||
if ident not in self.data_clients:
|
if ident not in self.data_clients:
|
||||||
|
self.data_clients.clear() # 单客户端,只保留最新连接
|
||||||
self.data_clients.add(ident)
|
self.data_clients.add(ident)
|
||||||
self.current_data_client = ident # 保存唯一客户端身份,用于后续回复滤波结果
|
self.current_data_client = ident
|
||||||
print(f"[INFO] 新数据客户端连接成功:{ident}")
|
algo_log(f"新数据客户端连接成功: {ident}", level="INFO")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 3. 精确长度校验(核心:固定(5,66) float32 = 5*66*4=1320字节,与int32字节数相同)
|
# 精确长度校验
|
||||||
EXPECTED_BYTES = 5 * 66 * 4 # 每个float32占4字节
|
EXPECTED_BYTES = self.device_info['frame_points'] * self.device_info['channel_nums'] * np.dtype(np.float64).itemsize
|
||||||
if len(data_bytes) != EXPECTED_BYTES:
|
if len(data_bytes) != EXPECTED_BYTES:
|
||||||
print(f"[ERROR] 数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节")
|
algo_log(f"数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节", level="ERROR")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 4. 零拷贝二进制解析 + 维度转换
|
# 零拷贝解析 + 维度转换
|
||||||
# 步骤:字节流 → (330,) float32数组 → (5,66) 原始格式 → 转置为 (66,5) 缓冲区标准格式
|
data_np = np.frombuffer(data_bytes, dtype=np.float64)
|
||||||
data_np = np.frombuffer(data_bytes, dtype=np.float32)
|
data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums'])
|
||||||
# 重塑为上位机原始维度
|
|
||||||
data_np = data_np.reshape(5, 66)
|
|
||||||
# 转置为(通道数, 采样点数)标准格式,转换为float64保证滤波运算精度
|
|
||||||
data_np = data_np.T.astype(np.float64)
|
data_np = data_np.T.astype(np.float64)
|
||||||
|
|
||||||
# 5. 同时写入双环形缓冲区(方法名与现有类保持一致:appendBuffer)
|
# 写入滤波缓冲区
|
||||||
# 注意:上位机已发送微伏物理量,无需再乘以增益系数
|
with self.filterBufferLock:
|
||||||
self.paradigmBuffer.appendBuffer(data_np)
|
|
||||||
self.filterBuffer.appendBuffer(data_np)
|
self.filterBuffer.appendBuffer(data_np)
|
||||||
|
|
||||||
# 生产环境必须注释!每秒50次打印会导致CPU占用飙升30%以上
|
# 写入范式缓冲区
|
||||||
algo_log(f"数据写入成功:shape={data_np.shape}, 范围=[{data_np.min():.2f}, {data_np.max():.2f}] μV", level="DEBUG", record_once=True)
|
with self.paradigmBufferLock:
|
||||||
|
if self.interval_inited:
|
||||||
|
self.epoch_finished = self.detect_event(data_np)
|
||||||
|
if self.pack_contain_event:
|
||||||
|
self.paradigmBuffer.resetAllPara()
|
||||||
|
self.paradigmBuffer.appendBuffer(data_np)
|
||||||
|
|
||||||
|
if self.epoch_finished:
|
||||||
|
now = datetime.datetime.now()
|
||||||
|
time_diff_str = ""
|
||||||
|
# 计算与上一次Epoch完成的时间差
|
||||||
|
if self.last_epoch_finish_time is not None:
|
||||||
|
# 时间差 单位:秒,保留3位小数
|
||||||
|
delta_seconds = (now - self.last_epoch_finish_time).total_seconds()
|
||||||
|
time_diff_str = f" | 与上一次间隔: {delta_seconds:.3f} s"
|
||||||
|
|
||||||
|
# 拼接日志,增加时间差信息
|
||||||
|
log_msg = f"Epoch采集完成: {now.strftime('%H:%M:%S.%f')[:-3]}{time_diff_str}"
|
||||||
|
algo_log(log_msg, level="DEBUG")
|
||||||
|
|
||||||
|
# 更新上一次Epoch完成时间为当前时间
|
||||||
|
self.last_epoch_finish_time = now
|
||||||
|
else:
|
||||||
|
self.paradigmBuffer.appendBuffer(data_np)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
algo_log(f"数据处理失败:{str(e)}", level="ERROR")
|
algo_log(f"数据处理失败: {str(e)}", level="ERROR")
|
||||||
# 调试阶段临时打开,生产环境务必注释
|
if IniRead('system', 'algo_log_level', 'INFO') == 'DEBUG':
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
def _process_send_queue(self):
|
# -------------------------- 事件检测 --------------------------
|
||||||
"""处理发送队列,向所有命令客户端广播消息"""
|
def detect_event(self, samples):
|
||||||
while not self.send_queue.empty():
|
self.pack_contain_event = False
|
||||||
method, params = self.send_queue.get()
|
# 第65通道为事件通道
|
||||||
if self.cmd_clients:
|
event = int(samples[-2][0])
|
||||||
try:
|
# for idx, event in enumerate(events):
|
||||||
msg = {'method': method, 'params': params}
|
if event in self.events:
|
||||||
msg_bytes = json.dumps(msg).encode('utf-8')
|
new_key = "".join(
|
||||||
|
[
|
||||||
# 打印日志(隐藏大尺寸数据)
|
str(event),
|
||||||
if method in ['single_trial_plot', 'miReport']:
|
datetime.datetime.now().strftime("%Y-%m-%d \
|
||||||
print(f"{{'method': '{method}', 'params': <Base64 Image Data>}}")
|
-%H-%M-%S"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.currentLabel = event
|
||||||
|
if event == self.predict_event:
|
||||||
|
self.count_events[new_key] = self.latency + 1
|
||||||
else:
|
else:
|
||||||
print(f"Sending CMD message: {msg}")
|
self.count_events[new_key] = self.train_latency + 1
|
||||||
|
self.event_inner_idx = self.device_info['frame_points'] - 1
|
||||||
|
# algo_log(f"事件检测到: {event},索引: {idx}", level="DEBUG")
|
||||||
|
self.pack_contain_event = True
|
||||||
|
|
||||||
# 广播到所有命令客户端
|
# 倒计时并清理过期事件
|
||||||
for client_id in list(self.cmd_clients):
|
drop_items = []
|
||||||
try:
|
for key, value in self.count_events.items():
|
||||||
self.cmd_socket.send_multipart([client_id, b'', msg_bytes])
|
value -= 1
|
||||||
except Exception as e:
|
if value == 0:
|
||||||
print(f"Error sending to CMD client {client_id}: {e}")
|
drop_items.append(key)
|
||||||
self.cmd_clients.discard(client_id) # 移除失效客户端
|
self.count_events[key] = value
|
||||||
except Exception as e:
|
|
||||||
print(f"Error preparing broadcast: {e}")
|
|
||||||
|
|
||||||
|
for key in drop_items:
|
||||||
|
del self.count_events[key]
|
||||||
|
|
||||||
|
if drop_items:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
# -------------------------- 主循环 --------------------------
|
||||||
def run(self):
|
def run(self):
|
||||||
self.running = True
|
self.running = True
|
||||||
print(f"ZMQ Server started - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}")
|
algo_log(f"ZMQ服务器启动成功 - host: {self.host}, 命令端口: {self.cmd_port}, 数据端口: {self.data_port}", level="INFO")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while self.running:
|
while self.running:
|
||||||
# 1. 处理发送队列(命令端口广播)
|
# 1. 处理两个端口的发送队列(必须在主线程执行)
|
||||||
self._process_send_queue()
|
self._process_cmd_send_queue()
|
||||||
|
self._process_data_send_queue()
|
||||||
|
|
||||||
# 2. 轮训监听两个Socket的输入事件(10ms超时,避免阻塞)
|
# 2. 轮询监听两个端口的输入事件
|
||||||
socks = dict(self.poller.poll(10))
|
socks = dict(self.poller.poll(50))
|
||||||
|
|
||||||
# 处理命令端口消息
|
# 处理8099命令端口消息
|
||||||
if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN:
|
if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN:
|
||||||
frames = self.cmd_socket.recv_multipart()
|
frames = self.cmd_socket.recv_multipart()
|
||||||
self._handle_cmd_message(frames)
|
self._handle_cmd_message(frames)
|
||||||
|
|
||||||
# 处理数据端口消息
|
# 处理8100数据端口消息
|
||||||
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
|
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
|
||||||
frames = self.data_socket.recv_multipart()
|
frames = self.data_socket.recv_multipart()
|
||||||
self._handle_data_message(frames)
|
self._handle_data_message(frames)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Server error occurred: {e}")
|
algo_log(f"服务器主循环异常: {e}", level="ERROR")
|
||||||
finally:
|
finally:
|
||||||
self.running = False
|
self.running = False
|
||||||
# 关闭所有Socket和上下文
|
# 优雅关闭所有资源
|
||||||
self.cmd_socket.close()
|
self.cmd_socket.close()
|
||||||
self.data_socket.close()
|
self.data_socket.close()
|
||||||
self.context.term()
|
self.context.term()
|
||||||
print("Server sockets and context closed.")
|
algo_log("ZMQ服务器已关闭", level="INFO")
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""显式关闭服务器"""
|
"""显式关闭服务器"""
|
||||||
@@ -243,10 +426,10 @@ class zmqServer(threading.Thread):
|
|||||||
self.cmd_socket.close()
|
self.cmd_socket.close()
|
||||||
self.data_socket.close()
|
self.data_socket.close()
|
||||||
self.context.term()
|
self.context.term()
|
||||||
print(f"Server closed explicitly - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}")
|
algo_log(f"服务器已显式关闭 - 命令端口: {self.cmd_port}, 数据端口: {self.data_port}", level="INFO")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# 初始化并启动服务器(默认cmd=8099, data=8100)
|
# 初始化并启动服务器
|
||||||
server = zmqServer()
|
server = zmqServer()
|
||||||
server.start()
|
server.start()
|
||||||
|
|
||||||
@@ -255,5 +438,5 @@ if __name__ == '__main__':
|
|||||||
while server.running:
|
while server.running:
|
||||||
threading.Event().wait(1)
|
threading.Event().wait(1)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("Received KeyboardInterrupt, stopping server...")
|
algo_log("收到键盘中断信号,正在停止服务器...", level="INFO")
|
||||||
server.stop()
|
server.stop()
|
||||||
@@ -1,445 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import zmq
|
|
||||||
import threading
|
|
||||||
import json
|
|
||||||
import queue
|
|
||||||
import time
|
|
||||||
from Device.SunnyLinker import SunnyLinker64, RingBuffer
|
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
|
|
||||||
class zmqServer(threading.Thread):
|
|
||||||
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100):
|
|
||||||
threading.Thread.__init__(self)
|
|
||||||
self.host = host
|
|
||||||
self.cmd_port = cmd_port
|
|
||||||
self.data_port = data_port
|
|
||||||
self.running = False
|
|
||||||
self.get_Impedance = False
|
|
||||||
self.open_Impedance = None
|
|
||||||
self.StartDecode = False
|
|
||||||
self.StartTrain = False
|
|
||||||
self.state_mode = None
|
|
||||||
self.currentLabel = -1
|
|
||||||
self.IsExitApp = False
|
|
||||||
self.getReport = False
|
|
||||||
self.daemon = True
|
|
||||||
|
|
||||||
# ZMQ Context
|
|
||||||
self.context = zmq.Context()
|
|
||||||
|
|
||||||
# 指令通道 (8099) - ROUTER
|
|
||||||
self.cmd_socket = self.context.socket(zmq.ROUTER)
|
|
||||||
self.cmd_socket.setsockopt(zmq.RCVHWM, 1000)
|
|
||||||
self.cmd_socket.setsockopt(zmq.SNDHWM, 1000)
|
|
||||||
self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
|
|
||||||
|
|
||||||
# 数据通道 (8100)) - ROUTER
|
|
||||||
self.data_socket = self.context.socket(zmq.ROUTER)
|
|
||||||
self.data_socket.setsockopt(zmq.RCVHWM, 1000)
|
|
||||||
self.data_socket.setsockopt(zmq.RCVTIMEO, 50)
|
|
||||||
self.data_socket.bind(f"tcp://{self.host}:{data_port}")
|
|
||||||
|
|
||||||
self.targetFreqs = []
|
|
||||||
self.changeTarget = False
|
|
||||||
self.sunnyLinker = SunnyLinker64(None, None, None, None, None)
|
|
||||||
self.labels = [0x01, 0x02, 0x03]
|
|
||||||
|
|
||||||
self.decoder_switch = False
|
|
||||||
self.decoder_class = None
|
|
||||||
self.cmd_clients = set()
|
|
||||||
self.data_clients = set()
|
|
||||||
self.send_queue = queue.Queue()
|
|
||||||
|
|
||||||
# ========== 数据缓冲区 (RingBuffer) ==========
|
|
||||||
# 与 SunnyLinker 保持一致,使用 RingBuffer
|
|
||||||
# 66 = 64 EEG通道 + 1 事件通道(第65) + 1 标签序号通道(第66)
|
|
||||||
# 缓存约 10 秒数据 (250Hz * 10s = 2500 点)
|
|
||||||
self.n_chan = 66
|
|
||||||
self.t_buffer = 10.0 # 缓冲区时长(秒)
|
|
||||||
self.__ringBuffer = RingBuffer(self.n_chan, int(self.t_buffer * 250))
|
|
||||||
|
|
||||||
# 事件检测相关
|
|
||||||
self._event_lock = threading.Lock()
|
|
||||||
self._epoch_finished = False
|
|
||||||
self._event_inner_idx = -1
|
|
||||||
self.pack_contain_event = False
|
|
||||||
self.predict_event = 99
|
|
||||||
self.events = [1, 2, self.predict_event]
|
|
||||||
self.count_events = {}
|
|
||||||
self.latency = 50
|
|
||||||
self.train_latency = 50
|
|
||||||
|
|
||||||
# 当前事件标签序号 (从第66通道获取)
|
|
||||||
self.current_label_index = 0
|
|
||||||
|
|
||||||
# 初始化标志
|
|
||||||
self._interval_inited = False
|
|
||||||
self._currentLabel = -1
|
|
||||||
|
|
||||||
# 注册的客户端(兼容旧接口)
|
|
||||||
self.clients = set()
|
|
||||||
|
|
||||||
# ========== 事件属性:线程安全访问 ==========
|
|
||||||
@property
|
|
||||||
def epoch_finished(self):
|
|
||||||
with self._event_lock:
|
|
||||||
return self._epoch_finished
|
|
||||||
|
|
||||||
@epoch_finished.setter
|
|
||||||
def epoch_finished(self, value):
|
|
||||||
with self._event_lock:
|
|
||||||
self._epoch_finished = value
|
|
||||||
|
|
||||||
@property
|
|
||||||
def event_inner_idx(self):
|
|
||||||
with self._event_lock:
|
|
||||||
return self._event_inner_idx
|
|
||||||
|
|
||||||
@event_inner_idx.setter
|
|
||||||
def event_inner_idx(self, value):
|
|
||||||
with self._event_lock:
|
|
||||||
self._event_inner_idx = value
|
|
||||||
|
|
||||||
@property
|
|
||||||
def interval_inited(self):
|
|
||||||
return self._interval_inited
|
|
||||||
|
|
||||||
@interval_inited.setter
|
|
||||||
def interval_inited(self, value):
|
|
||||||
self._interval_inited = value
|
|
||||||
|
|
||||||
@property
|
|
||||||
def currentLabel(self):
|
|
||||||
return self._currentLabel
|
|
||||||
|
|
||||||
@currentLabel.setter
|
|
||||||
def currentLabel(self, value):
|
|
||||||
self._currentLabel = value
|
|
||||||
|
|
||||||
def broadcast_message(self, method, params):
|
|
||||||
"""Put message into queue to be sent to all connected clients"""
|
|
||||||
self.send_queue.put((method, params))
|
|
||||||
|
|
||||||
# ========== 数据缓冲区操作接口 ==========
|
|
||||||
def GetDataLenCount(self):
|
|
||||||
"""返回缓冲区当前数据点数"""
|
|
||||||
return self.__ringBuffer.nUpdate
|
|
||||||
|
|
||||||
def getData(self, count):
|
|
||||||
"""获取最新count个数据点,不消费(只读)"""
|
|
||||||
with self.__ringBuffer.RingBufferLock:
|
|
||||||
count = min(count, self.__ringBuffer.nUpdate)
|
|
||||||
if count == 0:
|
|
||||||
return np.zeros((self.n_chan, 0))
|
|
||||||
|
|
||||||
# 计算读取范围(从尾部取最新数据)
|
|
||||||
read_end = (self.__ringBuffer.currentPtr - 1) % self.__ringBuffer.n_points
|
|
||||||
read_start = (read_end - count + 1) % self.__ringBuffer.n_points
|
|
||||||
|
|
||||||
if self.__ringBuffer.currentPtr == 0:
|
|
||||||
read_start = self.__ringBuffer.n_points - count
|
|
||||||
read_end = self.__ringBuffer.n_points - 1
|
|
||||||
|
|
||||||
if read_start <= read_end:
|
|
||||||
data = self.__ringBuffer.buffer[:, read_start:read_end + 1]
|
|
||||||
else:
|
|
||||||
part1 = self.__ringBuffer.buffer[:, read_start:]
|
|
||||||
part2 = self.__ringBuffer.buffer[:, :read_end + 1]
|
|
||||||
data = np.concatenate((part1, part2), axis=1)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def consumeData(self, count):
|
|
||||||
"""消费(丢弃)指定数量的数据点,从头部移除"""
|
|
||||||
with self.__ringBuffer.RingBufferLock:
|
|
||||||
count = min(count, self.__ringBuffer.nUpdate)
|
|
||||||
self.__ringBuffer.readPtr = (self.__ringBuffer.readPtr + count) % self.__ringBuffer.n_points
|
|
||||||
self.__ringBuffer.nUpdate -= count
|
|
||||||
|
|
||||||
def ResetAll(self):
|
|
||||||
"""重置缓冲区"""
|
|
||||||
with self.__ringBuffer.RingBufferLock:
|
|
||||||
self.__ringBuffer.resetAllPara()
|
|
||||||
with self._event_lock:
|
|
||||||
self._epoch_finished = False
|
|
||||||
self._event_inner_idx = -1
|
|
||||||
self.pack_contain_event = False
|
|
||||||
self.count_events.clear()
|
|
||||||
self.current_label_index = 0
|
|
||||||
|
|
||||||
def reset_data_buffer(self):
|
|
||||||
self.ResetAll()
|
|
||||||
|
|
||||||
def reset_state(self):
|
|
||||||
self.ResetAll()
|
|
||||||
|
|
||||||
def interval_init(self, decoder_class):
|
|
||||||
"""初始化事件检测参数"""
|
|
||||||
import ast
|
|
||||||
from PubLibrary.InifileHelper import IniRead
|
|
||||||
|
|
||||||
if decoder_class == 'ssmvep':
|
|
||||||
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
|
||||||
self.interval_epoch = [int(i * 250) for i in interval_epoch]
|
|
||||||
self.train_epoch = [int(self.interval_epoch[0]),
|
|
||||||
int(self.interval_epoch[1] + 0.1 * 250)]
|
|
||||||
self.latency = (self.interval_epoch[1] + 0.1 * 250) // 5
|
|
||||||
self.train_latency = (self.train_epoch[1] + 0.1 * 250) // 5
|
|
||||||
|
|
||||||
elif decoder_class == 'mi':
|
|
||||||
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
|
||||||
self.interval_epoch = [int(i * 250) for i in interval_epoch]
|
|
||||||
self.train_epoch = self.interval_epoch.copy()
|
|
||||||
self.latency = self.interval_epoch[1] // 5
|
|
||||||
self.train_latency = self.latency
|
|
||||||
|
|
||||||
self.count_events = {}
|
|
||||||
self._event_inner_idx = -1
|
|
||||||
self._epoch_finished = False
|
|
||||||
self.pack_contain_event = False
|
|
||||||
self.predict_event = 99
|
|
||||||
self.events = [1, 2, self.predict_event]
|
|
||||||
self._interval_inited = True
|
|
||||||
|
|
||||||
# ========== 事件检测 ==========
|
|
||||||
def detect_event(self, data_matrix):
|
|
||||||
"""
|
|
||||||
检测事件通道中的触发信号
|
|
||||||
|
|
||||||
@param data_matrix: shape (66, N) - N个采样点的数据
|
|
||||||
第65行(索引64) = 事件通道
|
|
||||||
第66行(索引65) = 标签通道
|
|
||||||
@return: 是否检测到事件
|
|
||||||
"""
|
|
||||||
if data_matrix.shape[1] == 0:
|
|
||||||
return False
|
|
||||||
|
|
||||||
self.pack_contain_event = False
|
|
||||||
event_channel = data_matrix[64, :] # 第65通道 = 标签值(event值)
|
|
||||||
label_channel = data_matrix[65, :] # 第66通道 = 标签序号(label index)
|
|
||||||
|
|
||||||
events = event_channel.tolist()
|
|
||||||
|
|
||||||
with self._event_lock:
|
|
||||||
self._event_inner_idx = -1
|
|
||||||
self.current_event_label = 0
|
|
||||||
|
|
||||||
for idx, event in enumerate(events):
|
|
||||||
if int(event) in self.events:
|
|
||||||
self._event_inner_idx = idx
|
|
||||||
self.current_label_index = int(label_channel[idx])
|
|
||||||
self.pack_contain_event = True
|
|
||||||
|
|
||||||
new_key = f"{event}_{time.time()}"
|
|
||||||
latency = self.latency if event == self.predict_event else self.train_latency
|
|
||||||
self.count_events[new_key] = latency + 1
|
|
||||||
|
|
||||||
# 延迟计数递减
|
|
||||||
drop_items = []
|
|
||||||
for key, value in self.count_events.items():
|
|
||||||
value = value - 1
|
|
||||||
if value == 0:
|
|
||||||
drop_items.append(key)
|
|
||||||
self.count_events[key] = value
|
|
||||||
for key in drop_items:
|
|
||||||
del self.count_events[key]
|
|
||||||
|
|
||||||
if drop_items:
|
|
||||||
self._epoch_finished = True
|
|
||||||
# 检测到事件时,清除RingBuffer中之前的数据,只保留当前包
|
|
||||||
if self.pack_contain_event:
|
|
||||||
self.__ringBuffer.resetAllPara()
|
|
||||||
return True
|
|
||||||
|
|
||||||
self._epoch_finished = False
|
|
||||||
return False
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
self.running = True
|
|
||||||
print(f"Server running - CMD: {self.cmd_port}, DATA: {self.data_port}")
|
|
||||||
|
|
||||||
cmd_poller = zmq.Poller()
|
|
||||||
cmd_poller.register(self.cmd_socket, zmq.POLLIN)
|
|
||||||
|
|
||||||
data_poller = zmq.Poller()
|
|
||||||
data_poller.register(self.data_socket, zmq.POLLIN)
|
|
||||||
|
|
||||||
try:
|
|
||||||
while self.running:
|
|
||||||
# --- 处理发送队列 (指令通道) ---
|
|
||||||
while not self.send_queue.empty():
|
|
||||||
method, params = self.send_queue.get()
|
|
||||||
if self.cmd_clients:
|
|
||||||
try:
|
|
||||||
msg = {'method': method, 'params': params}
|
|
||||||
msg_bytes = json.dumps(msg).encode('utf-8')
|
|
||||||
for client_id in list(self.cmd_clients):
|
|
||||||
try:
|
|
||||||
self.cmd_socket.send_multipart([client_id, b'', msg_bytes])
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# --- 处理指令通道 ---
|
|
||||||
socks = dict(cmd_poller.poll(10))
|
|
||||||
if self.cmd_socket in socks:
|
|
||||||
self._handle_cmd_socket()
|
|
||||||
|
|
||||||
# --- 处理数据通道 ---
|
|
||||||
socks = dict(data_poller.poll(10))
|
|
||||||
if self.data_socket in socks:
|
|
||||||
self._handle_data_socket()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Server error: {e}")
|
|
||||||
finally:
|
|
||||||
self.running = False
|
|
||||||
self.cmd_socket.close()
|
|
||||||
self.data_socket.close()
|
|
||||||
self.context.term()
|
|
||||||
|
|
||||||
def _handle_cmd_socket(self):
|
|
||||||
"""处理指令通道消息"""
|
|
||||||
try:
|
|
||||||
frames = self.cmd_socket.recv_multipart()
|
|
||||||
if len(frames) < 3:
|
|
||||||
return
|
|
||||||
ident, _, message_bytes = frames[:3]
|
|
||||||
self.cmd_clients.add(ident)
|
|
||||||
self.clients.add(ident)
|
|
||||||
|
|
||||||
message = json.loads(message_bytes.decode('utf-8'))
|
|
||||||
method = message.get("method")
|
|
||||||
params = message.get("params")
|
|
||||||
|
|
||||||
print(f"[CMD] {method}: {params}")
|
|
||||||
|
|
||||||
if method == "sync":
|
|
||||||
self.state_mode = 'sync'
|
|
||||||
elif method == "targetFreqs":
|
|
||||||
if isinstance(params, list) and params != self.targetFreqs:
|
|
||||||
self.targetFreqs = params
|
|
||||||
self.changeTarget = True
|
|
||||||
elif method == "decoderClass":
|
|
||||||
if isinstance(params, str) and params != self.decoder_class:
|
|
||||||
self.decoder_class = params
|
|
||||||
self.decoder_switch = True
|
|
||||||
elif method == "getReport":
|
|
||||||
self.getReport = True
|
|
||||||
elif method == "train":
|
|
||||||
self.state_mode = 'train'
|
|
||||||
self.StartTrain = True
|
|
||||||
self.currentLabel = params
|
|
||||||
elif method == "predict":
|
|
||||||
self.state_mode = 'predict'
|
|
||||||
if params == 1:
|
|
||||||
self.StartDecode = True
|
|
||||||
elif params == 2:
|
|
||||||
self.IsExitApp = True
|
|
||||||
self.running = False
|
|
||||||
elif method == "rest":
|
|
||||||
self.state_mode = 'rest'
|
|
||||||
elif method == "impedance":
|
|
||||||
if params == 1:
|
|
||||||
self.open_Impedance = True
|
|
||||||
self.get_Impedance = True
|
|
||||||
elif params == 2:
|
|
||||||
self.open_Impedance = False
|
|
||||||
self.get_Impedance = False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"CMD socket error: {e}")
|
|
||||||
|
|
||||||
def _handle_data_socket(self):
|
|
||||||
"""处理数据通道消息 (EEG数据)
|
|
||||||
|
|
||||||
上位机数据格式:
|
|
||||||
- 数据帧: [identity, '', meta_json, data_buffer]
|
|
||||||
data_buffer = [N, 66] float32 -> 转置为 [66, N]
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
frames = self.data_socket.recv_multipart()
|
|
||||||
if len(frames) < 4:
|
|
||||||
return
|
|
||||||
ident, _, message_bytes = frames[:3]
|
|
||||||
self.data_clients.add(ident)
|
|
||||||
|
|
||||||
meta = json.loads(message_bytes.decode('utf-8'))
|
|
||||||
|
|
||||||
# data: [N, 66] -> 转置 -> [66, N]
|
|
||||||
raw_data = np.frombuffer(frames[3], dtype=np.float32)
|
|
||||||
n_samples, n_channels = meta.get('shape', [5, 66])
|
|
||||||
data_matrix = raw_data.reshape(n_samples, n_channels).T.astype(np.float32)
|
|
||||||
|
|
||||||
# 写入 RingBuffer
|
|
||||||
with self.__ringBuffer.RingBufferLock:
|
|
||||||
self.__ringBuffer.appendBuffer(data_matrix)
|
|
||||||
|
|
||||||
# 事件检测
|
|
||||||
self.detect_event(data_matrix)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"DATA socket error: {e}")
|
|
||||||
|
|
||||||
# ========== 各范式数据访问接口 ==========
|
|
||||||
def get_MIData(self):
|
|
||||||
"""获取MI导联数据 (21通道 + 事件)"""
|
|
||||||
data = self.getData(self.GetDataLenCount())
|
|
||||||
rows_to_extract = [8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10, 19, 20, 61, 51, 60, 21, 64, 65]
|
|
||||||
row_to_select = np.array(rows_to_extract)
|
|
||||||
if data.shape[1] > 0:
|
|
||||||
return data[row_to_select, :]
|
|
||||||
return np.zeros((len(rows_to_extract), 0))
|
|
||||||
|
|
||||||
def get_SSMVEPData(self):
|
|
||||||
"""获取SSMVEP导联数据 (8通道 + 事件)"""
|
|
||||||
data = self.getData(self.GetDataLenCount())
|
|
||||||
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64, 65]
|
|
||||||
row_to_select = np.array(rows_to_extract)
|
|
||||||
if data.shape[1] > 0:
|
|
||||||
return data[row_to_select, :]
|
|
||||||
return np.zeros((len(rows_to_extract), 0))
|
|
||||||
|
|
||||||
def getDataViaSSVEP(self, count):
|
|
||||||
"""获取SSVEP数据 (8通道 + 事件)"""
|
|
||||||
data = self.getData(count)
|
|
||||||
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64]
|
|
||||||
row_to_select = np.array(rows_to_extract)
|
|
||||||
if data.shape[1] > 0:
|
|
||||||
return data[row_to_select, :]
|
|
||||||
return np.zeros((len(rows_to_extract), 0))
|
|
||||||
|
|
||||||
def get_concentrateData(self, count):
|
|
||||||
"""获取专注力数据 (2通道)"""
|
|
||||||
data = self.getData(count)
|
|
||||||
rows_to_extract = [0, 1]
|
|
||||||
row_to_select = np.array(rows_to_extract)
|
|
||||||
if data.shape[1] > 0:
|
|
||||||
return data[row_to_select, :]
|
|
||||||
return np.zeros((len(rows_to_extract), 0))
|
|
||||||
|
|
||||||
def get_blinkData(self, count):
|
|
||||||
"""获取眨眼数据 (2通道)"""
|
|
||||||
data = self.getData(count)
|
|
||||||
rows_to_extract = [0, 1]
|
|
||||||
row_to_select = np.array(rows_to_extract)
|
|
||||||
if data.shape[1] > 0:
|
|
||||||
return data[row_to_select, :]
|
|
||||||
return np.zeros((len(rows_to_extract), 0))
|
|
||||||
|
|
||||||
def getImpedance(self, data, decoder_class):
|
|
||||||
"""计算阻抗(ZMQ模式下不可用)"""
|
|
||||||
return np.zeros(8)
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
self.running = False
|
|
||||||
self.cmd_socket.close()
|
|
||||||
self.data_socket.close()
|
|
||||||
self.context.term()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
server = zmqServer()
|
|
||||||
server.start()
|
|
||||||
@@ -8,6 +8,7 @@ import os
|
|||||||
# import logging
|
# import logging
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
|
import math
|
||||||
|
|
||||||
# logger = logging.getLogger(__name__)
|
# logger = logging.getLogger(__name__)
|
||||||
#
|
#
|
||||||
@@ -22,7 +23,7 @@ import io
|
|||||||
|
|
||||||
|
|
||||||
class Calculate():
|
class Calculate():
|
||||||
def __init__(self, Threshold_value_low, Threshold_value_high, fs=250, win_len=10):
|
def __init__(self, Threshold_value_low, Threshold_value_high, fs=250, win_len=10, config=None):
|
||||||
self.Threshold_value_low = Threshold_value_low
|
self.Threshold_value_low = Threshold_value_low
|
||||||
self.Threshold_value_high = Threshold_value_high
|
self.Threshold_value_high = Threshold_value_high
|
||||||
self.fs = fs
|
self.fs = fs
|
||||||
@@ -31,47 +32,73 @@ class Calculate():
|
|||||||
self.EVI_result = []
|
self.EVI_result = []
|
||||||
self.eegQueue = deque(maxlen=win_len)
|
self.eegQueue = deque(maxlen=win_len)
|
||||||
|
|
||||||
# # 存储历史数据用于绘图
|
|
||||||
# self.beta_history = []
|
|
||||||
# self.alpha_history = []
|
|
||||||
# self.theta_history = []
|
|
||||||
# self.focus_history = []
|
|
||||||
# self.timestamp_history = []
|
|
||||||
#
|
|
||||||
# # 记录开始时间
|
|
||||||
# self.start_time = None
|
|
||||||
# self.recording = False
|
|
||||||
#
|
|
||||||
# # 图表保存路径
|
|
||||||
# self.chart_dir = "reports"
|
|
||||||
# if not os.path.exists(self.chart_dir):
|
|
||||||
# os.makedirs(self.chart_dir)
|
|
||||||
# print(f"[调试] 创建目录: {self.chart_dir}")
|
|
||||||
|
|
||||||
# 初始化滤波器
|
# 初始化滤波器
|
||||||
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30)
|
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30)
|
||||||
self.b_design = signal.firwin(65, [2 / (self.fs/2), 40 / (self.fs/2)], pass_zero=False)
|
self.b_design = signal.firwin(65, [2 / (self.fs/2), 40 / (self.fs/2)], pass_zero=False)
|
||||||
|
|
||||||
|
self.last_focus = None
|
||||||
|
# 异步滤波系数配置(核心手感控制纽)
|
||||||
|
self.alpha_up = 1 # 上升系数:较小,保证分数平滑爬升,过滤偶发的瞬时高能量
|
||||||
|
# alpha_down / shrink_factor 从 config.ini 读取,方便上位机调参
|
||||||
|
if config:
|
||||||
|
self.alpha_down = float(config.get('alpha_down', 0.8))
|
||||||
|
self.shrink_factor = float(config.get('shrink_factor', 0.5))
|
||||||
|
else:
|
||||||
|
self.alpha_down = 0.8
|
||||||
|
self.shrink_factor = 0.5
|
||||||
print("[调试] Calculate 类初始化完成")
|
print("[调试] Calculate 类初始化完成")
|
||||||
|
|
||||||
def calculate_focus(self, beta, alpha, theta):
|
def calculate_focus(self, beta, alpha, theta):
|
||||||
"""
|
"""
|
||||||
专注度计算 - 固定映射版本
|
专注度计算 - 三区间门限异步滤波版本
|
||||||
"""
|
"""
|
||||||
|
# 0. 频带特征预处理
|
||||||
|
theta_mod = theta ** 0.7
|
||||||
|
|
||||||
# 原始比值
|
# 原始比值
|
||||||
raw = beta / (alpha + theta + 1e-10)
|
raw = beta / (alpha + theta_mod + 1e-10)
|
||||||
|
|
||||||
# Sigmoid 映射:让 raw 在 0.3-1.5 区间敏感
|
exponent = 2.0
|
||||||
# 参数可调:
|
|
||||||
# k = 12 (斜率,越大越陡)
|
|
||||||
# x0 = 0.6 (中心点,raw=0.6时focus≈50)
|
|
||||||
k = 12.0
|
|
||||||
x0 = 0.6
|
|
||||||
focus = 100.0 / (1.0 + np.exp(-k * (raw - x0)))
|
|
||||||
|
|
||||||
# 可选:添加滑动平均平滑
|
# 1. 防止脑电比值出现负数异常值
|
||||||
|
raw_input = max(raw, 0.0)
|
||||||
|
|
||||||
|
# 2. 2次幂纵轴压缩映射 (shrink_factor 从 config.ini 读取)
|
||||||
|
focus_raw = 100 * self.shrink_factor * (raw_input ** exponent)
|
||||||
|
|
||||||
|
# 3. 计算当前帧的瞬时分数 (基准量级 0-120)
|
||||||
|
instant_focus = 120 * (1.0 - np.exp(-focus_raw / 100.0))
|
||||||
|
|
||||||
|
# 4. 核心修改:三区间门限时域滤波
|
||||||
|
if self.last_focus is None:
|
||||||
|
# 冷启动:首帧直接赋值
|
||||||
|
focus = instant_focus
|
||||||
|
else:
|
||||||
|
# 判断当前瞬时分数是否处于【极端区】(80以上 或 60以下)
|
||||||
|
if instant_focus > 85.0 or instant_focus < 60.0:
|
||||||
|
# 执行异步低通时域滤波
|
||||||
|
if instant_focus >= self.last_focus:
|
||||||
|
# 趋势上升:慢爬升
|
||||||
|
focus = self.alpha_up * instant_focus + (1 - self.alpha_up) * self.last_focus
|
||||||
|
else:
|
||||||
|
# 趋势下降:快跌落
|
||||||
|
focus = self.alpha_down * instant_focus + (1 - self.alpha_down) * self.last_focus
|
||||||
|
else:
|
||||||
|
# 【高灵敏自由区】(60 <= instant_focus <= 80)
|
||||||
|
# 不执行异步滤波,分数直接跟随瞬时值,保证中间状态绝对跟手
|
||||||
|
focus = instant_focus
|
||||||
|
|
||||||
|
# 5. 更新历史状态缓存
|
||||||
|
self.last_focus = focus
|
||||||
|
|
||||||
|
# 打印在线调试日志,方便观察区间切换
|
||||||
|
zone_tag = "极端区(滤波)" if (instant_focus > 80 or instant_focus < 60) else "自由区(直通)"
|
||||||
|
print(f"原始特征比值 raw: {raw:.4f} | 瞬时分数: {instant_focus:.1f} | 滤波后分数: {focus:.1f}")
|
||||||
|
|
||||||
|
# 最终返回整型
|
||||||
return int(focus)
|
return int(focus)
|
||||||
|
|
||||||
|
|
||||||
def calculate_all(self, data, fs, nperseg=1000):
|
def calculate_all(self, data, fs, nperseg=1000):
|
||||||
mean_x = np.mean(data, axis=-1, keepdims=True)
|
mean_x = np.mean(data, axis=-1, keepdims=True)
|
||||||
data = data - mean_x
|
data = data - mean_x
|
||||||
@@ -319,14 +346,16 @@ class Calculate():
|
|||||||
if eegData.size == 0:
|
if eegData.size == 0:
|
||||||
return None
|
return None
|
||||||
eegData -= np.mean(eegData, axis=-1, keepdims=True)
|
eegData -= np.mean(eegData, axis=-1, keepdims=True)
|
||||||
eegData = signal.lfilter(self.b_notch, self.a_notch, eegData)
|
# eegData = signal.lfilter(self.b_notch, self.a_notch, eegData) # 陷波
|
||||||
eegData = signal.lfilter(self.b_design, 1, eegData)
|
# eegData = signal.lfilter(self.b_design, 1, eegData) # 滤波
|
||||||
focus_score, CLI_score, beta, alpha, theta = self.calculate_all(eegData, fs=self.fs, nperseg=1000)
|
focus_score, CLI_score, beta_psd, alpha_psd, theta_psd = self.calculate_all(eegData, fs=self.fs, nperseg=1000)
|
||||||
|
|
||||||
# self.add_data_point(focus_score, beta, alpha, theta)
|
# self.add_data_point(focus_score, beta_psd, alpha_psd, theta_psd) # 已注释(方法已移除)
|
||||||
|
|
||||||
|
# return (focus_score)
|
||||||
|
return (focus_score, beta_psd)
|
||||||
|
# return None
|
||||||
|
|
||||||
return focus_score
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class Calculate2():
|
class Calculate2():
|
||||||
|
|||||||
157
config.ini
157
config.ini
@@ -13,159 +13,20 @@ Num_blocks = 1
|
|||||||
Num_trials = 10
|
Num_trials = 10
|
||||||
Audio_device = 0
|
Audio_device = 0
|
||||||
Rest_time = 2
|
Rest_time = 2
|
||||||
Device_type = 1
|
|
||||||
Device_Host = 127.0.0.1
|
|
||||||
Device_Port = 5086
|
|
||||||
Upper_Host = 127.0.0.1
|
Upper_Host = 127.0.0.1
|
||||||
Upper_Port = 8088
|
Upper_Port = 8088
|
||||||
|
Decoder_Host = 127.0.0.1
|
||||||
|
Decoder_Port = 8099
|
||||||
Serial_port = COM44
|
Serial_port = COM44
|
||||||
algo_log_level = DEBUG
|
algo_log_level = DEBUG
|
||||||
console_output = 1
|
console_output = 1
|
||||||
|
save_train_data = 0
|
||||||
|
zmqServer_host = 127.0.0.1
|
||||||
|
|
||||||
; 64 导设备配置
|
; 64 导设备配置
|
||||||
[device_type_1]
|
[device_type_1]
|
||||||
device_sample_rate = 250
|
sample_rate = 250
|
||||||
device_channel_nums = 66
|
frame_points = 5
|
||||||
device_channel_names = ['FP1', 'FP2', 'FC1', 'FC2', 'CP1', 'CP2', 'F3', 'F4', 'P3', 'P4', 'O1', 'O2', 'FT9', 'FT10', 'F7', 'F8', 'TP9', 'TP10', 'AF4', 'PO8', 'PZ', 'FCZ']
|
channel_nums = 66
|
||||||
device_channel_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18]
|
channel_names = ['FP1', 'FP2', 'PO6', 'POZ', 'F3', 'F4', 'FPZ', 'AF4', 'FC3', 'PO8', 'CP2', 'CP1', 'FCZ', 'PO5', 'FC2', 'FC1', 'C3', 'C4', 'FC4', 'CP4', 'P3', 'P4', 'F5', 'C5', 'F6', 'PO4', 'CP6', 'CP5', 'PO3', 'CP3', 'FC6', 'FC5', 'CB1', 'CB2', 'P5', 'AF7', 'A1', 'T7', 'FT7', 'TP7', 'FT8', 'AF8', 'F8', 'F7', 'P6', 'C6', 'O2', 'O1', 'T8', 'P7', 'CZ', 'PZ', 'P8', 'FZ', 'OZ', 'PO7', 'TP8', 'AF3', 'C2', 'C1', 'P2', 'P1', 'F2', 'F1', 'label', 'label_tag']
|
||||||
|
channel_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65]
|
||||||
|
|
||||||
|
|
||||||
[Layout]
|
|
||||||
main_splitter_left = 993
|
|
||||||
main_splitter_right = 922
|
|
||||||
right_splitter_left = 233
|
|
||||||
right_splitter_right = 771
|
|
||||||
left_splitter_left = 503
|
|
||||||
left_splitter_right = 501q
|
|
||||||
|
|
||||||
[channel]
|
|
||||||
channel_x_fp1 = 419
|
|
||||||
channel_y_fp1 = 124
|
|
||||||
channel_x_fc1 = 439
|
|
||||||
channel_y_fc1 = 296
|
|
||||||
channel_x_fp2 = 576
|
|
||||||
channel_y_fp2 = 124
|
|
||||||
channel_x_fc2 = 556
|
|
||||||
channel_y_fc2 = 299
|
|
||||||
channel_x_f3 = 397
|
|
||||||
channel_y_f3 = 231
|
|
||||||
channel_x_cp1 = 439
|
|
||||||
channel_y_cp1 = 426
|
|
||||||
channel_x_f4 = 601
|
|
||||||
channel_y_f4 = 232
|
|
||||||
channel_x_cp2 = 559
|
|
||||||
channel_y_cp2 = 425
|
|
||||||
channel_x_fc3 = 379
|
|
||||||
channel_y_fc3 = 295
|
|
||||||
channel_x_af4 = 571
|
|
||||||
channel_y_af4 = 171
|
|
||||||
channel_x_po8 = 645
|
|
||||||
channel_y_po8 = 564
|
|
||||||
channel_x_fpz = 499
|
|
||||||
channel_y_fpz = 112
|
|
||||||
channel_x_fcz = 499
|
|
||||||
channel_y_fcz = 300
|
|
||||||
channel_x_poz = 500
|
|
||||||
channel_y_poz = 554
|
|
||||||
channel_x_po5 = 387
|
|
||||||
channel_y_po5 = 551
|
|
||||||
channel_x_po6 = 611
|
|
||||||
channel_y_po6 = 551
|
|
||||||
channel_x_c3 = 373
|
|
||||||
channel_y_c3 = 363
|
|
||||||
channel_x_fc5 = 319
|
|
||||||
channel_y_fc5 = 292
|
|
||||||
channel_x_c4 = 620
|
|
||||||
channel_y_c4 = 363
|
|
||||||
channel_x_fc6 = 676
|
|
||||||
channel_y_fc6 = 288
|
|
||||||
channel_x_p3 = 398
|
|
||||||
channel_y_p3 = 491
|
|
||||||
channel_x_cp5 = 322
|
|
||||||
channel_y_cp5 = 430
|
|
||||||
channel_x_p4 = 600
|
|
||||||
channel_y_p4 = 489
|
|
||||||
channel_x_cp6 = 678
|
|
||||||
channel_y_cp6 = 430
|
|
||||||
channel_x_c5 = 313
|
|
||||||
channel_y_c5 = 361
|
|
||||||
channel_x_f6 = 650
|
|
||||||
channel_y_f6 = 223
|
|
||||||
channel_x_f5 = 349
|
|
||||||
channel_y_f5 = 224
|
|
||||||
channel_x_po4 = 573
|
|
||||||
channel_y_po4 = 551
|
|
||||||
channel_x_po3 = 429
|
|
||||||
channel_y_po3 = 550
|
|
||||||
channel_x_cp4 = 619
|
|
||||||
channel_y_cp4 = 424
|
|
||||||
channel_x_cp3 = 381
|
|
||||||
channel_y_cp3 = 426
|
|
||||||
channel_x_fc4 = 619
|
|
||||||
channel_y_fc4 = 295
|
|
||||||
channel_x_o1 = 423
|
|
||||||
channel_y_o1 = 598
|
|
||||||
channel_x_ft9 = 252
|
|
||||||
channel_y_ft9 = 168
|
|
||||||
channel_x_o2 = 576
|
|
||||||
channel_y_o2 = 597
|
|
||||||
channel_x_ft10 = 798
|
|
||||||
channel_y_ft10 = 277
|
|
||||||
channel_x_f7 = 295
|
|
||||||
channel_y_f7 = 214
|
|
||||||
channel_x_tp9 = 202
|
|
||||||
channel_y_tp9 = 445
|
|
||||||
channel_x_f8 = 701
|
|
||||||
channel_y_f8 = 215
|
|
||||||
channel_x_t7 = 252
|
|
||||||
channel_y_t7 = 362
|
|
||||||
channel_x_tp7 = 261
|
|
||||||
channel_y_tp7 = 436
|
|
||||||
channel_x_ft8 = 734
|
|
||||||
channel_y_ft8 = 283
|
|
||||||
channel_x_ft7 = 264
|
|
||||||
channel_y_ft7 = 286
|
|
||||||
channel_x_af8 = 645
|
|
||||||
channel_y_af8 = 159
|
|
||||||
channel_x_af7 = 351
|
|
||||||
channel_y_af7 = 160
|
|
||||||
channel_x_p6 = 652
|
|
||||||
channel_y_p6 = 499
|
|
||||||
channel_x_p5 = 348
|
|
||||||
channel_y_p5 = 499
|
|
||||||
channel_x_c6 = 683
|
|
||||||
channel_y_c6 = 362
|
|
||||||
channel_x_f1 = 447
|
|
||||||
channel_y_f1 = 236
|
|
||||||
channel_x_t8 = 745
|
|
||||||
channel_y_t8 = 361
|
|
||||||
channel_x_f2 = 549
|
|
||||||
channel_y_f2 = 235
|
|
||||||
channel_x_p7 = 300
|
|
||||||
channel_y_p7 = 505
|
|
||||||
channel_x_c1 = 435
|
|
||||||
channel_y_c1 = 363
|
|
||||||
channel_x_p8 = 698
|
|
||||||
channel_y_p8 = 508
|
|
||||||
channel_x_c2 = 559
|
|
||||||
channel_y_c2 = 359
|
|
||||||
channel_x_fz = 499
|
|
||||||
channel_y_fz = 238
|
|
||||||
channel_x_po7 = 354
|
|
||||||
channel_y_po7 = 562
|
|
||||||
channel_x_tp8 = 735
|
|
||||||
channel_y_tp8 = 438
|
|
||||||
channel_x_oz = 498
|
|
||||||
channel_y_oz = 609
|
|
||||||
channel_x_af3 = 428
|
|
||||||
channel_y_af3 = 170
|
|
||||||
channel_x_pz = 501
|
|
||||||
channel_y_pz = 486
|
|
||||||
channel_x_p2 = 551
|
|
||||||
channel_y_p2 = 483
|
|
||||||
channel_x_cz = 499
|
|
||||||
channel_y_cz = 361
|
|
||||||
channel_x_p1 = 448
|
|
||||||
channel_y_p1 = 488
|
|
||||||
|
|
||||||
|
|||||||
188
datamock.py
Normal file
188
datamock.py
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
import zmq
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# ========== 参数配置 ==========
|
||||||
|
FS = 250 # 采样率 Hz
|
||||||
|
N_SAMPLES_PER_PKT = 5 # 每包采样点数
|
||||||
|
N_CHAN = 66 # 通道数: 64 EEG + 1 标签值 + 1 标签序号
|
||||||
|
EEG_FREQ = 10 # EEG 正弦波频率 Hz
|
||||||
|
EEG_AMP = 100.0 # EEG 幅值 100μV
|
||||||
|
LABEL_INTERVAL = 5 # 标签间隔秒数
|
||||||
|
SERVER_ADDR = 'tcp://127.0.0.1:8100'
|
||||||
|
LABEL_CMD_ADDR = 'tcp://127.0.0.1:8101' # 接收来自上位机范式的标签命令
|
||||||
|
|
||||||
|
# 发送间隔: 每包 5 采样点 / 250Hz = 20ms
|
||||||
|
PKT_INTERVAL = N_SAMPLES_PER_PKT / FS
|
||||||
|
|
||||||
|
|
||||||
|
def build_packet(global_sample_idx):
|
||||||
|
"""
|
||||||
|
生成一包 [5, 66] 的 float64 数据
|
||||||
|
:param global_sample_idx: 当前包第一个采样点在全局序列中的索引 (从 0 开始)
|
||||||
|
:return: np.ndarray shape [5, 66]
|
||||||
|
"""
|
||||||
|
# 当前包内 5 个采样点对应的时间(秒)
|
||||||
|
t = (global_sample_idx + np.arange(N_SAMPLES_PER_PKT)) / FS
|
||||||
|
|
||||||
|
# Ch0-63: EEG 10Hz 正弦波,幅值 100μV
|
||||||
|
# t shape [5,],sin 乘以标量后仍是 [5,],需要 reshape 为 [5,1] 再广播到 64 通道
|
||||||
|
eeg = (EEG_AMP * np.sin(2 * np.pi * EEG_FREQ * t)).reshape(N_SAMPLES_PER_PKT, 1) # [5, 1]
|
||||||
|
eeg = np.tile(eeg, (1, 64)) # [5, 64]
|
||||||
|
|
||||||
|
# Ch64: 标签值通道,初始化为 0
|
||||||
|
event = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64)
|
||||||
|
|
||||||
|
# Ch65: 标签序号通道,初始化为 0
|
||||||
|
label_idx = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64)
|
||||||
|
|
||||||
|
# 拼成 [5, 66]
|
||||||
|
packet = np.concatenate([eeg, event, label_idx], axis=1).astype(np.float64)
|
||||||
|
return packet
|
||||||
|
|
||||||
|
|
||||||
|
def should_send_label(global_sample_idx):
|
||||||
|
"""
|
||||||
|
判断当前包是否包含标签触发点(每 5s 的最后一个采样点)
|
||||||
|
采样点索引从 0 开始,每 5s = 1250 个采样点
|
||||||
|
最后一个采样点索引: 1249, 2499, 3749, ...
|
||||||
|
由于每包 5 个采样点,标签点落在包内的最后一个采样点位置
|
||||||
|
即当前包起始索引 global_sample_idx 必须使得:
|
||||||
|
global_sample_idx <= 标签点索引 < global_sample_idx + N_SAMPLES_PER_PKT
|
||||||
|
也就是 global_sample_idx <= 1249 < global_sample_idx + 5
|
||||||
|
即 global_sample_idx = 1245, 2495, 3745, ...
|
||||||
|
即 global_sample_idx = n * LABEL_INTERVAL * FS - N_SAMPLES_PER_PKT
|
||||||
|
"""
|
||||||
|
samples_per_interval = LABEL_INTERVAL * FS
|
||||||
|
# 检查当前包是否包含 interval 的最后一个采样点
|
||||||
|
# 标签点索引 = n * 1250 - 1,当 global_sample_idx = n*1250-5 时,标签在包内索引 4
|
||||||
|
return (global_sample_idx + N_SAMPLES_PER_PKT - 1) % samples_per_interval == samples_per_interval - 1
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
ctx = zmq.Context()
|
||||||
|
sock = ctx.socket(zmq.DEALER)
|
||||||
|
sock.connect(SERVER_ADDR)
|
||||||
|
print(f"[{datetime.now().strftime('%H:%M:%S')}] ZMQ Dealer 连接到 {SERVER_ADDR}")
|
||||||
|
|
||||||
|
# ========== 上位机标签命令监听 ==========
|
||||||
|
# 使用线程安全的队列接收来自 ssmvep_main.py 的标签命令
|
||||||
|
# 标签值: 1 (train 0), 2 (train 1), 99 (predict)
|
||||||
|
pending_label = [None] # [label_value or None]
|
||||||
|
label_lock = threading.Lock()
|
||||||
|
|
||||||
|
label_cmd_sock = ctx.socket(zmq.PULL)
|
||||||
|
label_cmd_sock.bind(LABEL_CMD_ADDR)
|
||||||
|
print(f"[{datetime.now().strftime('%H:%M:%S')}] 标签命令监听绑定到 {LABEL_CMD_ADDR}")
|
||||||
|
|
||||||
|
stop_recv = threading.Event()
|
||||||
|
|
||||||
|
def label_cmd_thread():
|
||||||
|
"""监听来自上位机范式的标签命令,写入 pending_label"""
|
||||||
|
while not stop_recv.is_set():
|
||||||
|
try:
|
||||||
|
msg = label_cmd_sock.recv_string(zmq.NOBLOCK)
|
||||||
|
label_val = int(msg)
|
||||||
|
with label_lock:
|
||||||
|
pending_label[0] = label_val
|
||||||
|
ts = datetime.now().strftime('%H:%M:%S')
|
||||||
|
label_name = {1: 'train_0', 2: 'train_1', 99: 'predict'}.get(label_val, str(label_val))
|
||||||
|
print(f"[{ts}] 收到标签命令: {label_name} -> label={label_val}")
|
||||||
|
except zmq.Again:
|
||||||
|
time.sleep(0.005)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[label_cmd_thread] 错误: {e}")
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
label_thread = threading.Thread(target=label_cmd_thread, daemon=True)
|
||||||
|
label_thread.start()
|
||||||
|
print(f"[{datetime.now().strftime('%H:%M:%S')}] 标签命令监听线程已启动")
|
||||||
|
|
||||||
|
# 后台消费线程:持续 recv 从 ROUTER 返回的数据,避免 server 发送队列积压
|
||||||
|
recv_count = [0]
|
||||||
|
|
||||||
|
def consumer_thread():
|
||||||
|
"""消费线程:阻塞 recv,丢弃收到的数据,仅用于清空 ROUTER 发送队列"""
|
||||||
|
while not stop_recv.is_set():
|
||||||
|
try:
|
||||||
|
frames = sock.recv_multipart(zmq.NOBLOCK)
|
||||||
|
recv_count[0] += 1
|
||||||
|
# 收到的格式: [identity, '', filtered_data_bytes]
|
||||||
|
if recv_count[0] % 500 == 0:
|
||||||
|
print(f"[{datetime.now().strftime('%H:%M:%S')}] 消费线程已丢弃 {recv_count[0]} 帧滤波数据")
|
||||||
|
except zmq.Again:
|
||||||
|
time.sleep(0.01)
|
||||||
|
except zmq.error.Again: # 兼容旧版
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
consumer = threading.Thread(target=consumer_thread, daemon=True)
|
||||||
|
consumer.start()
|
||||||
|
print(f"[{datetime.now().strftime('%H:%M:%S')}] 消费线程已启动(daemon)")
|
||||||
|
|
||||||
|
global_sample_idx = 0 # 全局采样点计数器
|
||||||
|
label_type = 1 # 当前标签类型: 1 或 2
|
||||||
|
label1_count = 0 # label=1 的序号计数器
|
||||||
|
label2_count = 0 # label=2 的序号计数器
|
||||||
|
packet_count = 0 # 已发送包数
|
||||||
|
|
||||||
|
print(f"[{datetime.now().strftime('%H:%M:%S')}] 开始发送模拟数据 ...")
|
||||||
|
print(f" 采样率: {FS}Hz | 每包 {N_SAMPLES_PER_PKT} 采样点 | 发送间隔 {PKT_INTERVAL*1000:.0f}ms")
|
||||||
|
print(f" EEG: {EEG_FREQ}Hz 正弦波 | 幅值 {EEG_AMP}μV")
|
||||||
|
print(f" 标签: 来自上位机范式命令 (train_0=1, train_1=2, predict=99)")
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
t_start = time.perf_counter()
|
||||||
|
|
||||||
|
# 构建当前包
|
||||||
|
packet = build_packet(global_sample_idx)
|
||||||
|
|
||||||
|
# 检查是否有来自上位机范式的挂起标签命令
|
||||||
|
with label_lock:
|
||||||
|
ext_label = pending_label[0]
|
||||||
|
if ext_label is not None:
|
||||||
|
pending_label[0] = None
|
||||||
|
|
||||||
|
if ext_label is not None:
|
||||||
|
# 将标签写入当前包所有5个采样点的第65通道 (index 64)
|
||||||
|
# 覆盖全部采样点确保 event_inner_idx 无论落在哪个位置都能被正确检测
|
||||||
|
packet[:, 64] = float(ext_label)
|
||||||
|
ts = datetime.now().strftime('%H:%M:%S')
|
||||||
|
print(f"[{ts}] 打标签: label={ext_label} -> ch64[all 5 samples] (global_sample_idx={global_sample_idx})")
|
||||||
|
|
||||||
|
# 发送: multipart 2帧 ['', data]
|
||||||
|
# 使用标准格式,ROUTER 会自动附加 ZMQ 分配的客户端身份
|
||||||
|
sock.send_multipart([
|
||||||
|
b'',
|
||||||
|
packet.tobytes()
|
||||||
|
])
|
||||||
|
|
||||||
|
# 每 50 包打印一次进度
|
||||||
|
if packet_count % 50 == 0:
|
||||||
|
ts = datetime.now().strftime('%H:%M:%S')
|
||||||
|
print(f"[{ts}] 已发送 {packet_count} 包 (global_sample_idx={global_sample_idx})")
|
||||||
|
|
||||||
|
global_sample_idx += N_SAMPLES_PER_PKT
|
||||||
|
packet_count += 1
|
||||||
|
|
||||||
|
# 精确控制发送节奏: 等待到 PKT_INTERVAL 秒
|
||||||
|
elapsed = time.perf_counter() - t_start
|
||||||
|
sleep_time = PKT_INTERVAL - elapsed
|
||||||
|
if sleep_time > 0:
|
||||||
|
time.sleep(sleep_time)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print(f"\n[{datetime.now().strftime('%H:%M:%S')}] 停止发送,共发送 {packet_count} 包")
|
||||||
|
finally:
|
||||||
|
stop_recv.set()
|
||||||
|
consumer.join(timeout=2)
|
||||||
|
label_cmd_sock.close()
|
||||||
|
sock.close()
|
||||||
|
ctx.term()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
421
filter_test.py
Normal file
421
filter_test.py
Normal file
@@ -0,0 +1,421 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
脑电滤波服务 8100端口测试工具【统计逻辑专项优化版】
|
||||||
|
优化点:
|
||||||
|
1. 5秒预热(250个发包),预热结束后才启动丢包/数据统计
|
||||||
|
2. 业务比例:0.02s发1包,200ms收1包 → 每 10 个发包对应 1 个回包
|
||||||
|
3. 通道校验:发送(5,66) 仅对比前64通道,接收(50,64)全通道比对
|
||||||
|
4. 区分:全局总包数 / 有效统计区间包数、理论收包数、实际收包数、丢包数、丢包率
|
||||||
|
5. 新增64通道整体数据均值/极值比对,校验数据有效性
|
||||||
|
通信规范:send_multipart([client_id, b"", data_buf]) 三帧报文,服务端 recv_multipart 长度=3
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
from collections import deque
|
||||||
|
import numpy as np
|
||||||
|
import zmq
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.animation import FuncAnimation
|
||||||
|
|
||||||
|
# ===================== 全局前置:修复Matplotlib中文字体 & 负号显示 =====================
|
||||||
|
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "WenQuanYi Micro Hei"]
|
||||||
|
plt.rcParams["axes.unicode_minus"] = False
|
||||||
|
|
||||||
|
# ===================== 【1. 全局业务固定参数(核心统计规则)】 =====================
|
||||||
|
# ZMQ 服务端配置
|
||||||
|
ZMQ_SERVER_IP = "127.0.0.1"
|
||||||
|
ZMQ_SERVER_PORT = 8100
|
||||||
|
ZMQ_SOCKET_TIMEOUT = 3000 # 套接字超时(ms)
|
||||||
|
POLL_TIMEOUT = 10 # Poll轮询超时(ms)
|
||||||
|
|
||||||
|
# 时序 & 统计核心规则(严格对齐现场业务)
|
||||||
|
SEND_INTERVAL = 0.02 # 上位机发包间隔:20ms/包
|
||||||
|
RECV_INTERVAL = 0.2 # 服务端回包间隔:200ms/包
|
||||||
|
PREHEAT_SECONDS = 5.0 # 滤波缓存预热时长:5秒
|
||||||
|
# 计算:预热需要的发包总数 = 预热时长 / 单包发送间隔
|
||||||
|
PREHEAT_SEND_PACKS = int(PREHEAT_SECONDS / SEND_INTERVAL) # 5 / 0.02 = 250 包
|
||||||
|
# 收发包比例:每多少个发包对应1个回包
|
||||||
|
PACK_RATIO = int(RECV_INTERVAL / SEND_INTERVAL) # 0.2 / 0.02 = 10
|
||||||
|
|
||||||
|
# 数据报文形状
|
||||||
|
PKG_SEND_SHAPE = (5, 66) # 发送包 (点数, 总通道)
|
||||||
|
PKG_RECV_SHAPE = (50, 64) # 回包 (点数, 有效脑电通道)
|
||||||
|
SAMPLE_RATE = 250
|
||||||
|
|
||||||
|
# 通道定义(对比仅使用前64路脑电通道)
|
||||||
|
CH_EEG_VALID = 64 # 共同对比通道数:0~63
|
||||||
|
CH_EVENT = 64
|
||||||
|
CH_RESERVED = 65
|
||||||
|
|
||||||
|
# ZMQ 三帧报文固定字段
|
||||||
|
CLIENT_ID = b"test_client_001"
|
||||||
|
EMPTY_FRAME = b""
|
||||||
|
|
||||||
|
# 仿真信号配置
|
||||||
|
TARGET_CHANNEL = 0
|
||||||
|
SIGNAL_FREQ_LIST = [3, 13]
|
||||||
|
SIGNAL_AMP = 1.8
|
||||||
|
NOISE_GAUSSIAN_AMP = 0.4
|
||||||
|
NOISE_POWER50_AMP = 0.3
|
||||||
|
EVENT_LABEL_VAL = 1
|
||||||
|
RESERVED_VAL = 0.0
|
||||||
|
|
||||||
|
# 可视化配置
|
||||||
|
MAX_PLOT_POINTS = 800
|
||||||
|
PLOT_REFRESH_INTERVAL = 80
|
||||||
|
FFT_N_POINTS = 256
|
||||||
|
PLOT_X_LIMIT_FREQ = (0, 60)
|
||||||
|
|
||||||
|
# 运行控制
|
||||||
|
MAX_RUN_SECONDS = None
|
||||||
|
ENABLE_RECONNECT = True
|
||||||
|
PRINT_STAT_INTERVAL = 5.0
|
||||||
|
|
||||||
|
# ===================== 【2. 全局变量 + 统计结构体(重构统计逻辑)】 =====================
|
||||||
|
g_running = threading.Event()
|
||||||
|
g_running.set()
|
||||||
|
data_lock = threading.Lock()
|
||||||
|
|
||||||
|
# 绘图缓冲区
|
||||||
|
raw_data_buf = deque(maxlen=MAX_PLOT_POINTS)
|
||||||
|
filt_data_buf = deque(maxlen=MAX_PLOT_POINTS)
|
||||||
|
|
||||||
|
# ===================== 全新统计变量(区分预热/正式统计) =====================
|
||||||
|
stat = {
|
||||||
|
# 全局总包数(包含预热包)
|
||||||
|
"total_send": 0,
|
||||||
|
"total_recv": 0,
|
||||||
|
|
||||||
|
# 有效统计区间(预热250包之后)
|
||||||
|
"valid_send": 0, # 有效发包数
|
||||||
|
"valid_recv": 0, # 有效收包数
|
||||||
|
"theo_recv": 0, # 理论应收到包数 = valid_send // PACK_RATIO
|
||||||
|
|
||||||
|
# 运行时间
|
||||||
|
"start_time": time.perf_counter(),
|
||||||
|
"last_print_time": time.perf_counter(),
|
||||||
|
|
||||||
|
# 数据校验缓存:保存最新一包原始64通道数据,用于和回包比对
|
||||||
|
"latest_raw_64ch": None
|
||||||
|
}
|
||||||
|
|
||||||
|
# ===================== 【3. 日志配置】 =====================
|
||||||
|
def init_logger():
|
||||||
|
log_format = "%(asctime)s | %(levelname)-8s | %(message)s"
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format=log_format,
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S"
|
||||||
|
)
|
||||||
|
return logging.getLogger("FilterTest")
|
||||||
|
|
||||||
|
logger = init_logger()
|
||||||
|
|
||||||
|
# ===================== 【4. 仿真脑电数据生成 (5,66)】 =====================
|
||||||
|
def generate_eeg_packet(pkt_idx: int) -> np.ndarray:
|
||||||
|
"""生成单包 (5,66) 仿真数据"""
|
||||||
|
n_point, n_chan = PKG_SEND_SHAPE
|
||||||
|
base_t = pkt_idx * n_point / SAMPLE_RATE
|
||||||
|
t_arr = base_t + np.arange(n_point) / SAMPLE_RATE
|
||||||
|
|
||||||
|
data = np.zeros((n_point, n_chan), dtype=np.float64)
|
||||||
|
|
||||||
|
# 64路脑电信号
|
||||||
|
for ch in range(CH_EEG_VALID):
|
||||||
|
sig = 0.0
|
||||||
|
for freq in SIGNAL_FREQ_LIST:
|
||||||
|
sig += SIGNAL_AMP * np.sin(2 * np.pi * freq * t_arr)
|
||||||
|
sig += NOISE_POWER50_AMP * np.sin(2 * np.pi * 50 * t_arr)
|
||||||
|
sig += NOISE_GAUSSIAN_AMP * np.random.randn(n_point)
|
||||||
|
data[:, ch] = sig
|
||||||
|
|
||||||
|
# 事件通道、保留通道
|
||||||
|
data[:, CH_EVENT] = EVENT_LABEL_VAL
|
||||||
|
data[:, CH_RESERVED] = RESERVED_VAL
|
||||||
|
return data
|
||||||
|
|
||||||
|
# ===================== 【5. ZMQ 核心IO线程(单连接+Poller,保留原有通信逻辑)】 =====================
|
||||||
|
def zmq_io_thread():
|
||||||
|
context = zmq.Context()
|
||||||
|
pkt_index = 0
|
||||||
|
send_interval = SEND_INTERVAL
|
||||||
|
|
||||||
|
logger.info(f"滤波预热配置:{PREHEAT_SECONDS}秒 / {PREHEAT_SEND_PACKS} 个发包后开始统计")
|
||||||
|
logger.info(f"收发比例:每 {PACK_RATIO} 个发包 → 1 个滤波回包")
|
||||||
|
|
||||||
|
while g_running.is_set():
|
||||||
|
try:
|
||||||
|
sock = context.socket(zmq.DEALER)
|
||||||
|
sock.setsockopt(zmq.RCVTIMEO, ZMQ_SOCKET_TIMEOUT)
|
||||||
|
sock.setsockopt(zmq.SNDTIMEO, ZMQ_SOCKET_TIMEOUT)
|
||||||
|
sock.connect(f"tcp://{ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
|
||||||
|
logger.info(f"ZMQ 连接成功 -> {ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
|
||||||
|
|
||||||
|
poller = zmq.Poller()
|
||||||
|
poller.register(sock, zmq.POLLIN)
|
||||||
|
next_send_ts = time.perf_counter()
|
||||||
|
|
||||||
|
while g_running.is_set():
|
||||||
|
# 全局运行时长限制
|
||||||
|
if MAX_RUN_SECONDS is not None:
|
||||||
|
run_sec = time.perf_counter() - stat["start_time"]
|
||||||
|
if run_sec > MAX_RUN_SECONDS:
|
||||||
|
logger.info(f"已到达设定运行时长 {MAX_RUN_SECONDS}s,停止任务")
|
||||||
|
return
|
||||||
|
|
||||||
|
# ========== 1. 轮询接收服务端回包 ==========
|
||||||
|
socks_ready = dict(poller.poll(POLL_TIMEOUT))
|
||||||
|
if sock in socks_ready:
|
||||||
|
frames = sock.recv_multipart()
|
||||||
|
if not frames:
|
||||||
|
continue
|
||||||
|
recv_bytes = frames[-1]
|
||||||
|
if not recv_bytes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 解析回包 (50,64)
|
||||||
|
filt_data = np.frombuffer(recv_bytes, dtype=np.float64)
|
||||||
|
expect_size = PKG_RECV_SHAPE[0] * PKG_RECV_SHAPE[1]
|
||||||
|
if filt_data.size != expect_size:
|
||||||
|
logger.warning(f"回包长度异常:实际{filt_data.size},预期{expect_size}")
|
||||||
|
continue
|
||||||
|
filt_data = filt_data.reshape(PKG_RECV_SHAPE)
|
||||||
|
|
||||||
|
# 全局收包计数
|
||||||
|
stat["total_recv"] += 1
|
||||||
|
|
||||||
|
# 仅预热完成后,计入有效统计收包
|
||||||
|
if stat["total_send"] > PREHEAT_SEND_PACKS:
|
||||||
|
stat["valid_recv"] += 1
|
||||||
|
|
||||||
|
# 写入绘图缓冲区
|
||||||
|
with data_lock:
|
||||||
|
filt_data_buf.extend(filt_data[:, TARGET_CHANNEL])
|
||||||
|
|
||||||
|
# ---------- 新增:64通道数据比对(发包前64通道 <-> 回包64通道) ----------
|
||||||
|
raw_64ch = stat["latest_raw_64ch"]
|
||||||
|
if raw_64ch is not None:
|
||||||
|
raw_mean = np.mean(raw_64ch)
|
||||||
|
filt_mean = np.mean(filt_data)
|
||||||
|
raw_amp = np.max(np.abs(raw_64ch))
|
||||||
|
filt_amp = np.max(np.abs(filt_data))
|
||||||
|
logger.debug(
|
||||||
|
f"【通道数据比对】原始64通道均值:{raw_mean:.4f} 幅值:{raw_amp:.4f} | "
|
||||||
|
f"滤波后均值:{filt_mean:.4f} 幅值:{filt_amp:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ========== 2. 精准定时发送数据包 ==========
|
||||||
|
current_ts = time.perf_counter()
|
||||||
|
if current_ts >= next_send_ts:
|
||||||
|
# 生成(5,66)仿真包
|
||||||
|
pkt_data = generate_eeg_packet(pkt_index)
|
||||||
|
pkt_index += 1
|
||||||
|
send_buf = pkt_data.tobytes()
|
||||||
|
|
||||||
|
# 标准三帧Multipart发送
|
||||||
|
sock.send_multipart([CLIENT_ID, EMPTY_FRAME, send_buf])
|
||||||
|
|
||||||
|
# ---------- 发包计数逻辑(核心优化:预热区分) ----------
|
||||||
|
stat["total_send"] += 1
|
||||||
|
# 预热完成后,计入有效发包
|
||||||
|
if stat["total_send"] > PREHEAT_SEND_PACKS:
|
||||||
|
stat["valid_send"] += 1
|
||||||
|
# 计算理论应收包数
|
||||||
|
stat["theo_recv"] = stat["valid_send"] // PACK_RATIO
|
||||||
|
|
||||||
|
# 缓存当前包前64通道,用于后续数据比对
|
||||||
|
stat["latest_raw_64ch"] = pkt_data[:, :CH_EEG_VALID]
|
||||||
|
|
||||||
|
# 绘图缓冲区(单通道波形)
|
||||||
|
with data_lock:
|
||||||
|
raw_data_buf.extend(pkt_data[:, TARGET_CHANNEL])
|
||||||
|
|
||||||
|
# 更新下一次发包时间
|
||||||
|
next_send_ts += send_interval
|
||||||
|
|
||||||
|
# ========== 3. 定时打印统计信息(区分预热/正式统计) ==========
|
||||||
|
now = time.perf_counter()
|
||||||
|
if now - stat["last_print_time"] > PRINT_STAT_INTERVAL:
|
||||||
|
run_sec = now - stat["start_time"]
|
||||||
|
total_send = stat["total_send"]
|
||||||
|
total_recv = stat["total_recv"]
|
||||||
|
|
||||||
|
# 分支1:仍在预热阶段
|
||||||
|
if total_send <= PREHEAT_SEND_PACKS:
|
||||||
|
remain = PREHEAT_SEND_PACKS - total_send
|
||||||
|
logger.info(
|
||||||
|
f"[预热中] 运行:{run_sec:.1f}s | 已发包:{total_send}/{PREHEAT_SEND_PACKS} | "
|
||||||
|
f"剩余预热包:{remain} | 暂不统计丢包"
|
||||||
|
)
|
||||||
|
# 分支2:预热完成,进入正式统计
|
||||||
|
else:
|
||||||
|
v_send = stat["valid_send"]
|
||||||
|
v_recv = stat["valid_recv"]
|
||||||
|
t_recv = stat["theo_recv"]
|
||||||
|
loss_cnt = t_recv - v_recv
|
||||||
|
loss_rate = (loss_cnt / t_recv * 100) if t_recv > 0 else 0.0
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[正式统计] 运行:{run_sec:.1f}s | "
|
||||||
|
f"全局总包: 发{total_send}/收{total_recv} | "
|
||||||
|
f"有效区间: 发{v_send}/应收{t_recv}/实收{v_recv} | "
|
||||||
|
f"丢包数:{loss_cnt} | 丢包率:{loss_rate:.2f}%"
|
||||||
|
)
|
||||||
|
stat["last_print_time"] = now
|
||||||
|
|
||||||
|
except zmq.ZMQError as e:
|
||||||
|
if e.errno == zmq.EAGAIN:
|
||||||
|
continue
|
||||||
|
logger.warning(f"ZMQ 连接异常: {e}")
|
||||||
|
sock.close()
|
||||||
|
poller.unregister(sock)
|
||||||
|
if not ENABLE_RECONNECT:
|
||||||
|
break
|
||||||
|
logger.info("500ms 后尝试重连...")
|
||||||
|
time.sleep(0.5)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"IO线程未知异常:\n{traceback.format_exc()}")
|
||||||
|
break
|
||||||
|
|
||||||
|
context.term()
|
||||||
|
logger.info("ZMQ IO 线程已退出")
|
||||||
|
|
||||||
|
# ===================== 【6. 可视化绘图(无改动)】 =====================
|
||||||
|
def init_plot():
|
||||||
|
fig = plt.figure(figsize=(14, 9))
|
||||||
|
fig.suptitle(f"脑电滤波测试 | 观测通道: {TARGET_CHANNEL}", fontsize=14)
|
||||||
|
|
||||||
|
ax1 = plt.subplot(2, 2, 1)
|
||||||
|
ax1.set_title("原始输入波形 (含噪声+工频)")
|
||||||
|
ax1.set_ylabel("幅值")
|
||||||
|
ax1.grid(True, alpha=0.3)
|
||||||
|
line_raw, = ax1.plot([], [], color="#1f77b4", linewidth=1)
|
||||||
|
|
||||||
|
ax2 = plt.subplot(2, 2, 2)
|
||||||
|
ax2.set_title("滤波后输出波形")
|
||||||
|
ax2.set_ylabel("幅值")
|
||||||
|
ax2.grid(True, alpha=0.3)
|
||||||
|
line_filt, = ax2.plot([], [], color="#d62728", linewidth=1)
|
||||||
|
|
||||||
|
ax3 = plt.subplot(2, 2, 3)
|
||||||
|
ax3.set_title("原始信号频谱")
|
||||||
|
ax3.set_xlabel("频率 (Hz)")
|
||||||
|
ax3.set_xlim(*PLOT_X_LIMIT_FREQ)
|
||||||
|
ax3.grid(True, alpha=0.3)
|
||||||
|
line_raw_fft, = ax3.plot([], [], color="#1f77b4")
|
||||||
|
|
||||||
|
ax4 = plt.subplot(2, 2, 4)
|
||||||
|
ax4.set_title("滤波后信号频谱")
|
||||||
|
ax4.set_xlabel("频率 (Hz)")
|
||||||
|
ax4.set_xlim(*PLOT_X_LIMIT_FREQ)
|
||||||
|
ax4.grid(True, alpha=0.3)
|
||||||
|
line_filt_fft, = ax4.plot([], [], color="#d62728")
|
||||||
|
|
||||||
|
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
||||||
|
return fig, [line_raw, line_filt, line_raw_fft, line_filt_fft], [ax1, ax2, ax3, ax4]
|
||||||
|
|
||||||
|
def update_plot(frame, lines, axes):
|
||||||
|
line_raw, line_filt, line_raw_fft, line_filt_fft = lines
|
||||||
|
ax1, ax2, ax3, ax4 = axes
|
||||||
|
|
||||||
|
with data_lock:
|
||||||
|
raw_data = list(raw_data_buf)
|
||||||
|
filt_data = list(filt_data_buf)
|
||||||
|
|
||||||
|
if raw_data:
|
||||||
|
x_raw = np.arange(len(raw_data))
|
||||||
|
line_raw.set_data(x_raw, raw_data)
|
||||||
|
ax1.relim()
|
||||||
|
ax1.autoscale_view()
|
||||||
|
|
||||||
|
if filt_data:
|
||||||
|
x_filt = np.arange(len(filt_data))
|
||||||
|
line_filt.set_data(x_filt, filt_data)
|
||||||
|
ax2.relim()
|
||||||
|
ax2.autoscale_view()
|
||||||
|
|
||||||
|
def calc_fft(sig, n_fft):
|
||||||
|
if len(sig) < n_fft:
|
||||||
|
return [], []
|
||||||
|
win = np.hanning(n_fft)
|
||||||
|
sig_win = sig[-n_fft:] * win
|
||||||
|
fft_vals = np.fft.fft(sig_win)
|
||||||
|
fft_amp = np.abs(fft_vals)[:n_fft//2]
|
||||||
|
freq = np.fft.fftfreq(n_fft, 1/SAMPLE_RATE)[:n_fft//2]
|
||||||
|
return freq, fft_amp
|
||||||
|
|
||||||
|
freq_raw, amp_raw = calc_fft(raw_data, FFT_N_POINTS)
|
||||||
|
freq_filt, amp_filt = calc_fft(filt_data, FFT_N_POINTS)
|
||||||
|
|
||||||
|
line_raw_fft.set_data(freq_raw, amp_raw)
|
||||||
|
line_filt_fft.set_data(freq_filt, amp_filt)
|
||||||
|
ax3.relim()
|
||||||
|
ax3.autoscale_view(scaley=True)
|
||||||
|
ax4.relim()
|
||||||
|
ax4.autoscale_view(scaley=True)
|
||||||
|
|
||||||
|
return lines
|
||||||
|
|
||||||
|
# ===================== 【7. 资源释放 & 最终汇总统计】 =====================
|
||||||
|
def clean_resource():
|
||||||
|
g_running.clear()
|
||||||
|
logger.info("开始停止所有线程...")
|
||||||
|
time.sleep(0.3)
|
||||||
|
plt.close("all")
|
||||||
|
logger.info("资源释放完成")
|
||||||
|
|
||||||
|
def main():
|
||||||
|
logger.info("=" * 70)
|
||||||
|
logger.info("脑电滤波测试客户端【统计逻辑优化版】启动")
|
||||||
|
logger.info(f"服务端地址: {ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
|
||||||
|
logger.info(f"发包: {PKG_SEND_SHAPE}({SEND_INTERVAL*1000:.0f}ms) | 回包: {PKG_RECV_SHAPE}({RECV_INTERVAL*1000:.0f}ms)")
|
||||||
|
logger.info(f"预热规则: {PREHEAT_SECONDS}秒 / {PREHEAT_SEND_PACKS} 包后开启统计")
|
||||||
|
logger.info(f"收发比例: 每 {PACK_RATIO} 个发包对应 1 个回包")
|
||||||
|
logger.info("=" * 70)
|
||||||
|
|
||||||
|
# 启动ZMQ收发线程
|
||||||
|
io_thread = threading.Thread(target=zmq_io_thread, daemon=True, name="ZMQ_IO_Thread")
|
||||||
|
io_thread.start()
|
||||||
|
|
||||||
|
# 启动可视化
|
||||||
|
fig, lines, axes = init_plot()
|
||||||
|
ani = FuncAnimation(
|
||||||
|
fig, update_plot,
|
||||||
|
fargs=(lines, axes),
|
||||||
|
interval=PLOT_REFRESH_INTERVAL,
|
||||||
|
blit=True,
|
||||||
|
cache_frame_data=False
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
plt.show()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("收到 Ctrl+C 中断信号,准备退出")
|
||||||
|
finally:
|
||||||
|
# 输出最终完整汇总报表
|
||||||
|
run_total = time.perf_counter() - stat["start_time"]
|
||||||
|
total_send = stat["total_send"]
|
||||||
|
total_recv = stat["total_recv"]
|
||||||
|
v_send = stat["valid_send"]
|
||||||
|
v_recv = stat["valid_recv"]
|
||||||
|
t_recv = stat["theo_recv"]
|
||||||
|
|
||||||
|
loss_cnt = t_recv - v_recv
|
||||||
|
loss_rate = (loss_cnt / t_recv * 100) if t_recv > 0 else 0.0
|
||||||
|
|
||||||
|
logger.info(f"\n{'='*50} 最终运行汇总 {'='*50}")
|
||||||
|
logger.info(f"总运行时长: {run_total:.1f} s")
|
||||||
|
logger.info(f"【全局总包数】发送: {total_send} | 接收: {total_recv}")
|
||||||
|
logger.info(f"【有效统计区间(跳过预热{PREHEAT_SEND_PACKS}包)】")
|
||||||
|
logger.info(f" 有效发包: {v_send} | 理论应收包: {t_recv} | 实际收包: {v_recv}")
|
||||||
|
logger.info(f" 总丢包数: {loss_cnt} | 整体丢包率: {loss_rate:.2f} %")
|
||||||
|
logger.info(f"{'='*106}")
|
||||||
|
|
||||||
|
clean_resource()
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
119
logs/log.py
119
logs/log.py
@@ -1,87 +1,114 @@
|
|||||||
# log.py
|
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime, timedelta
|
||||||
import logging
|
import logging
|
||||||
from logging.handlers import RotatingFileHandler
|
from logging.handlers import RotatingFileHandler
|
||||||
|
import inspect
|
||||||
from PubLibrary.InifileHelper import IniRead
|
from PubLibrary.InifileHelper import IniRead
|
||||||
|
|
||||||
|
# 全局配置
|
||||||
console_output = IniRead('system', 'console_output', '1')
|
console_output = IniRead('system', 'console_output', '1')
|
||||||
log_level = IniRead('system', 'algo_log_level', 'INFO')
|
log_level = IniRead('system', 'algo_log_level', 'INFO')
|
||||||
|
|
||||||
# 新增:日志去重缓存,key为日志内容,value为是否已打印
|
|
||||||
log_once_cache = set()
|
log_once_cache = set()
|
||||||
|
logger_cache = {}
|
||||||
|
LOG_RETENTION_DAYS = 3
|
||||||
|
LOG_DIR = './logs/'
|
||||||
|
LOG_FILE_PREFIX = 'algo_log_'
|
||||||
|
|
||||||
|
# 日志格式:时间 - 日志器名 - 级别 - 文件名:行号 - 函数名 - 日志内容
|
||||||
|
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
|
||||||
|
|
||||||
|
|
||||||
def init_module_logger():
|
def clean_old_logs():
|
||||||
"""
|
"""清理超过指定天数的旧日志文件"""
|
||||||
初始化指定模块的日志器
|
try:
|
||||||
:return: 对应模块的logger实例
|
if not os.path.exists(LOG_DIR):
|
||||||
"""
|
return
|
||||||
# 缓存命中则直接返回
|
expire_date = datetime.now() - timedelta(days=LOG_RETENTION_DAYS)
|
||||||
log_dir = './logs/' # 确保日志目录存在
|
for filename in os.listdir(LOG_DIR):
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
if not filename.startswith(LOG_FILE_PREFIX) or not filename.endswith('.log'):
|
||||||
|
continue
|
||||||
|
date_str = filename[len(LOG_FILE_PREFIX):-4]
|
||||||
|
try:
|
||||||
|
file_date = datetime.strptime(date_str, '%Y-%m-%d')
|
||||||
|
if file_date < expire_date:
|
||||||
|
file_path = os.path.join(LOG_DIR, filename)
|
||||||
|
os.remove(file_path)
|
||||||
|
print(f"清理过期日志: {file_path}")
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
print(f"清理旧日志异常: {str(e)}")
|
||||||
|
|
||||||
log_file = os.path.join(log_dir, f'algo_log_{datetime.now().strftime("%Y-%m-%d")}.log')
|
|
||||||
|
|
||||||
# 初始化logger
|
def init_module_logger(logger_name):
|
||||||
logger = logging.getLogger('decoderLogger')
|
"""初始化日志器 + 清理旧日志"""
|
||||||
|
os.makedirs(LOG_DIR, exist_ok=True)
|
||||||
|
clean_old_logs()
|
||||||
|
|
||||||
|
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||||
|
log_file = os.path.join(LOG_DIR, f"{LOG_FILE_PREFIX}{current_date}.log")
|
||||||
|
|
||||||
|
if logger_name in logger_cache:
|
||||||
|
return logger_cache[logger_name]
|
||||||
|
|
||||||
|
logger = logging.getLogger(logger_name)
|
||||||
logger.setLevel(log_level)
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
if logger.handlers:
|
if logger.handlers:
|
||||||
|
logger_cache[logger_name] = logger
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
# 设置日志轮转,最大10个文件,每个10MB
|
# 文件输出处理器
|
||||||
file_handler = RotatingFileHandler(
|
file_handler = RotatingFileHandler(
|
||||||
log_file,
|
log_file,
|
||||||
maxBytes=10 * 1024 * 1024,
|
maxBytes=10 * 1024 * 1024,
|
||||||
backupCount=10,
|
backupCount=10,
|
||||||
encoding='utf-8'
|
encoding='utf-8'
|
||||||
)
|
)
|
||||||
|
formatter = logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT)
|
||||||
# 日志格式
|
|
||||||
formatter = logging.Formatter(
|
|
||||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
|
||||||
datefmt='%Y-%m-%d %H:%M:%S'
|
|
||||||
)
|
|
||||||
file_handler.setFormatter(formatter)
|
file_handler.setFormatter(formatter)
|
||||||
logger.setLevel(log_level)
|
|
||||||
logger.addHandler(file_handler)
|
logger.addHandler(file_handler)
|
||||||
|
|
||||||
|
# 控制台输出
|
||||||
if console_output:
|
if console_output:
|
||||||
console_handler = logging.StreamHandler()
|
console_handler = logging.StreamHandler()
|
||||||
console_handler.setFormatter(formatter)
|
console_handler.setFormatter(formatter)
|
||||||
logger.addHandler(console_handler)
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
logger_cache[logger_name] = logger
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
|
||||||
def algo_log(content, level="INFO", record_once=False):
|
def algo_log(content, level="INFO", record_once=False):
|
||||||
"""
|
"""
|
||||||
通用日志函数,支持按模块输出到不同日志文件
|
日志入口函数
|
||||||
:param content: 日志内容
|
自动记录:调用文件名、代码行号、所在函数
|
||||||
:param level: 日志级别(DEBUG/INFO/WARNING/ERROR/FATAL)
|
|
||||||
:param record_once: 是否只打印一次该日志内容,默认False
|
|
||||||
"""
|
"""
|
||||||
# 初始化模块日志器
|
# 回溯栈帧,获取真正调用 algo_log 的代码位置
|
||||||
logger = init_module_logger()
|
# f_back(1) -> algo_log 自身,f_back(2) -> 业务调用处
|
||||||
|
frame = inspect.currentframe().f_back.f_back
|
||||||
|
if not frame:
|
||||||
|
file_name = "unknown"
|
||||||
|
else:
|
||||||
|
file_name = os.path.basename(frame.f_code.co_filename)
|
||||||
|
|
||||||
# 新增:处理只打印一次的逻辑
|
logger = init_module_logger(file_name)
|
||||||
|
|
||||||
|
# 单次日志去重
|
||||||
if record_once:
|
if record_once:
|
||||||
# 生成唯一标识(可根据需要调整,比如拼接level增强唯一性)
|
|
||||||
log_key = f"{level.upper()}_{content}"
|
log_key = f"{level.upper()}_{content}"
|
||||||
if log_key in log_once_cache:
|
if log_key in log_once_cache:
|
||||||
return # 已打印过,直接返回
|
return
|
||||||
log_once_cache.add(log_key) # 未打印过,加入缓存
|
log_once_cache.add(log_key)
|
||||||
|
|
||||||
# 根据级别输出日志
|
# 日志级别分发
|
||||||
level_upper = level.upper()
|
level_upper = level.upper()
|
||||||
if level_upper == "DEBUG":
|
log_map = {
|
||||||
logger.debug(content)
|
"DEBUG": logger.debug,
|
||||||
elif level_upper == "WARNING":
|
"WARNING": logger.warning,
|
||||||
logger.warning(content)
|
"ERROR": logger.error,
|
||||||
elif level_upper == "ERROR":
|
"FATAL": logger.fatal,
|
||||||
logger.error(content)
|
"INFO": logger.info
|
||||||
elif level_upper == "FATAL":
|
}
|
||||||
logger.fatal(content)
|
log_func = log_map.get(level_upper, logger.info)
|
||||||
else: # 默认INFO级别
|
log_func(content)
|
||||||
logger.info(content)
|
|
||||||
55
nuitka_3in1_package.sh
Normal file
55
nuitka_3in1_package.sh
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Git Bash 中文 UTF-8 兼容配置(通用版,无报错)
|
||||||
|
export LC_ALL=en_US.UTF-8
|
||||||
|
export LANG=en_US.UTF-8
|
||||||
|
|
||||||
|
echo "========================"
|
||||||
|
echo "Nuitka 打包脚本 - 优化稳定版"
|
||||||
|
echo "适配:PyTorch2.0.0 + CUDA11.7 + 脑电解码项目"
|
||||||
|
echo "========================"
|
||||||
|
|
||||||
|
# ===================== 自定义配置区 =====================
|
||||||
|
PY_FILE="runDecoder.py" # 主程序文件
|
||||||
|
OUT_DIR="dist_nuitka" # 输出文件夹
|
||||||
|
MODEL_DIR="online_Models" # 模型文件夹
|
||||||
|
# ========================================================
|
||||||
|
|
||||||
|
# 检查主文件是否存在
|
||||||
|
if [ ! -f "${PY_FILE}" ]; then
|
||||||
|
echo "错误:未找到主文件 ${PY_FILE},请检查路径!"
|
||||||
|
read -n 1 -s -r -p "按任意键退出"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "开始打包:${PY_FILE}"
|
||||||
|
echo "输出目录:${OUT_DIR}"
|
||||||
|
|
||||||
|
# Nuitka 核心打包命令(无错误、无冗余、全依赖)
|
||||||
|
python -m nuitka \
|
||||||
|
--standalone \
|
||||||
|
--msvc=latest \
|
||||||
|
--windows-console-mode=disable \
|
||||||
|
--module-parameter=torch-disable-jit=yes \
|
||||||
|
--enable-plugin=no-qt \
|
||||||
|
--include-package=numpy \
|
||||||
|
--include-module=numpy.core._multiarray_umath \
|
||||||
|
--include-package=scipy \
|
||||||
|
--no-deployment-flag=self-execution \
|
||||||
|
--include-data-dir="${MODEL_DIR}=${MODEL_DIR}" \
|
||||||
|
--output-dir="${OUT_DIR}" \
|
||||||
|
--remove-output \
|
||||||
|
"${PY_FILE}"
|
||||||
|
|
||||||
|
# 打包结果判断
|
||||||
|
if [ $? -eq 0 ]; then
|
||||||
|
echo -e "\n========================"
|
||||||
|
echo "✅ 打包成功!"
|
||||||
|
echo "📦 产物路径:${OUT_DIR}/${PY_FILE%.py}.exe"
|
||||||
|
echo "========================"
|
||||||
|
else
|
||||||
|
echo -e "\n❌ 打包失败!"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Git Bash 兼容的暂停
|
||||||
|
read -n 1 -s -r -p "按任意键退出..."
|
||||||
|
echo
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,252 +0,0 @@
|
|||||||
0 0.5
|
|
||||||
1 0.5
|
|
||||||
2 0.375
|
|
||||||
3 0.5
|
|
||||||
4 0.4375
|
|
||||||
5 0.375
|
|
||||||
6 0.5
|
|
||||||
7 0.5
|
|
||||||
8 0.375
|
|
||||||
9 0.375
|
|
||||||
10 0.375
|
|
||||||
11 0.375
|
|
||||||
12 0.5
|
|
||||||
13 0.5625
|
|
||||||
14 0.5625
|
|
||||||
15 0.5
|
|
||||||
16 0.5
|
|
||||||
17 0.5
|
|
||||||
18 0.5
|
|
||||||
19 0.5625
|
|
||||||
20 0.4375
|
|
||||||
21 0.5
|
|
||||||
22 0.5
|
|
||||||
23 0.375
|
|
||||||
24 0.375
|
|
||||||
25 0.375
|
|
||||||
26 0.375
|
|
||||||
27 0.375
|
|
||||||
28 0.3125
|
|
||||||
29 0.375
|
|
||||||
30 0.5625
|
|
||||||
31 0.5
|
|
||||||
32 0.5
|
|
||||||
33 0.5625
|
|
||||||
34 0.5625
|
|
||||||
35 0.3125
|
|
||||||
36 0.3125
|
|
||||||
37 0.3125
|
|
||||||
38 0.375
|
|
||||||
39 0.5625
|
|
||||||
40 0.3125
|
|
||||||
41 0.5625
|
|
||||||
42 0.3125
|
|
||||||
43 0.375
|
|
||||||
44 0.5625
|
|
||||||
45 0.5
|
|
||||||
46 0.375
|
|
||||||
47 0.375
|
|
||||||
48 0.3125
|
|
||||||
49 0.375
|
|
||||||
50 0.375
|
|
||||||
51 0.5
|
|
||||||
52 0.5625
|
|
||||||
53 0.375
|
|
||||||
54 0.5625
|
|
||||||
55 0.5625
|
|
||||||
56 0.375
|
|
||||||
57 0.375
|
|
||||||
58 0.375
|
|
||||||
59 0.5
|
|
||||||
60 0.3125
|
|
||||||
61 0.375
|
|
||||||
62 0.375
|
|
||||||
63 0.375
|
|
||||||
64 0.375
|
|
||||||
65 0.375
|
|
||||||
66 0.3125
|
|
||||||
67 0.375
|
|
||||||
68 0.5625
|
|
||||||
69 0.5625
|
|
||||||
70 0.5625
|
|
||||||
71 0.5
|
|
||||||
72 0.5625
|
|
||||||
73 0.375
|
|
||||||
74 0.375
|
|
||||||
75 0.375
|
|
||||||
76 0.375
|
|
||||||
77 0.375
|
|
||||||
78 0.5
|
|
||||||
79 0.375
|
|
||||||
80 0.375
|
|
||||||
81 0.5
|
|
||||||
82 0.375
|
|
||||||
83 0.375
|
|
||||||
84 0.375
|
|
||||||
85 0.375
|
|
||||||
86 0.3125
|
|
||||||
87 0.375
|
|
||||||
88 0.375
|
|
||||||
89 0.5
|
|
||||||
90 0.375
|
|
||||||
91 0.4375
|
|
||||||
92 0.3125
|
|
||||||
93 0.3125
|
|
||||||
94 0.375
|
|
||||||
95 0.375
|
|
||||||
96 0.375
|
|
||||||
97 0.375
|
|
||||||
98 0.3125
|
|
||||||
99 0.4375
|
|
||||||
100 0.375
|
|
||||||
101 0.375
|
|
||||||
102 0.375
|
|
||||||
103 0.3125
|
|
||||||
104 0.5625
|
|
||||||
105 0.5
|
|
||||||
106 0.5625
|
|
||||||
107 0.5625
|
|
||||||
108 0.5
|
|
||||||
109 0.3125
|
|
||||||
110 0.5625
|
|
||||||
111 0.5625
|
|
||||||
112 0.5
|
|
||||||
113 0.3125
|
|
||||||
114 0.5
|
|
||||||
115 0.3125
|
|
||||||
116 0.375
|
|
||||||
117 0.3125
|
|
||||||
118 0.3125
|
|
||||||
119 0.3125
|
|
||||||
120 0.3125
|
|
||||||
121 0.375
|
|
||||||
122 0.375
|
|
||||||
123 0.375
|
|
||||||
124 0.375
|
|
||||||
125 0.3125
|
|
||||||
126 0.375
|
|
||||||
127 0.375
|
|
||||||
128 0.375
|
|
||||||
129 0.375
|
|
||||||
130 0.5625
|
|
||||||
131 0.375
|
|
||||||
132 0.5
|
|
||||||
133 0.3125
|
|
||||||
134 0.3125
|
|
||||||
135 0.3125
|
|
||||||
136 0.375
|
|
||||||
137 0.5
|
|
||||||
138 0.3125
|
|
||||||
139 0.375
|
|
||||||
140 0.3125
|
|
||||||
141 0.3125
|
|
||||||
142 0.3125
|
|
||||||
143 0.5625
|
|
||||||
144 0.3125
|
|
||||||
145 0.375
|
|
||||||
146 0.5
|
|
||||||
147 0.5
|
|
||||||
148 0.375
|
|
||||||
149 0.4375
|
|
||||||
150 0.5
|
|
||||||
151 0.3125
|
|
||||||
152 0.375
|
|
||||||
153 0.375
|
|
||||||
154 0.375
|
|
||||||
155 0.3125
|
|
||||||
156 0.375
|
|
||||||
157 0.4375
|
|
||||||
158 0.4375
|
|
||||||
159 0.375
|
|
||||||
160 0.375
|
|
||||||
161 0.3125
|
|
||||||
162 0.375
|
|
||||||
163 0.375
|
|
||||||
164 0.375
|
|
||||||
165 0.3125
|
|
||||||
166 0.3125
|
|
||||||
167 0.3125
|
|
||||||
168 0.375
|
|
||||||
169 0.3125
|
|
||||||
170 0.3125
|
|
||||||
171 0.3125
|
|
||||||
172 0.375
|
|
||||||
173 0.3125
|
|
||||||
174 0.3125
|
|
||||||
175 0.5
|
|
||||||
176 0.3125
|
|
||||||
177 0.375
|
|
||||||
178 0.375
|
|
||||||
179 0.3125
|
|
||||||
180 0.3125
|
|
||||||
181 0.3125
|
|
||||||
182 0.3125
|
|
||||||
183 0.5625
|
|
||||||
184 0.5625
|
|
||||||
185 0.3125
|
|
||||||
186 0.5
|
|
||||||
187 0.5
|
|
||||||
188 0.5625
|
|
||||||
189 0.5
|
|
||||||
190 0.5625
|
|
||||||
191 0.5625
|
|
||||||
192 0.5625
|
|
||||||
193 0.5
|
|
||||||
194 0.5
|
|
||||||
195 0.5625
|
|
||||||
196 0.5625
|
|
||||||
197 0.5625
|
|
||||||
198 0.5625
|
|
||||||
199 0.5
|
|
||||||
200 0.5625
|
|
||||||
201 0.5625
|
|
||||||
202 0.375
|
|
||||||
203 0.375
|
|
||||||
204 0.375
|
|
||||||
205 0.375
|
|
||||||
206 0.375
|
|
||||||
207 0.5
|
|
||||||
208 0.5
|
|
||||||
209 0.5625
|
|
||||||
210 0.5625
|
|
||||||
211 0.5625
|
|
||||||
212 0.3125
|
|
||||||
213 0.5
|
|
||||||
214 0.5
|
|
||||||
215 0.5625
|
|
||||||
216 0.5
|
|
||||||
217 0.5
|
|
||||||
218 0.5
|
|
||||||
219 0.5625
|
|
||||||
220 0.5
|
|
||||||
221 0.4375
|
|
||||||
222 0.5
|
|
||||||
223 0.5
|
|
||||||
224 0.4375
|
|
||||||
225 0.5
|
|
||||||
226 0.4375
|
|
||||||
227 0.5
|
|
||||||
228 0.5
|
|
||||||
229 0.375
|
|
||||||
230 0.375
|
|
||||||
231 0.3125
|
|
||||||
232 0.375
|
|
||||||
233 0.375
|
|
||||||
234 0.375
|
|
||||||
235 0.5625
|
|
||||||
236 0.5625
|
|
||||||
237 0.5625
|
|
||||||
238 0.5625
|
|
||||||
239 0.5625
|
|
||||||
240 0.5
|
|
||||||
241 0.5
|
|
||||||
242 0.5
|
|
||||||
243 0.5625
|
|
||||||
244 0.5625
|
|
||||||
245 0.375
|
|
||||||
246 0.375
|
|
||||||
247 0.375
|
|
||||||
248 0.3125
|
|
||||||
249 0.375
|
|
||||||
The average accuracy is: 0.42675
|
|
||||||
The best accuracy is: 0.5625
|
|
||||||
52
requirements.txt
Normal file
52
requirements.txt
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
Bottleneck==1.4.2
|
||||||
|
brotlicffi==1.2.0.0
|
||||||
|
certifi==2026.5.20
|
||||||
|
cffi==2.0.0
|
||||||
|
charset-normalizer==3.4.4
|
||||||
|
contourpy==1.3.2
|
||||||
|
cycler==0.12.1
|
||||||
|
einops==0.8.2
|
||||||
|
filelock==3.20.3
|
||||||
|
fonttools==4.63.0
|
||||||
|
gmpy2==2.2.2
|
||||||
|
idna==3.11
|
||||||
|
Jinja2==3.1.6
|
||||||
|
joblib==1.5.3
|
||||||
|
kiwisolver==1.5.0
|
||||||
|
MarkupSafe==3.0.2
|
||||||
|
matplotlib==3.10.9
|
||||||
|
mkl_fft==1.3.11
|
||||||
|
mkl_random==1.2.8
|
||||||
|
mkl-service==2.5.2
|
||||||
|
mpmath==1.3.0
|
||||||
|
networkx==3.4.2
|
||||||
|
Nuitka==4.1.1
|
||||||
|
numexpr==2.14.1
|
||||||
|
numpy==1.24.3
|
||||||
|
packaging==26.0
|
||||||
|
pandas==2.3.3
|
||||||
|
pillow==12.2.0
|
||||||
|
pip==26.0.1
|
||||||
|
pycparser==3.0
|
||||||
|
pyparsing==3.3.2
|
||||||
|
pyserial==3.5
|
||||||
|
PySocks==1.7.1
|
||||||
|
python-dateutil==2.9.0.post0
|
||||||
|
pytz==2026.1.post1
|
||||||
|
pyzmq==27.1.0
|
||||||
|
requests==2.33.1
|
||||||
|
scikit-learn==1.7.1
|
||||||
|
scipy==1.15.3
|
||||||
|
setuptools==82.0.1
|
||||||
|
six==1.17.0
|
||||||
|
sympy==1.14.0
|
||||||
|
threadpoolctl==3.5.0
|
||||||
|
torch==2.0.0
|
||||||
|
torchaudio==2.0.0
|
||||||
|
torchsummary==1.5.1
|
||||||
|
torchvision==0.15.0
|
||||||
|
typing_extensions==4.15.0
|
||||||
|
tzdata==2026.2
|
||||||
|
urllib3==2.7.0
|
||||||
|
wheel==0.46.3
|
||||||
|
win_inet_pton==1.1.0
|
||||||
@@ -6,32 +6,33 @@ import time
|
|||||||
from Decoder import Decoder_main
|
from Decoder import Decoder_main
|
||||||
from PubLibrary.RunOnce import is_program_running
|
from PubLibrary.RunOnce import is_program_running
|
||||||
from PubLibrary.InifileHelper import IniRead
|
from PubLibrary.InifileHelper import IniRead
|
||||||
|
from logs.log import algo_log
|
||||||
|
|
||||||
def get_device_info(device_type):
|
def get_device_info(device_type):
|
||||||
|
|
||||||
|
|
||||||
section = f'device_type_{device_type}'
|
section = f'device_type_{device_type}'
|
||||||
device_info = {
|
device_info = {
|
||||||
'device_sample_rate': int(IniRead(section, 'sample_rate')) if IniRead(section, 'sample_rate') is not None else 250,
|
'sample_rate': int(IniRead(section, 'sample_rate')) if IniRead(section, 'sample_rate') is not None else 250,
|
||||||
|
'frame_points': int(IniRead(section, 'frame_points')) if IniRead(section, 'frame_points') is not None else 5,
|
||||||
''
|
'channel_nums': int(IniRead(section, 'channel_nums')) if IniRead(section, 'channel_nums') is not None else 66,
|
||||||
|
'channel_names': IniRead(section, 'channel_names') if IniRead(section, 'channel_names') is not None else None,
|
||||||
|
'channel_index': IniRead(section, 'channel_index') if IniRead(section, 'channel_index') is not None else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return device_info
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if not is_program_running():
|
if not is_program_running():
|
||||||
# 解析命令行参数
|
# 解析命令行参数
|
||||||
parser = argparse.ArgumentParser(description="EEG Decoder Application")
|
# parser = argparse.ArgumentParser(description="EEG Decoder Application")
|
||||||
parser.add_argument('-dt', '-t','--device-type', type=int, default=None, help="Device Type")
|
# parser.add_argument('-dt', '-t','--device-type', type=int, default=None, help="Device Type")
|
||||||
# parser.add_argument('-dh', '--device-host', type=str, default=None, help="Device Host IP")
|
# parser.add_argument('-dh', '--device-host', type=str, default=None, help="Device Host IP")
|
||||||
# parser.add_argument('-dp', '--device-port', type=int, default=None, help="Device Port")
|
# parser.add_argument('-dp', '--device-port', type=int, default=None, help="Device Port")
|
||||||
# parser.add_argument('-uh', '--upper-host', type=str, default=None, help="Upper Computer Host IP")
|
# parser.add_argument('-uh', '--upper-host', type=str, default=None, help="Upper Computer Host IP")
|
||||||
# parser.add_argument('-up', '--upper-port', type=int, default=None, help="Upper Computer Port")
|
# parser.add_argument('-up', '--upper-port', type=int, default=None, help="Upper Computer Port")
|
||||||
|
# args = parser.parse_args()
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
device_info= get_device_info(args.device_type)
|
|
||||||
|
|
||||||
|
|
||||||
decoder = Decoder_main(device_info=device_info)
|
|
||||||
# decoder.connect(
|
# decoder.connect(
|
||||||
# device_type=args.device_type,
|
# device_type=args.device_type,
|
||||||
# device_host=args.device_host,
|
# device_host=args.device_host,
|
||||||
@@ -40,6 +41,10 @@ if __name__ == "__main__":
|
|||||||
# upper_port=args.upper_port
|
# upper_port=args.upper_port
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
device_info= get_device_info(1)
|
||||||
|
algo_log(f"device_info: {device_info}", level="DEBUG")
|
||||||
|
decoder = Decoder_main(device_info=device_info)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
decoder.start()
|
decoder.start()
|
||||||
while not decoder.zmqServer.IsExitApp:
|
while not decoder.zmqServer.IsExitApp:
|
||||||
|
|||||||
305
upperHost_stimmock/MI_headless.py
Normal file
305
upperHost_stimmock/MI_headless.py
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
"""
|
||||||
|
MI_headless.py
|
||||||
|
无界面版 MI 运动想象范式通讯流程模拟脚本。
|
||||||
|
复现 MI_main.py 的完整指令序列(train 0/1, rest, predict, saveData),
|
||||||
|
但不依赖 psychopy 也不打开任何窗口/音频,用 time.sleep 替代帧循环等待。
|
||||||
|
|
||||||
|
启动顺序:
|
||||||
|
1. runDecoder.py
|
||||||
|
2. datamock.py
|
||||||
|
3. MI_headless.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import zmq
|
||||||
|
import numpy as np
|
||||||
|
import ast
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
from PubLibrary.InifileHelper import IniRead
|
||||||
|
|
||||||
|
personname = 'demo'
|
||||||
|
session = '01'
|
||||||
|
|
||||||
|
DATAMOCK_LABEL_ADDR = 'tcp://127.0.0.1:8101' # datamock 标签命令地址
|
||||||
|
|
||||||
|
|
||||||
|
# ========== ZMQ 结果接收服务 ==========
|
||||||
|
class ZmqResultServer(threading.Thread):
|
||||||
|
def __init__(self, port=8088):
|
||||||
|
threading.Thread.__init__(self)
|
||||||
|
self.port = port
|
||||||
|
self.running = True
|
||||||
|
self.energy = 0
|
||||||
|
self.paradigm = 0 # 0=个体校准, 1=康复训练, 2=等待模型训练
|
||||||
|
self.ChoosenNum = -1
|
||||||
|
self.context = zmq.Context()
|
||||||
|
self.socket = self.context.socket(zmq.ROUTER)
|
||||||
|
self.socket.bind(f"tcp://0.0.0.0:{self.port}")
|
||||||
|
self.daemon = True
|
||||||
|
self.trial_idx = 0
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
print(f"[Server] UpperHost_Server listening on {self.port}")
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
frames = self.socket.recv_multipart(zmq.NOBLOCK)
|
||||||
|
if len(frames) < 3:
|
||||||
|
continue
|
||||||
|
message = json.loads(frames[2].decode('utf-8'))
|
||||||
|
method = message.get('method')
|
||||||
|
params = message.get('params')
|
||||||
|
if method == 'energy':
|
||||||
|
self.energy = params
|
||||||
|
elif method == 'paradigm':
|
||||||
|
self.paradigm = params
|
||||||
|
print(f"[Server] paradigm -> {params}")
|
||||||
|
elif method == 'result':
|
||||||
|
self.ChoosenNum = params
|
||||||
|
self.trial_idx += 1
|
||||||
|
print(f"[Server] result={self.ChoosenNum} (trial {self.trial_idx})")
|
||||||
|
except zmq.Again:
|
||||||
|
time.sleep(0.005)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Server] error: {e}")
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self.running = False
|
||||||
|
self.socket.close()
|
||||||
|
self.context.term()
|
||||||
|
|
||||||
|
|
||||||
|
# ========== ZMQ 命令发送客户端 ==========
|
||||||
|
class ZmqCmdClient:
|
||||||
|
def __init__(self, host, port):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.context = zmq.Context()
|
||||||
|
self.socket = self.context.socket(zmq.DEALER)
|
||||||
|
# PUSH socket 用于向 datamock.py 发送标签命令
|
||||||
|
self._label_sock = self.context.socket(zmq.PUSH)
|
||||||
|
self._label_sock.connect(DATAMOCK_LABEL_ADDR)
|
||||||
|
print(f"[Client] label PUSH connected to {DATAMOCK_LABEL_ADDR}")
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
self.socket.connect(f"tcp://{self.host}:{self.port}")
|
||||||
|
print(f"[Client] connected to {self.host}:{self.port}")
|
||||||
|
|
||||||
|
def start_recv_thread(self, result_server):
|
||||||
|
"""启动后台线程,持续接收 decoder 通过 8099 ROUTER 回发的消息,并更新 result_server 的状态"""
|
||||||
|
self._result_server = result_server
|
||||||
|
self._stop_recv = threading.Event()
|
||||||
|
|
||||||
|
def _recv_loop():
|
||||||
|
while not self._stop_recv.is_set():
|
||||||
|
try:
|
||||||
|
frames = self.socket.recv_multipart(zmq.NOBLOCK)
|
||||||
|
# DEALER 收到的格式: [b'', json_bytes]
|
||||||
|
data_bytes = frames[-1]
|
||||||
|
message = json.loads(data_bytes.decode('utf-8'))
|
||||||
|
method = message.get('method')
|
||||||
|
params = message.get('params')
|
||||||
|
ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
|
||||||
|
print(f"[{ts}] [CmdClient] recv: {method}={params}")
|
||||||
|
if method == 'paradigm':
|
||||||
|
self._result_server.paradigm = params
|
||||||
|
print(f"[{ts}] [CmdClient] paradigm updated -> {params}")
|
||||||
|
elif method == 'result':
|
||||||
|
self._result_server.ChoosenNum = params
|
||||||
|
self._result_server.trial_idx += 1
|
||||||
|
print(f"[{ts}] [CmdClient] result={params} (trial {self._result_server.trial_idx})")
|
||||||
|
elif method == 'energy':
|
||||||
|
self._result_server.energy = params
|
||||||
|
except zmq.Again:
|
||||||
|
time.sleep(0.005)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[CmdClient recv] error: {e}")
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
self._recv_thread = threading.Thread(target=_recv_loop, daemon=True)
|
||||||
|
self._recv_thread.start()
|
||||||
|
print(f"[Client] 后台接收线程已启动(监听 decoder 8099 回发消息)")
|
||||||
|
|
||||||
|
def stop_recv_thread(self):
|
||||||
|
if hasattr(self, '_stop_recv'):
|
||||||
|
self._stop_recv.set()
|
||||||
|
|
||||||
|
def _send_label(self, label_value):
|
||||||
|
"""向 datamock.py 发送标签命令"""
|
||||||
|
try:
|
||||||
|
self._label_sock.send_string(str(label_value), zmq.NOBLOCK)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Client] label send error: {e}")
|
||||||
|
|
||||||
|
def send_data(self, method, params):
|
||||||
|
msg = {'method': method, 'params': params}
|
||||||
|
try:
|
||||||
|
self.socket.send_multipart([b'', json.dumps(msg).encode('utf-8')])
|
||||||
|
ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
|
||||||
|
print(f"[{ts}] send_data: {method}={params}")
|
||||||
|
# 根据 train/predict 命令向 datamock 发送标签
|
||||||
|
if method == 'train':
|
||||||
|
if params == 0:
|
||||||
|
self._send_label(1)
|
||||||
|
print(f"[Label] train 0 -> datamock label=1")
|
||||||
|
elif params == 1:
|
||||||
|
self._send_label(2)
|
||||||
|
print(f"[Label] train 1 -> datamock label=2")
|
||||||
|
elif method == 'predict':
|
||||||
|
self._send_label(99)
|
||||||
|
print(f"[Label] predict -> datamock label=99")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Client] send error: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 主流程 ==========
|
||||||
|
def run_headless():
|
||||||
|
server = ZmqResultServer(port=8088)
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
_dh = str(IniRead('system', 'Decoder_Host'))
|
||||||
|
_dp = int(IniRead('system', 'Decoder_Port'))
|
||||||
|
client = ZmqCmdClient(_dh, _dp)
|
||||||
|
client.connect()
|
||||||
|
client.start_recv_thread(server) # 启动后台接收线程,监听 decoder 8099 回发的 paradigm/result 消息
|
||||||
|
|
||||||
|
time.sleep(1) # 等待连接建立
|
||||||
|
client.send_data('decoderClass', 'mi')
|
||||||
|
|
||||||
|
# MI_IntervalEpoch = [0.5, 4.5],trial时长 = 4.5-0.5 = 4.0s
|
||||||
|
_mi_iv = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) # [0.5, 4.5]
|
||||||
|
_trial_sec = float(_mi_iv[1] - _mi_iv[0]) # 4.0s
|
||||||
|
_margin = 1.0
|
||||||
|
train_time = max(5.0, _trial_sec + _margin) # 训练刺激时长(与 MI_main.py 保持一致)
|
||||||
|
|
||||||
|
# MI epoch latency = interval_epoch[1] // 5 = (4.5*250)//5 = 225包 × 20ms = 4.5s
|
||||||
|
# train_latency = 225包(MI中 train_latency == latency)
|
||||||
|
# 在 train_time 后需再等 epoch_wait 秒,decoder 才能完成 epoch 采集
|
||||||
|
epoch_wait = _mi_iv[1] / _mi_iv[1] * (_mi_iv[1] * 250 // 5) * 0.02 # = latency * 20ms
|
||||||
|
# 更直接的计算:latency = interval_epoch[1] // 5 = int(4.5*250)//5 = 225,225*0.02 = 4.5s
|
||||||
|
epoch_wait = (int(_mi_iv[1] * 250) // 5) * 0.02 # 4.5s
|
||||||
|
|
||||||
|
# predict epoch wait(与 train 相同,MI中 latency == train_latency)
|
||||||
|
predict_epoch_wait = epoch_wait # 4.5s
|
||||||
|
|
||||||
|
test_time = 7.0 # 预测窗口时长(与 MI_main.py 保持一致)
|
||||||
|
right_rehabilitation = float(IniRead('system', 'Right_rehabilitation'))
|
||||||
|
fault_rehabilitation = float(IniRead('system', 'Fault_rehabilitation'))
|
||||||
|
rest_time = float(IniRead('system', 'Rest_time'))
|
||||||
|
|
||||||
|
num_blocks = int(IniRead('system', 'Num_blocks'))
|
||||||
|
num_trials = int(IniRead('system', 'Num_trials'))
|
||||||
|
|
||||||
|
trained = 0
|
||||||
|
Num_Total = 0
|
||||||
|
Num_Success = 0
|
||||||
|
user_choice = []
|
||||||
|
|
||||||
|
print("=" * 50)
|
||||||
|
print("[Headless] 开始运行 MI 通讯流程(无界面)")
|
||||||
|
print(f" MI_IntervalEpoch={_mi_iv}, trial_sec={_trial_sec:.2f}s")
|
||||||
|
print(f" train_time={train_time:.2f}s, epoch_wait={epoch_wait:.2f}s")
|
||||||
|
print(f" test_time={test_time:.2f}s, predict_epoch_wait={predict_epoch_wait:.2f}s")
|
||||||
|
print(f" num_blocks={num_blocks}, num_trials={num_trials}")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
# -------- 个体校准阶段 --------
|
||||||
|
print("\n[Phase] 个体校准阶段 (paradigm=0)")
|
||||||
|
client.send_data('rest', 0)
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
while server.paradigm == 0:
|
||||||
|
# 左侧 MI 刺激(train 0,label=1)
|
||||||
|
print(f"\n[Train] 左侧 MI 刺激 (train 0) trained={trained}")
|
||||||
|
client.send_data('rest', 0)
|
||||||
|
time.sleep(0.5) # ding 提示后等待
|
||||||
|
|
||||||
|
client.send_data('train', 0)
|
||||||
|
time.sleep(train_time + epoch_wait) # 等待刺激时间 + epoch 完成时间
|
||||||
|
|
||||||
|
trained += 1
|
||||||
|
client.send_data('rest', 0)
|
||||||
|
time.sleep(1.0) # 类间休息
|
||||||
|
|
||||||
|
# 空闲态样本采集(train 1,label=2)
|
||||||
|
print(f"\n[Train] 空闲态采集 (train 1) trained={trained}")
|
||||||
|
client.send_data('train', 1)
|
||||||
|
time.sleep(train_time + epoch_wait) # 等待刺激时间 + epoch 完成时间
|
||||||
|
|
||||||
|
trained += 1
|
||||||
|
client.send_data('rest', 0)
|
||||||
|
time.sleep(1.0) # 类间休息
|
||||||
|
|
||||||
|
# 个体校准阶段结束
|
||||||
|
print("\n[Phase] 个体校准结束,等待模型训练 (paradigm=2) ...")
|
||||||
|
trained = 0
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# 等待模型训练完成 (paradigm=2 -> paradigm=1)
|
||||||
|
while server.paradigm == 2:
|
||||||
|
print("[Phase] 等待模型训练完成 ...")
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
# -------- 康复训练阶段 --------
|
||||||
|
while server.paradigm == 1:
|
||||||
|
print("\n[Phase] 康复训练阶段 (paradigm=1)")
|
||||||
|
for block_idx in range(num_blocks):
|
||||||
|
print(f"\n [Block {block_idx+1}/{num_blocks}]")
|
||||||
|
time.sleep(10) # 每轮开始前等待
|
||||||
|
|
||||||
|
for trial_idx in range(num_trials):
|
||||||
|
print(f" [Trial {trial_idx+1}/{num_trials}]")
|
||||||
|
|
||||||
|
time.sleep(0.5) # ding 提示
|
||||||
|
server.ChoosenNum = -1
|
||||||
|
|
||||||
|
# 开始预测
|
||||||
|
# MI predict epoch latency = 225包 × 20ms = 4.5s,需额外等待 epoch 完成
|
||||||
|
client.send_data('predict', 1)
|
||||||
|
t_start = time.perf_counter()
|
||||||
|
while time.perf_counter() - t_start < test_time + predict_epoch_wait:
|
||||||
|
if server.ChoosenNum >= 0:
|
||||||
|
Num_Total += 1
|
||||||
|
user_choice.append(server.ChoosenNum)
|
||||||
|
if server.ChoosenNum == 0:
|
||||||
|
Num_Success += 1
|
||||||
|
rest_time = right_rehabilitation
|
||||||
|
elif server.ChoosenNum == 1:
|
||||||
|
rest_time = fault_rehabilitation
|
||||||
|
break
|
||||||
|
time.sleep(0.02)
|
||||||
|
|
||||||
|
trained += 1
|
||||||
|
client.send_data('rest', 0)
|
||||||
|
time.sleep(0.5)
|
||||||
|
time.sleep(rest_time)
|
||||||
|
server.ChoosenNum = -1
|
||||||
|
|
||||||
|
# 训练结束
|
||||||
|
print("\n[Phase] 康复训练结束")
|
||||||
|
break # 退出康复训练循环
|
||||||
|
|
||||||
|
# 统计结果
|
||||||
|
overall_accuracy = Num_Success / Num_Total if Num_Total > 0 else 0
|
||||||
|
print(f"\n[Result] Overall={overall_accuracy:.3f} ({Num_Success}/{Num_Total})")
|
||||||
|
print(f"[Result] user_choice={user_choice}")
|
||||||
|
break # 完成一个完整流程后退出
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n[Headless] 用户中断")
|
||||||
|
finally:
|
||||||
|
client.send_data('predict', 2) # 关闭系统
|
||||||
|
client.send_data('saveData', 0)
|
||||||
|
server.stop()
|
||||||
|
print("[Headless] 已发送关闭指令,退出。")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
run_headless()
|
||||||
301
upperHost_stimmock/ssmvep_headless.py
Normal file
301
upperHost_stimmock/ssmvep_headless.py
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
"""
|
||||||
|
ssmvep_headless.py
|
||||||
|
无界面版 SSMVEP 范式通讯流程模拟脚本。
|
||||||
|
复现 ssmvep_main.py 的完整指令序列(train 0/1/2, rest, predict, saveData),
|
||||||
|
但不依赖 psychopy 也不打开任何窗口/音频,用 time.sleep 替代帧循环等待。
|
||||||
|
|
||||||
|
启动顺序:
|
||||||
|
1. runDecoder.py
|
||||||
|
2. datamock.py
|
||||||
|
3. ssmvep_headless.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import zmq
|
||||||
|
import numpy as np
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
from PubLibrary.InifileHelper import IniRead
|
||||||
|
|
||||||
|
personname = 'demo'
|
||||||
|
session = '01'
|
||||||
|
|
||||||
|
DATAMOCK_LABEL_ADDR = 'tcp://127.0.0.1:8101' # datamock 标签命令地址
|
||||||
|
|
||||||
|
|
||||||
|
# ========== ZMQ 结果接收服务 ==========
|
||||||
|
class ZmqResultServer(threading.Thread):
|
||||||
|
def __init__(self, port=8088):
|
||||||
|
threading.Thread.__init__(self)
|
||||||
|
self.port = port
|
||||||
|
self.running = True
|
||||||
|
self.energy = 0
|
||||||
|
self.paradigm = 0 # 0=个体校准, 1=康复训练, 2=等待模型训练
|
||||||
|
self.ChoosenNum = -1
|
||||||
|
self.context = zmq.Context()
|
||||||
|
self.socket = self.context.socket(zmq.ROUTER)
|
||||||
|
self.socket.bind(f"tcp://0.0.0.0:{self.port}")
|
||||||
|
self.daemon = True
|
||||||
|
self.trial_idx = 0
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
print(f"[Server] UpperHost_Server listening on {self.port}")
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
frames = self.socket.recv_multipart(zmq.NOBLOCK)
|
||||||
|
if len(frames) < 3:
|
||||||
|
continue
|
||||||
|
message = json.loads(frames[2].decode('utf-8'))
|
||||||
|
method = message.get('method')
|
||||||
|
params = message.get('params')
|
||||||
|
if method == 'energy':
|
||||||
|
self.energy = params
|
||||||
|
elif method == 'paradigm':
|
||||||
|
self.paradigm = params
|
||||||
|
print(f"[Server] paradigm -> {params}")
|
||||||
|
elif method == 'result':
|
||||||
|
self.ChoosenNum = params
|
||||||
|
self.trial_idx += 1
|
||||||
|
print(f"[Server] result={self.ChoosenNum} (trial {self.trial_idx})")
|
||||||
|
except zmq.Again:
|
||||||
|
time.sleep(0.005)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Server] error: {e}")
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self.running = False
|
||||||
|
self.socket.close()
|
||||||
|
self.context.term()
|
||||||
|
|
||||||
|
|
||||||
|
# ========== ZMQ 命令发送客户端 ==========
|
||||||
|
class ZmqCmdClient:
|
||||||
|
def __init__(self, host, port):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.context = zmq.Context()
|
||||||
|
self.socket = self.context.socket(zmq.DEALER)
|
||||||
|
# PUSH socket 用于向 datamock.py 发送标签命令
|
||||||
|
self._label_sock = self.context.socket(zmq.PUSH)
|
||||||
|
self._label_sock.connect(DATAMOCK_LABEL_ADDR)
|
||||||
|
print(f"[Client] label PUSH connected to {DATAMOCK_LABEL_ADDR}")
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
self.socket.connect(f"tcp://{self.host}:{self.port}")
|
||||||
|
print(f"[Client] connected to {self.host}:{self.port}")
|
||||||
|
|
||||||
|
def start_recv_thread(self, result_server):
|
||||||
|
"""启动后台线程,持续接收 decoder 通过 8099 ROUTER 回发的消息,并更新 result_server 的状态"""
|
||||||
|
self._result_server = result_server
|
||||||
|
self._stop_recv = threading.Event()
|
||||||
|
|
||||||
|
def _recv_loop():
|
||||||
|
while not self._stop_recv.is_set():
|
||||||
|
try:
|
||||||
|
frames = self.socket.recv_multipart(zmq.NOBLOCK)
|
||||||
|
# DEALER 收到的格式: [b'', json_bytes]
|
||||||
|
data_bytes = frames[-1]
|
||||||
|
message = json.loads(data_bytes.decode('utf-8'))
|
||||||
|
method = message.get('method')
|
||||||
|
params = message.get('params')
|
||||||
|
ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
|
||||||
|
print(f"[{ts}] [CmdClient] recv: {method}={params}")
|
||||||
|
if method == 'paradigm':
|
||||||
|
self._result_server.paradigm = params
|
||||||
|
print(f"[{ts}] [CmdClient] paradigm updated -> {params}")
|
||||||
|
elif method == 'result':
|
||||||
|
self._result_server.ChoosenNum = params
|
||||||
|
self._result_server.trial_idx += 1
|
||||||
|
print(f"[{ts}] [CmdClient] result={params} (trial {self._result_server.trial_idx})")
|
||||||
|
elif method == 'energy':
|
||||||
|
self._result_server.energy = params
|
||||||
|
except zmq.Again:
|
||||||
|
time.sleep(0.005)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[CmdClient recv] error: {e}")
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
self._recv_thread = threading.Thread(target=_recv_loop, daemon=True)
|
||||||
|
self._recv_thread.start()
|
||||||
|
print(f"[Client] 后台接收线程已启动(监听 decoder 8099 回发消息)")
|
||||||
|
|
||||||
|
def stop_recv_thread(self):
|
||||||
|
if hasattr(self, '_stop_recv'):
|
||||||
|
self._stop_recv.set()
|
||||||
|
|
||||||
|
def _send_label(self, label_value):
|
||||||
|
"""向 datamock.py 发送标签命令"""
|
||||||
|
try:
|
||||||
|
self._label_sock.send_string(str(label_value), zmq.NOBLOCK)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Client] label send error: {e}")
|
||||||
|
|
||||||
|
def send_data(self, method, params):
|
||||||
|
msg = {'method': method, 'params': params}
|
||||||
|
try:
|
||||||
|
self.socket.send_multipart([b'', json.dumps(msg).encode('utf-8')])
|
||||||
|
ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
|
||||||
|
print(f"[{ts}] send_data: {method}={params}")
|
||||||
|
# 根据 train/predict 命令向 datamock 发送标签
|
||||||
|
if method == 'train':
|
||||||
|
if params == 0:
|
||||||
|
self._send_label(1)
|
||||||
|
print(f"[Label] train 0 -> datamock label=1")
|
||||||
|
elif params == 1:
|
||||||
|
self._send_label(2)
|
||||||
|
print(f"[Label] train 1 -> datamock label=2")
|
||||||
|
elif method == 'predict':
|
||||||
|
self._send_label(99)
|
||||||
|
print(f"[Label] predict -> datamock label=99")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Client] send error: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 主流程 ==========
|
||||||
|
def run_headless():
|
||||||
|
server = ZmqResultServer(port=8088)
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
_dh = str(IniRead('system', 'Decoder_Host'))
|
||||||
|
_dp = int(IniRead('system', 'Decoder_Port'))
|
||||||
|
client = ZmqCmdClient(_dh, _dp)
|
||||||
|
client.connect()
|
||||||
|
client.start_recv_thread(server) # 启动后台接收线程,监听 decoder 8099 回发的 paradigm/result 消息
|
||||||
|
|
||||||
|
time.sleep(1) # 等待连接建立
|
||||||
|
client.send_data('decoderClass', 'ssmvep')
|
||||||
|
|
||||||
|
train_time = 2.5 # 每轮训练刺激时长 (s)
|
||||||
|
test_time = 2.5 # 每轮测试刺激时长 (s)
|
||||||
|
right_rehabilitation = float(IniRead('system', 'Right_rehabilitation'))
|
||||||
|
fault_rehabilitation = float(IniRead('system', 'Fault_rehabilitation'))
|
||||||
|
rest_time = float(IniRead('system', 'Rest_time'))
|
||||||
|
|
||||||
|
num_blocks = int(IniRead('system', 'Num_blocks'))
|
||||||
|
num_trials = int(IniRead('system', 'Num_trials'))
|
||||||
|
|
||||||
|
position = [0, 1]
|
||||||
|
truePos_seq = position * int(num_trials / len(position))
|
||||||
|
truePos_seq = np.random.permutation(truePos_seq).tolist()
|
||||||
|
user_choice = []
|
||||||
|
|
||||||
|
os.makedirs('EEGFiles', exist_ok=True)
|
||||||
|
seq_file_path = f'EEGFiles/pos_seq_{personname}{session}_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
|
||||||
|
seq_info = {
|
||||||
|
'position': position,
|
||||||
|
'sequence': truePos_seq,
|
||||||
|
'start_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
}
|
||||||
|
with open(seq_file_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(seq_info, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
trained = 0
|
||||||
|
Num_Total = 0
|
||||||
|
Num_Success = 0
|
||||||
|
|
||||||
|
print("=" * 50)
|
||||||
|
print("[Headless] 开始运行 SSMVEP 通讯流程(无界面)")
|
||||||
|
print(f" num_blocks={num_blocks}, num_trials={num_trials}")
|
||||||
|
print(f" train_time={train_time}s, test_time={test_time}s")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
# -------- 个体校准阶段 --------
|
||||||
|
print("\n[Phase] 个体校准阶段 (paradigm=0)")
|
||||||
|
client.send_data('rest', 0)
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# epoch完成需要的额外等待时间:train_latency=120包×20ms=2.4s
|
||||||
|
# 在train_time后需再等epoch_wait秒,decoder才能完成epoch采集并取出数据
|
||||||
|
epoch_wait = 2.4 # 秒,与train_latency对应
|
||||||
|
|
||||||
|
while server.paradigm == 0:
|
||||||
|
# 左腿刺激
|
||||||
|
print(f"\n[Train] 左腿刺激 (train 0) trained={trained}")
|
||||||
|
client.send_data('train', 0)
|
||||||
|
time.sleep(train_time + epoch_wait) # 等待刺激时间+epoch完成时间
|
||||||
|
|
||||||
|
trained += 1
|
||||||
|
client.send_data('rest', 0)
|
||||||
|
time.sleep(max(0, abs(fault_rehabilitation - train_time) - epoch_wait))
|
||||||
|
|
||||||
|
# 右腿刺激
|
||||||
|
print(f"\n[Train] 右腿刺激 (train 1) trained={trained}")
|
||||||
|
client.send_data('train', 1)
|
||||||
|
time.sleep(train_time + epoch_wait) # 等待刺激时间+epoch完成时间
|
||||||
|
|
||||||
|
trained += 1
|
||||||
|
client.send_data('rest', 0)
|
||||||
|
time.sleep(max(0, fault_rehabilitation - epoch_wait))
|
||||||
|
|
||||||
|
# 个体校准阶段结束
|
||||||
|
print("\n[Phase] 个体校准结束,等待 paradigm=1 ...")
|
||||||
|
trained = 0
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# -------- 康复训练阶段 --------
|
||||||
|
while server.paradigm == 1:
|
||||||
|
print("\n[Phase] 康复训练阶段 (paradigm=1)")
|
||||||
|
for block_idx in range(num_blocks):
|
||||||
|
print(f"\n [Block {block_idx+1}/{num_blocks}]")
|
||||||
|
time.sleep(10) # 每轮开始前等待
|
||||||
|
|
||||||
|
for trial_idx in range(num_trials):
|
||||||
|
true_position = truePos_seq[trial_idx]
|
||||||
|
print(f" [Trial {trial_idx+1}/{num_trials}] true_pos={true_position}")
|
||||||
|
|
||||||
|
time.sleep(0.5) # 提示 + 叮声
|
||||||
|
server.ChoosenNum = -1
|
||||||
|
|
||||||
|
# 开始测试
|
||||||
|
# predict epoch latency = 115包×20ms = 2.3s,需额外等待epoch完成
|
||||||
|
predict_epoch_wait = 2.3 # 秒,与predict latency=115包对应
|
||||||
|
client.send_data('predict', 1)
|
||||||
|
t_start = time.perf_counter()
|
||||||
|
while time.perf_counter() - t_start < test_time + predict_epoch_wait:
|
||||||
|
if server.ChoosenNum >= 0:
|
||||||
|
Num_Total += 1
|
||||||
|
user_choice.append(server.ChoosenNum)
|
||||||
|
if server.ChoosenNum in [0, 1]:
|
||||||
|
Num_Success += 1
|
||||||
|
rest_time = right_rehabilitation
|
||||||
|
break
|
||||||
|
time.sleep(0.02)
|
||||||
|
|
||||||
|
trained += 1
|
||||||
|
client.send_data('rest', 0)
|
||||||
|
time.sleep(0.5)
|
||||||
|
time.sleep(rest_time)
|
||||||
|
server.ChoosenNum = -1
|
||||||
|
|
||||||
|
# 训练结束
|
||||||
|
print("\n[Phase] 康复训练结束")
|
||||||
|
break # 退出康复训练循环
|
||||||
|
|
||||||
|
# 统计结果
|
||||||
|
overall_accuracy = Num_Success / Num_Total if Num_Total > 0 else 0
|
||||||
|
expected_seq = truePos_seq * num_blocks
|
||||||
|
min_len = min(len(user_choice), len(expected_seq))
|
||||||
|
same_count = sum(1 for a, b in zip(user_choice[:min_len], expected_seq[:min_len]) if a == b)
|
||||||
|
true_accuracy = same_count / min_len if min_len > 0 else 0
|
||||||
|
print(f"\n[Result] Overall={overall_accuracy:.3f} ({Num_Success}/{Num_Total})")
|
||||||
|
print(f"[Result] TrueAcc={true_accuracy:.3f} ({same_count}/{min_len})")
|
||||||
|
break # 完成一个完整流程后退出
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n[Headless] 用户中断")
|
||||||
|
finally:
|
||||||
|
client.send_data('predict', 2) # 关闭系统
|
||||||
|
client.send_data('saveData', 0)
|
||||||
|
server.stop()
|
||||||
|
print("[Headless] 已发送关闭指令,退出。")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
run_headless()
|
||||||
364
upperHost_stimmock/ssvep_main.py
Normal file
364
upperHost_stimmock/ssvep_main.py
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
from psychopy import visual, core, logging # import some libraries from PsychoPy
|
||||||
|
import random
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# LAB STREAMING LAYER1
|
||||||
|
from pylsl import StreamInfo, StreamOutlet
|
||||||
|
from psychopy import event
|
||||||
|
import numpy as np
|
||||||
|
from DecoderDW.Server import TCPServer
|
||||||
|
from DecoderDW.Client import TCPClient
|
||||||
|
# import subprocess
|
||||||
|
|
||||||
|
# ----------------------
|
||||||
|
# constants
|
||||||
|
# size of the window
|
||||||
|
WINWIDTH = 1920
|
||||||
|
WINHEIGHT = 1080
|
||||||
|
REFRESH_RATE = 144
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_keypress():
|
||||||
|
keys = event.getKeys()
|
||||||
|
if keys:
|
||||||
|
return keys[0]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def shutdown(win,client):
|
||||||
|
client.send_data('saveData', 0)
|
||||||
|
client.send_data('predict',2)
|
||||||
|
win.close()
|
||||||
|
core.quit()
|
||||||
|
|
||||||
|
|
||||||
|
# end of configuration
|
||||||
|
# ----------------------
|
||||||
|
|
||||||
|
def generate_square_wave(frequency, sampling_rate=REFRESH_RATE, duration=5):
|
||||||
|
"""
|
||||||
|
生成方波序列
|
||||||
|
|
||||||
|
参数:
|
||||||
|
frequency (float): 频率(Hz)
|
||||||
|
sampling_rate (int): 采样率(Hz),应与屏幕刷新率一致
|
||||||
|
duration (float): 时长(秒)
|
||||||
|
|
||||||
|
返回:
|
||||||
|
square_wave (list): 方波序列
|
||||||
|
"""
|
||||||
|
# 计算总点数
|
||||||
|
n_points = int(duration * sampling_rate)
|
||||||
|
|
||||||
|
# 生成时间序列
|
||||||
|
time = np.linspace(0, duration, n_points, endpoint=False)
|
||||||
|
|
||||||
|
# 生成正弦波数据
|
||||||
|
sin_wave = np.sin(2 * np.pi * frequency * time)
|
||||||
|
# 生成方波数据
|
||||||
|
square_wave = np.where(sin_wave >= 0, 1, 0)
|
||||||
|
|
||||||
|
return square_wave.tolist()
|
||||||
|
|
||||||
|
|
||||||
|
# 启动一个进程,不等待其完成
|
||||||
|
import os
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# ----------------------------------------------------------------------------------
|
||||||
|
# main window settings
|
||||||
|
main_win = visual.Window(size=(WINWIDTH, WINHEIGHT), units='height', screen=0, fullscr=False,
|
||||||
|
gammaErrorPolicy='warn', color=(0.7, 0.7, 0.7))
|
||||||
|
print('starting 1')
|
||||||
|
# Set up LabStreamingLayer stream.
|
||||||
|
info = StreamInfo(name='psychopy_stimuli', type='Markers', channel_count=1, channel_format='string',
|
||||||
|
source_id='psychopy_stimuli_001')
|
||||||
|
outlet = StreamOutlet(info) # Broadcast the stream.
|
||||||
|
|
||||||
|
imageStim1 = visual.ImageStim(main_win, size=(300, 300), pos=(-600, 300), units='pix', image='UI/figures/xy.jpg')
|
||||||
|
txtStim1 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||||
|
italic=False, pos=(-600, 30))
|
||||||
|
|
||||||
|
imageStim2 = visual.ImageStim(main_win, size=(300, 300), pos=(0, 300), units='pix', image='UI/figures/xy.jpg')
|
||||||
|
txtStim2 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||||
|
italic=False, pos=(0, 30))
|
||||||
|
|
||||||
|
imageStim3 = visual.ImageStim(main_win, size=(300, 300), pos=(600, 300), units='pix', image='UI/figures/xy.jpg')
|
||||||
|
txtStim3 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||||
|
italic=False, pos=(600, 30))
|
||||||
|
imageStim4 = visual.ImageStim(main_win, size=(300, 300), pos=(-600, -200), units='pix', image='UI/figures/xy.jpg')
|
||||||
|
txtStim4 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||||
|
italic=False, pos=(-600, -470))
|
||||||
|
imageStim5 = visual.ImageStim(main_win, size=(300, 300), pos=(0, -200), units='pix', image='UI/figures/xy.jpg')
|
||||||
|
txtStim5 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||||
|
italic=False, pos=(0, -470))
|
||||||
|
imageStim6 = visual.ImageStim(main_win, size=(300, 300), pos=(600, -200), units='pix', image='UI/figures/xy.jpg')
|
||||||
|
txtStim6 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||||
|
italic=False, pos=(600, -470))
|
||||||
|
imageStim1red = visual.ImageStim(main_win, size=(300, 300), pos=(-600, 300), units='pix', image='UI/figures/xy_red.jpg')
|
||||||
|
imageStim2red = visual.ImageStim(main_win, size=(300, 300), pos=(0, 300), units='pix', image='UI/figures/xy_red.jpg')
|
||||||
|
imageStim3red = visual.ImageStim(main_win, size=(300, 300), pos=(600, 300), units='pix', image='UI/figures/xy_red.jpg')
|
||||||
|
imageStim4red = visual.ImageStim(main_win, size=(300, 300), pos=(-600, -200), units='pix', image='UI/figures/xy_red.jpg')
|
||||||
|
imageStim5red = visual.ImageStim(main_win, size=(300, 300), pos=(0, -200), units='pix', image='UI/figures/xy_red.jpg')
|
||||||
|
imageStim6red = visual.ImageStim(main_win, size=(300, 300), pos=(600, -200), units='pix', image='UI/figures/xy_red.jpg')
|
||||||
|
|
||||||
|
|
||||||
|
frequencies = [25,26,27,28,29,30] #[9,10,11,12,13,14] #[30,31,32,33,34,35] [25,26,27,28,29,30]
|
||||||
|
# 生成方波数据
|
||||||
|
square_wave_9 = generate_square_wave(frequencies[0], REFRESH_RATE, 5)
|
||||||
|
square_wave_11 = generate_square_wave(frequencies[1], REFRESH_RATE, 5)
|
||||||
|
square_wave_12 = generate_square_wave(frequencies[2], REFRESH_RATE, 5)
|
||||||
|
square_wave_13 = generate_square_wave(frequencies[3], REFRESH_RATE, 5)
|
||||||
|
square_wave_14 = generate_square_wave(frequencies[4], REFRESH_RATE, 5)
|
||||||
|
square_wave_15 = generate_square_wave(frequencies[5], REFRESH_RATE, 5)
|
||||||
|
|
||||||
|
# 创建刺激对象列表,便于管理
|
||||||
|
image_stims = [imageStim1, imageStim2, imageStim3, imageStim4, imageStim5, imageStim6]
|
||||||
|
txt_stims = [txtStim1, txtStim2, txtStim3, txtStim4, txtStim5, txtStim6]
|
||||||
|
square_waves = [square_wave_9, square_wave_11, square_wave_12, square_wave_13, square_wave_14, square_wave_15]
|
||||||
|
|
||||||
|
time.sleep(2)
|
||||||
|
# grating.color = 'black'
|
||||||
|
server = TCPServer()
|
||||||
|
server.start()
|
||||||
|
client = TCPClient('127.0.0.1', 8099)
|
||||||
|
client.connect()
|
||||||
|
print('Connected decoder_main')
|
||||||
|
# client.send_data('impedance', 1)
|
||||||
|
# time.sleep(20)
|
||||||
|
# client.send_data('impedance', 2)
|
||||||
|
client.send_data('targetFreqs', frequencies) # 使用frequencies变量,确保与刺激频率一致
|
||||||
|
time.sleep(1)
|
||||||
|
# 开启全程数据保存到 EEGFiles
|
||||||
|
client.send_data('saveData',1)
|
||||||
|
# client.send_data('impedance',1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 实验参数
|
||||||
|
repeats = 3
|
||||||
|
seq_freq = frequencies * repeats
|
||||||
|
seq_freq = np.random.permutation(seq_freq).tolist()
|
||||||
|
num_trials = len(seq_freq) # 总试验次数, 6*6=36
|
||||||
|
trial_count = 0
|
||||||
|
|
||||||
|
# 在线解码精度计算
|
||||||
|
online_results = [] # 存储每个trial的解码结果
|
||||||
|
correct_predictions = 0 # 正确预测计数
|
||||||
|
|
||||||
|
# 保存序列信息
|
||||||
|
seq_info = {
|
||||||
|
'total_trials': num_trials,
|
||||||
|
'frequencies': frequencies,
|
||||||
|
'sequence': seq_freq,
|
||||||
|
'start_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
}
|
||||||
|
# 保存序列信息到文件
|
||||||
|
import json
|
||||||
|
seq_file_path = f'EEGFiles/sequence_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
|
||||||
|
with open(seq_file_path, 'a', encoding='utf-8') as f:
|
||||||
|
json.dump(seq_info, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
#========================Trials Started======================#
|
||||||
|
while trial_count < num_trials:
|
||||||
|
# 从序列中获取当前试验的目标频率
|
||||||
|
target_freq = seq_freq[trial_count]
|
||||||
|
target_freq_index = frequencies.index(target_freq)
|
||||||
|
print(f'Trials {trial_count + 1}/{num_trials} - Target Frequency: {target_freq}Hz (Label: {target_freq_index + 1})')
|
||||||
|
|
||||||
|
# Stage 1: Cue Stage
|
||||||
|
# print('Cue Stage: The target frequency is in Red')
|
||||||
|
client.send_data('setLabelAndTrialInfo', {
|
||||||
|
'label': 0,
|
||||||
|
'trial_info': {
|
||||||
|
'trial': trial_count + 1,
|
||||||
|
'phase': 'cue',
|
||||||
|
'target_freq': target_freq
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
for frameN in range(int(1 * REFRESH_RATE)): # 1秒提示
|
||||||
|
key_press = get_keypress()
|
||||||
|
if key_press in ['q']:
|
||||||
|
shutdown(main_win, client)
|
||||||
|
|
||||||
|
# 显示所有刺激,目标刺激为红色
|
||||||
|
for i, stim in enumerate(image_stims):
|
||||||
|
if i == target_freq_index:
|
||||||
|
# 目标刺激显示红色
|
||||||
|
if i == 0:
|
||||||
|
imageStim1red.draw()
|
||||||
|
elif i == 1:
|
||||||
|
imageStim2red.draw()
|
||||||
|
elif i == 2:
|
||||||
|
imageStim3red.draw()
|
||||||
|
elif i == 3:
|
||||||
|
imageStim4red.draw()
|
||||||
|
elif i == 4:
|
||||||
|
imageStim5red.draw()
|
||||||
|
elif i == 5:
|
||||||
|
imageStim6red.draw()
|
||||||
|
else:
|
||||||
|
# 其他刺激显示正常颜色
|
||||||
|
stim.draw()
|
||||||
|
|
||||||
|
main_win.flip()
|
||||||
|
|
||||||
|
# Stage 2: Flanker Stimulus
|
||||||
|
# print('Flanker Stage: flank all frequencies')
|
||||||
|
client.send_data('predict', 1)
|
||||||
|
client.send_data('setLabelAndTrialInfo', {
|
||||||
|
'label': target_freq_index + 1, # 设置目标频率标签 这里+1,是因为0代表不记录数据
|
||||||
|
'trial_info': {
|
||||||
|
'trial': trial_count + 1, # trial 从0开始
|
||||||
|
'phase': 'stimulus',
|
||||||
|
'target_freq': target_freq
|
||||||
|
}
|
||||||
|
})
|
||||||
|
outlet.push_sample(['S 1'])
|
||||||
|
|
||||||
|
for frameN in range(6 * REFRESH_RATE): # 6秒刺激
|
||||||
|
key_press = get_keypress()
|
||||||
|
if key_press in ['q']:
|
||||||
|
shutdown(main_win, client)
|
||||||
|
|
||||||
|
# 所有频率按照方波闪烁
|
||||||
|
if square_wave_9[frameN % len(square_wave_9)] == 1:
|
||||||
|
imageStim1.draw()
|
||||||
|
if square_wave_11[frameN % len(square_wave_11)] == 1:
|
||||||
|
imageStim2.draw()
|
||||||
|
if square_wave_12[frameN % len(square_wave_12)] == 1:
|
||||||
|
imageStim3.draw()
|
||||||
|
if square_wave_13[frameN % len(square_wave_13)] == 1:
|
||||||
|
imageStim4.draw()
|
||||||
|
if square_wave_14[frameN % len(square_wave_14)] == 1:
|
||||||
|
imageStim5.draw()
|
||||||
|
if square_wave_15[frameN % len(square_wave_15)] == 1:
|
||||||
|
imageStim6.draw()
|
||||||
|
|
||||||
|
main_win.flip()
|
||||||
|
if server.ChoosenNum != -1:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 记录在线解码结果
|
||||||
|
predicted_freq_index = server.ChoosenNum # 解码结果
|
||||||
|
predicted_freq = frequencies[predicted_freq_index] if predicted_freq_index != -1 else -1
|
||||||
|
|
||||||
|
# 判断解码是否正确
|
||||||
|
is_correct = (predicted_freq_index == target_freq_index) if predicted_freq_index != -1 else False
|
||||||
|
if is_correct:
|
||||||
|
correct_predictions += 1
|
||||||
|
|
||||||
|
# 记录trial结果
|
||||||
|
trial_result = {
|
||||||
|
'trial': trial_count + 1,
|
||||||
|
'target_freq': target_freq,
|
||||||
|
'target_freq_index': target_freq_index,
|
||||||
|
'predicted_freq': predicted_freq,
|
||||||
|
'predicted_freq_index': predicted_freq_index,
|
||||||
|
'is_correct': is_correct,
|
||||||
|
'status': 'Success' if predicted_freq_index != -1 else 'Failed'
|
||||||
|
}
|
||||||
|
online_results.append(trial_result)
|
||||||
|
|
||||||
|
# 打印当前trial结果
|
||||||
|
status_symbol = "✓" if is_correct else "✗"
|
||||||
|
if predicted_freq_index == -1:
|
||||||
|
print(f'Trial {trial_count + 1}: 目标{target_freq}Hz -> 解码失败 - {status_symbol}')
|
||||||
|
else:
|
||||||
|
print(f'Trial {trial_count + 1}: 目标{target_freq}Hz -> 预测{predicted_freq}Hz - {status_symbol}')
|
||||||
|
|
||||||
|
|
||||||
|
# Stage 3: Decoding Feedback
|
||||||
|
outlet.push_sample(['S 2'])
|
||||||
|
client.send_data('setLabelAndTrialInfo', {
|
||||||
|
'label': 0, # 反馈阶段标签为0
|
||||||
|
'trial_info': {
|
||||||
|
'trial': trial_count + 1,
|
||||||
|
'phase': 'feedback',
|
||||||
|
'target_freq': target_freq
|
||||||
|
}
|
||||||
|
})
|
||||||
|
# print('反馈阶段: 显示解码结果')
|
||||||
|
|
||||||
|
for frameN in range(1 * REFRESH_RATE): # 1秒反馈
|
||||||
|
key_press = get_keypress()
|
||||||
|
if key_press in ['q']:
|
||||||
|
shutdown(main_win, client)
|
||||||
|
|
||||||
|
# 显示所有刺激但不闪烁
|
||||||
|
for stim in image_stims:
|
||||||
|
stim.draw()
|
||||||
|
|
||||||
|
# 显示解码结果
|
||||||
|
if server.ChoosenNum == 0:
|
||||||
|
txtStim1.draw()
|
||||||
|
elif server.ChoosenNum == 1:
|
||||||
|
txtStim2.draw()
|
||||||
|
elif server.ChoosenNum == 2:
|
||||||
|
txtStim3.draw()
|
||||||
|
elif server.ChoosenNum == 3:
|
||||||
|
txtStim4.draw()
|
||||||
|
elif server.ChoosenNum == 4:
|
||||||
|
txtStim5.draw()
|
||||||
|
elif server.ChoosenNum == 5:
|
||||||
|
txtStim6.draw()
|
||||||
|
|
||||||
|
main_win.flip()
|
||||||
|
|
||||||
|
server.ChoosenNum = -1
|
||||||
|
trial_count += 1
|
||||||
|
|
||||||
|
# 计算总体在线解码精度
|
||||||
|
total_trials = len(online_results)
|
||||||
|
successful_trials = len([r for r in online_results if r['status'] == 'Success'])
|
||||||
|
failed_trials = len([r for r in online_results if r['status'] == 'Failed'])
|
||||||
|
overall_accuracy = correct_predictions / total_trials if total_trials > 0 else 0
|
||||||
|
|
||||||
|
# Print Accuracy
|
||||||
|
print(f"Total Accuracy: {overall_accuracy:.3f} ({correct_predictions}/{total_trials})")
|
||||||
|
|
||||||
|
# 按频率分析准确率
|
||||||
|
print(f"\n=== 按频率分析准确率 ===")
|
||||||
|
freq_accuracy = {}
|
||||||
|
for result in online_results:
|
||||||
|
freq = result['target_freq']
|
||||||
|
if freq not in freq_accuracy:
|
||||||
|
freq_accuracy[freq] = {'correct': 0, 'total': 0, 'failed': 0}
|
||||||
|
|
||||||
|
freq_accuracy[freq]['total'] += 1
|
||||||
|
if result['status'] == 'Failed':
|
||||||
|
freq_accuracy[freq]['failed'] += 1
|
||||||
|
elif result['is_correct']:
|
||||||
|
freq_accuracy[freq]['correct'] += 1
|
||||||
|
|
||||||
|
print(f"{'频率':<8} {'准确率':<8} {'正确/总数':<10} {'失败数':<8}")
|
||||||
|
print("-" * 40)
|
||||||
|
for freq in sorted(freq_accuracy.keys()):
|
||||||
|
stats = freq_accuracy[freq]
|
||||||
|
accuracy = stats['correct'] / stats['total'] if stats['total'] > 0 else 0
|
||||||
|
print(f"{freq}Hz{'':<4} {accuracy:.3f}{'':<4} {stats['correct']}/{stats['total']}{'':<6} {stats['failed']}")
|
||||||
|
|
||||||
|
# 保存在线解码结果到文件
|
||||||
|
online_results_file = f'EEGFiles/online_results_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
|
||||||
|
online_summary = {
|
||||||
|
'total_trials': total_trials,
|
||||||
|
'successful_trials': successful_trials,
|
||||||
|
'failed_trials': failed_trials,
|
||||||
|
'correct_predictions': correct_predictions,
|
||||||
|
'overall_accuracy': overall_accuracy,
|
||||||
|
# 'freq_accuracy': freq_accuracy,
|
||||||
|
'trial_results': online_results,
|
||||||
|
# 'end_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(online_results_file, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(online_summary, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
client.send_data('predict',2) # 关闭系统
|
||||||
|
main_win.close()
|
||||||
304
verify_datamock.py
Normal file
304
verify_datamock.py
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
"""
|
||||||
|
datamock 验证脚本(模拟算法端)
|
||||||
|
作为 ZMQ ROUTER 监听 8100 端口,等待 datamock.py 连接并验证数据流
|
||||||
|
|
||||||
|
运行顺序:
|
||||||
|
第一步: python verify_datamock.py (先启动,监听 8100)
|
||||||
|
第二步: python datamock.py (后启动,连接 8100)
|
||||||
|
"""
|
||||||
|
import zmq
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
import sys
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use('TkAgg')
|
||||||
|
|
||||||
|
# 在导入 pyplot 之前确保 Tkinter 正确初始化
|
||||||
|
try:
|
||||||
|
import tkinter as tk
|
||||||
|
root = tk.Tk()
|
||||||
|
root.withdraw() # 隐藏主窗口,我们只需要它的事件循环
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARN] Tkinter 初始化警告: {e}")
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# ===== 可视化参数 =====
|
||||||
|
PLOT_WINDOW_SEC = 2.0 # 滑动窗口时长(秒)
|
||||||
|
PLOT_CHANNELS = [0, 1, 2, 3] # 要显示的 EEG 通道索引
|
||||||
|
|
||||||
|
SERVER_ADDR = 'tcp://127.0.0.1:8100'
|
||||||
|
FS = 250
|
||||||
|
N_SAMPLES_PER_PKT = 5
|
||||||
|
N_CHAN = 66
|
||||||
|
EEG_FREQ = 10
|
||||||
|
EEG_AMP = 100.0 # EEG 幅值 100μV(峰值)
|
||||||
|
EEG_AMP_MEAN = EEG_AMP * 2 / np.pi # 正弦波 |mean| ≈ 63.7μV
|
||||||
|
EEG_AMP_TOLERANCE = 1.5 # 幅值容差倍数
|
||||||
|
LABEL_INTERVAL = 5
|
||||||
|
FFT_SAMPLES = 250 # 做一次 FFT 需要的采样点数(1s数据)
|
||||||
|
EXPECTED_BYTES = N_SAMPLES_PER_PKT * N_CHAN * 4 # 1320 bytes (5*66*4)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_fft(samples):
|
||||||
|
"""对 Ch0 数据做 FFT,返回峰值频率"""
|
||||||
|
freqs = np.fft.rfftfreq(FFT_SAMPLES, d=1 / FS)
|
||||||
|
fft_mag = np.abs(np.fft.rfft(samples))
|
||||||
|
peak_idx = np.argmax(fft_mag[1:]) + 1 # 跳过 DC
|
||||||
|
return freqs[peak_idx], fft_mag, freqs
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
ctx = zmq.Context()
|
||||||
|
sock = ctx.socket(zmq.ROUTER)
|
||||||
|
sock.bind(SERVER_ADDR)
|
||||||
|
print(f"[{datetime.now().strftime('%H:%M:%S')}] ZMQ ROUTER 绑定 {SERVER_ADDR},等待 datamock.py 连接...\n")
|
||||||
|
|
||||||
|
# ===== 初始化交互式绘图 =====
|
||||||
|
plt.ion() # 开启交互模式
|
||||||
|
fig = plt.figure(figsize=(14, 10))
|
||||||
|
fig.suptitle('EEG Data Monitor (Real-time)', fontsize=14)
|
||||||
|
|
||||||
|
# 使用 GridSpec 进行布局
|
||||||
|
from matplotlib.gridspec import GridSpec
|
||||||
|
gs = GridSpec(len(PLOT_CHANNELS) + 2, 1, figure=fig, hspace=0.3)
|
||||||
|
axes = []
|
||||||
|
lines_eeg = []
|
||||||
|
for i, ch in enumerate(PLOT_CHANNELS):
|
||||||
|
ax = fig.add_subplot(gs[i])
|
||||||
|
axes.append(ax)
|
||||||
|
ax.set_ylabel(f'Ch{ch} (μV)', fontsize=8)
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
ax.set_ylim(-150, 150)
|
||||||
|
line, = ax.plot([], [], lw=0.8)
|
||||||
|
lines_eeg.append(line)
|
||||||
|
ax.set_title(f'EEG Channel {ch}', fontsize=9)
|
||||||
|
|
||||||
|
# 标签通道子图 (Ch64 - 标签值)
|
||||||
|
ax_label = fig.add_subplot(gs[len(PLOT_CHANNELS)])
|
||||||
|
axes.append(ax_label)
|
||||||
|
ax_label.set_ylabel('Label Value', fontsize=8)
|
||||||
|
ax_label.grid(True, alpha=0.3)
|
||||||
|
ax_label.set_ylim(-0.5, 2.5)
|
||||||
|
line_label, = ax_label.plot([], [], 'ro-', lw=1.5, markersize=4)
|
||||||
|
line_label_data = line_label
|
||||||
|
ax_label.set_title('Ch64 - Label Value', fontsize=9)
|
||||||
|
|
||||||
|
# Ch65 标签序号子图
|
||||||
|
ax_seq = fig.add_subplot(gs[len(PLOT_CHANNELS) + 1])
|
||||||
|
axes.append(ax_seq)
|
||||||
|
ax_seq.set_ylabel('Label Seq', fontsize=8)
|
||||||
|
ax_seq.set_xlabel('Time (samples)', fontsize=8)
|
||||||
|
ax_seq.grid(True, alpha=0.3)
|
||||||
|
ax_seq.set_ylim(-0.5, 10)
|
||||||
|
line_seq, = ax_seq.plot([], [], 'gs-', lw=1.5, markersize=4)
|
||||||
|
line_seq_data = line_seq
|
||||||
|
ax_seq.set_title('Ch65 - Label Sequence', fontsize=9)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# ===== 状态 =====
|
||||||
|
global_idx = 0 # 全局采样点索引
|
||||||
|
label_events = [] # 捕获的标签事件
|
||||||
|
start_time = None
|
||||||
|
fft_done = False
|
||||||
|
fft_buffer = [] # 暂存前 250 点做 FFT
|
||||||
|
ch64_zero_ok = True # 验证 Ch64 非标签采样点均为 0
|
||||||
|
ch65_zero_ok = True # 验证 Ch65 非标签采样点均为 0
|
||||||
|
label_pos_ok_all = True # 验证标签均在包内索引 4
|
||||||
|
|
||||||
|
# ===== 数据缓冲区 =====
|
||||||
|
max_samples = int(FS * PLOT_WINDOW_SEC)
|
||||||
|
eeg_buffer = {ch: np.zeros(max_samples) for ch in PLOT_CHANNELS}
|
||||||
|
label_buffer = np.zeros(max_samples)
|
||||||
|
seq_buffer = np.zeros(max_samples)
|
||||||
|
time_axis = np.arange(max_samples)
|
||||||
|
|
||||||
|
# ZMQ 收发统计
|
||||||
|
recv_count = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 首次 pause 用于显示窗口
|
||||||
|
plt.pause(0.5)
|
||||||
|
print(f"[INFO] 交互窗口已显示,如未看到请检查任务栏")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# ROUTER recv: prepended 一个 identity 帧
|
||||||
|
# datamock 发送 3帧 [b'datamock', b'', data_bytes]
|
||||||
|
# ROUTER 接收后变成 4帧 [router_identity, b'datamock', b'', data_bytes]
|
||||||
|
frames = sock.recv_multipart()
|
||||||
|
recv_count += 1
|
||||||
|
now = time.time()
|
||||||
|
if start_time is None:
|
||||||
|
start_time = now
|
||||||
|
|
||||||
|
# 帧格式: [router_identity, b'datamock', b'', data_bytes]
|
||||||
|
router_id = frames[0] # ROUTER 添加的身份帧
|
||||||
|
identity = frames[1] # 发送端的 identity
|
||||||
|
_empty = frames[2] # 空帧
|
||||||
|
raw_data = frames[3] # 实际数据字节
|
||||||
|
|
||||||
|
# 数据长度校验
|
||||||
|
if len(raw_data) != EXPECTED_BYTES:
|
||||||
|
print(f"[ERROR] 数据长度错误: 期望{EXPECTED_BYTES}字节, 实际{len(raw_data)}字节")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 解析为 [5, 66] float32 数组
|
||||||
|
packet = np.frombuffer(raw_data, dtype=np.float32).reshape(N_SAMPLES_PER_PKT, N_CHAN)
|
||||||
|
|
||||||
|
elapsed = now - start_time
|
||||||
|
|
||||||
|
# ===== 验证 1: 数据形状 =====
|
||||||
|
if recv_count == 1:
|
||||||
|
shape_ok = packet.shape == (N_SAMPLES_PER_PKT, N_CHAN)
|
||||||
|
print(f"[{'✓' if shape_ok else '✗'}] 数据形状: {packet.shape} "
|
||||||
|
f"(期望 [{N_SAMPLES_PER_PKT}, {N_CHAN}])")
|
||||||
|
if not shape_ok:
|
||||||
|
print(f" ✗ 形状不匹配,退出")
|
||||||
|
break
|
||||||
|
|
||||||
|
# ===== 验证 2: EEG 幅值(首包) =====
|
||||||
|
if recv_count == 1:
|
||||||
|
eeg = packet[:, :64]
|
||||||
|
amp_mean = np.mean(np.abs(eeg))
|
||||||
|
amp_ok = amp_mean <= EEG_AMP_MEAN * EEG_AMP_TOLERANCE
|
||||||
|
print(f"[{'✓' if amp_ok else '✗'}] EEG 幅值: 均值={amp_mean:.2f}μV "
|
||||||
|
f"(期望 ~{EEG_AMP_MEAN:.2f}μV,峰值 ~{EEG_AMP:.2f}μV)")
|
||||||
|
if not amp_ok:
|
||||||
|
print(f" ✗ 幅值超出容差范围")
|
||||||
|
|
||||||
|
# ===== 验证 3: EEG 频率(首秒数据收集满后做 FFT) =====
|
||||||
|
fft_buffer.append(packet[:, 0].copy()) # 收集 Ch0
|
||||||
|
|
||||||
|
if not fft_done and len(fft_buffer) * N_SAMPLES_PER_PKT >= FFT_SAMPLES:
|
||||||
|
# 凑够 250 点,做 FFT
|
||||||
|
all_ch0 = np.concatenate(fft_buffer)[:FFT_SAMPLES]
|
||||||
|
peak_freq, fft_mag, freqs = validate_fft(all_ch0)
|
||||||
|
freq_ok = abs(peak_freq - EEG_FREQ) < 1.0
|
||||||
|
|
||||||
|
print(f"[{'✓' if freq_ok else '✗'}] EEG 频率: 峰值={peak_freq:.1f}Hz "
|
||||||
|
f"(期望 ~{EEG_FREQ}Hz)")
|
||||||
|
print(f" FFT 幅度谱前 5 峰值:")
|
||||||
|
top5 = np.argsort(fft_mag[1:])[-5:][::-1] + 1
|
||||||
|
for rank, idx in enumerate(top5):
|
||||||
|
print(f" {rank+1}. {freqs[idx]:.1f}Hz 幅度={fft_mag[idx]:.1f}")
|
||||||
|
print()
|
||||||
|
fft_done = True
|
||||||
|
|
||||||
|
# ===== 验证 4: 标签通道(Ch64/Ch65) =====
|
||||||
|
ch64 = packet[:, 64]
|
||||||
|
ch65 = packet[:, 65]
|
||||||
|
ch64_nonzero = np.where(ch64 != 0)[0]
|
||||||
|
ch65_nonzero = np.where(ch65 != 0)[0]
|
||||||
|
|
||||||
|
# 检查非标签采样点是否全为 0
|
||||||
|
ch64_zeros = np.all(ch64[:4] == 0)
|
||||||
|
ch65_zeros = np.all(ch65[:4] == 0)
|
||||||
|
ch64_zero_ok = ch64_zero_ok and ch64_zeros
|
||||||
|
ch65_zero_ok = ch65_zero_ok and ch65_zeros
|
||||||
|
|
||||||
|
if len(ch64_nonzero) > 0:
|
||||||
|
pos_in_pkt = int(ch64_nonzero[0])
|
||||||
|
label_val = int(ch64[pos_in_pkt])
|
||||||
|
label_seq = int(ch65[pos_in_pkt])
|
||||||
|
|
||||||
|
pos_ok = (len(ch64_nonzero) == 1 and pos_in_pkt == 4)
|
||||||
|
label_pos_ok_all = label_pos_ok_all and pos_ok
|
||||||
|
|
||||||
|
elapsed_since_start = now - start_time
|
||||||
|
print(f"[✓] 标签触发 @ {elapsed_since_start:.1f}s "
|
||||||
|
f"(global_idx={global_idx} 包{recv_count})")
|
||||||
|
print(f" Ch64 标签值: {label_val} Ch65 序号: {label_seq}")
|
||||||
|
print(f" 包内位置: 采样点 {pos_in_pkt}/4 "
|
||||||
|
f"({'✓' if pos_ok else '✗ 期望 4'}) "
|
||||||
|
f"其余采样点 Ch64=0: {'✓' if ch64_zeros else '✗'} "
|
||||||
|
f"Ch65=0: {'✓' if ch65_zeros else '✗'}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
label_events.append({
|
||||||
|
'time': elapsed_since_start,
|
||||||
|
'label': label_val,
|
||||||
|
'seq': label_seq
|
||||||
|
})
|
||||||
|
|
||||||
|
global_idx += N_SAMPLES_PER_PKT
|
||||||
|
|
||||||
|
# ===== 更新绘图缓冲区 =====
|
||||||
|
for ch_idx, ch in enumerate(PLOT_CHANNELS):
|
||||||
|
eeg_buffer[ch] = np.roll(eeg_buffer[ch], -N_SAMPLES_PER_PKT)
|
||||||
|
eeg_buffer[ch][-N_SAMPLES_PER_PKT:] = packet[:, ch]
|
||||||
|
|
||||||
|
label_buffer = np.roll(label_buffer, -N_SAMPLES_PER_PKT)
|
||||||
|
label_buffer[-N_SAMPLES_PER_PKT:] = packet[:, 64]
|
||||||
|
|
||||||
|
seq_buffer = np.roll(seq_buffer, -N_SAMPLES_PER_PKT)
|
||||||
|
seq_buffer[-N_SAMPLES_PER_PKT:] = packet[:, 65]
|
||||||
|
|
||||||
|
# ===== 实时更新绘图 =====
|
||||||
|
for i, ch in enumerate(PLOT_CHANNELS):
|
||||||
|
lines_eeg[i].set_data(time_axis, eeg_buffer[ch]) # 数据已是 μV 单位
|
||||||
|
line_label_data.set_data(time_axis, label_buffer)
|
||||||
|
line_seq_data.set_data(time_axis, seq_buffer)
|
||||||
|
|
||||||
|
# 设置 x 轴范围
|
||||||
|
for ax in axes:
|
||||||
|
ax.set_xlim(0, max_samples)
|
||||||
|
|
||||||
|
# 刷新图形(交互模式)
|
||||||
|
fig.canvas.draw_idle()
|
||||||
|
plt.pause(0.001)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n" + "=" * 55)
|
||||||
|
print(" 验证结果汇总")
|
||||||
|
print("=" * 55)
|
||||||
|
print(f" 运行时长: {time.time() - start_time:.1f}s")
|
||||||
|
print(f" 收到包数: {recv_count}")
|
||||||
|
print(f" FFT 验证: {'✓ 已完成' if fft_done else '✗ 未完成(时长不足1s)'}")
|
||||||
|
print(f" 非标签采样点 Ch64=0: {'✓' if ch64_zero_ok else '✗'}")
|
||||||
|
print(f" 非标签采样点 Ch65=0: {'✓' if ch65_zero_ok else '✗'}")
|
||||||
|
print(f" 标签均在包内位置4: {'✓' if label_pos_ok_all else '✗'}")
|
||||||
|
|
||||||
|
if label_events:
|
||||||
|
print(f"\n 共捕获 {len(label_events)} 次标签事件:")
|
||||||
|
for i, ev in enumerate(label_events):
|
||||||
|
print(f" {i+1}. t={ev['time']:.1f}s label={ev['label']} 序号={ev['seq']}")
|
||||||
|
|
||||||
|
# 标签间隔
|
||||||
|
print(f"\n 标签间隔验证 (期望 ~{LABEL_INTERVAL}s):")
|
||||||
|
for i in range(1, len(label_events)):
|
||||||
|
dt = label_events[i]['time'] - label_events[i-1]['time']
|
||||||
|
ok = abs(dt - LABEL_INTERVAL) < 0.1
|
||||||
|
print(f" {i}->{i+1}: {dt:.2f}s {'✓' if ok else '✗'}")
|
||||||
|
|
||||||
|
# 标签交替
|
||||||
|
labels = [e['label'] for e in label_events]
|
||||||
|
alt_ok = all(labels[i] != labels[i+1] for i in range(len(labels) - 1))
|
||||||
|
print(f"\n 标签交替: {labels} {'✓ 交替正确' if alt_ok else '✗ 交替错误'}")
|
||||||
|
|
||||||
|
# 序号
|
||||||
|
label1_seqs = [e['seq'] for e in label_events if e['label'] == 1]
|
||||||
|
label2_seqs = [e['seq'] for e in label_events if e['label'] == 2]
|
||||||
|
s1_ok = label1_seqs == list(range(1, len(label1_seqs) + 1))
|
||||||
|
s2_ok = label2_seqs == list(range(1, len(label2_seqs) + 1))
|
||||||
|
print(f" label=1 序号: {label1_seqs} {'✓' if s1_ok else '✗'}")
|
||||||
|
print(f" label=2 序号: {label2_seqs} {'✓' if s2_ok else '✗'}")
|
||||||
|
else:
|
||||||
|
print(f"\n 未捕获标签事件(运行时长不足 {LABEL_INTERVAL}s)")
|
||||||
|
|
||||||
|
print("=" * 55)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
sock.close()
|
||||||
|
ctx.term()
|
||||||
|
plt.ioff()
|
||||||
|
plt.close('all')
|
||||||
|
try:
|
||||||
|
root.destroy()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user