# File: bci_algo/Decoder.py (Python, 632 lines, 32 KiB)
# Last change: 2026-06-05 09:34:29 +08:00
import ast
import threading
from datetime import datetime
import multiprocessing as mp
import numpy as np
import time
import torch
from queue import Empty
from scipy import signal
from torch.autograd import Variable
from Device.SunnyLinker import SunnyLinker64
from SSMVEP.algorithm.tdca import TDCA
from SSMVEP.algorithm.base import generate_cca_references
from concentration.algorithm.calculate_focus import Calculate
from blinkdetection.algorithm.eye_detection import blink_detection
from Zmq.zmqServer import zmqServer
from Zmq.zmqClient import zmqClient
from MI.Algorithm.conformer_2class import onlineTrain
from PubLibrary.InifileHelper import IniRead
from SSVEP.dwfbcca import FbccaDw
from Tools.plot_MI_EEG import plotMain
from collections import deque
class Decoder_main(threading.Thread):
    """Online EEG decoding thread.

    Pulls samples from the acquisition-device thread (``self.thread_data_server``),
    runs the decoder selected by the upper computer over ZMQ (``self.zmqServer``)
    and sends results back through ``self.zmqClient``.

    Paradigms dispatched by ``run``: 'ssvep'/'pvs', 'ssmvep', 'mi'/'ma',
    'concentration' and 'blink'; any other value only drains incoming data.
    """

    def __init__(self):
        threading.Thread.__init__(self)
        self.Runing = True  # main-loop flag, cleared by stop()
        self.decoder = None
        self.fs = 250  # sampling rate (Hz)
        self.energy = 0  # battery level last reported to the upper computer
        self.status_code = 0  # device link status code: 0 = abnormal, 1 = normal
        self.decoder_class = None  # active decoder category
        self.decodingSteps = 0  # 0=stopped 1=warm-up 2=decoding 3=decoded, sending result
        # Report buffers exist from the start so a report request that arrives
        # before the first decoder switch does not hit a missing attribute.
        self.plotData = []
        self.plotLabel = []

    def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None):
        """Start the device data thread and the ZMQ server/client.

        Any argument left as ``None`` is read with ``IniRead('system', ...)``
        (keys Device_type, Device_Host, Device_Port, Upper_Host, Upper_Port).
        """
        self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type'))
        _device_host = device_host if device_host is not None else str(IniRead('system', 'Device_Host'))
        _device_port = device_port if device_port is not None else int(IniRead('system', 'Device_Port'))
        _upper_host = upper_host if upper_host is not None else str(IniRead('system', 'Upper_Host'))
        _upper_port = upper_port if upper_port is not None else int(IniRead('system', 'Upper_Port'))
        # NOTE(review): only device type 1 creates ``thread_data_server``; for any
        # other type it stays unset and run() will fail on first access — confirm.
        if self.DeviceType == 1:
            self.thread_data_server = SunnyLinker64(_device_host, _device_port, self.fs, 64, method='tcp')
            self.thread_data_server.host = _device_host
            self.thread_data_server.port = _device_port
            self.thread_data_server.toUv = True
            self.thread_data_server.start()
        self.zmqServer = zmqServer()
        self.zmqServer.start()
        self.zmqClient = zmqClient(_upper_host, _upper_port)
        self.zmqClient.set_zmq_server(self.zmqServer)
        self.zmqClient.connect()

    def is_valid_signal(self, data, threshold=1e5):
        """Return False when the mean per-channel variance of *data* exceeds *threshold*.

        :param data: array of shape (chans, samples)
        :param threshold: upper bound on the mean channel variance
        :return: True if the signal is considered valid
        """
        energy = np.mean(np.var(data, axis=1))  # mean of per-channel variances
        if energy > threshold:
            return False
        return True

    def init_Decoder(self, decoder_class):
        '''
        Initialise the decoder for the requested paradigm.

        :param decoder_class: 'ssvep'/'pvs', 'ssmvep', 'mi'/'ma', 'concentration',
            'blink', or any other value (e.g. '') for no decoding.
        :return: None
        '''
        self.decoder_class = decoder_class
        if decoder_class == 'ssvep' or decoder_class == 'pvs':
            self.n_chan = 8
            self.thread_data_server.interval_inited = False
            DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
            self.ListFreq = self.zmqServer.targetFreqs
            self.num_target = len(self.ListFreq)
            if self.num_target == 0:
                # No targets: drop a filter bank built for a previous target set so
                # decoder_SSVEP does not keep decoding against stale references.
                if hasattr(self, 'dw'):
                    del self.dw
                return
            # Second-generation (dynamic-window FBCCA) decoder object.
            self.dw = FbccaDw(self.fs, self.num_target, self.n_chan, 5, 5,
                              0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method)
            self.dw.filterFrequenceBank()  # filter-bank sub-bands
            self.dw.setNotchFilterPara()
            self.calculateCount = 0
            self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.fs), 5)
            self.dw.filterInit()
            self.dw.onlineInit()  # reset state during the first second of stimulus flicker (online acquisition)
        elif decoder_class == 'ssmvep':
            self.thread_data_server.interval_init(decoder_class)
            self.n_chan = 8
            self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
            # Decoding window length in seconds, rounded to 6 decimals to absorb float error.
            self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6)
            self.single_train = 10  # training trials required per class
            self.num_target = 2  # number of classes
            self.list_freqs = np.array([8, 9])  # stimulus frequencies (Hz)
            self.list_phase = np.array([0, 0])  # stimulus phases
            self.tdca = TDCA(padding_len=5, n_components=1)
            self.Yf = generate_cca_references(self.list_freqs, srate=self.fs, T=self.sample_length,
                                              phases=self.list_phase, n_harmonics=5)
            self.parameter_init(5, 45)
        elif decoder_class == 'mi' or decoder_class == 'ma':
            self.thread_data_server.interval_init(decoder_class)
            self.n_chan = 21
            self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
            # Decoding window length in seconds, rounded to 6 decimals to absorb float error.
            self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6)
            self.single_train = 40  # training trials required per class
            self.num_target = 2  # number of classes
            self.parameter_init(8, 30)
        elif decoder_class == 'concentration':
            self.thread_data_server.interval_inited = False
            self.n_chan = 6
            self.win_len = 10
            self.win_step = 1  # one result per win_step seconds
            self.low_threshold, self.high_threshold = ast.literal_eval(
                IniRead('system', 'concentration_ThresholdValue'))
            self.calculate = Calculate(self.low_threshold, self.high_threshold, self.fs, self.win_len)
            self.interval_epoch = [0, 1]
            self.parameter_init(2, 40)
            # The EEG sample queue lives inside the Calculate object.
        elif decoder_class == 'blink':
            self.n_chan = 2
            self.l_freq = 0.1  # band-pass low cutoff (Hz)
            self.h_freq = 8.0  # band-pass high cutoff (Hz)
            self.total_samples = 0  # total samples received
            self.window_ms = 600  # detection window (ms)
            self.step_ms = 100  # sliding step (ms)
            self.window_samples = int(self.window_ms * self.fs / 1000)  # 150 samples at 250 Hz
            self.step_samples = int(self.step_ms * self.fs / 1000)  # 25 samples at 250 Hz
            self.buffer_size = self.window_samples + self.step_samples * 5
            self.fp1_buffer = deque(maxlen=self.buffer_size)
            self.fp2_buffer = deque(maxlen=self.buffer_size)
            self.sample_counter = 0  # samples since the last detection pass
            # jitterwin is compared against milliseconds in process_blink_detection;
            # double_blink_interval / double_blink_jitter against time.time() deltas (seconds).
            (self.Dmin, self.Dmax, self.EMin, self.EMax, self.jitterwin,
             self.double_blink_interval, self.double_blink_jitter) = ast.literal_eval(IniRead('system', 'blink'))
            self.blink_count = 0  # number of accepted single blinks
            self.last_blink_time = 0  # sample index of the last accepted single blink
            self.blink_timestamps = deque(maxlen=10)  # wall-clock times of the last 10 single blinks
            self.double_blink_count = 0  # number of detected double blinks
            self.double_blink_events = []  # double-blink event records
            self.last_double_blink_time = 0  # wall-clock time of the last double blink
            self.blink_events = []
            # Filter coefficients are designed once here rather than per detection pass.
            self.blink_b, self.blink_a = signal.butter(
                4, [self.l_freq / (self.fs / 2), self.h_freq / (self.fs / 2)], btype='band')

    def parameter_init(self, bandPass_low, bandPass_high):
        """Convert ``self.interval_epoch`` to samples and set up filters, buffers and IPC queues.

        :param bandPass_low: low cutoff (Hz) of the FIR band-pass used by preprocess()
        :param bandPass_high: high cutoff (Hz) of the FIR band-pass used by preprocess()
        """
        self.interval_epoch = [int(i * self.fs) for i in self.interval_epoch]  # epoch bounds in samples
        # Training epoch: the decoding window plus a 0.1 s tail.
        self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.fs)]
        self.trainData = []  # training trials
        self.trainLabel = []  # training labels
        self.plotData = []  # trials for the analysis report
        self.plotLabel = []  # labels for the analysis report
        self.currentLabel = -1  # training label currently shown by the stimulus UI
        self.train_started = False  # whether model training has been started
        self.load_model = False  # whether a trained model is ready for prediction
        self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs / 2), 30)  # 50 Hz mains notch, Q = 30
        # 65-tap band-pass FIR between bandPass_low and bandPass_high (Hz).
        self.b_design = signal.firwin(65, [bandPass_low / (self.fs / 2), bandPass_high / (self.fs / 2)],
                                      pass_zero=False)
        fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        filePath = './online_Models/'
        self.modelPath = ''.join([filePath, fileName, '.pth'])
        self.mp_data_queue = mp.Queue()  # parent -> training process
        self.mp_result_queue = mp.Queue()  # training process -> parent

    def preprocess(self, signal_data):
        """Remove each row's mean, then apply the 50 Hz notch and the FIR band-pass along the last axis."""
        signal_data = signal_data - np.mean(signal_data, axis=-1, keepdims=True)
        signal_data = signal.lfilter(self.b_notch, self.a_notch, signal_data, axis=-1)  # mains notch
        signal_data = signal.lfilter(self.b_design, 1, signal_data, axis=-1)  # band-pass
        return signal_data

    def run(self):
        """Main loop: service requests from the upper computer, then feed the active decoder.

        Every iteration is guarded so an error in any step (decoder switch,
        report generation, impedance, decoding) is logged instead of killing
        the thread.
        """
        while self.Runing:
            try:
                self._handle_decoder_switch()
                self._sync_upper_computer()
                self._handle_impedance()
                self._handle_report_request()
                self._dispatch_decoder()
            except Exception as e:
                print(f"Decoder Loop Error: {e}")
                import traceback
                traceback.print_exc()
                time.sleep(0.1)  # prevent CPU spin if the error is persistent

    def _handle_decoder_switch(self):
        """Re-initialise the decoder when the paradigm or the target set changes."""
        if self.zmqServer.decoder_switch or self.zmqServer.changeTarget:
            print(f"Decoder_class Switch Detected: {self.zmqServer.decoder_class}")
            self.zmqServer.decoder_switch = False
            self.zmqServer.changeTarget = False
            self.reset_state()  # clear old state before switching
            self.init_Decoder(self.zmqServer.decoder_class)

    def _sync_upper_computer(self):
        """Answer sync requests and forward device status-code / battery changes."""
        if self.zmqServer.state_mode == 'sync':
            self.zmqClient.send_to_all('sync', self.zmqClient.state)
            self.zmqServer.state_mode = 'rest'
        if self.status_code != self.thread_data_server.status_code:  # device status changed
            self.status_code = self.thread_data_server.status_code
            self.zmqClient.send_to_all('status_code', int(self.status_code))
            print('status code')
        if self.energy != self.thread_data_server.energy:  # battery level changed
            self.energy = self.thread_data_server.energy
            self.zmqClient.send_to_all('energy', int(self.energy))
            print('energy')

    def _handle_impedance(self):
        """Toggle impedance mode on request and stream impedance values while it is active."""
        # open_Impedance is True/False for a pending request and -1 once handled.
        if self.zmqServer.open_Impedance == True:
            self.thread_data_server.Impedance(True)
            print('Impedance')
            self.zmqServer.open_Impedance = -1
        elif self.zmqServer.open_Impedance == False:
            self.thread_data_server.Impedance(False)
            self.zmqServer.open_Impedance = -1
        if self.zmqServer.get_Impedance:
            if self.thread_data_server.GetDataLenCount() > 250:
                Impe_data = self.thread_data_server.getData(250)
                imps = self.thread_data_server.getImpedance(Impe_data, self.zmqServer.decoder_class)
                self.zmqClient.send_to_all('impedance', imps.tolist())
            else:
                time.sleep(0.005)  # wait for data instead of busy-spinning

    def _handle_report_request(self):
        """Build and send the MI training report when requested (0 if fewer than 30 trials)."""
        if not self.zmqServer.getReport:
            return
        self.zmqServer.getReport = False
        allData = np.array(self.plotData)
        # Shift labels by +1 to match MI_label=1 / Rest_label=2 passed to plotMain.
        allLabel = np.array(self.plotLabel) + 1
        nTrials = min(len(allLabel), len(allData))
        if nTrials < 30:
            self.zmqClient.send_to_all('miReport', 0)
            return
        allData = allData[:nTrials]
        allLabel = allLabel[:nTrials]
        ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1',
                    'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4']
        compare_names = ['C3', 'CZ', 'C4']
        miReport = plotMain(ch_names=ch_names, compare_names=compare_names, Data=allData, labels=allLabel,
                            MI_label=1, Rest_label=2, fs=self.fs)
        self.zmqClient.send_to_all('miReport', miReport)

    def _dispatch_decoder(self):
        """Run one step of the decoder matching ``self.decoder_class``."""
        decoder_class = self.decoder_class
        if decoder_class in ('ssvep', 'pvs'):
            self.decoder_SSVEP()
        elif decoder_class == 'ssmvep':
            self.decoder_SSMVEP()
        elif decoder_class in ('mi', 'ma'):  # init_Decoder sets both up identically
            self.decoder_MI()
        elif decoder_class == 'concentration':
            self.decoder_concentration()
        elif decoder_class == 'blink':
            self.decoder_blink()
        else:
            self._drain_idle()

    def _drain_idle(self):
        """Discard incoming data in 25-sample chunks; skipped while impedance is being measured."""
        if self.zmqServer.get_Impedance == False:
            if self.thread_data_server.GetDataLenCount() < 25:
                time.sleep(0.005)
                return
            self.thread_data_server.getData(25)

    def decoder_SSVEP(self):
        """Run one step of the SSVEP/PVS state machine on the next 50-sample block.

        ``decodingSteps``: 1 = warm up the filters, 2 = decode each block until a
        valid result, then the result is sent and the state returns to 0.
        """
        if self.zmqServer.StartDecode:
            self.zmqServer.StartDecode = False
            self.decodingSteps = 1
            self.thread_data_server.ResetAll()
            print('启动预测')
        if self.thread_data_server.GetDataLenCount() < 50:
            time.sleep(0.005)
            return
        if self.zmqServer.get_Impedance != False:  # no decoding during impedance measurement
            return
        data = self.thread_data_server.getDataViaSSVEP(50)
        data = data[:self.n_chan, :]
        if not hasattr(self, 'dw'):  # no filter bank configured (e.g. empty target list)
            return
        if self.decodingSteps == 1:  # warm-up
            self.dw.onlineInit()  # reset state during the first second of stimulus flicker
            self.dw.warmFilter(data)
            self.decodingSteps = 2
            print('预热数据完成。开始预测')
            return
        if self.decodingSteps == 2:  # decoding
            choosenNum = self.dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount)
            self.calculateCount += 1
            if choosenNum != -1 and self.is_valid_signal(data):
                self.decodingSteps = 3
                print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount))
                self.calculateCount = 0
                # Step 3: send the decoded target to the UI.
                self.zmqClient.send_to_all('result', int(choosenNum))
                self.decodingSteps = 0
                print('发送给界面完成。')

    def decoder_SSMVEP(self):
        """Run one step of the SSMVEP (TDCA) train / predict / rest state machine."""
        # --- Model training once every class has enough trials ---------------
        if not self.load_model and isinstance(self.trainLabel, list) and all(
                self.trainLabel.count(i) >= self.single_train for i in range(len(self.list_freqs))):
            # Build arrays locally: if fit() raises, the lists stay intact and
            # training is retried instead of failing later on ndarray.count.
            X = np.array(self.trainData)
            y = np.array(self.trainLabel)
            print(np.shape(X), (y))
            self.decoder = self.tdca.fit(X, y, Yf=self.Yf)
            self.trainData, self.trainLabel = X, y
            formatted_time = datetime.now().strftime('%H:%M:%S.%f')[:-3]
            print('模型训练完成', formatted_time)
            self.load_model = True
            self.zmqClient.send_to_all('paradigm', 1)
        # --- Training stage: collect one epoch per cue -----------------------
        if self.zmqServer.state_mode == 'train':
            if self.zmqServer.StartTrain:
                self.currentLabel = self.zmqServer.currentLabel
                self.zmqServer.StartTrain = False
            if self.thread_data_server.epoch_finished == False or \
                    self.thread_data_server.GetDataLenCount() < \
                    self.train_epoch[1] + self.thread_data_server.event_inner_idx:
                time.sleep(0.0001)
                return
            print('训练队列数据:', self.thread_data_server.GetDataLenCount())
            trainTrial = self.thread_data_server.get_SSMVEPData()  # take all buffered data
            idx = self.thread_data_server.event_inner_idx  # event position within the fetched data
            print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, idx])
            trainTrial = self.preprocess(trainTrial[:self.n_chan, :])
            trainTrial = trainTrial[:, idx + self.train_epoch[0]:idx + self.train_epoch[1]]
            print('trial: ', idx, self.train_epoch[0], self.train_epoch[1])
            if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) \
                    and isinstance(self.trainLabel, list) \
                    and self.trainLabel.count(self.currentLabel) < self.single_train:
                self.trainData.append(trainTrial)
                self.trainLabel.append(self.currentLabel)
        # --- Prediction stage -------------------------------------------------
        elif self.zmqServer.state_mode == 'predict':
            if not self.load_model:  # model not trained yet
                time.sleep(0.01)
                return
            if self.zmqServer.StartDecode:
                self.zmqServer.StartDecode = False
                formatted_time = datetime.now().strftime('%H:%M:%S.%f')[:-3]
                print('启动预测 ', formatted_time)
            if self.thread_data_server.epoch_finished == False or \
                    self.thread_data_server.GetDataLenCount() < \
                    self.interval_epoch[1] + self.thread_data_server.event_inner_idx:
                time.sleep(0.0001)
                return
            data = self.thread_data_server.get_SSMVEPData()  # take all buffered data
            idx = self.thread_data_server.event_inner_idx
            print('取出的: ', data.shape, 'event: ', data[-2, idx])
            data = self.preprocess(data[:self.n_chan, :])
            data = data[:, idx + self.interval_epoch[0]:idx + self.interval_epoch[1]]
            # Zero-pad to the training epoch length (decoding window + 0.1 s tail),
            # derived from the sample-domain bounds rather than recomputed floats.
            pad_eeg_test = np.zeros((data.shape[0], self.train_epoch[1] - self.train_epoch[0]))
            pad_eeg_test[:, :data.shape[1]] = data
            choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
            if isinstance(choosenNum, np.ndarray):
                choosenNum = choosenNum[0]
            rho = sorted(features_2[0])
            print('结果:', choosenNum, 'rho: ', rho, rho[-1] - rho[-2])
            self.zmqClient.send_to_all('result', int(choosenNum))
            print('发送给界面完成。')
        else:  # rest
            self._drain_idle()

    def decoder_MI(self):
        """Run one step of the motor-imagery train / predict / rest state machine.

        Training runs in a child process (``onlineTrain``) which saves the model to
        ``self.modelPath``; the model is then loaded here and used per trial.
        """
        # --- Start training once every class has enough trials ---------------
        if not self.train_started and all(
                self.trainLabel.count(i) >= self.single_train for i in range(self.num_target)):
            self.zmqClient.send_to_all('paradigm', 2)  # training set complete, notify the upper computer
            self.train_started = True
            self.trainData = np.array(self.trainData)
            self.trainLabel = np.array(self.trainLabel) + 1
            # Train in a child process so this loop is not blocked while training.
            p = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue))
            p.start()
            self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel, 'modelPath': self.modelPath,
                                    'n_chan': self.n_chan})
        # --- Load the model once the child process reports back --------------
        # NOTE(review): on training or load failure the result has already been
        # consumed and train_started stays True, so training is not retried.
        if not self.load_model and self.train_started:
            try:
                result = self.mp_result_queue.get_nowait()
                if result['status'] == 'success':
                    print("模型训练完成,加载新模型")
                    self.model = torch.load(self.modelPath, weights_only=False)
                    self.model.eval()
                    # Warm up with the same input shape used for prediction (the
                    # trained epoch length), not a fixed length.
                    n_samples = self.interval_epoch[1] - self.interval_epoch[0]
                    warmup_data = torch.from_numpy(np.random.uniform(-1, 1, (1, 1, self.n_chan, n_samples)))
                    warmup_data = warmup_data.type(torch.cuda.FloatTensor)
                    with torch.no_grad():
                        _ = self.model(warmup_data)
                    self.load_model = True
                    self.zmqClient.send_to_all('paradigm', 1)  # model ready, notify the upper computer
                else:
                    print("训练失败:", result['msg'])
            except Empty:
                pass  # training still running
            except Exception as e:
                print('模型调用失败: ', e)
        # --- Training stage: collect one epoch per cue -----------------------
        if self.zmqServer.state_mode == 'train' and not self.train_started:
            if self.zmqServer.StartTrain:
                self.currentLabel = self.zmqServer.currentLabel
                self.zmqServer.StartTrain = False
            if self.thread_data_server.epoch_finished == False or \
                    self.thread_data_server.GetDataLenCount() < \
                    self.interval_epoch[1] + self.thread_data_server.event_inner_idx:
                time.sleep(0.0001)
                return
            print('训练队列数据:', self.thread_data_server.GetDataLenCount())
            originalTrial = self.thread_data_server.get_MIData()  # MI channel data
            idx = self.thread_data_server.event_inner_idx
            print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, idx])
            seg_start, seg_end = idx + self.interval_epoch[0], idx + self.interval_epoch[1]
            trainTrial = self.preprocess(originalTrial[:self.n_chan, :])[:, seg_start:seg_end]
            print('trial: ', idx, self.interval_epoch[0], self.interval_epoch[1])
            if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) \
                    and isinstance(self.trainLabel, list) \
                    and self.trainLabel.count(self.currentLabel) < self.single_train:
                self.trainData.append(trainTrial)
                self.trainLabel.append(self.currentLabel)
                print('训练集:', np.shape(self.trainData))
                # Unfiltered epoch kept for the training report.
                self.plotData.append(originalTrial[:self.n_chan, seg_start:seg_end])
                self.plotLabel.append(self.currentLabel)
        # --- Prediction stage -------------------------------------------------
        elif self.zmqServer.state_mode == 'predict' and self.load_model:
            if self.zmqServer.StartDecode:
                self.zmqServer.StartDecode = False
                formatted_time = datetime.now().strftime('%H:%M:%S.%f')[:-3]
                print('启动预测 ', formatted_time)
            if self.thread_data_server.epoch_finished == False or \
                    self.thread_data_server.GetDataLenCount() < \
                    self.interval_epoch[1] + self.thread_data_server.event_inner_idx:
                time.sleep(0.0001)
                return
            originalData = self.thread_data_server.get_MIData()  # take all buffered data
            idx = self.thread_data_server.event_inner_idx
            print('取出的: ', originalData.shape, 'event: ', originalData[-2, idx])
            start = time.time()
            seg_start, seg_end = idx + self.interval_epoch[0], idx + self.interval_epoch[1]
            data = self.preprocess(originalData[:self.n_chan, :])[:, seg_start:seg_end]
            self.plotData.append(originalData[:self.n_chan, seg_start:seg_end])
            # Model input shape: (batch=1, 1, n_chan, samples).
            test_data = torch.from_numpy(data[np.newaxis, np.newaxis, :, :]).type(torch.cuda.FloatTensor)
            with torch.no_grad():
                Cls = self.model(test_data)
            y_pred = torch.max(Cls, 1)[1]
            self.plotLabel.append(int(y_pred.item()))
            print('运动意图识别: ', y_pred)
            self.zmqClient.send_to_all('result', int(y_pred.item()))
            end = time.time()
            print(f'发送给界面完成,耗时{end - start:.3f}s。')
        else:  # rest
            self._drain_idle()

    def decoder_concentration(self):
        """In predict mode, feed ``win_step`` seconds of data to Calculate and forward any result."""
        if self.zmqServer.state_mode != 'predict':  # rest
            self._drain_idle()
            return
        if self.zmqServer.StartDecode:
            self.zmqServer.StartDecode = False
            self.thread_data_server.ResetAll()
            formatted_time = datetime.now().strftime('%H:%M:%S.%f')[:-3]
            print('启动专注力预测 ', formatted_time)
        step_samples = int(self.win_step * self.fs)  # one result per win_step
        if self.thread_data_server.GetDataLenCount() < step_samples:
            time.sleep(0.005)
            return
        if self.zmqServer.get_Impedance != False:  # no decoding during impedance measurement
            return
        data = self.thread_data_server.get_concentrateData(step_samples)
        result = self.calculate.queueOpt(data)
        if result is not None:
            self.zmqClient.send_to_all('result', int(result))

    #### Blink detection #####
    def check_double_blink(self, current_time):
        """
        Check whether the two most recent single blinks form a double blink.

        Times are ``time.time()`` values (seconds), so ``double_blink_interval``
        and ``double_blink_jitter`` are compared in seconds.

        @param current_time: timestamp of the current blink
        @return: True if a double blink is detected
        """
        if len(self.blink_timestamps) < 2:
            return False
        # Debounce: ignore double blinks too soon after the previous one.
        if self.last_double_blink_time > 0 and \
                current_time - self.last_double_blink_time < self.double_blink_jitter:
            return False
        interval = self.blink_timestamps[-1] - self.blink_timestamps[-2]  # current vs previous blink
        return interval <= self.double_blink_interval

    def process_blink_detection(self):
        """
        Run single-blink detection on the latest window of buffered samples.

        Sends ``result`` 1 for each accepted single blink and ``result`` 2 when
        it completes a double blink.
        """
        if len(self.fp1_buffer) < self.window_samples:
            return
        fp1_data = np.array(list(self.fp1_buffer)[-self.window_samples:])
        fp2_data = np.array(list(self.fp2_buffer)[-self.window_samples:])
        fp12_mean = (fp1_data + fp2_data) / 2.0  # average of FP1 and FP2
        try:
            fp12_filtered = signal.filtfilt(self.blink_b, self.blink_a, fp12_mean)  # zero-phase band-pass
        except Exception as e:
            print(f"Filter error: {e}")
            return
        F = np.diff(fp12_filtered)
        if len(F) < 3:
            return
        is_blink, duration, energy = blink_detection(F, self.fs, self.Dmin, self.Dmax, self.EMin, self.EMax)
        if is_blink != 1:
            return
        samples_since_last = self.total_samples - self.last_blink_time
        time_since_last_ms = (samples_since_last / self.fs) * 1000
        if time_since_last_ms < self.jitterwin:  # single-blink debounce window (ms)
            return
        self.blink_count += 1
        self.last_blink_time = self.total_samples
        current_time = time.time()
        self.blink_timestamps.append(current_time)
        self.blink_events.append({
            'count': self.blink_count,
            'time': current_time,
            'sample_index': self.total_samples,
            'duration_ms': duration,
            'energy': energy
        })
        self.zmqClient.send_to_all('result', 1)  # single blink detected
        if self.check_double_blink(current_time):
            self.double_blink_count += 1
            self.double_blink_events.append({
                'double_blink_count': self.double_blink_count,
                'blink1_time': self.blink_timestamps[-2],
                'blink2_time': self.blink_timestamps[-1],
                'interval': self.blink_timestamps[-1] - self.blink_timestamps[-2]
            })
            self.last_double_blink_time = current_time
            self.zmqClient.send_to_all('result', 2)  # double blink detected

    def decoder_blink(self):
        """Append the next 50 samples of channels 1/2 (FP1/FP2) and run detection every ``step_samples``."""
        if self.thread_data_server.GetDataLenCount() < 50:
            time.sleep(0.005)
            return
        if self.zmqServer.get_Impedance != False:  # no decoding during impedance measurement
            return
        data = self.thread_data_server.get_blinkData(50)
        for fp1_sample, fp2_sample in zip(data[0, :], data[1, :]):
            self.fp1_buffer.append(fp1_sample)
            self.fp2_buffer.append(fp2_sample)
            self.total_samples += 1
            self.sample_counter += 1
            if self.sample_counter >= self.step_samples:
                self.process_blink_detection()
                self.sample_counter = 0

    def stop(self):
        '''
        Stop the ZMQ server and end the main loop.
        @return: None
        '''
        self.zmqServer.stop()
        self.Runing = False

    @staticmethod
    def _drain_queue(q):
        """Discard queued items until the queue reports Empty (never spins on empty())."""
        while True:
            try:
                q.get_nowait()
            except Empty:
                break

    def reset_state(self):
        """Clear decoder state and cached data before switching decoder."""
        self.thread_data_server.reset_state()  # device-level buffers
        self.decodingSteps = 0
        self.calculateCount = 0
        self.plotData = []
        self.plotLabel = []
        self.trainData = []
        self.trainLabel = []
        self.currentLabel = -1
        self.train_started = False
        self.load_model = False
        # Empty the IPC queues so stale items do not leak into the next decoder.
        for queue_name in ('mp_data_queue', 'mp_result_queue'):
            q = getattr(self, queue_name, None)
            if q is not None:
                self._drain_queue(q)