init commit

This commit is contained in:
Ivey Song
2026-06-01 13:42:38 +08:00
parent 94e2886698
commit 2226a22ae8
60 changed files with 12075 additions and 0 deletions

View File

@@ -0,0 +1,632 @@
import ast
import threading
from datetime import datetime
import multiprocessing as mp
import numpy as np
import time
import torch
from queue import Empty
from scipy import signal
from torch.autograd import Variable
from Device.SunnyLinker import SunnyLinker64
from SSMVEP.algorithm.tdca import TDCA
from SSMVEP.algorithm.base import generate_cca_references
from concentration.algorithm.calculate_focus import Calculate
from blinkdetection.algorithm.eye_detection import blink_detection
from Zmq.zmqServer import zmqServer
from Zmq.zmqClient import zmqClient
from MI.Algorithm.conformer_2class import onlineTrain
from PubLibrary.InifileHelper import IniRead
from SSVEP.dwfbcca import FbccaDw
from Tools.plot_MI_EEG import plotMain
from collections import deque
class Decoder_main(threading.Thread):
def __init__(self):
threading.Thread.__init__(self)
self.Runing=True
self.decoder = None
self.fs = 250 # 采样率
self.energy = 0 # 电量
self.status_code = 0 # 与采集设备通信的状态码0为异常1为正常
self.decoder_class = None #解码器类别
self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果
def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None):
self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type'))
_device_host = device_host if device_host is not None else str(IniRead('system', 'Device_Host'))
_device_port = device_port if device_port is not None else int(IniRead('system', 'Device_Port'))
_upper_host = upper_host if upper_host is not None else str(IniRead('system', 'Upper_Host'))
_upper_port = upper_port if upper_port is not None else int(IniRead('system', 'Upper_Port'))
if self.DeviceType == 1:
self.thread_data_server = SunnyLinker64(_device_host, _device_port, self.fs, 64, method='tcp')
self.thread_data_server.host = _device_host
self.thread_data_server.port = _device_port
self.thread_data_server.toUv = True
self.thread_data_server.start()
self.zmqServer = zmqServer()
self.zmqServer.start()
self.zmqClient = zmqClient(_upper_host, _upper_port)
self.zmqClient.set_zmq_server(self.zmqServer)
self.zmqClient.connect()
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
# data: (chans, samples)
energy = np.mean(np.var(data, axis=1)) # 各通道方差均值
if energy > threshold:
return False
return True
def init_Decoder(self,decoder_class):
'''
初始化解码器
:param decoder_class: 'ssvep' or 'ssmvep' or 'mi' or 'concentration' or ''
:return:
'''
self.decoder_class = decoder_class
if decoder_class == 'ssvep' or decoder_class == 'pvs':
self.n_chan = 8
self.thread_data_server.interval_inited = False
DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
self.ListFreq = self.zmqServer.targetFreqs
self.num_target = len(self.ListFreq)
if self.num_target == 0:
return
# 初始化对象 二代算法
self.dw = FbccaDw(self.fs, self.num_target, self.n_chan, 5, 5,
0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method)
# frequence band
self.dw.filterFrequenceBank()
self.dw.setNotchFilterPara()
self.calculateCount = 0
self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.fs),
5)
self.dw.filterInit()
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
elif decoder_class == 'ssmvep':
self.thread_data_server.interval_init(decoder_class)
self.n_chan = 8
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
self.single_train = 10 # 单类别数量
self.num_target = 2 # 分类目标数目
self.list_freqs = np.array([8, 9]) # 刺激频率
self.list_phase = np.array([0, 0]) # 相位
self.tdca = TDCA(padding_len=5, n_components=1)
self.Yf = generate_cca_references(self.list_freqs, srate=self.fs, T=self.sample_length,
phases=self.list_phase, n_harmonics=5)
self.parameter_init(5,45)
elif decoder_class == 'mi' or decoder_class == 'ma':
self.thread_data_server.interval_init(decoder_class)
self.n_chan = 21
self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
self.single_train = 40 # 单类别数量
self.num_target = 2 # 分类目标数目
self.parameter_init(8, 30)
elif decoder_class == 'concentration':
self.thread_data_server.interval_inited = False
self.n_chan = 6
self.win_len = 10
self.win_step = 1
self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue'))
self.calculate = Calculate(self.low_threshold, self.high_threshold, self.fs, self.win_len)
self.interval_epoch = [0, 1]
self.parameter_init(2, 40)
# self.eegQueue moved to Calculate class
elif decoder_class == 'blink':
self.n_chan = 2
self.l_freq = 0.1 # 带通滤波器低频截止
self.h_freq = 8.0 # 带通滤波器高频截止
self.total_samples = 0 # 总采样点数
self.window_ms = 600 # 检测窗口大小 (ms)
self.step_ms = 100 # 滑动步长 (ms)
self.window_samples = int(self.window_ms * self.fs / 1000) # 150个样本点
self.step_samples = int(self.step_ms * self.fs / 1000) # 25个样本点
self.buffer_size = self.window_samples + self.step_samples * 5
self.fp1_buffer = deque(maxlen=self.buffer_size)
self.fp2_buffer = deque(maxlen=self.buffer_size)
self.sample_counter = 0
# 预计算滤波器系数,避免在循环中重复设计
self.Dmin, self.Dmax, self.EMin, self.EMax, self.jitterwin,self.double_blink_interval,self.double_blink_jitter = ast.literal_eval(IniRead('system', 'blink'))
self.blink_count = 0 # 单次眨眼的次数
self.last_blink_time = 0 # 上次检测到单次眨眼的时间(样本索引)
self.blink_timestamps = deque(maxlen=10) # 记录最近10次 单次眨眼的时间戳
self.double_blink_count = 0 # 连续两次眨眼的次数
self.double_blink_events = [] # 连续眨眼事件记录
self.last_double_blink_time = 0 # 上次检测到连续眨眼的时间戳
self.blink_events = []
self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.fs / 2), self.h_freq / (self.fs / 2)], btype='band')
def parameter_init(self,bandPass_low,bandPass_high):
self.interval_epoch = [int(i * self.fs) for i in self.interval_epoch] # epoch截取信息
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.fs)] # 训练样本epoch
self.trainData = [] #训练数据
self.trainLabel = [] #训练标签
self.plotData = [] #报告分析数据
self.plotLabel = [] #报告分析标签
self.currentLabel = -1 #刺激界面当前显示的训练标签
self.train_started = False #是否开始训练模型
self.load_model = False # 调用模型是否完成的标志
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30) # 50Hz工频陷波250是采样率30是质量因子
self.b_design = signal.firwin(65, [bandPass_low / (self.fs/2), bandPass_high / (self.fs/2)], pass_zero=False) # 设计8-30Hz带通滤波器
fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
filePath = './online_Models/'
self.modelPath = ''.join([filePath, fileName, '.pth'])
self.mp_data_queue = mp.Queue() #多进程传参队列
self.mp_result_queue = mp.Queue() #多进程结果队列
def preprocess(self, signal_data):
# # 计算每行的平均值
row_means = np.mean(signal_data, axis=-1, keepdims=True)
# 对每一行去均值
signal_data = signal_data - row_means
signal_data = signal.lfilter(self.b_notch, self.a_notch, signal_data, axis=-1) # 工频陷波
signal_data = signal.lfilter(self.b_design, 1, signal_data, axis=-1) # 带通滤波
return signal_data
def run(self):
while self.Runing:
if self.zmqServer.decoder_switch or self.zmqServer.changeTarget:
print(f"Decoder_class Switch Detected: {self.zmqServer.decoder_class}")
self.zmqServer.decoder_switch = False
self.zmqServer.changeTarget = False
self.reset_state() # 切换前先统一清理旧状态
self.init_Decoder(self.zmqServer.decoder_class)
# 同步信息
if self.zmqServer.state_mode == 'sync':
self.zmqClient.send_to_all('sync', self.zmqClient.state)
self.zmqServer.state_mode = 'rest'
# 状态异常,报告上位机
if self.status_code != self.thread_data_server.status_code:
self.status_code = self.thread_data_server.status_code
self.zmqClient.send_to_all('status_code', int(self.status_code))
print('status code')
# 返回电量
if self.energy != self.thread_data_server.energy:
self.energy = self.thread_data_server.energy
self.zmqClient.send_to_all('energy', int(self.energy))
print('energy')
if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
self.thread_data_server.Impedance(True)
print('Impedance')
self.zmqServer.open_Impedance = -1
elif self.zmqServer.open_Impedance == False:
self.thread_data_server.Impedance(False)
self.zmqServer.open_Impedance = -1
if self.zmqServer.get_Impedance: # 返回阻抗值
# print(self.zmqServer.get_Impedance)
# print(self.thread_data_server.GetDataLenCount())
if self.thread_data_server.GetDataLenCount() > 250:
Impe_data = self.thread_data_server.getData(250)
# 计算阻抗
imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class)
self.zmqClient.send_to_all('impedance', imps.tolist())
else:
pass
if self.zmqServer.getReport: #返回训练报告内容
self.zmqServer.getReport = False
allData = np.array(self.plotData)
allLabel = np.array(self.plotLabel) + 1
nTrials = min(len(allLabel),len(allData))
if nTrials < 30:
self.zmqClient.send_to_all('miReport',0)
else:
allData = allData[:nTrials]
allLabel = allLabel[:nTrials]
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1',
'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4']
compare_names = ['C3', 'CZ', 'C4']
miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2,
fs=self.fs)
self.zmqClient.send_to_all('miReport',miReport)
# --- 取数优先:先执行 decoder消费环形缓冲再处理 plot/report 等重负载 ---
try:
if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
self.decoder_SSVEP()
elif self.decoder_class == 'ssmvep':
self.decoder_SSMVEP()
elif self.decoder_class == 'mi':
self.decoder_MI()
elif self.decoder_class == 'concentration':
self.decoder_concentration()
elif self.decoder_class == 'blink':
self.decoder_blink()
else:
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
if self.thread_data_server.GetDataLenCount() < 25:
time.sleep(0.005)
continue;
self.thread_data_server.getData(25)
except Exception as e:
print(f"Decoder Loop Error: {e}")
import traceback
traceback.print_exc()
time.sleep(0.1) # Prevent CPU spin if error is persistent
def decoder_SSVEP(self):
if self.zmqServer.StartDecode:
self.zmqServer.StartDecode = False
self.decodingSteps = 1
self.thread_data_server.ResetAll()
print('启动预测')
if self.thread_data_server.GetDataLenCount() < 50:
time.sleep(0.005)
return
if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
return
data = self.thread_data_server.getDataViaSSVEP(50)
data = data[:self.n_chan, :]
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
self.dw.warmFilter(data) # 预热
self.decodingSteps = 2
print('预热数据完成。开始预测')
return
if self.decodingSteps == 2 and hasattr(self,'dw'): # 解码中
choosenNum = self.dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount)
self.calculateCount += 1
if choosenNum != -1 and self.is_valid_signal(data):
self.decodingSteps = 3
print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount))
self.calculateCount = 0
if self.decodingSteps == 3: # 发送解码后的信息
self.zmqClient.send_to_all('result', int(choosenNum))
self.decodingSteps = 0
print('发送给界面完成。')
def decoder_SSMVEP(self):
'''模型训练'''
if self.load_model == False and all(
self.trainLabel.count(i) >= self.single_train for i in range(len(self.list_freqs))): # 模型尚未训练完成
self.trainData = np.array(self.trainData)
self.trainLabel = np.array(self.trainLabel)
print(np.shape(self.trainData), (self.trainLabel))
# 保存多个数组到文件
# np.savez('20250520_yy.npz', array1=self.trainData, array2=self.trainLabel)
# self.decoder = self.fbtdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
self.decoder = self.tdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
now = datetime.now()
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('模型训练完成', formatted_time)
self.load_model = True
self.zmqClient.send_to_all('paradigm', 1)
'''训练阶段采集数据'''
if self.zmqServer.state_mode == 'train': # 训练状态
if self.zmqServer.StartTrain:
self.currentLabel = self.zmqServer.currentLabel
self.zmqServer.StartTrain = False
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
self.train_epoch[1] \
+ self.thread_data_server.event_inner_idx:
time.sleep(0.0001)
return
print('训练队列数据:', self.thread_data_server.GetDataLenCount())
trainTrial = self.thread_data_server.get_SSMVEPData() # 取出所有数据
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.thread_data_server.event_inner_idx])
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.train_epoch[
0]:self.thread_data_server.event_inner_idx + self.train_epoch[1]]
print('trial: ', self.thread_data_server.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
self.trainLabel, list) \
and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel)
elif self.zmqServer.state_mode == 'predict': # 测试状态
if self.load_model == False: # 模型尚未训练完成
time.sleep(0.01)
return
else: # 已有模型
if self.zmqServer.StartDecode:
self.zmqServer.StartDecode = False
now = datetime.now()
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('启动预测 ', formatted_time)
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
self.interval_epoch[1] \
+ self.thread_data_server.event_inner_idx:
time.sleep(0.0001)
return
data = self.thread_data_server.get_SSMVEPData() # 读取全部数据
print('取出的: ', data.shape, 'event: ', data[-2, self.thread_data_server.event_inner_idx])
data = self.preprocess(data[:self.n_chan, :]) # 预处理
data = data[:,
self.thread_data_server.event_inner_idx + self.interval_epoch[
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
pad_eeg_test = np.zeros(
(data.shape[0], int((self.sample_length + 0.1) * self.fs)))
pad_eeg_test[:, :int(self.sample_length * self.fs)] = data
choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
if isinstance(choosenNum, np.ndarray):
choosenNum = choosenNum[0]
print('结果:', choosenNum, 'rho: ', sorted(features_2[0]),
sorted(features_2[0])[-1] - sorted(features_2[0])[-2])
self.zmqClient.send_to_all('result', int(choosenNum))
print('发送给界面完成。')
else: # 休息状态
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
if self.thread_data_server.GetDataLenCount() < 25:
time.sleep(0.005)
return
self.thread_data_server.getData(25)
def decoder_MI(self):
'''模型训练'''
if self.train_started == False and all(
self.trainLabel.count(i) >= self.single_train for i in range(self.num_target)): # 模型尚未训练
self.zmqClient.send_to_all('paradigm', 2) # 模型训练前,训练集采集完毕,通知上位机
self.train_started = True
self.trainData = np.array(self.trainData)
self.trainLabel = np.array(self.trainLabel) + 1
# print('训练集:',np.shape(self.trainData), (self.trainLabel))
p = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue)) # 开启子进程,训练模型
p.start()
self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel, 'modelPath': self.modelPath,
'n_chan': self.n_chan})
'''检查模型是否训练完成,调用'''
if self.load_model == False and self.train_started == True:
try:
result = self.mp_result_queue.get_nowait()
if result['status'] == 'success':
print("模型训练完成,加载新模型")
# 调用模型
self.model = torch.load(self.modelPath, weights_only=False)
self.model.eval()
# 模型预热
warmup_data = np.random.uniform(-1, 1, (1, 1, self.n_chan, 1000))
warmup_data = torch.from_numpy(warmup_data)
warmup_data = Variable(warmup_data.type(torch.cuda.FloatTensor))
with torch.no_grad():
_ = self.model(warmup_data)
self.load_model = True
self.zmqClient.send_to_all('paradigm', 1) # 模型调用完毕,通知上位机
else:
print("训练失败:", result['msg'])
except Empty:
pass # 还没完成
except Exception as e:
print('模型调用失败: ', e)
'''训练阶段采集数据'''
if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态
if self.zmqServer.StartTrain:
self.currentLabel = self.zmqServer.currentLabel
self.zmqServer.StartTrain = False
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
self.interval_epoch[1] \
+ self.thread_data_server.event_inner_idx:
time.sleep(0.0001)
return
print('训练队列数据:', self.thread_data_server.GetDataLenCount())
originalTrial = self.thread_data_server.get_MIData() # 取出MI导联数据
print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.thread_data_server.event_inner_idx])
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.interval_epoch[
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
print('trial: ', self.thread_data_server.event_inner_idx, self.interval_epoch[0], self.interval_epoch[1])
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
list) \
and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel)
print('训练集:', np.shape(self.trainData))
self.plotData.append(originalTrial[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]])
self.plotLabel.append(self.currentLabel)
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
if self.zmqServer.StartDecode:
self.zmqServer.StartDecode = False
now = datetime.now()
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('启动预测 ', formatted_time)
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
self.interval_epoch[1] \
+ self.thread_data_server.event_inner_idx:
time.sleep(0.0001)
return
originalData = self.thread_data_server.get_MIData() # 读取全部数据
print('取出的: ', originalData.shape, 'event: ', originalData[-2, self.thread_data_server.event_inner_idx])
start = time.time()
data = self.preprocess(originalData[:self.n_chan, :]) # 预处理
data = data[:,
self.thread_data_server.event_inner_idx + self.interval_epoch[
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
self.plotData.append(
originalData[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]])
test_data = data[np.newaxis, np.newaxis, :, :]
test_data = torch.from_numpy(test_data)
test_data = Variable(test_data.type(torch.cuda.FloatTensor))
with torch.no_grad():
Cls = self.model(test_data)
y_pred = torch.max(Cls, 1)[1]
self.plotLabel.append(int(y_pred.item()))
print('运动意图识别: ', y_pred)
self.zmqClient.send_to_all('result', int(y_pred.item()))
end = time.time()
print(f'发送给界面完成,耗时{end - start:.3f}s。')
else: # 休息状态
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
if self.thread_data_server.GetDataLenCount() < 25:
time.sleep(0.005)
return
self.thread_data_server.getData(25)
def decoder_concentration(self):
if self.zmqServer.state_mode == 'predict':
if self.zmqServer.StartDecode:
self.zmqServer.StartDecode = False
self.thread_data_server.ResetAll()
now = datetime.now()
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('启动专注力预测 ', formatted_time)
if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.fs): # 每win_step得出一次结果
time.sleep(0.005)
return
if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
return
data = self.thread_data_server.get_concentrateData(int(self.win_step * self.fs)) # 修改每次读取的数据
result = self.calculate.queueOpt(data)
if result is not None:
self.zmqClient.send_to_all('result', int(result))
else: # 休息状态
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
if self.thread_data_server.GetDataLenCount() < 25:
time.sleep(0.005)
return
self.thread_data_server.getData(25)
#### Blink detection #####
def check_double_blink(self, current_time):
"""
检查是否检测到连续两次眨眼
@param current_time: 当前眨眼时间戳
@return: True表示检测到连续两次眨眼
"""
if len(self.blink_timestamps) < 2:
return False
# 检查是否在去抖期内
if self.last_double_blink_time > 0:
time_since_last_double_blink = current_time - self.last_double_blink_time
if time_since_last_double_blink < self.double_blink_jitter:
return False # 在去抖期内,忽略连续眨眼检测
last_time = self.blink_timestamps[-1] # 当前眨眼
prev_time = self.blink_timestamps[-2] # 上次眨眼
interval = last_time - prev_time
if interval <= self.double_blink_interval:
return True
return False
def process_blink_detection(self):
"""
在缓冲区数据上执行,单次眨眼检测
"""
if len(self.fp1_buffer) < self.window_samples:
return
fp1_data = np.array(list(self.fp1_buffer)[-self.window_samples:])
fp2_data = np.array(list(self.fp2_buffer)[-self.window_samples:])
# 计算FP1和FP2的平均
fp12_mean = (fp1_data + fp2_data) / 2.0
# 带通滤波
try:
fp12_filtered = signal.filtfilt(self.blink_b, self.blink_a, fp12_mean)
except Exception as e:
print(f"Filter error: {e}")
return
F = np.diff(fp12_filtered)
if len(F) < 3:
return
b, d, e = blink_detection(F, self.fs, self.Dmin, self.Dmax, self.EMin, self.EMax)
if b == 1:
samples_since_last = self.total_samples - self.last_blink_time
time_since_last_ms = (samples_since_last / self.fs) * 1000
if time_since_last_ms >= self.jitterwin: # self.jitterwin 单次眨眼去抖 using time_since_last_ms
self.blink_count += 1
self.last_blink_time = self.total_samples
current_time = time.time()
self.blink_timestamps.append(current_time)
blink_event = {
'count': self.blink_count,
'time': current_time,
'sample_index': self.total_samples,
'duration_ms': d,
'energy': e
}
self.blink_events.append(blink_event)
self.zmqClient.send_to_all('result', 1) # 检测到眨眼信号,通知上位机
if self.check_double_blink(current_time):
self.double_blink_count += 1
interval = self.blink_timestamps[-1] - self.blink_timestamps[-2]
double_blink_event = {
'double_blink_count': self.double_blink_count,
'blink1_time': self.blink_timestamps[-2],
'blink2_time': self.blink_timestamps[-1],
'interval': interval
}
self.double_blink_events.append(double_blink_event)
self.last_double_blink_time = current_time
self.zmqClient.send_to_all('result', 2) # 发送双次眨眼事件
def decoder_blink(self):
if self.thread_data_server.GetDataLenCount() < 50:
time.sleep(0.005)
return
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
data = self.thread_data_server.get_blinkData(50)
fp1_data = data[0, :] # ch1 (相当于FP1)
fp2_data = data[1, :] # ch2 (相当于FP2)
for i in range(len(fp1_data)):
self.fp1_buffer.append(fp1_data[i])
self.fp2_buffer.append(fp2_data[i])
self.total_samples += 1
self.sample_counter += 1
if self.sample_counter >= self.step_samples:
self.process_blink_detection()
self.sample_counter = 0
def stop(self):
'''
停止运行
@return:
'''
self.zmqServer.stop()
self.Runing=False
def reset_state(self):
"""清空解码器状态和缓存数据"""
# 重置设备层缓存
self.thread_data_server.reset_state()
# 重置解码状态
self.decodingSteps = 0
self.calculateCount = 0
# 重置训练数据
self.plotData = []
self.plotLabel = []
self.trainData = []
self.trainLabel = []
self.currentLabel = -1
self.train_started = False
self.load_model = False
# 重置多进程队列,确保切换 decoder 时旧数据不会泄漏到新队列
if hasattr(self, 'mp_data_queue'):
while not self.mp_data_queue.empty():
try: self.mp_data_queue.get_nowait()
except Empty: pass
if hasattr(self, 'mp_result_queue'):
while not self.mp_result_queue.empty():
try: self.mp_result_queue.get_nowait()
except Empty: pass

View File

@@ -0,0 +1,814 @@
# -*-coding:utf-8 -*-
'''
SunnyLinker的通讯驱动
'''
import ast
import socket
import threading
import time
import datetime
from typing import Dict
from collections import deque
import numpy as np
from threading import Thread, Event
import serial
from scipy import signal
from serial.serialutil import SerialException
from Device.protocol import ProtocolFrame
from PubLibrary.InifileHelper import IniRead
class RingBuffer:
def __init__(self, n_chan, n_points):
self.n_chan = n_chan
self.n_points = n_points
self.buffer = np.zeros((n_chan, n_points))
self.currentPtr = 0
self.readPtr = 0
self.nUpdate = 0
self.rawData = np.zeros((n_chan, 1))
## append buffer and update current pointer
def appendBuffer(self, data):
if self.nUpdate == self.n_points:
raise Exception("Buffer is full")
n = data.shape[1]
# 计算可以写入的元素数量
write_count = min(self.n_points - self.nUpdate, n)
# 写入新数据
self.buffer[:, np.mod(np.arange(self.currentPtr, self.currentPtr + write_count), self.n_points)] = data[:,:write_count]
# 更新结束指针
self.currentPtr = (self.currentPtr + write_count) % self.n_points
# 更新大小
self.nUpdate += write_count
## get data from buffer
def getData(self, count=50):
# 确保不会尝试读取超过缓冲区当前大小的数据
count = min(count, self.nUpdate)
# 计算读取结束后的下一个位置
next_read_ptr = (self.readPtr + count) % self.n_points
if self.readPtr + count <= self.n_points:
# 情况 1不环绕数据是连续的
end_index = next_read_ptr if next_read_ptr != 0 else self.n_points
data = self.buffer[:, self.readPtr:end_index]
else:
# 情况 2发生环绕数据被分成两部分
# 第一部分:从 readPtr 到缓冲区末尾
part1 = self.buffer[:, self.readPtr:]
# 第二部分:从缓冲区开头到 (count - part1.shape[1]) 个点
part2 = self.buffer[:, :next_read_ptr]
# 将两部分在列方向上拼接
data = np.concatenate((part1, part2), axis=1)
# 更新读指针
self.readPtr = next_read_ptr
# 更新大小
self.nUpdate -= count
return data
# reset buffer
def resetAllPara(self):
self.nUpdate = 0
self.currentPtr = 0
self.readPtr = 0 # add by lizhenhua 清空读指针
self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区
class SunnyLinker64(Thread, ):
serial_port = str(IniRead('system', 'Serial_port'))
t_buffer = 10
n_chan = 64
srate = 250
win_len = 10
win_step = 1
ring_buffer = 5
receiveData = b''
toUv=True#转为uV
RingBufferLock = threading.Lock()
# 单例模式
_instance = None
_initialized = False # 检查是否已经初始化
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(SunnyLinker64, cls).__new__(cls)
return cls._instance
def __init__(self, host='127.0.0.1', port=7878, srate=250, n_chan=64,method = 'tcp'):
if SunnyLinker64._initialized:
return
Thread.__init__(self)
self.daemon = True
self.host = host
self.port = port
self.srate = srate
self.n_chan = n_chan
self.method = method #传输方式,'tcp'表示tcp传输,'serial'表示串口传输
self.__ringBuffer = RingBuffer(self.n_chan + 2,
int(np.round(self.t_buffer * self.srate)))
self.energy = 0 # 电量
self.status_code = 0 # 与采集设备通信的状态码0为异常1为正常
self.gain_value = 6 # 增益倍数
self.interval_inited = False #ssmvep或mi时间窗是否初始化
# 设置初始化标志为True防止重复初始化
SunnyLinker64._initialized = True
# --- 新增:用于心跳检测 ---
self.last_called = 0 # 初始化为0
self.last_called_lock = threading.Lock() # 保护 last_called 的访问
def reset_state(self):
"""清空采集器状态和缓存数据"""
with self.RingBufferLock:
self.__ringBuffer.resetAllPara()
self.count_events = {}
self.epoch_finished = False
self.pack_contain_event = False
self.event_inner_idx = -1
self.interval_inited = False
def interval_init(self,decoder_class):
if decoder_class == 'ssmvep':
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
self.interval_epoch = [int(i * self.srate) for i in interval_epoch] # epoch截取信息
self.train_epoch = [int(self.interval_epoch[0]),
int(self.interval_epoch[1] + 0.1 * self.srate)] # 训练样本epoch
self.latency = (self.interval_epoch[
1] + 0.1 * self.srate) // 5 # 提取epoch的延迟标记5代表每次解包得到的5位采样点;0.1表示比实际需要的长度多取0.1,会被截掉
self.train_latency = (self.train_epoch[1] + 0.1 * self.srate) // 5
elif decoder_class == 'mi':
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
self.interval_epoch = [int(i * self.srate) for i in interval_epoch] # epoch截取信息
self.train_epoch = self.interval_epoch.copy()
self.latency = (self.interval_epoch[1]) // 5 # 提取epoch的延迟标记5代表每次解包得到的5位采样点;
self.train_latency = self.latency
print('时间窗:', (interval_epoch))
self.count_events: Dict[str, int] = {} # 表示包延迟的计数信息
self.event_inner_idx = -1 # event在5位数据包内部的idx
self.epoch_finished = False # 接收epoch是否完整
self.pack_contain_event = False # 当前包是否含有event
self.predict_event = 99
self.events = [1, 2, self.predict_event]
if getattr(self, 'serial', None) and self.serial.is_open:
self.serial.close()
self.serial = serial.Serial(self.serial_port, 460800, timeout=1) # 连接同步器串口
self.interval_inited = True
def set_sampleRate(self,sampleRate_Code=0x00):
'''
设置采样率
:param sampleRate_Code: 0x00:250Hz,0x01:500Hz,0x02:1000Hz,0x03:2000Hz
'''
function_code = 0x02
gain_code = 0x06
sampleRate_Code = [gain_code,sampleRate_Code]
packed_data = ProtocolFrame.pack(function_code, sampleRate_Code)
if self.method == 'tcp':
self.sock.send(packed_data)
def push_trigger(self,label):
'''
数据打标
@param label:标签类别
'''
function_code = None
label = [label]
packed_data = ProtocolFrame.pack(function_code, label)
if self.method == 'tcp' and hasattr(self,'serial'):
print('发送:', label, datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3])
self.serial.write(packed_data)
def Impedance(self, On):
'''
阻抗检测开关
:param On:True为开启False为关闭
:return: 组好的协议帧
'''
function_code = 0x01
if On:
data = [0x1]
self.gain_value = 6
else:
data = [0x0]
self.gain_value = 6
packed_data = ProtocolFrame.pack(function_code, data)
if self.method == 'tcp':
self.sock.send(packed_data)
def connect(self):
try:
if self.method == 'serial':
# 开启com口波特率115200超时5
self.sock = serial.Serial(self.host, self.port, timeout=5)
self.sock.flushInput() # 清空缓冲区
count = self.sock.inWaiting() # 获取串口缓冲区数据
while not count:
count = self.sock.inWaiting() # 获取串口缓冲区数据
# # 接收和存储数据
data = (self.sock.read(count))
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
elif self.method == 'tcp':
# 重连前关闭旧 socket避免资源泄漏
if hasattr(self, 'sock') and self.sock:
try:
self.sock.close()
except Exception:
pass
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.connect((self.host, int(self.port)))
self.set_sampleRate(0x00) #设置250Hz采样率
return True
except Exception as e:
print("请打开头环")
print(e)
return False
print("connected")
return True
def extract_packet(self, packet):
# 存储一个点的八通道数据
dataList = []
# 存储116个点的八通道数据
dataMatrix = []
for j in range(5):
for i in range(self.n_chan):
if not self.toUv:#原始数据直接输出
val = (packet[194 * j + 25 + i * 3] << 16) | (packet[194 * j + 25 + 1 + i * 3] << 8) | packet[
194 * j + 25 + 2 + i * 3]
else:#转为uV
val = (packet[194 * j + 25 + i * 3] << 16) | (packet[194 * j + 25 + 1 + i * 3] << 8) | packet[
194 * j + 25 + 2 + i * 3]
if val < 8388608:
val = val * 4.5 / self.gain_value / 8388608 * 1000000;
else:
val = (val - 16777216) * 4.5 / self.gain_value / 8388608 * 1000000;
dataList.append(val)
#同步触发源
val = packet[194 * j + 25 + (i+1) * 3]
dataList.append(val)
#同步触发序号
val = packet[194 * j + 25 + (i+1) * 3+1]
dataList.append(val)
# 将数据矩阵进行拼接
if len(dataMatrix) == 0:
dataMatrix = np.asmatrix(dataList)
else:
dataMatrix = np.concatenate((dataMatrix, np.asmatrix(dataList)), axis=0)
dataList.clear()
return np.transpose(dataMatrix)
def run(self):
self.running = True
self.PackageLength = 998
# 尝试连接循环,断开后自动重连
while self.running:
if self.connect():
break
print(f"无法连接到 {self.host}:{self.port}15秒后重试...")
time.sleep(15)
# 启动心跳检测线程
threading.Thread(target=self.heartbeat_checker, daemon=True).start()
while self.running:
try:
if self.method == 'serial':
count = self.sock.inWaiting() # 获取串口缓冲区数据
if count:
# 接收和存储数据
data = (self.sock.read(count))
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
elif self.method == 'tcp':
data = self.sock.recv(600)
if not data:
break
self.receiveData += data
with self.last_called_lock:
self.last_called = time.time()
self.status_code = 1 # 收到数据,标记为正常
if len(self.receiveData) >= self.PackageLength and self.receiveData.rfind(
b'\x55\x55') >= self.PackageLength - 2:
index = self.receiveData.index(b'\xaa')
self.receiveData = self.receiveData[index:]
if len(self.receiveData) >= self.PackageLength:
onepackage = self.receiveData[:self.PackageLength]
if onepackage[7] != 0:
self.energy = onepackage[7] # 电量
self.receiveData = self.receiveData[self.PackageLength:]
dataMatrix = self.extract_packet(onepackage)
try:
with self.RingBufferLock:
if self.interval_inited:
self.epoch_finished = self.detect_event(dataMatrix)
if self.pack_contain_event:
self.__ringBuffer.resetAllPara() # 检测到当前pack含有event清除ringbuffer中之前的数据
self.__ringBuffer.appendBuffer(dataMatrix)
# self.plotBuffer.appendBuffer(dataMatrix)
if self.epoch_finished:
time.sleep(0.005)
print('epoch_finished: ', datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3])
else:
self.__ringBuffer.appendBuffer(dataMatrix)
except Exception as e:
print("锁:写入异常",e)
# self.RingBufferLock.release()
except ConnectionResetError:
self.status_code = 0 # 状态异常
print("Connection was reset by the peer. 正在重新连接...")
self.sock.close()
# 退出循环后run() 开头的重连循环会自动接管
break
# 如果 running=True重连循环会接管不会执行到这里
# 检测是否含有标签
def detect_event(self, samples):
self.pack_contain_event = False
events = np.array(samples[-2])[0].tolist()
for idx, event in enumerate(events):
if int(event) in self.events:
new_key = "".join(
[
str(event),
datetime.datetime.now().strftime("%Y-%m-%d \
-%H-%M-%S"),
]
)
if event == self.predict_event:
self.count_events[new_key] = self.latency + 1
else:
self.count_events[new_key] = self.train_latency + 1
self.event_inner_idx = idx
self.pack_contain_event = True
drop_items = []
for key, value in self.count_events.items():
value = value - 1
if value == 0:
drop_items.append(key)
self.count_events[key] = value
for key in drop_items:
del self.count_events[key]
if drop_items:
return True
return False
# --- 新增:心跳检测线程 ---
def heartbeat_checker(self):
"""
定期检查是否在最近2秒内收到 eegData
如果超过2秒未收到则设置 status_code = 0
"""
while self.running:
time.sleep(0.5) # 每0.5秒检查一次
with self.last_called_lock:
now = time.time()
# 只有收到过一次数据后才开始判断超时
if self.last_called > 0 and (now - self.last_called) > 30:
if self.status_code != 0:
print("EEG data timeout: disconnected")
self.status_code = 0
def getDataViaSSVEP(self,count):
'''
ssvep的视觉通道共8个通道
@param count: 每通道读取的数值数量
@return: 返回最新的数值
'''
data=self.getData(count)
# PO5,POZ,PO6,O2,PO8,OZ,O1,PO7 64是event导联
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55,64]
row_to_select=np.array(rows_to_extract)
data=data[row_to_select,:]
return data
def get_MIData(self):
'''
取出当前所有数值
:return:
'''
data = self.getData(self.__ringBuffer.nUpdate)
#MI选取导联FC3,FC1,FCZ,FC2,FC4,C5,C3,C1,CZ,C2,C4,C6,CP3,CP1,CP2,CP4,P3,P1,PZ,P2,P4,event1,event2
rows_to_extract = [8, 15, 12, 14, 18, 23, 16,59,50,58,17,45,29,11,10,19,20,61,51,60,21,64,65]
row_to_select = np.array(rows_to_extract)
data = data[row_to_select,:]
return data
def get_SSMVEPData(self):
'''
取出当前所有数值
:return:
'''
data = self.getData(self.__ringBuffer.nUpdate)
# PO5,POZ,PO6,O2,PO8,OZ,O1,PO7 64是event导联
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64,65]
row_to_select = np.array(rows_to_extract)
data = data[row_to_select, :]
return data
def get_concentrateData(self,count):
'''
@param count: 每通道读取的数值数量
@return: 返回最新的数值
'''
data=self.getData(count)
rows_to_extract = [0, 1]
row_to_select = np.array(rows_to_extract)
data = data[row_to_select, :]
return data
def get_blinkData(self,count):
'''
@param count: 每通道读取的数值数量
@return: 返回最新的数值
'''
data=self.getData(count)
rows_to_extract = [0,1]
row_to_select = np.array(rows_to_extract)
data = data[row_to_select, :]
return data
def getImpedance(self, data,decoder_class):
'''
获取阻抗值已经放大100倍单位是kΩ
@param data: 准备计算的通道数据每通道200个值注意不要把信号打标的通道传进来
@return: 返回各个通道的阻抗值
'''
impedanceList = []
for channelindex in range(data.shape[0]):
if len(data[channelindex]) > 0:
data_list = []
# 设计陷波滤波器去除50Hz成分
is50filter = True
if is50filter:
b, a = signal.iirnotch(50, 30, self.srate) # 30是带宽1000是采样频率
data_list = signal.lfilter(b, a, data[channelindex].tolist())
else:
data_list.extend(data[channelindex].tolist())
data_list = data_list[-1000:]
# 执行FFT
fft_result = np.fft.fft(data_list)
fft_magnitude = np.abs(fft_result / len(data_list)) # 归一化FFT结果
freqs = np.fft.fftfreq(len(data_list), d=1 / self.srate) # 频率轴
# y_amp_modified = np.concatenate(([fft_magnitude[0] / len(t[0].tolist())],
# fft_magnitude[1:-1] * 2 / len(t[0].tolist()),
# [fft_magnitude[-1] / len(t[0].tolist())]))
# 找到幅值最大的频率成分的索引忽略直流分量即索引0
max_index = np.argmax(fft_magnitude[1:])
# 获取最大幅值的频率索引加上1因为索引0是直流分量
freq_index = max_index + 1
# 获取最大幅值
max_magnitude = fft_magnitude[freq_index]
# 阻抗
import math
result = math.sqrt(2) * math.pi * max_magnitude / 6 / 4
result *= 0.44 * 100 # 统一放大100倍
impedanceList.append(int(result))
# print(max_magnitude, result)
else:
impedanceList.append(0)
impedances = np.array(impedanceList)
if decoder_class in ('mi', 'ma'):
impedances = impedances[np.array([8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10, 19, 20, 61, 51, 60, 21])]
elif decoder_class == 'blink':
impedances = impedances[np.array([0, 1])]
elif decoder_class == 'concentration':
impedances = impedances[np.array([0, 1])]
else:
impedances = impedances[np.array([13, 3, 2, 46, 9, 54, 47, 55])]
return impedances
def getData(self,count):
'''
获取最新的数据
@param count: 每通道返回的最数值数目
@return: 所有通道的最新count个数值
'''
data=None
try:
with self.RingBufferLock:
data = self.__ringBuffer.getData(count)
except:
print("锁:读取异常")
# self.RingBufferLock.release()
return data
def GetDataLenCount(self):
'''
获取最新缓存中每个通道的数量
@return:
'''
return self.__ringBuffer.nUpdate
def ResetAll(self):
'''
清空缓存
@return:
'''
with self.RingBufferLock:
self.__ringBuffer.resetAllPara()
def stop(self):
self.running = False
class SunnyLinker8(Thread, ):
receiveData = ''
t_buffer = 10
n_chan = 9
srate = 1000
receiveData = b''
toUv=False#转为uV
RingBufferLock = threading.Lock()
def __init__(self, host, port, srate=1000, n_chan=9,method = 'tcp'):
Thread.__init__(self)
self.daemon = True
self.host = host
self.port = port
self.srate = srate
self.n_chan = n_chan
self.method = method #传输方式,'tcp'表示tcp传输,'serial'表示串口传输
self.__ringBuffer = RingBuffer(self.n_chan + 2,
int(np.round(self.t_buffer * self.srate)))
self.energy = 0 #电量
self.status_code = 0 #与采集设备通信的状态码0为异常1为正常
self.gain_value = 6 # 增益倍数
def push_trigger(self,label):
'''
数据打标
@param label:标签类别
'''
function_code = None
label = [label]
packed_data = ProtocolFrame.pack(function_code, label)
if self.method == 'tcp':
self.sock.send(packed_data)
elif self.method == 'serial':
self.sock.write(packed_data)
def Impedance(self, On):
'''
阻抗检测开关
:param On:True为开启False为关闭
:return: 组好的协议帧
'''
function_code = None
if On:
data = [0xA1]
self.gain_value = 24
else:
data = [0xA0]
self.gain_value = 6
packed_data = ProtocolFrame.pack(function_code, data)
if self.method == 'tcp':
self.sock.send(packed_data)
elif self.method == 'serial':
self.sock.write(packed_data)
def connect(self):
try:
if self.method == 'serial':
# 开启com口波特率115200超时5
self.sock = serial.Serial(self.host, self.port, timeout=5)
self.sock.flushInput() # 清空缓冲区
count = self.sock.inWaiting() # 获取串口缓冲区数据
while not count:
count = self.sock.inWaiting() # 获取串口缓冲区数据
# # 接收和存储数据
data = (self.sock.read(count))
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
print("connected")
elif self.method == 'tcp':
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.connect((self.host, int(self.port)))
print("connected")
except Exception as e:
print("请打开头环")
print(e)
print("connected")
def extract_packet(self, packet):
# 存储一个点的八通道数据
dataList = []
# 存储116个点的八通道数据
dataMatrix = []
# index = (packet[1] << 24) | (packet[2] << 16) | (packet[3] << 8) | packet[4]
# print(index)
for j in range(5):
for i in range(self.n_chan):
if not self.toUv:#原始数据直接输出
val = (packet[26 * j + 25 + i * 3] << 16) | (packet[26 * j + 25 + 1 + i * 3] << 8) | packet[
26 * j + 25 + 2 + i * 3]
else:#转为uV
val = (packet[26 * j + 25 + i * 3] << 16) | (packet[26 * j + 25 + 1 + i * 3] << 8) | packet[
26 * j + 25 + 2 + i * 3]
if val < 8388608:
val = val * 4.5 / self.gain_value / 8388608 * 1000000;
else:
val = (val - 16777216) * 4.5 / self.gain_value / 8388608 * 1000000;
dataList.append(val)
#同步触发源
val = packet[26 * j + 25 + (i+1) * 3]
dataList.append(val)
#同步触发序号
val = packet[26 * j + 25 + (i+1) * 3+1]
dataList.append(val)
# 将数据矩阵进行拼接
if len(dataMatrix) == 0:
dataMatrix = np.asmatrix(dataList)
else:
dataMatrix = np.concatenate((dataMatrix, np.asmatrix(dataList)), axis=0)
dataList.clear()
return np.transpose(dataMatrix)
def run(self):
self.connect()
self.running = True
self.PackageLength = 158
start_time = time.time()
try:
while self.running:
if self.method == 'serial':
end_time = time.time()
if end_time-start_time > 2: #超过2s未收到数据
self.status_code = 0 #状态异常
count = self.sock.inWaiting() # 获取串口缓冲区数据
if count:
start_time = time.time()
self.status_code = 1 # 收到数据,状态正常
# 接收和存储数据
data = (self.sock.read(count))
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
elif self.method == 'tcp':
data = self.sock.recv(100)
if not data:
break
self.receiveData += data
if len(self.receiveData) >= self.PackageLength and self.receiveData.rfind(
b'\x55\x55') >= self.PackageLength - 2:
index = self.receiveData.index(b'\xaa')
self.receiveData = self.receiveData[index:]
if len(self.receiveData) >= self.PackageLength:
onepackage = self.receiveData[:self.PackageLength]
if onepackage[7] != 0:
self.energy = onepackage[7] # 电量
self.receiveData = self.receiveData[self.PackageLength:]
dataMatrix = self.extract_packet(onepackage)
try:
with self.RingBufferLock:
self.__ringBuffer.appendBuffer(dataMatrix)
except:
print("锁:写入异常")
self.sock.close()
except ConnectionResetError:
self.status_code = 0 # 状态异常
print("Connection was reset by the peer.")
except SerialException as Se:
self.status_code = 0
print('串口通信异常!请检查适配器')
def process_packet(self):
if self.circular_buffer.buffer_length > 158:
packet = self.circular_buffer.extract_packet()
if packet:
# Here you would parse the packet according to the protocol
# print("Received packet:%s,index:%s", len(packet),str(integer_value))
return packet
else:
print("Received Nothing")
return None
def getDataViaSSVEP(self,count):
'''
ssvep的视觉通道共8个通道
@param count: 每通道读取的数值数量
@return: 返回最新的数值
'''
data=self.getData(count)
data=data[:8,:]
return data
def getImpedance(self, data):
'''
获取阻抗值已经放大100倍单位是kΩ
@param data: 准备计算的通道数据每通道200个值注意不要把信号打标的通道传进来
@return: 返回各个通道的阻抗值
'''
impedanceList = []
for channelindex in range(data.shape[0]):
if len(data[channelindex]) > 0:
data_list = []
# 设计陷波滤波器去除50Hz成分
is50filter = True
if is50filter:
b, a = signal.iirnotch(50, 30, self.srate) # 30是带宽1000是采样频率
data_list = signal.lfilter(b, a, data[channelindex].tolist())
else:
data_list.extend(data[channelindex].tolist())
data_list = data_list[-1000:]
# 执行FFT
fft_result = np.fft.fft(data_list)
fft_magnitude = np.abs(fft_result / len(data_list)) # 归一化FFT结果
freqs = np.fft.fftfreq(len(data_list), d=1 / self.srate) # 频率轴
# y_amp_modified = np.concatenate(([fft_magnitude[0] / len(t[0].tolist())],
# fft_magnitude[1:-1] * 2 / len(t[0].tolist()),
# [fft_magnitude[-1] / len(t[0].tolist())]))
# 找到幅值最大的频率成分的索引忽略直流分量即索引0
max_index = np.argmax(fft_magnitude[1:])
# 获取最大幅值的频率索引加上1因为索引0是直流分量
freq_index = max_index + 1
# 获取最大幅值
max_magnitude = fft_magnitude[freq_index]
# 阻抗
import math
result = math.sqrt(2) * math.pi * max_magnitude / 6 / 4
result *= 0.44 * 100 # 统一放大100倍
impedanceList.append(int(result))
# print(max_magnitude, result)
else:
impedanceList.append(0)
# impedances = ":".join(map(str, impedanceList))
impedances = np.array(impedanceList)
impedances = impedances[:8]
return impedances
def getData(self,count):
'''
获取最新的数据
@param count: 每通道返回的最数值数目
@return: 所有通道的最新count个数值
'''
data=None
try:
with self.RingBufferLock:
data = self.__ringBuffer.getData(count)
except:
print("锁:读取异常")
# self.RingBufferLock.release()
return data
def GetDataLenCount(self):
'''
获取最新缓存中每个通道的数量
@return:
'''
return self.__ringBuffer.nUpdate
def ResetAll(self):
'''
清空缓存
@return:
'''
with self.RingBufferLock:
self.__ringBuffer.resetAllPara()
def stop(self):
self.running = False
if __name__ == "__main__":
# Usage
Linker = SunnyLinker64('127.0.0.1', 5086, 1000, 65)
Linker.start()
try:
while True:
time.sleep(0.005)
if(Linker.count()>0):
# print(Linker.ringBuffer.nUpdate)
t = Linker.getData()
print(t.shape[1], Linker.count())
# Linker.ringBuffer.nUpdate=0
# time.sleep(0.2)
except KeyboardInterrupt:
Linker.stop()

View File

@@ -0,0 +1,193 @@
from typing import List, Tuple, Union, Optional
class ProtocolFrame:
# 协议常量
FRAME_HEADER = 0xAA
FRAME_TAIL1 = 0x55
FRAME_TAIL2 = 0x55
RESERVED_SIZE = 6
MIN_FRAME_SIZE = 13 # 帧头1 + 功能1 + 长度2 + 预留6 + CRC1 + 包尾2
MAX_DATA_LENGTH = 0xFFFF # 最大数据长度 (2字节能表示的最大值)
@staticmethod
def calculate_crc8(data: bytes) -> bytes:
"""
计算CRC8校验值
Args:
data: 需要计算CRC的数据
Returns:
一个字节的CRC值bytes类型
"""
crc = 0
for byte in data:
crc ^= byte
for _ in range(8):
crc = ((crc << 1) ^ 0x07 if crc & 0x80 else crc << 1) & 0xFF
return bytes([crc])
@classmethod
def pack(cls, function, data: Union[bytes, bytearray, List[int]],
reserved: Optional[Union[bytes, bytearray, List[int]]] = None) -> bytes:
"""
协议打包函数
Args:
function: 功能码 (1字节)
data: 数据块
reserved: 预留字节(6字节可选)
Returns:
打包后的字节数据
"""
# 检查功能码
if function != None:
if not 0 <= function <= 0xFF:
raise ValueError("功能码必须是1字节")
# 转换数据为bytearray
if isinstance(data, list):
data = bytearray(data)
elif isinstance(data, bytes):
data = bytearray(data)
# 检查数据长度
data_length = len(data)
if data_length > cls.MAX_DATA_LENGTH:
raise ValueError(f"数据长度超过最大值 {cls.MAX_DATA_LENGTH}")
# 处理预留字节
if reserved is None:
reserved = bytearray([0] * cls.RESERVED_SIZE)
else:
if isinstance(reserved, list):
reserved = bytearray(reserved)
elif isinstance(reserved, bytes):
reserved = bytearray(reserved)
if len(reserved) != cls.RESERVED_SIZE:
raise ValueError(f"预留字节必须是{cls.RESERVED_SIZE}字节")
# 构建帧
frame = bytearray([cls.FRAME_HEADER]) # 帧头 (1字节)
if function != None:
frame.append(function) # 功能码 (1字节)
data_length+=6
# 数据长度 (2字节大端序)
frame.append((data_length >> 8) & 0xFF) # 高字节
frame.append(data_length & 0xFF) # 低字节
if function != None:
frame.extend(reserved) # 预留字节 (6字节)
frame.extend(data) # 数据块 (变长)
# 计算CRC (从功能码开始到数据块结束)
crc = cls.calculate_crc8(frame[1:]) # 不包含帧头
frame.extend(crc) # CRC校验 (1字节)
# 添加帧尾
frame.extend([cls.FRAME_TAIL1, cls.FRAME_TAIL2]) # 帧尾 (2字节)
return bytes(frame)
@classmethod
def unpack(cls, data: Union[bytes, bytearray]) -> Tuple[int, bytearray, bytearray]:
"""
协议解包函数
Args:
data: 待解析的字节数据
Returns:
(功能码, 数据块, 预留字节)
Raises:
ValueError: 当数据格式不正确时
"""
# 检查数据长度
if len(data) < cls.MIN_FRAME_SIZE:
raise ValueError("数据长度不足")
# 检查帧头
if data[0] != cls.FRAME_HEADER:
raise ValueError("帧头错误")
# 检查帧尾
if data[-2:] != bytes([cls.FRAME_TAIL1, cls.FRAME_TAIL2]):
raise ValueError("帧尾错误")
# 解析基本信息
function = data[1] # 功能码 (1字节)
# 数据长度 (2字节大端序)
data_length = (data[2] << 8) | data[3]
reserved = data[4:10] # 预留字节 (6字节)
# 检查数据长度
expected_length = cls.MIN_FRAME_SIZE + data_length
if len(data) != expected_length:
raise ValueError(f"数据长度不匹配: 期望{expected_length}字节,实际{len(data)}字节")
# 提取数据块
payload = data[10:10 + data_length]
# 验证CRC (从功能码开始到数据块结束)
received_crc = data[-3]
calculated_crc = cls.calculate_crc8(data[1:-3])[0] # 获取字节值
if received_crc != calculated_crc:
raise ValueError(f"CRC校验失败: 期望{calculated_crc:02X},实际{received_crc:02X}")
return function, bytearray(payload), bytearray(reserved)
def print_hex(data: bytes, label: str = ""):
"""打印十六进制数据,并按字节添加空格"""
hex_str = ' '.join([f"{b:02X}" for b in data])
if label:
print(f"{label}: {hex_str}")
else:
print(hex_str)
def print_frame_details(data: bytes):
"""打印帧的详细信息"""
print("帧详细信息:")
print(f"帧头: {data[0]:02X}")
print(f"功能码: {data[1]:02X}")
print(f"数据长度: {data[2]:02X} {data[3]:02X} ({(data[2] << 8) | data[3]}字节)")
print(f"预留字节: {' '.join([f'{b:02X}' for b in data[4:10]])}")
data_length = (data[2] << 8) | data[3]
print(f"数据块: {' '.join([f'{b:02X}' for b in data[10:10 + data_length]])}")
print(f"CRC校验: {data[-3]:02X}")
print(f"帧尾: {data[-2]:02X} {data[-1]:02X}")
# 使用示例
def example_usage():
try:
# 示例1简单数据打包
function_code = 0x01
data = [0x1]
packed_data = ProtocolFrame.pack(function_code, data)
print_hex(packed_data, "示例1 - 完整帧")
print_frame_details(packed_data)
print()
# 示例3解包验证
function, payload, reserved = ProtocolFrame.unpack(packed_data)
print("解包结果:")
print(f"功能码: 0x{function:02X}")
print_hex(payload, "数据块")
print_hex(reserved, "预留字节")
except ValueError as e:
print(f"错误: {e}")
if __name__ == "__main__":
example_usage()

View File

@@ -0,0 +1,409 @@
"""
EEG Conformer
Convolutional Transformer for EEG decoding
Couple CNN and Transformer in a concise manner with amazing results
"""
# remember to change paths
import os
gpus = [0]
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))
import numpy as np
import math
import random
import time
import datetime
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
# from common_spatial_pattern import csp
# from torch.utils.tensorboard import SummaryWriter
from torch.backends import cudnn
cudnn.benchmark = True
cudnn.deterministic = True
from sklearn.model_selection import train_test_split
# writer = SummaryWriter('./TensorBoardX/')
# Convolution module
# use conv to capture local features, instead of postion embedding.
class PatchEmbedding(nn.Module):
def __init__(self, emb_size=40,n_chan=8):
# self.patch_size = patch_size
super().__init__()
self.shallownet = nn.Sequential(
nn.Conv2d(1, 40, (1, 25), (1, 1)),
nn.Conv2d(40, 40, (n_chan, 1), (1, 1)),
nn.BatchNorm2d(40),
nn.ELU(),
nn.AvgPool2d((1, 75), (1, 15)), # pooling acts as slicing to obtain 'patch' along the time dimension as in ViT
nn.Dropout(0.5),
)
self.projection = nn.Sequential(
nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)), # transpose, conv could enhance fiting ability slightly
Rearrange('b e (h) (w) -> b (h w) e'),
)
def forward(self, x: Tensor) -> Tensor:
b, _, _, _ = x.shape
x = self.shallownet(x)
x = self.projection(x)
return x
class MultiHeadAttention(nn.Module):
def __init__(self, emb_size, num_heads, dropout):
super().__init__()
self.emb_size = emb_size
self.num_heads = num_heads
self.keys = nn.Linear(emb_size, emb_size)
self.queries = nn.Linear(emb_size, emb_size)
self.values = nn.Linear(emb_size, emb_size)
self.att_drop = nn.Dropout(dropout)
self.projection = nn.Linear(emb_size, emb_size)
def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
if mask is not None:
fill_value = torch.finfo(torch.float32).min
energy.mask_fill(~mask, fill_value)
scaling = self.emb_size ** (1 / 2)
att = F.softmax(energy / scaling, dim=-1)
att = self.att_drop(att)
out = torch.einsum('bhal, bhlv -> bhav ', att, values)
out = rearrange(out, "b h n d -> b n (h d)")
out = self.projection(out)
return out
class ResidualAdd(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
res = x
x = self.fn(x, **kwargs)
x += res
return x
class FeedForwardBlock(nn.Sequential):
def __init__(self, emb_size, expansion, drop_p):
super().__init__(
nn.Linear(emb_size, expansion * emb_size),
nn.GELU(),
nn.Dropout(drop_p),
nn.Linear(expansion * emb_size, emb_size),
)
class GELU(nn.Module):
def forward(self, input: Tensor) -> Tensor:
return input*0.5*(1.0+torch.erf(input/math.sqrt(2.0)))
class TransformerEncoderBlock(nn.Sequential):
def __init__(self,
emb_size,
num_heads=10,
drop_p=0.5,
forward_expansion=4,
forward_drop_p=0.5):
super().__init__(
ResidualAdd(nn.Sequential(
nn.LayerNorm(emb_size),
MultiHeadAttention(emb_size, num_heads, drop_p),
nn.Dropout(drop_p)
)),
ResidualAdd(nn.Sequential(
nn.LayerNorm(emb_size),
FeedForwardBlock(
emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
nn.Dropout(drop_p)
)
))
class TransformerEncoder(nn.Sequential):
def __init__(self, depth, emb_size):
super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)])
class ClassificationHead(nn.Sequential):
def __init__(self, emb_size, n_classes):
super().__init__()
# global average pooling
self.clshead = nn.Sequential(
Reduce('b n e -> b e', reduction='mean'),
nn.LayerNorm(emb_size),
nn.Linear(emb_size, n_classes)
)
self.fc = nn.Sequential(
nn.Linear(2440, 256),
nn.ELU(),
nn.Dropout(0.5),
nn.Linear(256, 32),
nn.ELU(),
nn.Dropout(0.3),
nn.Linear(32, 2)
)
def forward(self, x):
x = x.contiguous().view(x.size(0), -1)
out = self.fc(x)
return out
class Conformer(nn.Sequential):
def __init__(self, emb_size=40, depth=6, n_classes=2,n_chan=8, **kwargs):
super().__init__(
PatchEmbedding(emb_size,n_chan),
TransformerEncoder(depth, emb_size),
ClassificationHead(emb_size, n_classes)
)
class ExP():
def __init__(self,n_chan):
super(ExP, self).__init__()
self.n_chan = n_chan
self.batch_size = 24
self.n_epochs = 250
self.c_dim = 4
self.lr = 0.0002
self.b1 = 0.5
self.b2 = 0.999
self.start_epoch = 0
# 创建目录
os.makedirs("online_Models", exist_ok=True)
self.log_write = open("./online_Models/log_result.txt", "w")
self.Tensor = torch.cuda.FloatTensor
self.LongTensor = torch.cuda.LongTensor
self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()
self.model = Conformer(n_chan=self.n_chan).cuda()
self.model = nn.DataParallel(self.model, device_ids=[i for i in range(len(gpus))])
self.model = self.model.cuda()
# self.model = EEGNet().cuda()
# self.model = nn.DataParallel(self.model,device_ids=[i for i in range(len(gpus))])
# self.model = self.model.cuda()
# summary(self.model, (1, 8, 1000))
# Segmentation and Reconstruction (S&R) data augmentation
def interaug(self, timg, label):
# 确保输入是 numpy 数组CPU
if isinstance(timg, torch.Tensor):
timg = timg.cpu().numpy()
if isinstance(label, torch.Tensor):
label = label.cpu().numpy()
aug_data = []
aug_label = []
for cls4aug in range(2):
cls_idx = np.where(label == cls4aug + 1)
tmp_data = timg[cls_idx]
tmp_label = label[cls_idx]
tmp_aug_data = np.zeros((int(self.batch_size / 2), 1, self.n_chan, 1000))
for ri in range(int(self.batch_size / 2)):
for rj in range(8):
rand_idx = np.random.randint(0, tmp_data.shape[0], 8)
tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :,
rj * 125:(rj + 1) * 125]
aug_data.append(tmp_aug_data)
aug_label.append(tmp_label[:int(self.batch_size / 2)])
aug_data = np.concatenate(aug_data)
aug_label = np.concatenate(aug_label)
aug_shuffle = np.random.permutation(len(aug_data))
aug_data = aug_data[aug_shuffle, :, :]
aug_label = aug_label[aug_shuffle]
# 返回 numpy 数组,由调用方决定是否移到 GPU
return aug_data, aug_label
def train(self,all_data,all_label,model_path):
all_data = np.array(all_data);all_label = np.array(all_label)
all_data = np.expand_dims(all_data, axis=1)
train_data, test_data, train_label, test_label = train_test_split(all_data, all_label, test_size=0.2,
random_state=42, stratify=all_label,shuffle=True)
# === 优化:一次性预生成增强数据,避免每个 batch 都重复计算 ===
aug_data, aug_label = self.interaug(train_data, train_label)
# 将原始数据和增强数据合并,再一起打乱
train_data_full = np.concatenate([train_data, aug_data], axis=0)
train_label_full = np.concatenate([train_label, aug_label], axis=0)
shuffle_idx = np.random.permutation(len(train_data_full))
train_data_full = train_data_full[shuffle_idx]
train_label_full = train_label_full[shuffle_idx]
img = torch.from_numpy(train_data_full)
label = torch.from_numpy(train_label_full-1)
dataset = torch.utils.data.TensorDataset(img, label)
self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
test_data = torch.from_numpy(test_data)
test_label = torch.from_numpy(test_label-1)
test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)
# Optimizers
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))
test_data = Variable(test_data.type(self.Tensor))
test_label = Variable(test_label.type(self.LongTensor))
bestAcc = 0
averAcc = 0
num = 0
Y_true = 0
Y_pred = 0
# Train the cnn model
for e in range(self.n_epochs):
# in_epoch = time.time()
self.model.train()
for i, (img, label) in enumerate(self.dataloader):
img = Variable(img.cuda().type(self.Tensor))
label = Variable(label.cuda().type(self.LongTensor))
outputs = self.model(img)
loss = self.criterion_cls(outputs, label)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# out_epoch = time.time()
# test process
if (e + 1) % 1 == 0:
self.model.eval()
Cls = self.model(test_data)
loss_test = self.criterion_cls(Cls, test_label)
y_pred = torch.max(Cls, 1)[1]
acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
train_pred = torch.max(outputs, 1)[1]
train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
print('Epoch:', e,
' Train loss: %.6f' % loss.detach().cpu().numpy(),
' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
' Train accuracy %.6f' % train_acc,
' Test accuracy is %.6f' % acc)
self.log_write.write(str(e) + " " + str(acc) + "\n")
num = num + 1
averAcc = averAcc + acc
if acc > bestAcc:
bestAcc = acc
Y_true = test_label
Y_pred = y_pred
torch.save(self.model, model_path)
averAcc = averAcc / num
print('The average accuracy is:', averAcc)
print('The best accuracy is:', bestAcc)
self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
return bestAcc, averAcc, Y_true, Y_pred
# writer.close()
def onlineTrain(data_queue,result_queue):
import torch
print(f"[DEBUG] torch.__version__ = {torch.__version__}")
print(f"[DEBUG] torch.cuda.is_available() = {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"[DEBUG] GPU = {torch.cuda.get_device_name(0)}")
try:
starttime = datetime.datetime.now()
# seed_n = np.random.randint(2025)
seed_n = 1877
random.seed(seed_n)
np.random.seed(seed_n)
torch.manual_seed(seed_n)
torch.cuda.manual_seed(seed_n)
torch.cuda.manual_seed_all(seed_n)
# 从队列获取训练数据
data = data_queue.get(timeout=30)
all_data, all_label,model_path,n_chan = data['data'], data['label'],data['modelPath'],data['n_chan']
exp = ExP(n_chan)
print('训练参数: ',np.shape(all_data),np.shape(all_label),model_path)
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path)
print('THE BEST ACCURACY IS ' + str(bestAcc))
endtime = datetime.datetime.now()
print('train duration: ',str(endtime - starttime))
# 将模型或参数传回
result_queue.put({
'status': 'success',
'model_state': model_path, # 或保存路径
'timestamp': time.time()
})
except Exception as e:
result_queue.put({'status': 'error', 'msg': str(e)})
def offlineTrain(all_data,all_label,modelPath):
starttime = datetime.datetime.now()
# seed_n = np.random.randint(2025)
seed_n = 1877
print('seed is ' + str(seed_n))
random.seed(seed_n)
np.random.seed(seed_n)
torch.manual_seed(seed_n)
torch.cuda.manual_seed(seed_n)
torch.cuda.manual_seed_all(seed_n)
exp = ExP()
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath)
print('THE BEST ACCURACY IS ' + str(bestAcc))
endtime = datetime.datetime.now()
print('train duration: ',str(endtime - starttime))
if __name__ == "__main__":
print(time.asctime(time.localtime(time.time())))
print(time.asctime(time.localtime(time.time())))

View File

@@ -0,0 +1,382 @@
"""
EEG Conformer
Convolutional Transformer for EEG decoding
Couple CNN and Transformer in a concise manner with amazing results
"""
# remember to change paths
import os
import numpy as np
import math
import random
import time
import datetime
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import nn
from torch import Tensor
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
from torch.backends import cudnn
from sklearn.model_selection import train_test_split
# writer = SummaryWriter('./TensorBoardX/')
# Convolution module
# use conv to capture local features, instead of postion embedding.
class PatchEmbedding(nn.Module):
def __init__(self, emb_size=40):
# self.patch_size = patch_size
super().__init__()
self.shallownet = nn.Sequential(
nn.Conv2d(1, 40, (1, 25), (1, 1)),
nn.Conv2d(40, 40, (8, 1), (1, 1)),
nn.BatchNorm2d(40),
nn.ELU(),
nn.AvgPool2d((1, 75), (1, 15)), # pooling acts as slicing to obtain 'patch' along the time dimension as in ViT
nn.Dropout(0.5),
)
self.projection = nn.Sequential(
nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)), # transpose, conv could enhance fiting ability slightly
Rearrange('b e (h) (w) -> b (h w) e'),
)
def forward(self, x: Tensor) -> Tensor:
b, _, _, _ = x.shape
x = self.shallownet(x)
x = self.projection(x)
return x
class MultiHeadAttention(nn.Module):
def __init__(self, emb_size, num_heads, dropout):
super().__init__()
self.emb_size = emb_size
self.num_heads = num_heads
self.keys = nn.Linear(emb_size, emb_size)
self.queries = nn.Linear(emb_size, emb_size)
self.values = nn.Linear(emb_size, emb_size)
self.att_drop = nn.Dropout(dropout)
self.projection = nn.Linear(emb_size, emb_size)
def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
if mask is not None:
fill_value = torch.finfo(torch.float32).min
energy.mask_fill(~mask, fill_value)
scaling = self.emb_size ** (1 / 2)
att = F.softmax(energy / scaling, dim=-1)
att = self.att_drop(att)
out = torch.einsum('bhal, bhlv -> bhav ', att, values)
out = rearrange(out, "b h n d -> b n (h d)")
out = self.projection(out)
return out
class ResidualAdd(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
res = x
x = self.fn(x, **kwargs)
x += res
return x
class FeedForwardBlock(nn.Sequential):
def __init__(self, emb_size, expansion, drop_p):
super().__init__(
nn.Linear(emb_size, expansion * emb_size),
nn.GELU(),
nn.Dropout(drop_p),
nn.Linear(expansion * emb_size, emb_size),
)
class GELU(nn.Module):
def forward(self, input: Tensor) -> Tensor:
return input*0.5*(1.0+torch.erf(input/math.sqrt(2.0)))
class TransformerEncoderBlock(nn.Sequential):
def __init__(self,
emb_size,
num_heads=10,
drop_p=0.5,
forward_expansion=4,
forward_drop_p=0.5):
super().__init__(
ResidualAdd(nn.Sequential(
nn.LayerNorm(emb_size),
MultiHeadAttention(emb_size, num_heads, drop_p),
nn.Dropout(drop_p)
)),
ResidualAdd(nn.Sequential(
nn.LayerNorm(emb_size),
FeedForwardBlock(
emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
nn.Dropout(drop_p)
)
))
class TransformerEncoder(nn.Sequential):
def __init__(self, depth, emb_size):
super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)])
class ClassificationHead(nn.Sequential):
def __init__(self, emb_size, n_classes):
super().__init__()
# global average pooling
self.clshead = nn.Sequential(
Reduce('b n e -> b e', reduction='mean'),
nn.LayerNorm(emb_size),
nn.Linear(emb_size, n_classes)
)
self.fc = nn.Sequential(
nn.Linear(2440, 256),
nn.ELU(),
nn.Dropout(0.5),
nn.Linear(256, 32),
nn.ELU(),
nn.Dropout(0.3),
nn.Linear(32, 2)
)
def forward(self, x):
x = x.contiguous().view(x.size(0), -1)
out = self.fc(x)
return out
class Conformer(nn.Sequential):
def __init__(self, emb_size=40, depth=6, n_classes=2, **kwargs):
super().__init__(
PatchEmbedding(emb_size),
TransformerEncoder(depth, emb_size),
ClassificationHead(emb_size, n_classes)
)
class ExP():
def __init__(self):
super(ExP, self).__init__()
self.batch_size = 24
self.n_epochs = 250
self.c_dim = 4
self.lr = 0.0002
self.b1 = 0.5
self.b2 = 0.999
self.start_epoch = 0
self.log_write = open("./online_Models/log_result.txt", "w")
# 自动选择设备:有 GPU 用 GPU否则用 CPU
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# self.device = torch.device("cpu")
print(f"Using device: {self.device}")
# 定义张量类型(不再强制使用 cuda
self.Tensor = torch.FloatTensor
self.LongTensor = torch.LongTensor
# 将模型移到指定设备
self.model = Conformer().to(self.device)
# 损失函数也移到设备
self.criterion_cls = torch.nn.CrossEntropyLoss().to(self.device)
# self.model = EEGNet().cuda()
# self.model = nn.DataParallel(self.model,device_ids=[i for i in range(len(gpus))])
# self.model = self.model.cuda()
# summary(self.model, (1, 8, 1000))
# Segmentation and Reconstruction (S&R) data augmentation
def interaug(self, timg, label):
aug_data = []
aug_label = []
for cls4aug in range(2):
cls_idx = np.where(label == cls4aug + 1)
tmp_data = timg[cls_idx]
tmp_label = label[cls_idx]
tmp_aug_data = np.zeros((int(self.batch_size / 2), 1, 8, 1000))
for ri in range(int(self.batch_size / 2)):
for rj in range(8):
rand_idx = np.random.randint(0, tmp_data.shape[0], 8)
tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :,
rj * 125:(rj + 1) * 125]
aug_data.append(tmp_aug_data)
aug_label.append(tmp_label[:int(self.batch_size / 2)])
aug_data = np.concatenate(aug_data)
aug_label = np.concatenate(aug_label)
aug_shuffle = np.random.permutation(len(aug_data))
aug_data = aug_data[aug_shuffle, :, :]
aug_label = aug_label[aug_shuffle]
aug_data = torch.from_numpy(aug_data).float().to(self.device)
aug_label = torch.from_numpy(aug_label - 1).long().to(self.device)
return aug_data, aug_label
def train(self,all_data,all_label,model_path):
all_data = np.array(all_data);all_label = np.array(all_label)
all_data = np.expand_dims(all_data, axis=1)
train_data, test_data, train_label, test_label = train_test_split(all_data, all_label, test_size=0.2,
random_state=42, stratify=all_label,shuffle=True)
# 转为 Tensor
img = torch.from_numpy(train_data).float().to(self.device)
label = torch.from_numpy(train_label - 1).long().to(self.device)
dataset = torch.utils.data.TensorDataset(img, label)
self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
test_data = torch.from_numpy(test_data).float().to(self.device)
test_label = torch.from_numpy(test_label - 1).long().to(self.device)
test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)
# Optimizers
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))
bestAcc = 0
averAcc = 0
num = 0
Y_true = 0
Y_pred = 0
# Train the cnn model
for e in range(self.n_epochs):
# in_epoch = time.time()
self.model.train()
for i, (img, label) in enumerate(self.dataloader):
# data augmentation
aug_data, aug_label = self.interaug(train_data, train_label)
img = torch.cat((img, aug_data))
label = torch.cat((label, aug_label))
outputs = self.model(img)
loss = self.criterion_cls(outputs, label)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# out_epoch = time.time()
# test process
if (e + 1) % 1 == 0:
self.model.eval()
with torch.no_grad():
Cls = self.model(test_data)
loss_test = self.criterion_cls(Cls, test_label)
y_pred = torch.max(Cls, 1)[1]
acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
train_pred = torch.max(outputs, 1)[1]
train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
print('Epoch:', e,
' Train loss: %.6f' % loss.detach().cpu().numpy(),
' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
' Train accuracy %.6f' % train_acc,
' Test accuracy is %.6f' % acc)
self.log_write.write(str(e) + " " + str(acc) + "\n")
num = num + 1
averAcc = averAcc + acc
if acc > bestAcc:
bestAcc = acc
Y_true = test_label
Y_pred = y_pred
torch.save(self.model, model_path)
averAcc = averAcc / num
print('The average accuracy is:', averAcc)
print('The best accuracy is:', bestAcc)
self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
return bestAcc, averAcc, Y_true, Y_pred
# writer.close()
def onlineTrain(data_queue,result_queue):
try:
starttime = datetime.datetime.now()
# seed_n = np.random.randint(2025)
seed_n = 1877
random.seed(seed_n)
np.random.seed(seed_n)
torch.manual_seed(seed_n)
torch.cuda.manual_seed(seed_n)
torch.cuda.manual_seed_all(seed_n)
exp = ExP()
# 从队列获取训练数据
data = data_queue.get(timeout=30)
all_data, all_label,model_path = data['data'], data['label'],data['modelPath']
print('训练参数: ',np.shape(all_data),np.shape(all_label),model_path)
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path)
print('THE BEST ACCURACY IS ' + str(bestAcc))
endtime = datetime.datetime.now()
print('train duration: ',str(endtime - starttime))
# 将模型或参数传回
result_queue.put({
'status': 'success',
'model_state': model_path, # 或保存路径
'timestamp': time.time()
})
except Exception as e:
result_queue.put({'status': 'error', 'msg': str(e)})
def offlineTrain(all_data,all_label,modelPath):
starttime = datetime.datetime.now()
# seed_n = np.random.randint(2025)
seed_n = 1877
print('seed is ' + str(seed_n))
random.seed(seed_n)
np.random.seed(seed_n)
torch.manual_seed(seed_n)
torch.cuda.manual_seed(seed_n)
torch.cuda.manual_seed_all(seed_n)
exp = ExP()
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath)
print('THE BEST ACCURACY IS ' + str(bestAcc))
endtime = datetime.datetime.now()
print('train duration: ',str(endtime - starttime))
if __name__ == "__main__":
print(time.asctime(time.localtime(time.time())))
print(time.asctime(time.localtime(time.time())))

View File

@@ -0,0 +1,184 @@
from torchsummary import summary
import torch
import torch.nn as nn
def weights_init(m):
if isinstance(m, nn.Conv2d):
nn.init.xavier_uniform_(m.weight)
# nn.init.constant(m.bias, 0) # bias may be none
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.constant_(m.bias, 0)
def square_activation(x):
return torch.square(x)
def safe_log(x):
return torch.clip(torch.log(x), min=1e-7, max=1e7)
class ShallowConvNet(nn.Module):
def __init__(self, num_classes=3, chans=19, samples=768):
super(ShallowConvNet, self).__init__()
self.conv_nums = 40
self.features = nn.Sequential(
nn.Conv2d(1, self.conv_nums, (1, 25)),
nn.Conv2d(self.conv_nums, self.conv_nums, (chans, 1), bias=False),
nn.BatchNorm2d(self.conv_nums)
)
self.avgpool = nn.AvgPool2d(kernel_size=(1, 75), stride=(1, 15))
self.dropout = nn.Dropout()
out = torch.ones((1, 1, chans, samples))
out = self.features(out)
out = self.avgpool(out)
n_out_time = out.cpu().data.numpy().shape
self.classifier = nn.Linear(n_out_time[-1] * n_out_time[-2] * n_out_time[-3], num_classes)
def forward(self, x):
x = self.features(x)
x = square_activation(x)
x = self.avgpool(x)
x = safe_log(x)
x = self.dropout(x)
features = torch.flatten(x, 1)
cls = self.classifier(features)
return cls
class EEGNet(nn.Module):
def __init__(self, num_classes=2, chans=8, samples=1000, dropout_rate=0.5, kernel_length=64, F1=8,
F2=16,):
super(EEGNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(1, F1, kernel_size=(1, kernel_length), bias=False),
nn.BatchNorm2d(F1),
nn.Conv2d(F1, F1, kernel_size=(chans, 1), groups=F1, bias=False), # groups=F1 for depthWiseConv
nn.BatchNorm2d(F1),
nn.ELU(inplace=True),
# nn.ReLU(),
nn.AvgPool2d((1, 4)),
nn.Dropout(dropout_rate),
# for SeparableCon2D
# SeparableConv2D(F1, F2, kernel1_size=(1, 16), bias=False),
nn.Conv2d(F1, F1, kernel_size=(1, 16), groups=F1, bias=False), # groups=F1 for depthWiseConv
nn.BatchNorm2d(F1),
nn.ELU(inplace=True),
# nn.ReLU(),
nn.Conv2d(F1, F2, kernel_size=(1, 1), groups=1, bias=False), # point-wise cnn
nn.BatchNorm2d(F2),
# nn.ReLU(),
nn.ELU(inplace=True),
nn.AvgPool2d((1, 8)),
nn.Dropout(p=dropout_rate),
# nn.Dropout(p=0.5),
)
out = torch.ones((1, 1, chans, samples))
out = self.features(out)
n_out_time = out.cpu().data.numpy().shape
self.classifier = nn.Linear(n_out_time[-1] * n_out_time[-2] * n_out_time[-3], num_classes)
def forward(self, x):
conv_features = self.features(x)
features = torch.flatten(conv_features, 1)
cls = self.classifier(features)
return cls
class LMDA(nn.Module):
"""
LMDA-Net for the paper
"""
def __init__(self, chans=19, samples=768, num_classes=3, depth=9, kernel=75, channel_depth1=24, channel_depth2=9,
ave_depth=1, avepool=5):
super(LMDA, self).__init__()
self.ave_depth = ave_depth
self.channel_weight = nn.Parameter(torch.randn(depth, 1, chans), requires_grad=True)
nn.init.xavier_uniform_(self.channel_weight.data)
self.time_conv = nn.Sequential(
nn.Conv2d(depth, channel_depth1, kernel_size=(1, 1), groups=1, bias=False),
nn.BatchNorm2d(channel_depth1),
nn.Conv2d(channel_depth1, channel_depth1, kernel_size=(1, kernel),
groups=channel_depth1, bias=False),
nn.BatchNorm2d(channel_depth1),
nn.GELU(),
)
# self.avgPool1 = nn.AvgPool2d((1, 24))
self.chanel_conv = nn.Sequential(
nn.Conv2d(channel_depth1, channel_depth2, kernel_size=(1, 1), groups=1, bias=False),
nn.BatchNorm2d(channel_depth2),
nn.Conv2d(channel_depth2, channel_depth2, kernel_size=(chans, 1), groups=channel_depth2, bias=False),
nn.BatchNorm2d(channel_depth2),
nn.GELU(),
)
self.norm = nn.Sequential(
nn.AvgPool3d(kernel_size=(1, 1, avepool)),
# nn.AdaptiveAvgPool3d((9, 1, 35)),
nn.Dropout(p=0.65),
)
# 定义自动填充模块
out = torch.ones((1, 1, chans, samples))
out = torch.einsum('bdcw, hdc->bhcw', out, self.channel_weight)
out = self.time_conv(out)
out = self.chanel_conv(out)
out = self.norm(out)
n_out_time = out.cpu().data.numpy().shape
print('In ShallowNet, n_out_time shape: ', n_out_time)
self.classifier = nn.Linear(n_out_time[-1]*n_out_time[-2]*n_out_time[-3], num_classes)
def EEGDepthAttention(self, x):
# x: input features with shape [N, C, H, W]
N, C, H, W = x.size()
# K = W if W % 2 else W + 1
k = 7
adaptive_pool = nn.AdaptiveAvgPool2d((1, W))
conv = nn.Conv2d(1, 1, kernel_size=(k, 1), padding=(k//2, 0), bias=True).to(x.device) # original kernel k
nn.init.xavier_uniform_(conv.weight)
nn.init.constant_(conv.bias, 0)
softmax = nn.Softmax(dim=-2)
x_pool = adaptive_pool(x)
x_transpose = x_pool.transpose(-2, -3)
y = conv(x_transpose)
y = softmax(y)
y = y.transpose(-2, -3)
return y * C * x
def forward(self, x):
x = torch.einsum('bdcw, hdc->bhcw', x, self.channel_weight)
x_time = self.time_conv(x) # batch, depth1, channel, samples_
x_time = self.EEGDepthAttention(x_time) # DA1
x = self.chanel_conv(x_time) # batch, depth2, 1, samples_
x = self.norm(x)
features = torch.flatten(x, 1)
cls = self.classifier(features)
return cls
if __name__ == '__main__':
model = ShallowConvNet(num_classes=4, chans=22, samples=1125).cuda()
a = torch.randn(12, 1, 3, 875).cuda().float()
l2 = model(a)
model_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
summary(model, show_input=True)
print(l2.shape)

View File

@@ -0,0 +1,30 @@
# -*-coding:utf-8 -*-
import configparser
import os
import sys
from audioop import error
BASE_DIR = os.getcwd()
IniFileName = os.path.join(BASE_DIR, 'config.ini')
# IniFileName=os.path.join( 'config.ini')
def IniWrite(section,keyname,value):
# 创建ConfigParser对象
config = configparser.ConfigParser()
config.read(IniFileName,encoding='utf-8')
with open(IniFileName, 'w') as configfile:
if not config.has_section(section):
config.add_section(section)
config[section][keyname]=str(value)
config.write(configfile)
def IniRead(section,key):
try:
config = configparser.ConfigParser()
config.read(IniFileName,encoding='utf-8')
return config[section][key]
except error as e:
print(e)
# 读取特定section和键的值
return '5'

View File

@@ -0,0 +1,15 @@
import ctypes
import sys
def is_program_running(name='Global\\Decoder'):
# 创建互斥体
mutex_name =name
h_mutex = ctypes.windll.kernel32.CreateMutexW(None, False, mutex_name)
# 检查互斥体是否已经存在
if ctypes.windll.kernel32.GetLastError() == 183: # ERROR_ALREADY_EXISTS
print("程序已经在运行.")
return True
return False

View File

@@ -0,0 +1,418 @@
# -*- coding: utf-8 -*-
#
# Authors: Swolf <swolfforever@gmail.com>
# Date: 2021/1/07
# License: MIT License
from typing import Optional, List, Tuple, Union
import warnings
import numpy as np
from numpy import ndarray
from numpy.linalg import linalg
from scipy.linalg import solve, qr
from scipy.signal import sosfiltfilt, cheby1, cheb1ord
from sklearn.base import BaseEstimator, TransformerMixin, clone
def robust_pattern(W : ndarray, Cx: ndarray, Cs: ndarray) -> ndarray:
"""Transform spatial filters to spatial patterns based on paper [1]_.
Referring to the method mentioned in article [1],the constructed spatial filter only shows how to combine
information from different channels to extract signals of interest from EEG signals, but if our goal is
neurophysiological interpretation or visualization of weights, activation patterns need to be constructed
from the obtained spatial filters.
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
W : ndarray
Spatial filters, shape (n_channels, n_filters).
Cx : ndarray
Covariance matrix of eeg data, shape (n_channels, n_channels).
Cs : ndarray
Covariance matrix of source data, shape (n_channels, n_channels).
Returns
-------
A : ndarray
Spatial patterns, shape (n_channels, n_patterns), each column is a spatial pattern.
References
----------
.. [1] Haufe, Stefan, et al. "On the interpretation of weight vectors of linear models in multivariate neuroimaging.
Neuroimage 87 (2014): 96-110.
"""
# use linalg.solve instead of inv, makes it more stable
# see https://github.com/robintibor/fbcsp/blob/master/fbcsp/signalproc.py
# and https://ww2.mathworks.cn/help/matlab/ref/mldivide.html
A = solve(Cs.T, np.dot(Cx, W).T).T
return A
class FilterBank(BaseEstimator, TransformerMixin):
"""
Filter bank decomposition is a bandpass filter array that divides the input signal into
multiple subband components and obtains the eigenvalues of each subband component.
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
base_estimator : class
Estimator for model training and feature extraction.
filterbank : list[ndarray]
A bandpass filter bank used to divide the input signal into multiple subband components.
n_jobs : int
Sets the number of CPU working cores. The default is None.
References
----------
.. [1] Chen X, Wang Y, Nakanishi M, et al. High-speed spelling with a noninvasive brain-computer interface[J].
Proceedings of the national academy of sciences, 2015, 112(44): E6058-E6067.
"""
def __init__(
self,
base_estimator: BaseEstimator,
filterbank: List[ndarray],
n_jobs: Optional[int] = None,
):
self.base_estimator = base_estimator
self.filterbank = filterbank
self.n_jobs = n_jobs
def fit(self, X: ndarray, y: Optional[ndarray] = None, **kwargs):
"""
Training model
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
X : None
Training signal (parameters can be ignored, only used to maintain code structure).
y : None
Label data (ibid., ignorable).
Yf : None
Reference signal (ibid., ignorable).
"""
self.estimators_ = [
clone(self.base_estimator) for _ in range(len(self.filterbank))
]
X = self.transform_filterbank(X)
for i, est in enumerate(self.estimators_):
est.fit(X[i], y, **kwargs)
# def wrapper(est, X, y, kwargs):
# est.fit(X, y, **kwargs)
# return est
# self.estimators_ = Parallel(n_jobs=self.n_jobs)(
# delayed(wrapper)(est, X[i], y, kwargs) for i, est in enumerate(self.estimators_))
return self
def transform(self, X: ndarray, **kwargs):
"""
The parameters stored in self are used to convert X into features, and X is filtered through the filter bank to
obtain the eigenvalues of each subband component.
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
X : ndarray, shape(n_trials, n_channels, n_samples)
Test the signal.
Returns
-------
feat : ndarray, shape(n_trials, n_fre)
Feature array.
"""
X = self.transform_filterbank(X)
feat = [est.transform(X[i], **kwargs) for i, est in enumerate(self.estimators_)]
# def wrapper(est, X, kwargs):
# retval = est.transform(X, **kwargs)
# return retval
# feat = Parallel(n_jobs=self.n_jobs)(
# delayed(wrapper)(est, X[i], kwargs) for i, est in enumerate(self.estimators_))
feat = np.concatenate(feat, axis=-1)
return feat
def transform_filterbank(self, X: ndarray):
"""
The input signal is filtered through a filter bank.
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
X : ndarray, shape(n_trials, n_channels, n_samples)
Input signal.
Returns
-------
Xs: ndarray, shape(Nfb, n_trials, n_channels, n_samples)
Individual subband components of the input signal.
"""
Xs = np.stack([sosfiltfilt(sos, X, axis=-1) for sos in self.filterbank])
return Xs
class FilterBankSSVEP(FilterBank):
"""
Filter bank analysis for SSVEP.
The SSVEP is analyzed using filter banks, that is, multiple filters are combined to decompose the SSVEP signal
into specific segments (subbands containing the original data) and obtain its characteristic data.
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
filterbank : list[ndarray]
The filter bank.
base_estimator : class
Estimator for model training and feature extraction.
filterweights : ndarray
Filter weight, default is None.
n_jobs : int
Sets the number of CPU working cores. The default is None.
"""
def __init__(
self,
filterbank: List[ndarray],
base_estimator: BaseEstimator,
filterweights: Optional[ndarray] = None,
n_jobs: Optional[int] = None,
):
self.filterweights = filterweights
super().__init__(base_estimator, filterbank, n_jobs=n_jobs)
def transform(self, X: ndarray): # type: ignore[override]
"""
X is converted into features by using the parameters stored in self, and the eigenvalues of each subband
component are obtained after the input signal is filtered by the filter bank.
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
X : ndarray, shape(n_trials, n_channels, n_samples)
Test the signal.
Returns
-------
features : ndarray, shape(n_trials, n_fre)
Feature array.
"""
features = super().transform(X)
if self.filterweights is None:
return features
else:
features = np.reshape(
features, (features.shape[0], len(self.filterbank), -1)
)
return np.sum(
features * self.filterweights[np.newaxis, :, np.newaxis], axis=1
)
def generate_filterbank(
passbands: List[Tuple[float, float]],
stopbands: List[Tuple[float, float]],
srate: int,
order: Optional[int] = None,
rp: float = 0.5,
):
"""
Create a filter bank, that is, obtain a bandpass filter coefficient that can divide the input signal into multiple
subband components.
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
passbands : list or tuple(float, float)
Passband parameters.
stopbands : list or tuple(float, float)
Stopband parameters.
srate : float
Sampling rate.
order : int
Filter order.
rp : float
The maximum ripple allowed in the passband below the unit gain is 0.5 by default.
Returns
-------
Filterbankndarray, shape(len(passbands), N, 6)
Filter bank coefficient.
"""
filterbank = []
for wp, ws in zip(passbands, stopbands):
if order is None:
N, wn = cheb1ord(wp, ws, 3, 40, fs=srate)
sos = cheby1(N, rp, wn, btype="bandpass", output="sos", fs=srate)
else:
sos = cheby1(order, rp, wp, btype="bandpass", output="sos", fs=srate)
filterbank.append(sos)
return filterbank
def process(data):
# 白化操作
meanValue = np.mat(data.mean(axis=1))
meanData = np.repeat(meanValue, data.shape[1], axis=1)
whiteTemp = data - meanData
# QR 分解
rankWhiteTemp = whiteTemp.shape[0]
whiteTemp = np.transpose(whiteTemp)
Q, R = qr(whiteTemp.A, mode='economic')
# 计算矩阵的秩
rankQ = linalg.matrix_rank(R)
if rankQ == 0:
raise ValueError('stats:canoncorr:badData')
elif rankQ <= rankWhiteTemp:
# warnings.warn('stats:canoncorr:NotFullRank')
Q = Q[:, 0:rankQ]
return Q, rankQ
def reference(listFreqs,fs, numberSmples, num_harms):
numberFrequence = len(listFreqs)
timeIndex = np.arange(1, numberSmples + 1) / fs # time index
referenceData = np.zeros((numberFrequence, 2 * num_harms, numberSmples))
for frequenceIndex in range(numberFrequence):
temp = []
for harmIndex in range(1, num_harms + 1):
stimFrequence = listFreqs[frequenceIndex] # in HZ
# Sin and Cos
temp.extend([np.sin(2 * np.pi * timeIndex * harmIndex * stimFrequence),
np.cos(2 * np.pi * timeIndex * harmIndex * stimFrequence)])
referenceTemp = np.mat(temp)
# 白化操作和QR分解
Q, rankQ = process(referenceTemp)
referenceData[frequenceIndex] = np.transpose(Q)
return referenceData
def generate_cca_references(
freqs: Union[ndarray, int, float],
srate,
T,
phases: Optional[Union[ndarray, int, float]] = None,
n_harmonics: int = 1,
):
"""
Construct a sine-cosine reference signal for canonical correlation analysis (CCA).
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
freqs : int or float
Frequency.
srate : int
Sampling rate.
T : int
Sampling time.
phases : int or float
Phase, default is None.
n_harmonics : int
The number of harmonics. The default value is 1.
Returns
-------
Yfndarray, shape(srate*T, n_harmonics*2)
Sine and cosine reference signal.
"""
if isinstance(freqs, int) or isinstance(freqs, float):
freqs = np.array([freqs])
freqs = np.array(freqs)[:, np.newaxis]
if phases is None:
phases = 0
if isinstance(phases, int) or isinstance(phases, float):
phases = np.array([phases])
phases = np.array(phases)[:, np.newaxis]
t = np.linspace(0, T, int(T * srate))
Yf = []
for i in range(n_harmonics):
Yf.append(
np.stack(
[
np.sin(2 * np.pi * (i + 1) * freqs * t + np.pi * phases),
np.cos(2 * np.pi * (i + 1) * freqs * t + np.pi * phases),
],
axis=1,
)
)
Yf = np.concatenate(Yf, axis=1)
return Yf
def sign_flip(u, s, vh=None):
"""Flip signs of SVD or EIG using the method in paper [1]_.
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
u: ndarray
left singular vectors, shape (M, K).
s: ndarray
singular values, shape (K,).
vh: ndarray or None
transpose of right singular vectors, shape (K, N).
Returns
-------
u: ndarray
corrected left singular vectors.
s: ndarray
singular values.
vh: ndarray
transpose of corrected right singular vectors.
References
----------
.. [1] https://www.sandia.gov/~tgkolda/pubs/pubfiles/SAND2007-6422.pdf
"""
if vh is None:
total_proj = np.sum(u * s, axis=0)
signs = np.sign(total_proj)
random_idx = signs == 0
if np.any(random_idx):
signs[random_idx] = 1
warnings.warn(
"The magnitude is close to zero, the sign will become arbitrary."
)
u = u * signs
return u, s
else:
left_proj = np.sum(s[:, np.newaxis] * vh, axis=-1)
right_proj = np.sum(u * s, axis=0)
total_proj = left_proj + right_proj
signs = np.sign(total_proj)
random_idx = signs == 0
if np.any(random_idx):
signs[random_idx] = 1
warnings.warn(
"The magnitude is close to zero, the sign will become arbitrary."
)
u = u * signs
vh = signs[:, np.newaxis] * vh
return u, s, vh

View File

@@ -0,0 +1,436 @@
# -*- coding: utf-8 -*-
# DSP: Discriminal Spatial Patterns
# Authors: Swolf <swolfforever@gmail.com>
# Junyang Wang <2144755928@qq.com>
# Last update date: 2022-8-11
# License: MIT License
from typing import Optional, List, Tuple
from itertools import combinations
import numpy as np
from scipy.linalg import eigh
from numpy import ndarray
from scipy.linalg import solve
from sklearn.base import BaseEstimator, TransformerMixin, ClassifierMixin
def robust_pattern(W : ndarray, Cx: ndarray, Cs: ndarray) -> ndarray:
"""Transform spatial filters to spatial patterns based on paper [1]_.
Referring to the method mentioned in article [1],the constructed spatial filter only shows how to combine
information from different channels to extract signals of interest from EEG signals, but if our goal is
neurophysiological interpretation or visualization of weights, activation patterns need to be constructed
from the obtained spatial filters.
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
W : ndarray
Spatial filters, shape (n_channels, n_filters).
Cx : ndarray
Covariance matrix of eeg data, shape (n_channels, n_channels).
Cs : ndarray
Covariance matrix of source data, shape (n_channels, n_channels).
Returns
-------
A : ndarray
Spatial patterns, shape (n_channels, n_patterns), each column is a spatial pattern.
References
----------
.. [1] Haufe, Stefan, et al. "On the interpretation of weight vectors of linear models in multivariate neuroimaging.
Neuroimage 87 (2014): 96-110.
"""
# use linalg.solve instead of inv, makes it more stable
# see https://github.com/robintibor/fbcsp/blob/master/fbcsp/signalproc.py
# and https://ww2.mathworks.cn/help/matlab/ref/mldivide.html
A = solve(Cs.T, np.dot(Cx, W).T).T
return A
def isPD(B: ndarray) -> bool:
"""Returns true when input matrix is positive-definite, via Cholesky decompositon method.
Parameters
----------
B : ndarray
Any matrix, shape (N, N)
Returns
-------
bool
True if B is positve-definite.
Notes
-----
Use numpy.linalg rather than scipy.linalg. In this case, scipy.linalg has unpredictable behaviors.
"""
try:
_ = np.linalg.cholesky(B)
return True
except np.linalg.LinAlgError:
return False
def nearestPD(A: ndarray) -> ndarray:
"""Find the nearest positive-definite matrix to input.
Parameters
----------
A : ndarray
Any square matrxi, shape (N, N)
Returns
-------
A3 : ndarray
positive-definite matrix to A
Notes
-----
A Python/Numpy port of John D'Errico's `nearestSPD` MATLAB code [1]_, which
origins at [2]_.
References
----------
.. [1] https://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd
.. [2] N.J. Higham, "Computing a nearest symmetric positive semidefinite matrix" (1988):
https://doi.org/10.1016/0024-3795(88)90223-6
"""
B = (A + A.T) / 2
_, s, V = np.linalg.svd(B)
H = np.dot(V.T, np.dot(np.diag(s), V))
A2 = (B + H) / 2
A3 = (A2 + A2.T) / 2
if isPD(A3):
return A3
print("Replace current matrix with the nearest positive-definite matrix.")
spacing = np.spacing(np.linalg.norm(A))
# The above is different from [1]. It appears that MATLAB's `chol` Cholesky
# decomposition will accept matrixes with exactly 0-eigenvalue, whereas
# Numpy's will not. So where [1] uses `eps(mineig)` (where `eps` is Matlab
# for `numpy.spacing`), we use the above definition. CAVEAT: our `spacing`
# will be much larger than [1]'s `eps(mineig)`, since `mineig` is usually on
# the order of 1e-16, and `eps(1e-16)` is on the order of 1e-34, whereas
# `spacing` will, for Gaussian random matrixes of small dimension, be on
# othe order of 1e-16. In practice, both ways converge, as the unit test
# below suggests.
eye = np.eye(A.shape[0])
k = 1
while not isPD(A3):
mineig = np.min(np.real(np.linalg.eigvals(A3)))
A3 += eye * (-mineig * k**2 + spacing)
k += 1
return A3
def xiang_dsp_kernel(
X: ndarray, y: ndarray
) -> Tuple[ndarray, ndarray, ndarray, ndarray]:
"""
DSP: Discriminal Spatial Patterns, only for two classes[1]_.
Import train data to solve spatial filters with DSP,
finds a projection matrix that maximize the between-class scatter matrix and
minimize the within-class scatter matrix. Currently only support for two types of data.
Author: Swolf <swolfforever@gmail.com>
Created on: 2021-1-07
Update log:
Parameters
----------
X : ndarray
EEG train data assuming removing mean, shape (n_trials, n_channels, n_samples)
y : ndarray
labels of EEG data, shape (n_trials, )
Returns
-------
W : ndarray
spatial filters, shape (n_channels, n_filters)
D : ndarray
eigenvalues in descending order
M : ndarray
mean value of all classes and trials, i.e. common mode signals, shape (n_channel, n_samples)
A : ndarray
spatial patterns, shape (n_channels, n_filters)
Notes
-----
the implementation removes regularization on within-class scatter matrix Sw.
References
----------
.. [1] Liao, Xiang, et al. "Combining spatial filters for the classification of single-trial EEG in
a finger movement task." IEEE Transactions on Biomedical Engineering 54.5 (2007): 821-831.
"""
X, y = np.copy(X), np.copy(y)
labels = np.unique(y)
X = np.reshape(X, (-1, *X.shape[-2:]))
X = X - np.mean(X, axis=-1, keepdims=True)
# the number of each label
n_labels = np.array([np.sum(y == label) for label in labels])
# average template of all trials
M = np.mean(X, axis=0)
# class conditional template
Ms, Ss = zip(
*[
(
np.mean(X[y == label], axis=0),
np.sum(
np.matmul(X[y == label], np.swapaxes(X[y == label], -1, -2)), axis=0
),
)
for label in labels
]
)
Ms, Ss = np.stack(Ms), np.stack(Ss)
# within-class scatter matrix
Sw = np.sum(
Ss
- n_labels[:, np.newaxis, np.newaxis] * np.matmul(Ms, np.swapaxes(Ms, -1, -2)),
axis=0,
)
Ms = Ms - M
# between-class scatter matrix
Sb = np.sum(
n_labels[:, np.newaxis, np.newaxis] * np.matmul(Ms, np.swapaxes(Ms, -1, -2)),
axis=0,
)
D, W = eigh(nearestPD(Sb), nearestPD(Sw))
ix = np.argsort(D)[::-1] # in descending order
D, W = D[ix], W[:, ix]
A = robust_pattern(W, Sb, W.T @ Sb @ W)
return W, D, M, A
def xiang_dsp_feature(
W: ndarray, M: ndarray, X: ndarray, n_components: int = 1
) -> ndarray:
"""
Return DSP features in paper [1]_.
Author: Swolf <swolfforever@gmail.com>
Created on: 2021-1-07
Update log:
Parameters
----------
W : ndarray
spatial filters from csp_kernel, shape (n_channels, n_filters)
M : ndarray
common template for all classes, shape (n_channel, n_samples)
X : ndarray
eeg test data, shape (n_trials, n_channels, n_samples)
n_components : int, optional
length of the spatial filters, first k components to use, by default 1
Returns
-------
features: ndarray
features, shape (n_trials, n_components, n_samples)
Raises
------
ValueError
n_components should less than half of the number of channels
Notes
-----
1. instead of meaning of filtered signals in paper [1]_., we directly return filtered signals.
References
----------
.. [1] Liao, Xiang, et al. "Combining spatial filters for the classification of single-trial EEG in
a finger movement task." IEEE Transactions on Biomedical Engineering 54.5 (2007): 821-831.
"""
W, M, X = np.copy(W), np.copy(M), np.copy(X)
max_components = W.shape[1]
if n_components > max_components:
raise ValueError("n_components should less than the number of channels")
X = np.reshape(X, (-1, *X.shape[-2:]))
X = X - np.mean(X, axis=-1, keepdims=True)
# print('************: ',np.shape(W),np.shape(X),np.shape(M))
features = np.matmul(W[:, :n_components].T, X - M)
return features
class DSP(BaseEstimator, TransformerMixin, ClassifierMixin):
"""
DSP: Discriminal Spatial Patterns
Author: Swolf <swolfforever@gmail.com>
Created on: 2021-1-07
Update log:
Parameters
----------
n_components : int
length of the spatial filter, first k components to use, by default 1
transform_method : str
method of template matching, by default corr (pearson correlation coefficient)
classes_ : int
number of the EEG classes
Attributes
----------
n_components : int
length of the spatial filter, first k components to use, by default 1
transform_method : str
method of template matching, by default corr (pearson correlation coefficient)
classes_ : int
number of the EEG classes
W_ : ndarray, shape(n_channels, n_filters)
Spatial filters, shape(n_channels, n_filters), in which n_channels = n_filters
D_ : ndarray, shape(n_filters )
eigenvalues in descending order, shape(n_filters, )
M_ : ndarray, shape(n_channels, n_samples)
mean value of all classes and trials, i.e. common mode signals, shape(n_channels, n_samples)
A_ : ndarray, shape(n_channels, n_filters)
spatial patterns, shape(n_channels, n_filters)
templates_: ndarray, shape(n_classes, n_filters, n_samples)
templates of train data, shape(n_classes, n_filters, n_samples)
"""
def __init__(self, n_components: int = 1, transform_method: str = "corr"):
self.n_components = n_components
self.transform_method = transform_method
def fit(self, X: ndarray, y: ndarray, Yf: Optional[ndarray] = None):
"""
Import the train data to get a model.
Parameters
----------
X : ndarray
train data, shape(n_trials, n_channels, n_samples)
y : ndarray
labels of train data, shape (n_trials, )
Yf : ndarray
optional parameter
Returns
-------
W_ : ndarray
spatial filters, shape (n_channels, n_filters), in which n_channels = n_filters
D_ : ndarray
eigenvalues in descending order, shape (n_filters, )
M_ : ndarray
template for all classes, shape (n_channel, n_samples)
A_ : ndarray
spatial patterns, shape (n_channels, n_filters)
templates_ : ndarray
templates of train data, shape (n_channels, n_filters, n_samples)
"""
X -= np.mean(X, axis=-1, keepdims=True)
self.classes_ = np.unique(y)
self.W_, self.D_, self.M_, self.A_ = xiang_dsp_kernel(X, y)
self.templates_ = np.stack(
[
np.mean(
xiang_dsp_feature(
self.W_, self.M_, X[y == label], n_components=self.W_.shape[1]
),
axis=0,
)
for label in self.classes_
]
)
return self
def transform(self, X: ndarray):
"""
Import the test data to get features.
Parameters
----------
X : ndarray
test data, shape(n_trials, n_channels, n_samples)
Returns
-------
feature : ndarray, shape(n_trials,n_classes)
correlation coefficients of templates of train data and features of test data, shape(n_trials, n_classes)
"""
n_components = self.n_components
X -= np.mean(X, axis=-1, keepdims=True)
features = xiang_dsp_feature(self.W_, self.M_, X, n_components=n_components)
if self.transform_method is None:
return features.reshape((features.shape[0], -1))
elif self.transform_method == "mean":
return np.mean(features, axis=-1)
elif self.transform_method == "corr":
return self._pearson_features(
features, self.templates_[:, :n_components, :]
)
else:
raise ValueError("non-supported transform method")
def _pearson_features(self, X: ndarray, templates: ndarray):
"""
Calculate pearson correlation coefficient.
Parameters
----------
X : ndarray
features of test data after spatial filters, shape(n_trials, n_components, n_samples)
templates : ndarray
templates of train data, shape(n_classes, n_components, n_samples)
Returns
-------
corr : ndarray
pearson correlation coefficient, shape(n_trials, n_classes)
"""
X = np.reshape(X, (-1, *X.shape[-2:]))
templates = np.reshape(templates, (-1, *templates.shape[-2:]))
X = X - np.mean(X, axis=-1, keepdims=True)
templates = templates - np.mean(templates, axis=-1, keepdims=True)
X = np.reshape(X, (X.shape[0], -1))
templates = np.reshape(templates, (templates.shape[0], -1))
istd_X = 1 / np.std(X, axis=-1, keepdims=True)
istd_templates = 1 / np.std(templates, axis=-1, keepdims=True)
corr = (X @ templates.T) / (templates.shape[1] - 1)
corr = istd_X * corr * istd_templates.T
return corr
def predict(self, X: ndarray):
"""
Import the templates and the test data to get prediction labels.
Parameters
----------
X : ndarray
test data, shape(n_trials, n_channels, n_samples)
Returns
-------
labels : ndarray
prediction labels of test data, shape(n_trials,)
"""
feat = self.transform(X)
if self.transform_method == "corr":
labels = self.classes_[np.argmax(feat, axis=-1)]
else:
raise NotImplementedError()
return labels

View File

@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-
#
# Authors: Swolf <swolfforever@gmail.com>
# Date: 2021/10/10
# License: MIT License
"""
Task Decomposition Component Analysis.
"""
from typing import List
import numpy as np
from scipy.linalg import qr
from scipy.stats import pearsonr
from numpy import ndarray
from sklearn.base import BaseEstimator, TransformerMixin, ClassifierMixin
from typing import Optional, List
from SSMVEP.algorithm.base import FilterBankSSVEP
from SSMVEP.algorithm.dsp import xiang_dsp_kernel, xiang_dsp_feature
def proj_ref(Yf: ndarray):
Q, R = qr(Yf.T, mode="economic")
P = Q @ Q.T
return P
def aug_2(X: ndarray, n_samples: int, padding_len: int, P: ndarray, training: bool = True):
X = X.reshape((-1, *X.shape[-2:]))
n_trials, n_channels, n_points = X.shape
# if n_points < padding_len + n_samples:
# raise ValueError("the length of X should be larger than l+n_samples.")
aug_X = np.zeros((n_trials, (padding_len + 1) * n_channels, n_samples))
if training:
for i in range(padding_len + 1):
aug_X[:, i * n_channels : (i + 1) * n_channels, :] = X[
..., i : i + n_samples
]
else:
for i in range(padding_len + 1):
aug_X[:, i * n_channels : (i + 1) * n_channels, : n_samples - i] = X[
..., i:n_samples
]
aug_Xp = aug_X @ P
aug_X = np.concatenate([aug_X, aug_Xp], axis=-1)
return aug_X
def tdca_feature(
X: ndarray,
templates: ndarray,
W: ndarray,
M: ndarray,
Ps: List[ndarray],
padding_len: int,
n_components: int = 1,
training=False,
):
rhos = []
for Xk, P in zip(templates, Ps):
a = xiang_dsp_feature(
W,
M,
aug_2(X, P.shape[0], padding_len, P, training=training),
n_components=n_components,
)
b = Xk[:n_components, :]
a = np.reshape(a, (-1))
b = np.reshape(b, (-1))
rhos.append(pearsonr(a, b)[0])
return rhos
class TDCA(BaseEstimator, TransformerMixin, ClassifierMixin):
def __init__(self, padding_len: int, n_components: int = 1):
self.padding_len = padding_len
self.n_components = n_components
def fit(self, X: ndarray, y: ndarray, Yf: ndarray):
X -= np.mean(X, axis=-1, keepdims=True)
self.classes_ = np.unique(y)
self.Ps_ = [proj_ref(Yf[i]) for i in range(len(self.classes_))]
# print(np.shape(self.Ps_))
aug_X_list, aug_Y_list = [], []
for i, label in enumerate(self.classes_):
aug_X_list.append(
aug_2(
X[y == label],
self.Ps_[i].shape[0],
self.padding_len,
self.Ps_[i],
training=True,
)
)
aug_Y_list.append(y[y == label])
aug_X = np.concatenate(aug_X_list, axis=0)
aug_Y = np.concatenate(aug_Y_list, axis=0)
self.W_, _, self.M_, _ = xiang_dsp_kernel(aug_X, aug_Y)
self.templates_ = np.stack(
[
np.mean(
xiang_dsp_feature(
self.W_,
self.M_,
aug_X[aug_Y == label],
n_components=self.W_.shape[1],
),
axis=0,
)
for label in self.classes_
]
)
return self
def transform(self, X: ndarray):
n_components = self.n_components
X -= np.mean(X, axis=-1, keepdims=True)
X = X.reshape((-1, *X.shape[-2:]))
rhos = [
tdca_feature(
tmp,
self.templates_,
self.W_,
self.M_,
self.Ps_,
self.padding_len,
n_components=n_components,
)
for tmp in X
]
rhos = np.stack(rhos)
return rhos
def predict(self, X: ndarray):
feat = self.transform(X)
labels = self.classes_[np.argmax(feat, axis=-1)]
return labels,feat
class FBTDCA(FilterBankSSVEP, ClassifierMixin):
def __init__(
self,
filterbank: List[ndarray],
padding_len: int,
n_components: int = 1,
filterweights: Optional[ndarray] = None,
n_jobs: Optional[int] = None,
):
self.padding_len = padding_len
self.n_components = n_components
self.filterweights = filterweights
self.n_jobs = n_jobs
super().__init__(
filterbank,
TDCA(padding_len, n_components=n_components),
filterweights=filterweights,
n_jobs=n_jobs,
)
def fit(self, X: ndarray, y: ndarray, Yf: Optional[ndarray] = None): # type: ignore[override]
self.classes_ = np.unique(y)
super().fit(X, y, Yf=Yf)
return self
def predict(self, X: ndarray):
features = self.transform(X)
if self.filterweights is None:
features = np.reshape(
features, (features.shape[0], len(self.filterbank), -1)
)
features = np.mean(features, axis=1)
labels = self.classes_[np.argmax(features, axis=-1)]
return labels,features

View File

@@ -0,0 +1,529 @@
# -*- coding: utf-8 -*-
import os
import time
import warnings
from os import error
import numpy as np
import scipy
from numpy.linalg import linalg
from scipy.io import loadmat
from scipy.linalg import qr
from scipy.signal import filtfilt, lfilter
# from numpy.linalg import _umath_linalg
class FbccaDw:
def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum,method):
print('******************************************')
print('parameter list')
print('target:', num_target)
print('number of filter bank:', num_filter)
print('parameter:', parameter)
print('width:', width)
self.phase = 0
self.bandWidth = width
self.winNum = winNum
self.num_harms = num_harms
self.num_target = num_target
self.num_chans = num_chans
self.winTimeDelay = stimTime
self.fs = fs
self.winMaxSampleNum = self.winTimeDelay * self.winNum * self.fs
self.winDelayNum = round(self.winTimeDelay * self.fs)
self.num_fbs = num_filter
parameterValue = np.power(np.arange(1, self.num_fbs + 1), -(parameter[0])) + parameter[1]
self.weightValue = parameterValue / (sum(parameterValue))
self.dataUseLen = [0] * self.winNum
self.Rbuffer = np.zeros([self.winNum, self.num_fbs, self.num_chans, self.num_chans])
self.Cxy = np.zeros([self.winNum, self.num_fbs, self.num_target, 2 * self.num_harms, self.num_chans])
self.rhoNum = 2
self.notchZh = [0]
self.filterZf = [0] * self.num_fbs
self.north_b = []
self.north_a = []
self.filterBank_A = []
self.filterBank_B = []
self.winStep = 1
self.DW_cost_method = 'DW11' if method==1 else 'DW1'
'''
filterFrequenceBank根据刺激频率生成的通带和阻带用于滤波器组频带分解
'''
def filterFrequenceBank(self):
# 阻带的最高频率
lastFrequence = 90
freqBandWidth = self.bandWidth[1]
fStep = self.bandWidth[0]
bandFrequence = np.zeros((5, 4))
# 第二列频率带
band = list(range(freqBandWidth, lastFrequence, fStep))
band[:] = [x - 2 for x in band]
colValue = np.maximum(np.asmatrix(band), 1)
bandFrequence[:, 1] = colValue[0, 0:5]
# 第一列频率带
bandFrequence[:, 0] = np.maximum(bandFrequence[:, 1] - 4, 1)
# 第三列频率带
bandFrequence[:, 2] = lastFrequence + 2
# 第四列频率带
bandFrequence[:, 3] = bandFrequence[:, 2] + 10
# bandFrequence = np.array([[30,33,77,82],
# [62,68,77,82]])
for idx_fb in range(self.num_fbs):
Nq = self.fs / 2
Wp = [bandFrequence[idx_fb, 1] / Nq, bandFrequence[idx_fb, 2] / Nq]
Ws = [bandFrequence[idx_fb, 0] / Nq, bandFrequence[idx_fb, 3] / Nq]
[N, Wn] = scipy.signal.cheb1ord(Wp, Ws, 3,
40) # band pass filter StopBand=[Ws(1)~Ws(2)] PassBand=[Wp(1)~Wp(2)]
[B, A] = scipy.signal.cheby1(N, 0.5, Wn, 'bandpass') # Wn passband edge frequency
self.filterBank_A.append(A)
self.filterBank_B.append(B)
# def filterFrequenceBank(self):
# # 阻带的最高频率
# lastFrequence = 90
# freqBandWidth = self.bandWidth[1]
# fStep = self.bandWidth[0]
# bandFrequence = np.zeros((5, 4))
# # 第二列频率带
# band = list(range(freqBandWidth, lastFrequence, fStep))
# band[:] = [x - 2 for x in band]
# colValue = np.maximum(np.asmatrix(band), 1)
# bandFrequence[:, 1] = colValue[0, 0:5]
# # 第一列频率带
# bandFrequence[:, 0] = np.maximum(bandFrequence[:, 1] - 4, 1)
# # 第三列频率带
# bandFrequence[:, 2] = lastFrequence + 2
# # 第四列频率带
# bandFrequence[:, 3] = bandFrequence[:, 2] + 10
# for idx_fb in range(self.num_fbs):
# Nq = self.fs / 2
# Wp = [bandFrequence[idx_fb, 1] / Nq, bandFrequence[idx_fb, 2] / Nq]
# Ws = [bandFrequence[idx_fb, 0] / Nq, bandFrequence[idx_fb, 3] / Nq]
# [N, Wn] = scipy.signal.cheb1ord(Wp, Ws, 3,
# 40) # band pass filter StopBand=[Ws(1)~Ws(2)] PassBand=[Wp(1)~Wp(2)]
# [B, A] = scipy.signal.cheby1(N, 0.5, Wn, 'bandpass') # Wn passband edge frequency
# self.filterBank_A.append(A)
# self.filterBank_B.append(B)
'''
Filter bank analysis
Input:
eeg : Input eeg data (# of targets, # of channels, Data length [sample])
Output:
filterData : Generated filter Data
'''
def filterbank(self, eeg):
filterData = np.zeros((self.num_fbs, eeg.shape[1], eeg.shape[0]))
for filterIndex in range(self.num_fbs):
if np.all(self.filterZf[filterIndex] == 0):
zi = np.zeros(
[max(len(self.filterBank_A[filterIndex]), len(self.filterBank_B[filterIndex])) - 1, self.num_chans])
_, self.filterZf[filterIndex] = lfilter(self.filterBank_B[filterIndex], self.filterBank_A[filterIndex],
eeg, zi=zi.T)
Data = lfilter(self.filterBank_B[filterIndex], self.filterBank_A[filterIndex], eeg)
else:
Data, self.filterZf[filterIndex] = lfilter(self.filterBank_B[filterIndex],
self.filterBank_A[filterIndex], eeg,
zi=self.filterZf[filterIndex])
filterData[filterIndex, :, :] = Data.T
return filterData
'''
process
矩阵的白化和QR正则化分解降低矩阵的维度加速计算时间
Input:
data : 输入的二维脑电信号
Output:
Q : 降维后的矩阵
rankQ :正则矩阵的秩
'''
def process(self, data):
# 白化操作
meanValue = np.asmatrix(data.mean(axis=1))
meanData = np.repeat(meanValue, data.shape[1], axis=1)
whiteTemp = data - meanData
# QR 分解
rankWhiteTemp = whiteTemp.shape[0]
whiteTemp = np.transpose(whiteTemp)
Q, R = qr(whiteTemp.A, mode='economic')
# 计算矩阵的秩
rankQ = linalg.matrix_rank(R)
if rankQ == 0:
raise ValueError('stats:canoncorr:badData')
elif rankQ <= rankWhiteTemp:
# warnings.warn('stats:canoncorr:NotFullRank')
Q = Q[:, 0:rankQ]
return Q, rankQ
'''
reference
Input:
listFreqs : 刺激频率列表
numberSmples : 用于分类的脑电信号采样点个数
num_harms : 谐波数
Output:
y_ref : 生成的参考信号 (刺激目标数, 2 * 谐波数, 数据长度/采样点数)
'''
def reference(self, listFreqs, numberSmples, num_harms):
numberFrequence = len(listFreqs)
timeIndex = np.arange(1, numberSmples + 1) / self.fs # time index
referenceData = np.zeros((numberFrequence, 2 * num_harms, numberSmples))
for frequenceIndex in range(numberFrequence):
temp = []
for harmIndex in range(1, num_harms + 1):
stimFrequence = listFreqs[frequenceIndex] # in HZ
# Sin and Cos
temp.extend([np.sin(2 * np.pi * timeIndex * harmIndex * stimFrequence),
np.cos(2 * np.pi * timeIndex * harmIndex * stimFrequence)])
referenceTemp = np.asmatrix(temp)
# 白化操作和QR分解
Q, rankQ = self.process(referenceTemp)
referenceData[frequenceIndex] = np.transpose(Q)
return referenceData
'''
setNorthFilterPara
陷波器的参数初始化
self.north_b, self.north_a : 陷波器的参数设计
'''
def setNotchFilterPara(self):
# notchFilterNum = 3
# northFreq = 50
# bwDen = 35
# wo = northFreq / (self.fs / 2)
# bw = wo / bwDen
# self.north_b, self.north_a = iirnotch(wo, Q=35) # self.north_b, self.north_a = iircomb(northFreq, bwDen, 'notch')
# # n倍零极点相当于重复滤波n次
# if notchFilterNum > 1:
# z, p, k = tf2zpk(self.north_b, self.north_a)
# zNew = np.repeat(z, notchFilterNum, axis=0)
# zNew[1], zNew[4] = zNew[4], zNew[1]
# pNew = np.repeat(p, notchFilterNum, axis=0)
# pNew[1], pNew[4] = pNew[4], pNew[1]
# kNew = np.power(k, notchFilterNum)
# self.north_b, self.north_a = zpk2tf(zNew, pNew, kNew)
self.north_b = [0.94801603944125245604368501517456, -1.7577184027642647201616910024313,
3.9303778338832491279219993884908, -3.7392330345967859095424046245171,
3.9303778338832482397435796883656, -1.7577184027642638319832713023061,
0.94801603944125156786526531504933]
self.north_a = [1, -1.8214007435820627200939725298667, 4.0000101767406484043476666556671,
-3.7380998614928691026193519064691, 3.8589119784285759173769747576443,
-1.6951692350503837491970671180752, 0.89786559147978006745205448169145]
'''
northFilter
进行信号的50hz陷波处理
Input:
data :输入脑电数据
Output:
dataFiltered : 陷波处理后的脑电数据
'''
def northFilter(self, data):
try:
if np.all(self.notchZh[0] == 0):
zi = np.zeros([max(len(self.north_a), len(self.north_b)) - 1, self.num_chans])
_, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=zi.T)
dataFiltered = lfilter(self.north_b, self.north_a, data)
else:
dataFiltered, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=self.notchZh[0])
return np.asmatrix(dataFiltered)
except Exception:
print(Exception)
'''
getDataQ
Inputs:
data脑电数据
Rbuffer待更新的中间系数
Output:
Qs1 : 脑电特征1
Qs2 : 脑电特征2
Rbuffer : 单窗口更新后的系数
'''
def getDataQ(self, data, Rbuffer):
Qs1 = [0] * self.num_fbs
Qs2 = [0] * self.num_fbs
nulldata = np.zeros([self.num_chans, self.num_chans])
Rnum = self.num_chans
for fb_num in range(self.num_fbs):
fb_data = np.squeeze(data[fb_num, :, :])
if np.all(Rbuffer[fb_num] == 0):
whiteTemp = fb_data
Q, R = qr(whiteTemp, mode='economic')
Qs1[fb_num] = nulldata
Qs2[fb_num] = Q
Rbuffer[fb_num] = R
else:
whiteTemp = np.concatenate((Rbuffer[fb_num], fb_data), axis=0)
Q, R = qr(whiteTemp, mode='economic')
Qs1[fb_num] = Q[0:Rnum, :]
Qs2[fb_num] = Q[Rnum:, :]
Rbuffer[fb_num] = R
return Qs1, Qs2, Rbuffer
'''
myCCA根据脑电特征和参考信号计算相关系数
Inputs:
dataQ脑电特征
Qc2y参考信号
d 相关系数取值数
Output:
rho : 相关系数
'''
def myCCA(self, dataQ, Qc2y, d):
if len(Qc2y) == 0:
Cov = dataQ
else:
Cov = np.dot(Qc2y, dataQ)
# U, S, V = scipy.linalg.svd(Cov, 0)
# rho = np.minimum(np.maximum(np.diag(S[0: d]).T, 0), 1)
# gufunc = _umath_linalg.svd_n
# rho = gufunc(Cov)
rho = np.linalg.svd(Cov, compute_uv=False)
return rho[0:d]
'''
weightCCA计算分类标签
Inputs:
Qs1脑电特征1
Qs2脑电特征2
ref 正余弦参考信号
Cxy 协方差中间参数
Output:
result : 分类标签
rho : 相关系数
Cxy : 更新后的协方差中间参数
'''
def weightCCA(self, Qs1, Qs2, ref, Cxy):
rMax = np.zeros([self.num_fbs, self.num_target])
for fi in range(self.num_fbs):
for si in range(self.num_target):
Qc2y = np.squeeze(ref[si, :, :])
# 更新协方差矩阵
if np.all(Cxy[fi][si] == 0):
Cxy[fi, si] = np.dot(Qc2y, Qs2[fi])
else:
Cxy[fi, si] = np.dot(Cxy[fi, si], Qs1[fi]) + np.dot(Qc2y, Qs2[fi])
r = self.myCCA(Cxy[fi, si], [], self.rhoNum)
rMax[fi, si] = r[0]
rho = np.dot(self.weightValue, np.power(rMax, 2)) # weighted sum of r from all different filter banks' result
result = np.argmax(rho)
return result, rho, Cxy
'''
costF损失函数根据计算的相关系数生成决策值用于和阈值进行比较
Inputs:
rho相关系数
method相关系数计算参数
C 参数
Output:
decideValue : 决策阈值
'''
def costF(self, rho, method, C):
rho = rho.tolist()
rho.sort(reverse=True)
if method == 'DW1':
decideValue = (rho[0] - rho[1]) / (sum(rho) - self.num_target * np.log(sum(np.exp(rho))))
elif method == 'DW11':
decideValue = -(rho[0] - rho[1])
elif method == 'DW2':
decideValue = (rho[0] - C) / (rho[1] - rho[0])
return decideValue
'''
onlineInit将窗口长度相位值、中间参数初始化
'''
def onlineInit(self):
self.dataUseLen = [0] * self.winNum
self.Rbuffer = np.zeros([self.winNum, self.num_fbs, self.num_chans, self.num_chans])
self.Cxy = np.zeros([self.winNum, self.num_fbs, self.num_target, 2 * self.num_harms, self.num_chans])
self.phase = 0
'''
filterInit重置陷波器和滤波器的滤波参数
'''
def filterInit(self):
self.notchZh = [0]
self.filterZf = [0] * self.num_fbs
'''
warmFilter预热滤波器去刺激前的4S数据对陷波器和滤波器参数进行初始化迭代去除过渡带的效果
Inputs:
data预处理脑电数据
'''
def warmFilter(self, data):
# 降采样在采集前完成
temp = self.preprocessFilter(data) #预热陷波滤波器
# 滤波器组频带分解
filterData = self.filterbank(temp) #预热滤波器组
'''
myDownSample数据降采样
Inputs:
data脑电数据
n降采样的倍数
Output:
eegData2 : 降采样后的数据
'''
def myDownSample(self, data, n):
data = data[:8, self.phase:]
dataNum = data.shape[1]
remainNum = (dataNum - 1) % n
self.phase = n - 1 - remainNum
dataDowmSample = []
for value in data:
value = value[0:value.size:n]
dataDowmSample.append(value)
eegData2 = np.array(dataDowmSample).reshape([8, int(np.round(data.shape[1] / n))])
return eegData2
'''
preprocessFilter预处理调用函数降采样和陷波处理
Inputs:
data脑电数据
Output:
filterData : 降采样和陷波后的数据
'''
def preprocessFilter(self, data):
# data = self.myDownSample(data, 4)
# filterData = self.northFilter(data[:8, :])
filterData = self.northFilter(data[:, :])
return filterData
'''
fbccaDWMW分类函数对输入的脑电信号进行识别输出决策标签
Inputs:
testdata脑电数据
referenceData参考信号
tValue出决策阈值
Output:
res : 决策标签
rho_new相关系数
minEps得到的决策阈值
'''
# 动态窗算法主函数
def fbccaDWMW(self, testdata, referenceData, tValue,calculateCount):
t1 = time.time()
# try:
# 初始参数
res = -1
minEps = float("inf")
# 降采样和陷波器处理
northData = self.preprocessFilter(testdata)
newSampleNum = northData.shape[1]
# 数据大于延迟长度,则无法根据后面的规则更新窗口
if newSampleNum > self.winDelayNum:
error('need add window delay time')
# 防止秩小于导联数
if newSampleNum < self.num_chans:
warnings.warn('data shape is [%d %d] need more data' % (newSampleNum, northData.shape[0]))
# 滤波器组频带分解
filterData = self.filterbank(northData)
winMinTime = 0
# 计算每个窗口的结果
for wi in range(0, self.winNum, self.winStep):
# print('dataUseLen:',wi,calculateCount, self.dataUseLen)
if wi == 0:
self.dataUseLen[wi] = self.dataUseLen[wi] + newSampleNum
else:
if self.dataUseLen[wi] == 0:
# 判断当前窗是否为新的窗口(因为每一次新的窗口进来时都会使上一个窗口datauseLen>50)
if self.dataUseLen[wi - self.winStep] > self.winDelayNum*self.winStep:
self.dataUseLen[wi] = newSampleNum
else:
# print('中断: ',wi,calculateCount)
break
else:
self.dataUseLen[wi] = self.dataUseLen[wi] + newSampleNum
if self.dataUseLen[wi] > self.winMaxSampleNum:
self.dataUseLen[wi] = newSampleNum
self.Rbuffer[wi, :, :, :] = 0
self.Cxy[wi, :, :, :, :] = 0
Qs1, Qs2, self.Rbuffer[wi, :, :, :] = self.getDataQ(filterData, self.Rbuffer[wi, :, :, :])
si = self.dataUseLen[wi] - newSampleNum
ei = self.dataUseLen[wi]
ref = referenceData[:, :, si:ei]
# 更新协方差
predLabel_new, rho_new, self.Cxy[wi, :, :, :] = self.weightCCA(Qs1, Qs2, ref, self.Cxy[wi, :, :, :])
# 增加限制,数据长度不能太短
if self.dataUseLen[wi] > winMinTime * self.fs:
epsilon = self.costF(rho_new, self.DW_cost_method, C=0)
if epsilon < minEps:
minEps = epsilon
predLabel = predLabel_new
xxx = rho_new
if minEps < tValue:
res = predLabel
if time.time() - t1 > 0.2 and self.winStep < 16:
self.winStep = self.winStep * 2
# print(self.winStep, " ", time.time() - t1)
# if res != -1:
# print('--------------------- ',res,xxx,' --------------------------')
return res
if __name__ == '__main__':
# The number of sub-bands in filter bank analysis
fs = 250
num_chans = 8
num_target = 40
num_filterBank = 3
num_harm = 5
stimTime = 0.2 # 多窗口窗长
winNum = 50 # 窗口的个数
trials = 1
step = 50
res = -1
list_freqs = [8.0, 8.2, 8.4, 8.6, 8.8, 9.0, 9.2, 9.4, 9.6, 9.8, 10.0, 10.2, 10.4, 10.6, 10.8, 11.0, 11.2, 11.4,
11.6, 11.8, 12.0, 12.2, 12.4, 12.6, 12.8, 13., 13.2, 13.4, 13.6, 13.8, 14., 14.2, 14.4, 14.6, 14.8,
15.0, 15.2, 15.4, 15.6, 15.8]
# 初始化对象
dw = FbccaDw(fs, num_target, num_chans, num_filterBank, num_harm, stimTime, [1.0, 0.3], [8, 8], winNum)
# frequenceband
dw.filterFrequenceBank()
referenceData = dw.reference(list_freqs, int(winNum * stimTime * fs), num_harm)
dw.setNotchFilterPara()
prelabels = np.zeros((1, 40))
coefficient = np.zeros([1, 1])
path = "D:\\工作相关\\项目代码\\SDK封装与测试\\八神BCI脑电信号SDK20230105\\双鹰SDK对比\\offline_data250\\"
for index in range(1, trials + 1):
D = loadmat(os.path.join(path + str(1) + '-warmData.mat'))
warmData = D['warmData']
dw.onlineInit()
dw.filterInit()
dw.warmFilter(warmData.T)
tagget_i = 0
for tagget_i in range(1, step + 1):
D = loadmat(os.path.join(path + str(1) + '-' + str(tagget_i) + '.mat'))
dataSlice = D['dataTemp']
res = dw.fbccaDWMW(dataSlice.T, referenceData, tValue=-0.2)
if res != -1:
break
prelabels[0, index - 1] = res + 1
print(index, '--', res + 1," 计算轮数", tagget_i)

View File

@@ -0,0 +1,851 @@
import matplotlib
matplotlib.use('Agg')
import os
import io
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
import matplotlib.cm as cm
import matplotlib.colors as mcolors
from scipy.spatial import Delaunay
from scipy.interpolate import Rbf
from scipy.signal import welch
from scipy.stats import sem
from scipy.signal import butter, filtfilt, hilbert
import base64
# 位置坐标
def read_ch_pos(file_path=r'xy_64.xlsx'):
"""
将电极位置信息转换为Dict
参数:
file_path: 电极位置存储文件, 必须包含'channel', 'x', 'y', 'z'
"""
script_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(script_dir,file_path )
df = pd.read_excel(file_path)
# 确保列名正确
if not all(col in df.columns for col in ['channel', 'x', 'y', 'z']):
raise ValueError("DataFrame必须包含'channel', 'x', 'y', 'z'")
# 创建电极位置字典
ch_pos = {}
for _, row in df.iterrows():
ch_pos[row['channel']] = [row['x'], row['y'], row['z']]
return ch_pos
# 头部轮廓
def draw_head(ax, center=(0, 0), radius=1.0, zorder=4):
"""
绘制头部轮廓、鼻子和耳朵。
参数:
- ax : matplotlib Axes 对象
- center : (x, y) 头中心坐标
- radius : float, 头半径
- zorder : 绘制层级
"""
# 头圆
head = plt.Circle(center, radius, fill=False, color='k', linewidth=1, zorder=zorder)
ax.add_artist(head)
# 鼻子(参考 _make_head_outlines
dx = np.exp(np.arccos(np.deg2rad(12)) * 1j)
dx_real, dx_imag = dx.real, dx.imag
nose_x = np.array([-dx_real, 0, dx_real]) * radius + center[0]
nose_y = np.array([dx_imag, 1.15, dx_imag]) * radius + center[1]
ax.plot(nose_x, nose_y, color='k', linewidth=1, zorder=zorder)
# 耳朵(参考 _make_head_outlines 手动标定)
ear_radius = radius * 0.12
ear_scale = radius * 2 # 根据半径缩放
theta = np.linspace(np.pi / 2, 3 * np.pi / 2, 30)
# 左耳
left_ear_x_array = np.array([0.497, 0.510, 0.518, 0.5299, 0.5419,
0.54, 0.547, 0.532, 0.510, 0.489]) * ear_scale
left_ear_y_array = np.array([0.0555, 0.0775, 0.0783, 0.0746, 0.0555,
-0.0055, -0.0932, -0.1313, -0.1384, -0.1199]) * ear_scale + center[1]
ax.plot(center[0] - left_ear_x_array, left_ear_y_array, color='k', linewidth=1, zorder=zorder)
# 右耳
right_ear_x_array = np.array([0.497, 0.510, 0.518, 0.5299, 0.5419,
0.54, 0.547, 0.532, 0.510, 0.489]) * ear_scale
right_ear_y_array = np.array([0.0555, 0.0775, 0.0783, 0.0746, 0.0555,
-0.0055, -0.0932, -0.1313, -0.1384, -0.1199]) * ear_scale + center[1]
ax.plot(center[0] + right_ear_x_array, right_ear_y_array, color='k', linewidth=1, zorder=zorder)
# 地形图 插值
def rbf_D_interpolate(xy, v, center=(0, 0), radius=1.1, grid_res=300,
n_extra=32, rbf_func='multiquadric', smooth=0,
border='mean', border_scale=1.0001, n_ngb=4):
"""
使用 RBF + Delaunay 邻域均值方式生成平滑的 EEG topomap 插值表面。
参数
----
xy : (N,2) array
电极二维坐标(与绘图坐标系一致)
v : (N,) array
每个电极对应的值e.g. PSD
center : tuple (x0, y0)
头部圆心(默认 (0,0)
radius : float
头部半径(用于生成边界点与网格范围)
grid_res : int
网格分辨率(每轴点数)
n_extra : int
边界虚拟点数量
rbf_func : str
RBF 内核名称('multiquadric','thin_plate','gaussian',...
smooth : float
RBF 平滑参数
border : 'mean' or float
'mean':边界点用邻近真实通道均值赋值(推荐)
若 float边界点赋相同常数值
border_scale : float
边界点半径相对 radius 的缩放(略微 >1 用以外推)
n_ngb : int
为每个边界点取值时使用的最近真实通道数
返回
----
zi : (grid_res, grid_res) ndarray
插值结果(与 grid_x, grid_y 对齐)
grid_x, grid_y : ndarrays
meshgrid由 np.meshgrid 生成)
"""
xy = np.asarray(xy)
v = np.asarray(v)
if xy.ndim != 2 or xy.shape[1] != 2:
raise ValueError("xy must be shape (n_channels, 2)")
n_points = xy.shape[0]
# --- 1. 生成边界虚拟点(圆周) ---
theta = np.linspace(0.0, 2 * np.pi, n_extra, endpoint=False)
r_border = radius * border_scale
border_xy = np.column_stack([center[0] + r_border * np.cos(theta),
center[1] + r_border * np.sin(theta)])
# --- 2. 用 Delaunay 建图以便找到邻居(对边界点取邻居均值) ---
# 合并用于三角化的位置(真实点 + 边界点)
tri_xy = np.vstack([xy, border_xy])
tri = Delaunay(tri_xy)
# --- 3. 为边界点赋值 ---
if isinstance(border, str) and border == 'mean':
# 使用 Delaunay 的 vertex_neighbor_vertices 索引
# 注意tri.vertex_neighbor_vertices 给出 vertices -> neighbor indptr
indices, indptr = tri.vertex_neighbor_vertices
v_extra = np.zeros(n_extra)
used = np.zeros(n_extra, dtype=bool)
# 边界点在 tri_xy 中的索引范围
rng = range(n_points, n_points + n_extra)
for idx, extra_idx in enumerate(rng):
neigh = indptr[indices[extra_idx]:indices[extra_idx + 1]]
# 仅保留原始点索引(小于 n_points
neigh = neigh[neigh < n_points]
if neigh.size > 0:
used[idx] = True
# 使用最近 n_ngb 个邻居的均值(若邻居多则取最近的 n_ngb
if neigh.size > n_ngb:
# 计算距离并选取最近 n_ngb
d = np.linalg.norm(xy[neigh] - tri_xy[extra_idx], axis=1)
order = np.argsort(d)[:n_ngb]
sel = neigh[order]
else:
sel = neigh
v_extra[idx] = v[sel].mean()
if not used.all() and used.any():
v_extra[~used] = np.mean(v_extra[used])
elif not used.any():
v_extra[:] = np.mean(v)
else:
# border 是数值
v_extra = np.full(n_extra, float(border))
# --- 4. 合并所有已知点并构建 RBF ---
all_xy = np.vstack([xy, border_xy])
all_v = np.concatenate([v, v_extra])
rbf = Rbf(all_xy[:, 0], all_xy[:, 1], all_v,
function=rbf_func, smooth=smooth)
# --- 5. 生成网格(使用 meshgrid与主函数保持一致 ---
xmin, xmax = center[0] - radius, center[0] + radius
ymin, ymax = center[1] - radius, center[1] + radius
xi = np.linspace(xmin, xmax, grid_res)
yi = np.linspace(ymin, ymax, grid_res)
grid_x, grid_y = np.meshgrid(xi, yi) # meshgrid 与 imshow 对齐
# --- 6. 评估 RBF返回与 grid 对齐的 zi ---
zi = rbf(grid_x, grid_y)
return zi, grid_x, grid_y
# plv矩阵计算
def calculate_plv(data):
"""
计算相位锁定值PLV矩阵。
Parameters
----------
data : ndarray, shape (num_channels, num_samples)
EEG 数据,通道数为 num_channels样本数为 num_samples。
Returns
-------
plv_matrix : ndarray, shape (num_channels, num_channels)
计算得到的 PLV 矩阵,表示各通道间的相位同步。
"""
num_channels, num_samples = data.shape
plv_matrix = np.zeros((num_channels, num_channels))
# 计算每个通道的解析信号
analytic_signals = np.apply_along_axis(hilbert, axis=1, arr=data)
for i in range(num_channels):
for j in range(i + 1, num_channels): # 只计算上三角矩阵,避免重复计算
# 计算 phase difference
phase_diff = np.angle(analytic_signals[i] * np.conj(analytic_signals[j]))
plv = np.abs(np.mean(np.exp(1j * phase_diff)))
plv_matrix[i, j] = plv
plv_matrix[j, i] = plv # 对称矩阵
return plv_matrix
# 矩阵阈值化
def threshold_proportional(adj, prop=0.2):
"""
Apply a proportional threshold to retain the top proportion of strongest edges.
Parameters
----------
adj : ndarray, shape (n_channels, n_channels)
Adjacency matrix to threshold.
prop : float
Proportion of edges to retain (0 < prop <= 1).
Returns
-------
bin_adj : ndarray, shape (n_channels, n_channels)
Binary adjacency matrix after thresholding.
"""
n = adj.shape[0]
triu_idx = np.triu_indices(n, k=1)
weights = adj[triu_idx]
k = int(np.floor(len(weights) * prop))
# Ensure that at least one edge is retained
k = max(k, 1)
# Get the threshold value
thr = np.sort(weights)[-k]
# Apply the threshold to create a binary adjacency matrix
bin_adj = np.where(adj >= thr, adj, 0.0)
return bin_adj
# 单个脑网络
def plot_single_network(ch_names,adj,ax=None,
node_size=20, node_color='orange',highlight_nodes=[], show_names=True,
edge_color='gray', weighted=True,
radius=1.1, figsize=(6, 6),cmap='RdYlBu_r'):
# 若 ax 未传入,则自己创建
own_fig = False
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
own_fig = True
else:
fig = ax.figure
# 坐标归一化
pos3d = read_ch_pos()
all_chs_xy = np.array([pos3d[ch][:2] for ch in pos3d.keys()])
all_chs_xy -= all_chs_xy.mean(axis=0)
all_chs_xy /= np.sqrt((all_chs_xy ** 2).sum(axis=1)).max()
xy_dict = dict(zip(pos3d.keys(), all_chs_xy))
xy = np.array([xy_dict[ch] for ch in ch_names])
center = xy_dict.get('CZ', np.mean(list(xy_dict.values()), axis=0))
# ===== 初始化绘图窗口 =====
ax.set_aspect('equal')
ax.axis('off')
# 设置边界(与原类保持一致)
ear_radius = radius * 0.12
nose_height = radius * 0.15
margin_x = radius * 0.12 + 0.05
ax.set_xlim(center[0] - radius - margin_x, center[0] + radius + margin_x)
ax.set_ylim(center[1] - radius - ear_radius, center[1] + radius + nose_height + ear_radius)
# 绘制头部轮廓
draw_head(ax, center=center, radius=radius)
# 节点
for ch in ch_names:
color = 'red' if ch in highlight_nodes else node_color
ax.scatter(*xy_dict[ch], s=node_size, color=color, edgecolor='k', zorder=4)
if show_names:
ax.text(xy_dict[ch][0], xy_dict[ch][1] + 0.03, ch,
ha='center', va='bottom', fontsize=8, zorder=5)
# colorbar
norm = mcolors.Normalize(vmin=0, vmax=1)
color_map = matplotlib.colormaps.get_cmap(cmap)
# ========= 边 ==========
N = len(ch_names)
for i in range(N):
for j in range(i + 1, N):
w = adj[i, j]
if w > 0:
x = [xy[i, 0], xy[j, 0]]
y = [xy[i, 1], xy[j, 1]]
lw = 1.5
if weighted:
ax.plot(x, y,
color=color_map(norm(w)),
linewidth=lw,
alpha=0.7,
zorder=3)
else:
ax.plot(x, y,
color=edge_color,
linewidth=lw,
alpha=0.7,
zorder=3)
if own_fig:
# 不回传 添加颜色条
sm = cm.ScalarMappable(norm=norm, cmap=color_map)
cbar = plt.colorbar(sm, ax=ax, fraction=0.035)
cbar.set_label('Connection Strength', fontsize=10)
cbar.ax.tick_params(direction='in', labelsize=10)
plt.show()
return fig
else:
return ax
# 脑网络对比
def plot_multiband_network(ch_names, adj_MI, adj_Rest,cmap='RdYlBu_r'):
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
fontsize = 16
fig.text(0.285, 0.08, 'MI', fontsize=fontsize, ha='center', va='center', rotation=0)
fig.text(0.68, 0.08, 'Rest', fontsize=fontsize, ha='center', va='center', rotation=0)
im1 = plot_single_network(ch_names,adj_MI,ax=axes[0], show_names=True,cmap=cmap)
# Rest 行
im2 = plot_single_network(ch_names,adj_Rest,ax=axes[1],show_names=True,cmap=cmap)
# --- 合并 colorbar右侧一个 ---
norm = mcolors.Normalize(vmin=0, vmax=1)
color_map = matplotlib.colormaps.get_cmap(cmap)
sm = cm.ScalarMappable(norm=norm, cmap=color_map)
cbar = plt.colorbar(sm, ax=axes.ravel().tolist(), fraction=0.02)
cbar.set_label('Connection Strength', fontsize=10)
cbar.ax.tick_params(direction='in', labelsize=10)
# 将图像保存到内存字节流PNG 格式)
buf = io.BytesIO()
fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
plt.close(fig) # 释放内存
buf.seek(0)
image_bytes = buf.read()
buf.close()
return image_bytes
# 多个频带psd
def compute_band_psd(eeg, fs, bands, labels, trial_idx=0,MI_label=1, Rest_label=2,avg = True):
"""
eeg: (n_trials, n_channels, n_samples)
"""
n_trials, n_channels, n_samples = eeg.shape
band_names = list(bands.keys())
n_bands = len(band_names)
psd_MI = np.zeros((n_bands, n_channels))
psd_Rest = np.zeros((n_bands, n_channels))
# 先计算所有 trial 的功率谱
f, Pxx = welch(eeg, fs=fs, axis=-1, nperseg=fs,noverlap = fs // 2)
for bi, (bname, (f1, f2)) in enumerate(bands.items()):
idx = np.logical_and(f >= f1, f <= f2)
band_power = Pxx[:, :, idx].mean(axis=-1)
band_power_flat = band_power.flatten()
power_min = band_power_flat.min()
power_max = band_power_flat.max()
if power_max - power_min > 1e-12:
band_power_norm = (band_power - power_min) / (power_max - power_min)
else:
band_power_norm = band_power
if avg:
psd_MI[bi] = band_power_norm[labels == MI_label].mean(axis=0)
psd_Rest[bi] = band_power_norm[labels == Rest_label].mean(axis=0)
else:
psd_MI[bi] = band_power_norm[labels == MI_label][trial_idx]
psd_Rest[bi] = band_power_norm[labels == Rest_label][trial_idx]
return band_names, psd_MI, psd_Rest
# 单个脑地形图
def plot_single_topomap(ch_names, psd_values, cmap='RdYlBu_r', vlim=(0, 1),
show_names=True, node_size=3, radius=1.1, grid_res=300,
n_contours=None, contour_color='k',
ax=None,figsize=(6,6)):
# 若 ax 未传入,则自己创建
own_fig = False
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
own_fig = True
else:
fig = ax.figure
# ===== 初始化绘图窗口 =====
ax.set_aspect('equal')
ax.axis('off')
# ax.set_title("EEG topomap (MNE-like)")
# 坐标归一化
pos3d = read_ch_pos()
all_chs_xy = np.array([pos3d[ch][:2] for ch in pos3d.keys()])
all_chs_xy -= all_chs_xy.mean(axis=0)
all_chs_xy /= np.sqrt((all_chs_xy ** 2).sum(axis=1)).max()
pos2d_dict = dict(zip(pos3d.keys(), all_chs_xy))
xy = np.array([pos2d_dict[ch] for ch in ch_names])
center = pos2d_dict.get('CZ', np.mean(list(pos2d_dict.values()), axis=0))
# 绘制头部轮廓
draw_head(ax, center=center, radius=radius)
# 绘制电极
fontsize = 4
ax.scatter(xy[:, 0], xy[:, 1], c='k', s=node_size, zorder=5)
if show_names:
for i, ch in enumerate(ch_names):
ax.text(xy[i, 0], xy[i, 1] + 0.03, ch,
ha='center', va='bottom', fontsize=fontsize, zorder=6)
# 数据插值
zi, grid_x, grid_y = rbf_D_interpolate(
xy, psd_values, radius=radius,
grid_res=grid_res
)
xmin, xmax = center[0] - radius, center[0] + radius
ymin, ymax = center[1] - radius, center[1] + radius
extent = (xmin, xmax, ymin, ymax)
im = ax.imshow(zi, extent=extent, origin='lower',
cmap=cmap, vmin=vlim[0], vmax=vlim[1],
interpolation='bicubic', zorder=0)
# 裁剪路径
patch_ = Ellipse(center, 2 * radius, 2 * radius, clip_on=True, transform=ax.transData)
im.set_clip_path(patch_)
# 初始等高线
linewidths = 0.5
if n_contours is None:
cset = ax.contour(grid_x, grid_y, zi,
colors=contour_color, linewidths=linewidths, zorder=2)
else:
cset = ax.contour(grid_x, grid_y, zi, levels=n_contours,
colors=contour_color, linewidths=linewidths, zorder=2)
cset.set_clip_path(patch_)
if own_fig:
# 不回传 添加颜色条
plt.colorbar(im, ax=ax, fraction=0.035)
plt.show()
return fig
else:
# plt.colorbar(im, ax=ax, fraction=0.035)
return im
# 脑地形图对比
def plot_multiband_topomaps(ch_names, psd_MI, psd_Rest, bands):
band_names = list(bands.keys()) # 改动 1新增这行
n_bands = len(band_names)
fig, axes = plt.subplots(2, n_bands, figsize=(3*n_bands, 6))
fontsize = 16
axes[0, 0].text(-0.1, 0.5, 'MI', transform=axes[0, 0].transAxes, rotation=0, va='center', ha='center', fontsize=fontsize-2)
axes[1, 0].text(-0.1, 0.5, 'Rest', transform=axes[1, 0].transAxes, rotation=0, va='center', ha='center', fontsize=fontsize-2)
imgs = []
for i, bname in enumerate(band_names):
axes[0, i].set_title(bname, fontsize=fontsize, pad=0)
# MI 行
im1 = plot_single_topomap(ch_names,psd_MI[i],ax=axes[0, i], show_names=True)
# Rest 行
im2 = plot_single_topomap(ch_names,psd_Rest[i],ax=axes[1, i],show_names=True)
imgs.append(im1)
# --- 单个右侧合并 colorbar ---
cbar = fig.colorbar(imgs[0], ax=axes,fraction=0.02)
# cbar.set_label("PSD Power",fontsize=fontsize-4)
cbar.ax.tick_params(direction='in', labelsize=10)
# 将图像保存到内存字节流PNG 格式)
buf = io.BytesIO()
fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
plt.close(fig) # 释放内存
buf.seek(0)
image_bytes = buf.read()
buf.close()
return image_bytes
# 小波
def morlet_wavelet(f, fs, n_cycles=7):
"""
创建 Morlet 小波
f: 频率
fs: 采样率
"""
sigma_t = n_cycles / (2 * np.pi * f)
t = np.arange(-3*sigma_t, 3*sigma_t, 1/fs)
wavelet = (np.pi**-0.25) * np.exp(2j*np.pi*f*t) * np.exp(-(t**2)/(2*sigma_t**2))
return wavelet
# 希尔伯特变换 计算ERDS 效果不佳
def bandpass_filter(data, fs, band, order=4):
nyq = fs / 2
b, a = butter(order, [band[0]/nyq, band[1]/nyq], btype='band')
return filtfilt(b, a, data, axis=-1)
def compute_power_hilbert(filtered_data,is_dB =True):
analytic = hilbert(filtered_data, axis=-1)
power = np.abs(analytic) ** 2
if is_dB:
power = 10 * np.log10(power)
return power
def compute_power(data, fs=250,
bands={"mu": (8,12), "beta": (13,30)}):
"""
返回:
power_dict[band] = (n_trials, n_ch, n_samples)
"""
power_dict = {}
for band_name, band_range in bands.items():
filt = bandpass_filter(data, fs, band_range)
power = compute_power_hilbert(filt)
power_dict[band_name] = power
return power_dict
def compute_erds(power_MI, power_Rest, baseline_period=None):
"""
计算事件相关去同步/同步 (ERDS)
Parameters
----------
power_MI, power_Rest: (n_trials, n_ch, n_samples)
功率数据,单位为 µV² 或 dB取决于 compute_power_hilbert 的 is_dB 参数)
baseline_period: tuple (start_idx, end_idx) or None
基线时间段索引。如果为None使用 Rest 状态的平均值作为基线
返回:
MI_erds_mean, MI_erds_sem
Rest_erds_mean, Rest_erds_sem
所有返回值的形状为 (n_ch, n_samples)
"""
if baseline_period is not None:
start_idx, end_idx = baseline_period
baseline = np.concatenate([power_MI[:, :, start_idx:end_idx],
power_Rest[:, :, start_idx:end_idx]], axis=0)
baseline = baseline.mean(axis=(0, 2), keepdims=True)
else:
baseline = power_Rest.mean(axis=(0,2), keepdims=True)
# === ERDS (%) ===
MI_erds = (power_MI - baseline) / baseline * 100
Rest_erds = (power_Rest - baseline) / baseline * 100
return (
MI_erds.mean(axis=0), sem(MI_erds, axis=0),
Rest_erds.mean(axis=0), sem(Rest_erds, axis=0),
)
def compute_all_erds(MI_power_dict, Rest_power_dict):
"""
对多个频带同时计算 ERDS。
输入:
MI_power_dict[band] = (n_trials, n_ch, n_samples)
Rest_power_dict[band] = (n_trials, n_ch, n_samples)
输出:
erds_MI[band] = (mean, sem)
erds_Rest[band] = (mean, sem)
"""
erds_MI = {}
erds_Rest = {}
for band in MI_power_dict.keys():
MI_power = MI_power_dict[band]
Rest_power = Rest_power_dict[band]
MI_mean, MI_sem, Rest_mean, Rest_sem = compute_erds(MI_power, Rest_power)
erds_MI[band] = (MI_mean, MI_sem)
erds_Rest[band] = (Rest_mean, Rest_sem)
return erds_MI, erds_Rest
def plot_compare_erds(data_MI, data_Rest, mode="power",
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4'],
compare_names=['C3', 'CZ', 'C4'], bands=['mu', 'beta'],
fs=250, t=None, figsize=(12,6)):
n_bands = len(bands)
n_chs = len(compare_names)
# 自动添加单位
if mode == "power":
# y_unit = "Power (µV²)"
y_unit = "Power (dB)"
elif mode == "erds":
y_unit = "ERDS (%)"
else:
y_unit = ""
if t is None:
n_samples = next(iter(data_MI.values())).shape[-1] \
if mode=="power" else next(iter(data_MI.values()))[0].shape[-1]
t = np.arange(n_samples) / fs
fig, axes = plt.subplots(n_bands, n_chs, figsize=figsize, sharex=True, sharey=True)
for i, band in enumerate(bands):
# 选择数据结构
if mode == "power":
MI_band = data_MI[band] # (trials, ch, samples)
Rest_band = data_Rest[band]
avg_MI = MI_band.mean(axis=0)
sem_MI = MI_band.std(axis=0)/np.sqrt(MI_band.shape[0])
avg_Rest = Rest_band.mean(axis=0)
sem_Rest = Rest_band.std(axis=0)/np.sqrt(Rest_band.shape[0])
elif mode == "erds":
avg_MI, sem_MI = data_MI[band]
avg_Rest, sem_Rest = data_Rest[band]
for j, ch in enumerate(compare_names):
ax = axes[i, j] if n_bands > 1 else axes[j]
ch_idx = ch_names.index(ch)
# 绘制 MI
ax.plot(t, avg_MI[ch_idx], color="C0", label="MI")
ax.fill_between(t,
avg_MI[ch_idx]-sem_MI[ch_idx],
avg_MI[ch_idx]+sem_MI[ch_idx],
alpha=0.3, color="C0")
# 绘制 Rest
ax.plot(t, avg_Rest[ch_idx], color="C1", label="Rest")
ax.fill_between(t,
avg_Rest[ch_idx]-sem_Rest[ch_idx],
avg_Rest[ch_idx]+sem_Rest[ch_idx],
alpha=0.3, color="C1")
if i == 0:
ax.set_title(ch)
# ← Y 轴加单位
if j == 0:
ax.set_ylabel(f"{band}\n{y_unit}")
if i == n_bands - 1:
ax.set_xlabel("Time (s)")
ax.grid(alpha=0.3)
if i == 0 and j == n_chs - 1:
ax.legend()
plt.tight_layout()
# 将图像保存到内存字节流PNG 格式)
buf = io.BytesIO()
fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
plt.close(fig) # 释放内存
buf.seek(0)
image_bytes = buf.read()
buf.close()
return image_bytes
# 对比 MI vs Rest 的功率谱密度 PSD
def plot_psd_compare(MI_data, Rest_data, ch_names, compare_names=['C3', 'CZ', 'C4'],
fs=250, nperseg=None, average=True, show_sem=True,
figsize=(12, 3), save_dir=None, filename="psd.png"):
"""
对比 MI vs Rest 的功率谱密度 PSD
MI_data, Rest_data: (n_trials, n_ch, n_samples)
channels: 需要绘制的通道
average: 是否对所有试次平均
show_sem: 是否绘制 SEM 阴影
"""
n_trials, n_ch, n_samples = MI_data.shape
n_trials = min(len(MI_data), len(Rest_data))
# assert Rest_data.shape == MI_data.shape, "MI 和 Rest 数据维度必须一致"
if nperseg is None:
nperseg = fs # 每 1 秒窗长度
# 计算 MI PSD
psd_MI_all = []
for trial in range(n_trials):
psd_trial = []
for ch in range(n_ch):
f, Pxx = welch(MI_data[trial, ch], fs=fs, nperseg=nperseg)
psd_trial.append(Pxx)
psd_MI_all.append(psd_trial)
psd_MI_all = np.array(psd_MI_all)
# 计算 Rest PSD
psd_Rest_all = []
for trial in range(n_trials):
psd_trial = []
for ch in range(n_ch):
_, Pxx = welch(Rest_data[trial, ch], fs=fs, nperseg=nperseg)
psd_trial.append(Pxx)
psd_Rest_all.append(psd_trial)
psd_Rest_all = np.array(psd_Rest_all)
# ---- Plot ----
fig, ax = plt.subplots(1, len(compare_names), figsize=figsize)
if len(compare_names) == 1:
ax = [ax]
for i, ch in enumerate(compare_names):
ch_idx = ch_names.index(ch)
psd_MI_ch = psd_MI_all[:, ch_idx, :]
psd_Rest_ch = psd_Rest_all[:, ch_idx, :]
if average:
mean_MI = psd_MI_ch.mean(axis=0)
mean_Rest = psd_Rest_ch.mean(axis=0)
ax[i].plot(f, mean_MI, color='C0', label='MI')
ax[i].plot(f, mean_Rest, color='C1', label='Rest')
if show_sem:
ax[i].fill_between(f, mean_MI - sem(psd_MI_ch, axis=0),
mean_MI + sem(psd_MI_ch, axis=0), color='C0', alpha=0.3)
ax[i].fill_between(f, mean_Rest - sem(psd_Rest_ch, axis=0),
mean_Rest + sem(psd_Rest_ch, axis=0), color='C1', alpha=0.3)
else:
ax[i].plot(f, psd_MI_ch.T, color='C0', alpha=0.3)
ax[i].plot(f, psd_Rest_ch.T, color='C1', alpha=0.3)
ax[i].set_title(ch)
ax[i].set_xlabel("Frequency (Hz)")
ax[i].set_ylabel("PSD (μV²/Hz)")
ax[i].grid(alpha=0.3)
if i == 0:
ax[i].legend()
plt.tight_layout()
# 将图像保存到内存字节流PNG 格式)
buf = io.BytesIO()
fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
plt.close(fig) # 释放内存
buf.seek(0)
image_bytes = buf.read()
buf.close()
return image_bytes
def plotMain(
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4'],
compare_names = [ 'C3','CZ','C4'],
Data = None,labels = None,MI_label = None,Rest_label = None,
fs = 250):
trial_idx = 0
# 数据划分
if not MI_label:
label_ = np.unique(labels)
else:
label_ = (MI_label,Rest_label)
MI_data = Data[labels == label_[0]]
Rest_data = Data[labels == label_[1]]
# 典型 EEG 频带
FREQ_BANDS = {
"Delta (0.8-4Hz)": (0.8, 4),
"Theta (4-8Hz)": (4, 8),
"Alpha (8-12Hz)": (8, 12),
"Beta (12-30Hz)": (12, 30),
"All (0.8-30Hz)": (0.8, 30)
}
# 利用welch估算PSD
band_names, psd_MI, psd_Rest= compute_band_psd(
eeg=Data,
fs=fs,
bands=FREQ_BANDS,
labels=labels,
trial_idx=trial_idx,
MI_label=MI_label,
Rest_label=Rest_label,
avg= True
)
# 绘制地形图
topomaps_imgBytes = plot_multiband_topomaps(
ch_names=ch_names,
psd_MI=psd_MI,
psd_Rest=psd_Rest,
bands=FREQ_BANDS
)
# 绘制脑网络
mi_plv_matrix = calculate_plv(MI_data[trial_idx])
mi_BI_matrix = threshold_proportional(mi_plv_matrix, prop=0.3)
rest_plv_matrix = calculate_plv(Rest_data[trial_idx])
rest_BI_matrix = threshold_proportional(rest_plv_matrix, prop=0.3)
network_imgBytes = plot_multiband_network(ch_names, mi_BI_matrix, rest_BI_matrix)
# ERDS 先计算erds后平均
MI_power = compute_power(MI_data)
Rest_power = compute_power(Rest_data)
erds_dict_MI, erds_dict_Rest = compute_all_erds(MI_power, Rest_power)
erds_imgBytes = plot_compare_erds(erds_dict_MI, erds_dict_Rest, ch_names=ch_names,
compare_names=compare_names, bands=['mu', 'beta'],
fs=fs, mode="erds")
# 绘制PSD
psd_imgBytes = plot_psd_compare(MI_data, Rest_data, ch_names = ch_names, compare_names=compare_names,
fs=fs, nperseg=None, average=True, show_sem=True,
figsize=(12, 3))
return {'topomaps_imgBytes':base64.b64encode(topomaps_imgBytes).decode(),'network_imgBytes':base64.b64encode(network_imgBytes).decode(),
'erds_imgBytes':base64.b64encode(erds_imgBytes).decode(),'psd_imgBytes':base64.b64encode(psd_imgBytes).decode()}
if __name__ == '__main__':
allData = np.random.uniform(-50,50,size=(80,21,1000))
allLabel = np.random.randint(1,3,size=(80,))
allData = allData[:len(allLabel)]
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1',
'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4']
compare_names = ['C3', 'CZ', 'C4']
ret = plotMain(ch_names=ch_names, compare_names=compare_names, Data=allData, labels=allLabel, MI_label=1, Rest_label=2,
fs=250)
print('计算完成,开始发送')
from Zmq.zmqClient import zmqClient
zmqClient = zmqClient('192.168.76.101', 8088)
zmqClient.connect()
zmqClient.send_to_all('miReport', ret)

View File

@@ -0,0 +1,68 @@
import threading
import time
import json
import zmq
class zmqClient:
def __init__(self, host, port):
self.host = host
self.port = port
self.client_socket = None
self.running = False
self.zmq_server = None # Reference to zmqServer for Unity communication
# 记录客户端连接前的状态
self.state = {
'status_code': None,
'energy': None
}
def set_zmq_server(self, server):
"""Set the zmqServer instance to forward messages to Unity"""
self.zmq_server = server
def connect(self):
# 创建 ZeroMQ 上下文
self.context = zmq.Context()
# 创建 REQ 套接字(请求端)
self.client_socket = self.context.socket(zmq.DEALER)
# client_id = b'client1'
# self.client_socket.setsockopt(zmq.IDENTITY,client_id)
self.client_socket.connect(f"tcp://{self.host}:{self.port}") # 连接到服务器
self.running = True
def send_to_all(self, method,params):
if method in self.state.keys():
self.state[method] = params
# Also send to Unity via zmqServer if connected
if self.zmq_server:
self.zmq_server.broadcast_message(method, params)
try:
if self.running and self.client_socket != None:
msg = {'method': method, 'params': params}
if method in ['single_trial_plot', 'miReport']:
print(f"{{'method': '{method}', 'params': <Base64 Image Data>}}")
else:
print(msg)
self.client_socket.send_multipart([b'', json.dumps(msg).encode('utf-8')])
else:
if method in self.state.keys():
self.state[method] = params
except ConnectionResetError:
print("Connection lost.")
self.running = False
except Exception as e:
print(f"An error occurred: {e}")
def close_connection(self):
self.running = False
self.client_socket.close()
self.context.term()
print("Client closed explicitly.")
# 使用TCP客户端
if __name__ == "__main__":
client = zmqClient('127.0.0.1', 8099)
client.connect()
# client.close_connection()

View File

@@ -0,0 +1,149 @@
import numpy as np
import zmq
import threading
import json
import queue
from Device.SunnyLinker import SunnyLinker64
class zmqServer(threading.Thread):
def __init__(self, host='0.0.0.0', port=8099):
threading.Thread.__init__(self)
self.host = host
self.port = port
self.running = False
self.get_Impedance = False # 是否返回阻抗值
self.open_Impedance = None # 是否开启阻抗检测功能
self.StartDecode = False # false 停止解码true=开始解码
self.StartTrain = False # False未进入训练状态True处于训练状态
self.state_mode = None # 'train'为训练状态rest'为休息状态,'test'为测试状态
self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签
self.IsExitApp = False # 当socket收到2的时候就置为True代表要退出系统了。
self.getReport = False # 获取训练报告内容
self.daemon = True
# 创建 ZeroMQ 上下文
self.context = zmq.Context()
# 创建 REP 套接字(响应端)
self.socket = self.context.socket(zmq.ROUTER)
self.socket.bind(f"tcp://{self.host}:{self.port}") # 绑定到端口 8099
self.targetFreqs = []
self.changeTarget = False # 更换目标频率
self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类已在Decoder实例化
self.labels = [0x01, 0x02,0x03]
self.decoder_switch = False #更换解码器
self.decoder_class = None #解码器类别 'ssvep','ssmvep','mi'
# Client Management (e.g. Unity, Other listeners)
self.clients = set() # 维护客户端ID
self.send_queue = queue.Queue() # 发送队列安全信箱维护socket线程
def broadcast_message(self, method, params):
"""Put message into queue to be sent to all connected clients"""
self.send_queue.put((method, params))
def run(self):
self.running = True
print(f"Server is running on {self.host}:{self.port}")
# Use Poller for non-blocking receive
poller = zmq.Poller()
poller.register(self.socket, zmq.POLLIN)
try:
while self.running:
# 1. Process Send Queue (Send to all clients)
while not self.send_queue.empty():
method, params = self.send_queue.get()
if self.clients:
try:
msg = {'method': method, 'params': params}
msg_bytes = json.dumps(msg).encode('utf-8')
if method in ['single_trial_plot', 'single_trial_plot', 'miReport']:
print(f"{{'method': '{method}', 'params': <Base64 Image Data>}}")
else:
print(f"Sending message: {msg}")
# Broadcast to all maintained clients
for client_id in list(self.clients):
try:
# Send: [ID, Empty, JSON]
self.socket.send_multipart([client_id, b'', msg_bytes])
except Exception as e:
print(f"Error sending to client {client_id}: {e}")
except Exception as e:
print(f"Error preparing broadcast: {e}")
# 2. Process Receive (Commands)
socks = dict(poller.poll(10)) # 100ms timeout
if self.socket in socks and socks[self.socket] == zmq.POLLIN:
frames = self.socket.recv_multipart()
if len(frames) < 3:
continue
ident, _, message_bytes = frames[:3]
if ident not in self.clients: # register client ID
self.clients.add(ident)
print(f"New Client Detected: {ident}")
try:
message = json.loads(message_bytes.decode('utf-8'))
except json.JSONDecodeError:
continue
print(f"Received request: {message}")
method = message.get("method") # process request
params = message.get("params")
if method == "sync":
self.state_mode = 'sync'
if method == "targetFreqs":
if not isinstance(params,list):
print('targetFreqs must be a list')
continue
if params != self.targetFreqs:
self.targetFreqs = params
self.changeTarget = True
if method == "decoderClass":
if not isinstance(params,str):
print('decoderClass must be a str')
continue
if params != self.decoder_class:
self.decoder_class = params
self.decoder_switch = True
if method == "getReport":
self.getReport = True
if method == "train":#训练状态
self.state_mode = 'train'
self.StartTrain = True
self.currentLabel = params # 当前刺激端的训练标签
self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
elif method == "predict":#预测状态
self.state_mode = 'predict'
if params == 1: #开始解码
self.StartDecode = True
self.sunnyLinker.push_trigger(0x63)
elif params == 2: #停止解码
self.IsExitApp = True
self.running = False
elif method == "rest": #休息状态
self.state_mode = 'rest'
elif method == "impedance":
if params == 1:
self.open_Impedance = True # 开启阻抗
self.get_Impedance = True # 返回阻抗
elif params == 2:
self.open_Impedance = False # 关闭阻抗
self.get_Impedance = False # 停止返回阻抗
except Exception as e:
print(f"An socket error occurred: {e}")
finally:
self.running = False
# 关闭套接字和上下文
self.socket.close()
self.context.term()
print("Server socket and context closed.")
def stop(self):
"""显式关闭服务器"""
self.running = False
self.socket.close()
self.context.term()
print("Server closed explicitly.")
if __name__ == '__main__':
server = zmqServer()
server.start()

View File

@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 29 16:14:17 2025
@author: 23749
"""
import numpy as np
from scipy.signal import butter, filtfilt
## 1.Bandpass Filter
def butter_bandpass(lowcut, highcut, fs, order=4):
# 滤波器
nyq = 0.5 * fs #ny:Nyquist频率即能表示的最大有效频率
low = lowcut / nyq
high = highcut / nyq
b, a = butter(order, [low, high], btype='band') #巴特沃斯滤波器order=4阶
return b, a
def bandpass_filter(data, lowcut, highcut, fs, order=4):
b, a = butter_bandpass(lowcut, highcut, fs, order)
return filtfilt(b, a, data)
## 2.Eye Blink Dectection
def blink_detection(F, fs, Dmin, Dmax, Emin, Emax):
"""
波形检测
输入: 差分特征向量 F, 采样率 fs
输出: b (0/1), 以及计算出的 d, e
"""
if F is None or len(F) < 3:
return 0, None, None
# 找最大时间(peak) & 最小时间(valley)
t_peak = np.argmax(F)
t_valley = np.argmin(F)
# 要求 peak 在 valley 之前(符合 blink 形态),否则交换
if t_valley < t_peak:
t_peak, t_valley = t_valley, t_peak
# 计算持续时间 d (ms)
d = (t_valley - t_peak) * 1000.0 / fs
# 计算能量 e (差分平方和)
e = np.sum(F[t_peak:t_valley + 1] ** 2)
# 阈值判定
if Dmin <= d <= Dmax and Emin <= e <= Emax:
b = 1 # 检测到眨眼
else:
b = 0 # 否则 no blink
return b, d, e
if __name__ == '__main__':
import matplotlib.pyplot as plt
fs = 250 # 采样率
t = np.arange(0, 5, 1/fs)
eog = 0.01 * np.random.randn(len(t)) # 基线+噪声
# 模拟眨眼(在 2.0s 注入脉冲)
center = int(2.0 * fs)
eog[center:center+5] += 0.5
eog[center+5:center+15] -= 0.4
# 测试 blink_detection
F = np.diff(eog)
b, d, e = blink_detection(F, fs, 70, 500, 0.1, 10)
print(f"Detected: {b}, Duration: {d}ms, Energy: {e}")

View File

@@ -0,0 +1,98 @@
# -*- mode: python ; coding: utf-8 -*-
import sys
import os
from PyInstaller.utils.hooks import collect_submodules, collect_data_files
# ========================================================
# 1. 工程配置区 (Project Config)
# ========================================================
block_cipher = None
ENTRY_POINT = 'runDecoder.py'
APP_NAME = 'runDecoder'
# ========================================================
# 2. 依赖分析 (Dependency Analysis)
# ========================================================
# 收集 sklearn, scipy 可能遗漏的隐藏导入
hidden_imports = [
'sklearn.utils._cython_blas',
'sklearn.neighbors.typedefs',
'sklearn.neighbors.quad_tree',
'sklearn.tree',
'sklearn.tree._utils',
'einops', # 必须显式添加
]
# 收集 torch 相关的隐式导入
hidden_imports += ['torch', 'torchvision']
# 收集 pandas 相关的隐式导入
hidden_imports += ['pandas']
# ========================================================
# 3. 资源锚定 (Data Anchoring)
# ========================================================
# Analysis 中的 datas 用于将文件嵌入到内部
datas = []
# ========================================================
# 4. 构建流程 (Build Process)
# ========================================================
a = Analysis(
[ENTRY_POINT],
pathex=[],
binaries=[],
datas=datas,
hiddenimports=hidden_imports,
hookspath=[],
hooksconfig={},
runtime_hooks=['rthook.py'], # 添加运行时钩子,处理路径和多进程
excludes=['tkinter', 'PyQt5', 'PySide2', 'IPython', 'notebook'], # 排除 GUI 和交互式库减小体积
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
[],
exclude_binaries=True,
name=APP_NAME,
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=False,
console=True, # 保持 True 以便查看日志,部署时可改为 False
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)
# ========================================================
# 5. 打包模式: OneDir (单文件夹) + 资源旁路
# ========================================================
# 使用 Tree 将文件夹原样复制到 dist/APP_NAME/ 下
# 格式: Tree('源路径', prefix='目标子目录')
coll = COLLECT(
exe,
a.binaries,
a.zipfiles,
a.datas,
# 显式复制资源文件夹到 exe 同级目录
Tree('online_Models', prefix='online_Models', excludes=['*.pyc']),
Tree('Tools', prefix='Tools', excludes=['*.pyc']),
# config.ini 作为单独文件
[('config.ini', 'config.ini', 'DATA')],
strip=False,
upx=False,
upx_exclude=[],
name=APP_NAME,
)

View File

@@ -0,0 +1,88 @@
import os
import shutil
import subprocess
import sys
def main():
# 1. 定义路径
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DIST_DIR = os.path.join(BASE_DIR, 'dist')
APP_NAME = 'runDecoder'
TARGET_DIR = os.path.join(DIST_DIR, APP_NAME)
# 定义需要复制的资源 {源路径: 目标子路径}
# 目标子路径相对于 TARGET_DIR
RESOURCES = {
'config.ini': 'config.ini',
'online_Models': 'online_Models',
'Tools': 'Tools',
}
# 2. 清理旧构建
print("[1/3] Cleaning up old builds...")
if os.path.exists(DIST_DIR):
try:
shutil.rmtree(DIST_DIR)
print(" Cleaned dist/")
except Exception as e:
print(f" Warning: Could not clean dist/: {e}")
BUILD_DIR = os.path.join(BASE_DIR, 'build')
if os.path.exists(BUILD_DIR):
try:
shutil.rmtree(BUILD_DIR)
print(" Cleaned build/")
except Exception as e:
print(f" Warning: Could not clean build/: {e}")
# 3. 运行 PyInstaller
print("[2/3] Running PyInstaller...")
# 注意:我们这里不传 --noupx因为已经在 spec 文件里把 upx=False 写死了
cmd = [
"pyinstaller",
"build_algorithm.spec",
"--clean"
]
try:
subprocess.check_call(cmd, shell=True)
except subprocess.CalledProcessError:
print("Error: PyInstaller failed.")
sys.exit(1)
# 4. 复制外部资源文件夹
print("[3/3] Verifying and Copying external resources...")
for src_name, dst_name in RESOURCES.items():
src_path = os.path.join(BASE_DIR, src_name)
dst_path = os.path.join(TARGET_DIR, dst_name)
if os.path.exists(src_path):
if os.path.isfile(src_path):
# 如果是文件
try:
shutil.copy2(src_path, dst_path)
print(f" Copied file: {src_name} -> {dst_name}")
except Exception as e:
print(f" Error copying file {src_name}: {e}")
else:
# 如果是文件夹
if os.path.exists(dst_path):
try:
shutil.rmtree(dst_path) # 先删除 spec 生成的旧文件夹 (如果有)
except Exception as e:
print(f" Warning: Could not remove existing dir {dst_path}: {e}")
try:
shutil.copytree(src_path, dst_path, ignore=shutil.ignore_patterns('*.pyc', '__pycache__'))
print(f" Copied dir: {src_name} -> {dst_name}")
except Exception as e:
print(f" Error copying dir {src_name}: {e}")
else:
print(f" Warning: Source resource not found at {src_path}")
print("\n" + "="*50)
print(f"SUCCESS! Build artifacts are in: {TARGET_DIR}")
print("="*50)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,396 @@
import numpy as np
from scipy.signal import welch
from scipy.fft import fft
from scipy import signal
from collections import deque
import time
import os
# import logging
import base64
import io
# logger = logging.getLogger(__name__)
#
# try:
# import matplotlib
# matplotlib.use('Agg')
# import matplotlib.pyplot as plt
# MATPLOTLIB_AVAILABLE = True
# except ImportError:
# MATPLOTLIB_AVAILABLE = False
# logger.warning("matplotlib未安装报告图表功能不可用")
class Calculate():
def __init__(self, Threshold_value_low, Threshold_value_high, fs=250, win_len=10):
self.Threshold_value_low = Threshold_value_low
self.Threshold_value_high = Threshold_value_high
self.fs = fs
self.focus_result = []
self.CLI_result = []
self.EVI_result = []
self.eegQueue = deque(maxlen=win_len)
# # 存储历史数据用于绘图
# self.beta_history = []
# self.alpha_history = []
# self.theta_history = []
# self.focus_history = []
# self.timestamp_history = []
#
# # 记录开始时间
# self.start_time = None
# self.recording = False
#
# # 图表保存路径
# self.chart_dir = "reports"
# if not os.path.exists(self.chart_dir):
# os.makedirs(self.chart_dir)
# print(f"[调试] 创建目录: {self.chart_dir}")
# 初始化滤波器
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30)
self.b_design = signal.firwin(65, [2 / (self.fs/2), 40 / (self.fs/2)], pass_zero=False)
print("[调试] Calculate 类初始化完成")
def calculate_focus(self, beta, alpha, theta):
"""
专注度计算 - 固定映射版本
"""
# 原始比值
raw = beta / (alpha + theta + 1e-10)
# Sigmoid 映射:让 raw 在 0.3-1.5 区间敏感
# 参数可调:
# k = 12 (斜率,越大越陡)
# x0 = 0.6 (中心点raw=0.6时focus≈50)
k = 12.0
x0 = 0.6
focus = 100.0 / (1.0 + np.exp(-k * (raw - x0)))
# 可选:添加滑动平均平滑
return int(focus)
def calculate_all(self, data, fs, nperseg=1000):
mean_x = np.mean(data, axis=-1, keepdims=True)
data = data - mean_x
freqs, psd = self.compute_psd_multichannel(data, fs, nperseg)
beta_psd = np.sum(self.band_psd(freqs, psd, (13, 30)))
alpha_psd = np.sum(self.band_psd(freqs, psd, (8, 13)))
theta_psd = np.sum(self.band_psd(freqs, psd, (4, 8)))
print(f"[功率] β={beta_psd:.2f} | α={alpha_psd:.2f} | θ={theta_psd:.2f}")
focus_score = self.calculate_focus(beta_psd, alpha_psd, theta_psd)
focus_score = max(0, min(100, focus_score))
self.focus_result.append(focus_score)
if len(self.focus_result) > 3:
self.focus_result.pop(0)
final_focus = int(self.simple_moving_average(self.focus_result, window_size=5))
cli_denom = alpha_psd + beta_psd
CLI_score = np.log(theta_psd / (cli_denom + 1e-10)) if cli_denom > 0 else 0
self.CLI_result.append(CLI_score)
if len(self.CLI_result) > 5:
self.CLI_result.pop(0)
final_CLI = round(self.simple_moving_average(self.CLI_result, window_size=5), 2)
return final_focus, final_CLI, beta_psd, alpha_psd, theta_psd
def compute_psd_multichannel(self, data, fs=250, nperseg=1000):
n_samples = data.shape[-1]
if n_samples < nperseg:
nperseg = n_samples
noverlap = 500
if noverlap >= nperseg:
noverlap = int(nperseg / 2)
if nperseg == 0:
return np.array([]), np.zeros((data.shape[0], 0))
freqs, psd = welch(data, fs=fs, nperseg=nperseg, noverlap=noverlap, axis=-1)
return freqs, psd
def band_psd(self, freqs, psd, band):
idx = np.logical_and(freqs >= band[0], freqs <= band[1])
return np.sum(psd[:, idx], axis=-1)
def simple_moving_average(self, data, window_size=5):
if len(data) == 0:
return 30
window = data[-window_size:]
return sum(window) / len(window)
def reset_queue(self):
self.eegQueue.clear()
# def start_recording(self):
# """开始记录数据"""
# self.recording = True
# self.start_time = time.time()
# self.beta_history = []
# self.alpha_history = []
# self.theta_history = []
# self.focus_history = []
# self.timestamp_history = []
# print("[调试] ========== 开始记录专注度数据 ==========")
# def stop_recording(self):
# """停止记录并生成图表"""
# print(f"[调试] stop_recording被调用, recording={self.recording}, focus_history长度={len(self.focus_history)}")
# self.recording = False
# if len(self.focus_history) > 0:
# print("[调试] 数据非空,开始生成图表...")
# # 保存到本地文件
# chart_path = self.save_chart_to_file()
# if chart_path:
# print(f"[调试] 本地文件保存成功: {chart_path}")
# else:
# print("[调试] 本地文件保存失败")
# # 生成base64编码
# base64_data = self.generate_chart_base64()
# return base64_data
# else:
# print("[调试] 没有数据可保存focus_history为空")
# return None
# def add_data_point(self, focus, beta, alpha, theta):
# if not self.recording:
# return
# current_time = time.time()
# elapsed = current_time - self.start_time
#
# self.beta_history.append(beta)
# self.alpha_history.append(alpha)
# self.theta_history.append(theta)
# self.focus_history.append(focus)
# self.timestamp_history.append(elapsed)
# print(f"[调试] 记录数据点: time={elapsed:.1f}s, focus={focus}, beta={beta:.2f}")
# def save_chart_to_file(self):
# """
# 保存图表到本地文件(唯一实现)
# """
# print(f"[调试] save_chart_to_file被调用, MATPLOTLIB_AVAILABLE={MATPLOTLIB_AVAILABLE}")
#
# if not MATPLOTLIB_AVAILABLE:
# print("[调试] matplotlib不可用无法保存")
# return None
#
# if len(self.focus_history) < 2:
# print(f"[调试] 数据点不足需要至少2个点当前{len(self.focus_history)}个点")
# return None
#
# print(f"[调试] 开始保存图表到本地文件...")
#
# # 确保所有列表长度一致
# min_len = min(len(self.beta_history), len(self.alpha_history),
# len(self.theta_history), len(self.focus_history),
# len(self.timestamp_history))
#
# print(f"[调试] 数据长度: min_len={min_len}")
#
# beta_list = self.beta_history[:min_len]
# alpha_list = self.alpha_history[:min_len]
# theta_list = self.theta_history[:min_len]
# focus_list = self.focus_history[:min_len]
# times = self.timestamp_history[:min_len]
#
# # 生成文件名
# timestamp = time.strftime("%Y%m%d_%H%M%S")
# chart_path = os.path.join(self.chart_dir, f"concentration_report_{timestamp}.png")
# print(f"[调试] 保存路径: {chart_path}")
#
# try:
# # 创建图表
# fig, ax1 = plt.subplots(figsize=(14, 8))
#
# # 左Y轴功率数据
# ax1.plot(times, beta_list, 'b-', linewidth=1.5, alpha=0.8, label='Beta Power')
# ax1.plot(times, alpha_list, 'g-', linewidth=1.5, alpha=0.8, label='Alpha Power')
# ax1.plot(times, theta_list, 'orange', linewidth=1.5, alpha=0.8, label='Theta Power')
# ax1.set_xlabel('Time (s)', fontsize=12)
# ax1.set_ylabel('Band Power', fontsize=12, color='black')
# ax1.tick_params(axis='y', labelcolor='black')
# ax1.legend(loc='upper left')
# ax1.grid(True, alpha=0.3)
#
# # 右Y轴专注度
# ax2 = ax1.twinx()
# ax2.plot(times, focus_list, 'r-', linewidth=2, alpha=0.9, label='Focus (%)')
# ax2.set_ylabel('Focus (%)', fontsize=12, color='red')
# ax2.tick_params(axis='y', labelcolor='red')
# ax2.set_ylim(0, 105)
# ax2.legend(loc='upper right')
#
# # 标题
# duration = times[-1] if times else 0
# avg_focus = np.mean(focus_list) if focus_list else 0
# plt.title(f'Concentration and EEG Band Power Trend\nDuration: {duration:.1f}s, Avg Focus: {avg_focus:.1f}%',
# fontsize=14)
#
# plt.tight_layout()
# plt.savefig(chart_path, dpi=150, bbox_inches='tight')
# plt.close()
#
# print(f"\n{'='*60}")
# print(f"专注度报告图片已保存到本地:")
# print(f" 文件路径: {chart_path}")
# print(f" 数据点数: {min_len}")
# print(f" 时长: {duration:.1f}秒")
# print(f" 平均专注度: {avg_focus:.1f}%")
# print(f"{'='*60}\n")
#
# return chart_path
#
# except Exception as e:
# print(f"[调试] 保存文件时出错: {e}")
# import traceback
# traceback.print_exc()
# return None
#
# def generate_chart_base64(self):
# """
# 生成图表的base64编码用于网络传输
# """
# if not MATPLOTLIB_AVAILABLE:
# return None
#
# if len(self.focus_history) < 2:
# return None
#
# min_len = min(len(self.beta_history), len(self.alpha_history),
# len(self.theta_history), len(self.focus_history),
# len(self.timestamp_history))
#
# beta_list = self.beta_history[:min_len]
# alpha_list = self.alpha_history[:min_len]
# theta_list = self.theta_history[:min_len]
# focus_list = self.focus_history[:min_len]
# times = self.timestamp_history[:min_len]
#
# fig, ax1 = plt.subplots(figsize=(14, 8))
#
# ax1.plot(times, beta_list, 'b-', linewidth=1.5, alpha=0.8, label='Beta Power')
# ax1.plot(times, alpha_list, 'g-', linewidth=1.5, alpha=0.8, label='Alpha Power')
# ax1.plot(times, theta_list, 'orange', linewidth=1.5, alpha=0.8, label='Theta Power')
# ax1.set_xlabel('Time (s)', fontsize=12)
# ax1.set_ylabel('Band Power', fontsize=12, color='black')
# ax1.tick_params(axis='y', labelcolor='black')
# ax1.legend(loc='upper left')
# ax1.grid(True, alpha=0.3)
#
# ax2 = ax1.twinx()
# ax2.plot(times, focus_list, 'r-', linewidth=2, alpha=0.9, label='Focus (%)')
# ax2.set_ylabel('Focus (%)', fontsize=12, color='red')
# ax2.tick_params(axis='y', labelcolor='red')
# ax2.set_ylim(0, 105)
# ax2.legend(loc='upper right')
#
# duration = times[-1] if times else 0
# avg_focus = np.mean(focus_list) if focus_list else 0
# plt.title(f'Concentration and EEG Band Power Trend\nDuration: {duration:.1f}s, Avg Focus: {avg_focus:.1f}%',
# fontsize=14)
#
# plt.tight_layout()
#
# buffer = io.BytesIO()
# plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
# buffer.seek(0)
# image_base64 = base64.b64encode(buffer.read()).decode('utf-8')
# plt.close()
#
# return image_base64
def queueOpt(self, data):
if data is None or data.size == 0:
return None
if len(self.eegQueue) < self.eegQueue.maxlen:
self.eegQueue.append(data)
else:
self.eegQueue.append(data)
if len(self.eegQueue) == self.eegQueue.maxlen:
eegData = np.hstack([self.eegQueue[i] for i in range(len(self.eegQueue))])
if eegData.size == 0:
return None
eegData -= np.mean(eegData, axis=-1, keepdims=True)
eegData = signal.lfilter(self.b_notch, self.a_notch, eegData)
eegData = signal.lfilter(self.b_design, 1, eegData)
focus_score, CLI_score, beta, alpha, theta = self.calculate_all(eegData, fs=self.fs, nperseg=1000)
# self.add_data_point(focus_score, beta, alpha, theta)
return focus_score
return None
class Calculate2():
def __init__(self, Threshold_value_low, Threshold_value_high):
self.Threshold_value_low = Threshold_value_low
self.Threshold_value_high = Threshold_value_high
self.focus_result = []
self.theta_result = []
self.alpha_result = []
self.flow_result = []
def calculate_all(self, data, fs, L=2500):
mean_x = np.mean(data, axis=-1, keepdims=True)
data = data - mean_x
Y = fft(data, axis=-1)
P2 = np.abs(Y / L)
P1 = P2[:, :L // 2 + 1]
P1[:, 1:-1] = 2 * P1[:, 1:-1]
beta_power = self.PSD(P1, L, fs, 13, 30)
alpha_power = self.PSD(P1, L, fs, 8, 13)
theta_power = self.PSD(P1, L, fs, 4, 8)
gamma_power = self.PSD(P1, L, fs, 30, 100)
focus_score = beta_power / (alpha_power + theta_power)
print('focus score:', focus_score)
focus_score = ((focus_score - self.Threshold_value_low) * 100) / (self.Threshold_value_high - self.Threshold_value_low)
self.focus_result.append(focus_score)
if len(self.focus_result) > 3:
self.focus_result.pop(0)
final_focus = int(self.simple_moving_average(self.focus_result, window_size=3))
self.theta_result.append(theta_power)
if len(self.theta_result) > 30:
self.theta_result.pop(0)
self.alpha_result.append(alpha_power)
if len(self.alpha_result) > 30:
self.alpha_result.pop(0)
rest_theta = self.simple_moving_average(self.theta_result, window_size=30)
rest_alpha = self.simple_moving_average(self.alpha_result, window_size=30)
distraction_score = (theta_power / rest_theta) * (1 - (alpha_power / rest_alpha))
flow_score = gamma_power / beta_power
flow_score = (flow_score / self.Threshold_value_high) * 100
self.flow_result.append(flow_score)
if len(self.flow_result) > 3:
self.flow_result.pop(0)
final_flow = int(self.simple_moving_average(self.flow_result, window_size=3))
return final_focus, distraction_score, final_flow
def PSD(self, P1, L, Fs, s_freq, e_freq):
s_point = round(s_freq * L / Fs)
e_point = round(e_freq * L / Fs)
x, y = P1.shape
band_PSD = 0
for i in range(x):
for j in range(s_point, e_point):
band_PSD += P1[i, j] ** 2
return band_PSD
def simple_moving_average(self, data, window_size=3):
if len(data) == 0:
return []
window = data[-window_size:]
return sum(window) / len(window)

View File

@@ -0,0 +1,161 @@
[system]
SSVEP_ThresholdValue = [1,-0.023]
;SSVEP_ThresholdValue = [2,-0.00200]
SSMVEP_IntervalEpoch = [0.2,2.2]
list_freqs = [8, 9]
phase = [0, 0]
concentration_ThresholdValue = [0.1, 0.8]
MI_IntervalEpoch = [0.5,4.5]
blink = [70,500,100,500,800,3,2]
Right_rehabilitation = 5
Fault_rehabilitation = 5
Num_blocks = 1
Num_trials = 10
Audio_device = 0
Rest_time = 2
Device_type = 1
Device_Host = 127.0.0.1
Device_Port = 5086
Upper_Host = 127.0.0.1
Upper_Port = 8088
Serial_port = COM44
[Layout]
main_splitter_left = 993
main_splitter_right = 922
right_splitter_left = 233
right_splitter_right = 771
left_splitter_left = 503
left_splitter_right = 501q
[channel]
channel_x_fp1 = 419
channel_y_fp1 = 124
channel_x_fc1 = 439
channel_y_fc1 = 296
channel_x_fp2 = 576
channel_y_fp2 = 124
channel_x_fc2 = 556
channel_y_fc2 = 299
channel_x_f3 = 397
channel_y_f3 = 231
channel_x_cp1 = 439
channel_y_cp1 = 426
channel_x_f4 = 601
channel_y_f4 = 232
channel_x_cp2 = 559
channel_y_cp2 = 425
channel_x_fc3 = 379
channel_y_fc3 = 295
channel_x_af4 = 571
channel_y_af4 = 171
channel_x_po8 = 645
channel_y_po8 = 564
channel_x_fpz = 499
channel_y_fpz = 112
channel_x_fcz = 499
channel_y_fcz = 300
channel_x_poz = 500
channel_y_poz = 554
channel_x_po5 = 387
channel_y_po5 = 551
channel_x_po6 = 611
channel_y_po6 = 551
channel_x_c3 = 373
channel_y_c3 = 363
channel_x_fc5 = 319
channel_y_fc5 = 292
channel_x_c4 = 620
channel_y_c4 = 363
channel_x_fc6 = 676
channel_y_fc6 = 288
channel_x_p3 = 398
channel_y_p3 = 491
channel_x_cp5 = 322
channel_y_cp5 = 430
channel_x_p4 = 600
channel_y_p4 = 489
channel_x_cp6 = 678
channel_y_cp6 = 430
channel_x_c5 = 313
channel_y_c5 = 361
channel_x_f6 = 650
channel_y_f6 = 223
channel_x_f5 = 349
channel_y_f5 = 224
channel_x_po4 = 573
channel_y_po4 = 551
channel_x_po3 = 429
channel_y_po3 = 550
channel_x_cp4 = 619
channel_y_cp4 = 424
channel_x_cp3 = 381
channel_y_cp3 = 426
channel_x_fc4 = 619
channel_y_fc4 = 295
channel_x_o1 = 423
channel_y_o1 = 598
channel_x_ft9 = 252
channel_y_ft9 = 168
channel_x_o2 = 576
channel_y_o2 = 597
channel_x_ft10 = 798
channel_y_ft10 = 277
channel_x_f7 = 295
channel_y_f7 = 214
channel_x_tp9 = 202
channel_y_tp9 = 445
channel_x_f8 = 701
channel_y_f8 = 215
channel_x_t7 = 252
channel_y_t7 = 362
channel_x_tp7 = 261
channel_y_tp7 = 436
channel_x_ft8 = 734
channel_y_ft8 = 283
channel_x_ft7 = 264
channel_y_ft7 = 286
channel_x_af8 = 645
channel_y_af8 = 159
channel_x_af7 = 351
channel_y_af7 = 160
channel_x_p6 = 652
channel_y_p6 = 499
channel_x_p5 = 348
channel_y_p5 = 499
channel_x_c6 = 683
channel_y_c6 = 362
channel_x_f1 = 447
channel_y_f1 = 236
channel_x_t8 = 745
channel_y_t8 = 361
channel_x_f2 = 549
channel_y_f2 = 235
channel_x_p7 = 300
channel_y_p7 = 505
channel_x_c1 = 435
channel_y_c1 = 363
channel_x_p8 = 698
channel_y_p8 = 508
channel_x_c2 = 559
channel_y_c2 = 359
channel_x_fz = 499
channel_y_fz = 238
channel_x_po7 = 354
channel_y_po7 = 562
channel_x_tp8 = 735
channel_y_tp8 = 438
channel_x_oz = 498
channel_y_oz = 609
channel_x_af3 = 428
channel_y_af3 = 170
channel_x_pz = 501
channel_y_pz = 486
channel_x_p2 = 551
channel_y_p2 = 483
channel_x_cz = 499
channel_y_cz = 361
channel_x_p1 = 448
channel_y_p1 = 488

View File

@@ -0,0 +1,252 @@
0 0.5
1 0.5
2 0.375
3 0.5
4 0.4375
5 0.375
6 0.5
7 0.5
8 0.375
9 0.375
10 0.375
11 0.375
12 0.5
13 0.5625
14 0.5625
15 0.5
16 0.5
17 0.5
18 0.5
19 0.5625
20 0.4375
21 0.5
22 0.5
23 0.375
24 0.375
25 0.375
26 0.375
27 0.375
28 0.3125
29 0.375
30 0.5625
31 0.5
32 0.5
33 0.5625
34 0.5625
35 0.3125
36 0.3125
37 0.3125
38 0.375
39 0.5625
40 0.3125
41 0.5625
42 0.3125
43 0.375
44 0.5625
45 0.5
46 0.375
47 0.375
48 0.3125
49 0.375
50 0.375
51 0.5
52 0.5625
53 0.375
54 0.5625
55 0.5625
56 0.375
57 0.375
58 0.375
59 0.5
60 0.3125
61 0.375
62 0.375
63 0.375
64 0.375
65 0.375
66 0.3125
67 0.375
68 0.5625
69 0.5625
70 0.5625
71 0.5
72 0.5625
73 0.375
74 0.375
75 0.375
76 0.375
77 0.375
78 0.5
79 0.375
80 0.375
81 0.5
82 0.375
83 0.375
84 0.375
85 0.375
86 0.3125
87 0.375
88 0.375
89 0.5
90 0.375
91 0.4375
92 0.3125
93 0.3125
94 0.375
95 0.375
96 0.375
97 0.375
98 0.3125
99 0.4375
100 0.375
101 0.375
102 0.375
103 0.3125
104 0.5625
105 0.5
106 0.5625
107 0.5625
108 0.5
109 0.3125
110 0.5625
111 0.5625
112 0.5
113 0.3125
114 0.5
115 0.3125
116 0.375
117 0.3125
118 0.3125
119 0.3125
120 0.3125
121 0.375
122 0.375
123 0.375
124 0.375
125 0.3125
126 0.375
127 0.375
128 0.375
129 0.375
130 0.5625
131 0.375
132 0.5
133 0.3125
134 0.3125
135 0.3125
136 0.375
137 0.5
138 0.3125
139 0.375
140 0.3125
141 0.3125
142 0.3125
143 0.5625
144 0.3125
145 0.375
146 0.5
147 0.5
148 0.375
149 0.4375
150 0.5
151 0.3125
152 0.375
153 0.375
154 0.375
155 0.3125
156 0.375
157 0.4375
158 0.4375
159 0.375
160 0.375
161 0.3125
162 0.375
163 0.375
164 0.375
165 0.3125
166 0.3125
167 0.3125
168 0.375
169 0.3125
170 0.3125
171 0.3125
172 0.375
173 0.3125
174 0.3125
175 0.5
176 0.3125
177 0.375
178 0.375
179 0.3125
180 0.3125
181 0.3125
182 0.3125
183 0.5625
184 0.5625
185 0.3125
186 0.5
187 0.5
188 0.5625
189 0.5
190 0.5625
191 0.5625
192 0.5625
193 0.5
194 0.5
195 0.5625
196 0.5625
197 0.5625
198 0.5625
199 0.5
200 0.5625
201 0.5625
202 0.375
203 0.375
204 0.375
205 0.375
206 0.375
207 0.5
208 0.5
209 0.5625
210 0.5625
211 0.5625
212 0.3125
213 0.5
214 0.5
215 0.5625
216 0.5
217 0.5
218 0.5
219 0.5625
220 0.5
221 0.4375
222 0.5
223 0.5
224 0.4375
225 0.5
226 0.4375
227 0.5
228 0.5
229 0.375
230 0.375
231 0.3125
232 0.375
233 0.375
234 0.375
235 0.5625
236 0.5625
237 0.5625
238 0.5625
239 0.5625
240 0.5
241 0.5
242 0.5
243 0.5625
244 0.5625
245 0.375
246 0.375
247 0.375
248 0.3125
249 0.375
The average accuracy is: 0.42675
The best accuracy is: 0.5625

View File

@@ -0,0 +1,13 @@
import sys
import os
import multiprocessing
# 1. 路径自适应:在 Frozen 模式下,将当前工作目录切换到可执行文件所在目录
# 这样代码中使用的相对路径(如 './config.ini')就能正确指向 exe 旁边的文件
if getattr(sys, 'frozen', False):
os.chdir(os.path.dirname(sys.executable))
# 2. 多进程保护:防止 Windows 下的无限递归炸弹
# Windows 下 multiprocessing 需要 freeze_support()
if sys.platform.startswith('win'):
multiprocessing.freeze_support()

View File

@@ -0,0 +1,35 @@
import matplotlib
matplotlib.use('Agg')
import argparse
import sys
import time
from Decoder import Decoder_main
from PubLibrary.RunOnce import is_program_running
if __name__ == "__main__":
if not is_program_running():
# 解析命令行参数
parser = argparse.ArgumentParser(description="EEG Decoder Application")
parser.add_argument('-dt', '--device-type', type=int, default=None, help="Device Type")
parser.add_argument('-dh', '--device-host', type=str, default=None, help="Device Host IP")
parser.add_argument('-dp', '--device-port', type=int, default=None, help="Device Port")
parser.add_argument('-uh', '--upper-host', type=str, default=None, help="Upper Computer Host IP")
parser.add_argument('-up', '--upper-port', type=int, default=None, help="Upper Computer Port")
args = parser.parse_args()
decoder = Decoder_main()
decoder.connect(
device_type=args.device_type,
device_host=args.device_host,
device_port=args.device_port,
upper_host=args.upper_host,
upper_port=args.upper_port
)
try:
decoder.start()
while not decoder.zmqServer.IsExitApp:
time.sleep(1)
except KeyboardInterrupt:
decoder.stop()