diff --git a/.gitignore b/.gitignore index d8a7988..021d662 100644 --- a/.gitignore +++ b/.gitignore @@ -5,8 +5,8 @@ __pycache__/ build/ dist/ upperHost_stim/ -!upperHost_stim/MI_headless.py -!upperHost_stim/ssmvep_headless.py +#!upperHost_stim/MI_headless.py +#!upperHost_stim/ssmvep_headless.py .env .venv env/ diff --git a/upperHost_stimmock/MI_headless.py b/upperHost_stimmock/MI_headless.py new file mode 100644 index 0000000..92aa669 --- /dev/null +++ b/upperHost_stimmock/MI_headless.py @@ -0,0 +1,305 @@ +""" +MI_headless.py +无界面版 MI 运动想象范式通讯流程模拟脚本。 +复现 MI_main.py 的完整指令序列(train 0/1, rest, predict, saveData), +但不依赖 psychopy 也不打开任何窗口/音频,用 time.sleep 替代帧循环等待。 + +启动顺序: + 1. runDecoder.py + 2. datamock.py + 3. MI_headless.py +""" + +import sys +import os +import json +import time +import threading +import zmq +import numpy as np +import ast +from datetime import datetime + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from PubLibrary.InifileHelper import IniRead + +personname = 'demo' +session = '01' + +DATAMOCK_LABEL_ADDR = 'tcp://127.0.0.1:8101' # datamock 标签命令地址 + + +# ========== ZMQ 结果接收服务 ========== +class ZmqResultServer(threading.Thread): + def __init__(self, port=8088): + threading.Thread.__init__(self) + self.port = port + self.running = True + self.energy = 0 + self.paradigm = 0 # 0=个体校准, 1=康复训练, 2=等待模型训练 + self.ChoosenNum = -1 + self.context = zmq.Context() + self.socket = self.context.socket(zmq.ROUTER) + self.socket.bind(f"tcp://0.0.0.0:{self.port}") + self.daemon = True + self.trial_idx = 0 + + def run(self): + print(f"[Server] UpperHost_Server listening on {self.port}") + while self.running: + try: + frames = self.socket.recv_multipart(zmq.NOBLOCK) + if len(frames) < 3: + continue + message = json.loads(frames[2].decode('utf-8')) + method = message.get('method') + params = message.get('params') + if method == 'energy': + self.energy = params + elif method == 'paradigm': + self.paradigm = params + print(f"[Server] paradigm -> {params}") + elif method == 'result': + self.ChoosenNum = params + self.trial_idx += 1 + print(f"[Server] result={self.ChoosenNum} (trial {self.trial_idx})") + except zmq.Again: + time.sleep(0.005) + except Exception as e: + print(f"[Server] error: {e}") + + def stop(self): + self.running = False + self.socket.close() + self.context.term() + + +# ========== ZMQ 命令发送客户端 ========== +class ZmqCmdClient: + def __init__(self, host, port): + self.host = host + self.port = port + self.context = zmq.Context() + self.socket = self.context.socket(zmq.DEALER) + # PUSH socket 用于向 datamock.py 发送标签命令 + self._label_sock = self.context.socket(zmq.PUSH) + self._label_sock.connect(DATAMOCK_LABEL_ADDR) + print(f"[Client] label PUSH connected to {DATAMOCK_LABEL_ADDR}") + + def connect(self): + self.socket.connect(f"tcp://{self.host}:{self.port}") + print(f"[Client] connected to {self.host}:{self.port}") + + def start_recv_thread(self, result_server): + """启动后台线程,持续接收 decoder 通过 8099 ROUTER 回发的消息,并更新 result_server 的状态""" + self._result_server = result_server + self._stop_recv = threading.Event() + + def _recv_loop(): + while not self._stop_recv.is_set(): + try: + frames = self.socket.recv_multipart(zmq.NOBLOCK) + # DEALER 收到的格式: [b'', json_bytes] + data_bytes = frames[-1] + message = json.loads(data_bytes.decode('utf-8')) + method = message.get('method') + params = message.get('params') + ts = datetime.now().strftime('%H:%M:%S.%f')[:-3] + print(f"[{ts}] [CmdClient] recv: {method}={params}") + if method == 'paradigm': + self._result_server.paradigm = params + print(f"[{ts}] [CmdClient] paradigm updated -> {params}") + elif method == 'result': + self._result_server.ChoosenNum = params + self._result_server.trial_idx += 1 + print(f"[{ts}] [CmdClient] result={params} (trial {self._result_server.trial_idx})") + elif method == 'energy': + self._result_server.energy = params + except zmq.Again: + time.sleep(0.005) + except Exception as e: + print(f"[CmdClient recv] error: {e}") + time.sleep(0.01) + + self._recv_thread = threading.Thread(target=_recv_loop, daemon=True) + self._recv_thread.start() + print(f"[Client] 后台接收线程已启动(监听 decoder 8099 回发消息)") + + def stop_recv_thread(self): + if hasattr(self, '_stop_recv'): + self._stop_recv.set() + + def _send_label(self, label_value): + """向 datamock.py 发送标签命令""" + try: + self._label_sock.send_string(str(label_value), zmq.NOBLOCK) + except Exception as e: + print(f"[Client] label send error: {e}") + + def send_data(self, method, params): + msg = {'method': method, 'params': params} + try: + self.socket.send_multipart([b'', json.dumps(msg).encode('utf-8')]) + ts = datetime.now().strftime('%H:%M:%S.%f')[:-3] + print(f"[{ts}] send_data: {method}={params}") + # 根据 train/predict 命令向 datamock 发送标签 + if method == 'train': + if params == 0: + self._send_label(1) + print(f"[Label] train 0 -> datamock label=1") + elif params == 1: + self._send_label(2) + print(f"[Label] train 1 -> datamock label=2") + elif method == 'predict': + self._send_label(99) + print(f"[Label] predict -> datamock label=99") + except Exception as e: + print(f"[Client] send error: {e}") + + +# ========== 主流程 ========== +def run_headless(): + server = ZmqResultServer(port=8088) + server.start() + + _dh = str(IniRead('system', 'Decoder_Host')) + _dp = int(IniRead('system', 'Decoder_Port')) + client = ZmqCmdClient(_dh, _dp) + client.connect() + client.start_recv_thread(server) # 启动后台接收线程,监听 decoder 8099 回发的 paradigm/result 消息 + + time.sleep(1) # 等待连接建立 + client.send_data('decoderClass', 'mi') + + # MI_IntervalEpoch = [0.5, 4.5],trial时长 = 4.5-0.5 = 4.0s + _mi_iv = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) + _trial_sec = float(_mi_iv[1] - _mi_iv[0]) + _margin = 1.0 + train_time = max(5.0, _trial_sec + _margin) # 训练刺激时长(与 MI_main.py 保持一致) + + # MI epoch latency = interval_epoch[1] // 5 = (4.5*250)//5 = 225包 × 20ms = 4.5s + # train_latency = 225包(MI中 train_latency == latency) + # 在 train_time 后需再等 epoch_wait 秒,decoder 才能完成 epoch 采集 + epoch_wait = _mi_iv[1] / _mi_iv[1] * (_mi_iv[1] * 250 // 5) * 0.02 # = latency * 20ms + # 更直接的计算:latency = interval_epoch[1] // 5 = int(4.5*250)//5 = 225,225*0.02 = 4.5s + epoch_wait = (int(_mi_iv[1] * 250) // 5) * 0.02 # 4.5s + + # predict epoch wait(与 train 相同,MI中 latency == train_latency) + predict_epoch_wait = epoch_wait # 4.5s + + test_time = 7.0 # 预测窗口时长(与 MI_main.py 保持一致) + right_rehabilitation = float(IniRead('system', 'Right_rehabilitation')) + fault_rehabilitation = float(IniRead('system', 'Fault_rehabilitation')) + rest_time = float(IniRead('system', 'Rest_time')) + + num_blocks = int(IniRead('system', 'Num_blocks')) + num_trials = int(IniRead('system', 'Num_trials')) + + trained = 0 + Num_Total = 0 + Num_Success = 0 + user_choice = [] + + print("=" * 50) + print("[Headless] 开始运行 MI 通讯流程(无界面)") + print(f" MI_IntervalEpoch={_mi_iv}, trial_sec={_trial_sec:.2f}s") + print(f" train_time={train_time:.2f}s, epoch_wait={epoch_wait:.2f}s") + print(f" test_time={test_time:.2f}s, predict_epoch_wait={predict_epoch_wait:.2f}s") + print(f" num_blocks={num_blocks}, num_trials={num_trials}") + print("=" * 50) + + try: + while True: + # -------- 个体校准阶段 -------- + print("\n[Phase] 个体校准阶段 (paradigm=0)") + client.send_data('rest', 0) + time.sleep(1) + + while server.paradigm == 0: + # 左侧 MI 刺激(train 0,label=1) + print(f"\n[Train] 左侧 MI 刺激 (train 0) trained={trained}") + client.send_data('rest', 0) + time.sleep(0.5) # ding 提示后等待 + + client.send_data('train', 0) + time.sleep(train_time + epoch_wait) # 等待刺激时间 + epoch 完成时间 + + trained += 1 + client.send_data('rest', 0) + time.sleep(1.0) # 类间休息 + + # 空闲态样本采集(train 1,label=2) + print(f"\n[Train] 空闲态采集 (train 1) trained={trained}") + client.send_data('train', 1) + time.sleep(train_time + epoch_wait) # 等待刺激时间 + epoch 完成时间 + + trained += 1 + client.send_data('rest', 0) + time.sleep(1.0) # 类间休息 + + # 个体校准阶段结束 + print("\n[Phase] 个体校准结束,等待模型训练 (paradigm=2) ...") + trained = 0 + time.sleep(1) + + # 等待模型训练完成 (paradigm=2 -> paradigm=1) + while server.paradigm == 2: + print("[Phase] 等待模型训练完成 ...") + time.sleep(0.5) + + # -------- 康复训练阶段 -------- + while server.paradigm == 1: + print("\n[Phase] 康复训练阶段 (paradigm=1)") + for block_idx in range(num_blocks): + print(f"\n [Block {block_idx+1}/{num_blocks}]") + time.sleep(10) # 每轮开始前等待 + + for trial_idx in range(num_trials): + print(f" [Trial {trial_idx+1}/{num_trials}]") + + time.sleep(0.5) # ding 提示 + server.ChoosenNum = -1 + + # 开始预测 + # MI predict epoch latency = 225包 × 20ms = 4.5s,需额外等待 epoch 完成 + client.send_data('predict', 1) + t_start = time.perf_counter() + while time.perf_counter() - t_start < test_time + predict_epoch_wait: + if server.ChoosenNum >= 0: + Num_Total += 1 + user_choice.append(server.ChoosenNum) + if server.ChoosenNum == 0: + Num_Success += 1 + rest_time = right_rehabilitation + elif server.ChoosenNum == 1: + rest_time = fault_rehabilitation + break + time.sleep(0.02) + + trained += 1 + client.send_data('rest', 0) + time.sleep(0.5) + time.sleep(rest_time) + server.ChoosenNum = -1 + + # 训练结束 + print("\n[Phase] 康复训练结束") + break # 退出康复训练循环 + + # 统计结果 + overall_accuracy = Num_Success / Num_Total if Num_Total > 0 else 0 + print(f"\n[Result] Overall={overall_accuracy:.3f} ({Num_Success}/{Num_Total})") + print(f"[Result] user_choice={user_choice}") + break # 完成一个完整流程后退出 + + except KeyboardInterrupt: + print("\n[Headless] 用户中断") + finally: + client.send_data('predict', 2) # 关闭系统 + client.send_data('saveData', 0) + server.stop() + print("[Headless] 已发送关闭指令,退出。") + + +if __name__ == '__main__': + run_headless() diff --git a/upperHost_stimmock/ssmvep_headless.py b/upperHost_stimmock/ssmvep_headless.py new file mode 100644 index 0000000..ea5567b --- /dev/null +++ b/upperHost_stimmock/ssmvep_headless.py @@ -0,0 +1,301 @@ +""" +ssmvep_headless.py +无界面版 SSMVEP 范式通讯流程模拟脚本。 +复现 ssmvep_main.py 的完整指令序列(train 0/1/2, rest, predict, saveData), +但不依赖 psychopy 也不打开任何窗口/音频,用 time.sleep 替代帧循环等待。 + +启动顺序: + 1. runDecoder.py + 2. datamock.py + 3. ssmvep_headless.py +""" + +import sys +import os +import json +import time +import threading +import zmq +import numpy as np +from datetime import datetime + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from PubLibrary.InifileHelper import IniRead + +personname = 'demo' +session = '01' + +DATAMOCK_LABEL_ADDR = 'tcp://127.0.0.1:8101' # datamock 标签命令地址 + + +# ========== ZMQ 结果接收服务 ========== +class ZmqResultServer(threading.Thread): + def __init__(self, port=8088): + threading.Thread.__init__(self) + self.port = port + self.running = True + self.energy = 0 + self.paradigm = 0 # 0=个体校准, 1=康复训练, 2=等待模型训练 + self.ChoosenNum = -1 + self.context = zmq.Context() + self.socket = self.context.socket(zmq.ROUTER) + self.socket.bind(f"tcp://0.0.0.0:{self.port}") + self.daemon = True + self.trial_idx = 0 + + def run(self): + print(f"[Server] UpperHost_Server listening on {self.port}") + while self.running: + try: + frames = self.socket.recv_multipart(zmq.NOBLOCK) + if len(frames) < 3: + continue + message = json.loads(frames[2].decode('utf-8')) + method = message.get('method') + params = message.get('params') + if method == 'energy': + self.energy = params + elif method == 'paradigm': + self.paradigm = params + print(f"[Server] paradigm -> {params}") + elif method == 'result': + self.ChoosenNum = params + self.trial_idx += 1 + print(f"[Server] result={self.ChoosenNum} (trial {self.trial_idx})") + except zmq.Again: + time.sleep(0.005) + except Exception as e: + print(f"[Server] error: {e}") + + def stop(self): + self.running = False + self.socket.close() + self.context.term() + + +# ========== ZMQ 命令发送客户端 ========== +class ZmqCmdClient: + def __init__(self, host, port): + self.host = host + self.port = port + self.context = zmq.Context() + self.socket = self.context.socket(zmq.DEALER) + # PUSH socket 用于向 datamock.py 发送标签命令 + self._label_sock = self.context.socket(zmq.PUSH) + self._label_sock.connect(DATAMOCK_LABEL_ADDR) + print(f"[Client] label PUSH connected to {DATAMOCK_LABEL_ADDR}") + + def connect(self): + self.socket.connect(f"tcp://{self.host}:{self.port}") + print(f"[Client] connected to {self.host}:{self.port}") + + def start_recv_thread(self, result_server): + """启动后台线程,持续接收 decoder 通过 8099 ROUTER 回发的消息,并更新 result_server 的状态""" + self._result_server = result_server + self._stop_recv = threading.Event() + + def _recv_loop(): + while not self._stop_recv.is_set(): + try: + frames = self.socket.recv_multipart(zmq.NOBLOCK) + # DEALER 收到的格式: [b'', json_bytes] + data_bytes = frames[-1] + message = json.loads(data_bytes.decode('utf-8')) + method = message.get('method') + params = message.get('params') + ts = datetime.now().strftime('%H:%M:%S.%f')[:-3] + print(f"[{ts}] [CmdClient] recv: {method}={params}") + if method == 'paradigm': + self._result_server.paradigm = params + print(f"[{ts}] [CmdClient] paradigm updated -> {params}") + elif method == 'result': + self._result_server.ChoosenNum = params + self._result_server.trial_idx += 1 + print(f"[{ts}] [CmdClient] result={params} (trial {self._result_server.trial_idx})") + elif method == 'energy': + self._result_server.energy = params + except zmq.Again: + time.sleep(0.005) + except Exception as e: + print(f"[CmdClient recv] error: {e}") + time.sleep(0.01) + + self._recv_thread = threading.Thread(target=_recv_loop, daemon=True) + self._recv_thread.start() + print(f"[Client] 后台接收线程已启动(监听 decoder 8099 回发消息)") + + def stop_recv_thread(self): + if hasattr(self, '_stop_recv'): + self._stop_recv.set() + + def _send_label(self, label_value): + """向 datamock.py 发送标签命令""" + try: + self._label_sock.send_string(str(label_value), zmq.NOBLOCK) + except Exception as e: + print(f"[Client] label send error: {e}") + + def send_data(self, method, params): + msg = {'method': method, 'params': params} + try: + self.socket.send_multipart([b'', json.dumps(msg).encode('utf-8')]) + ts = datetime.now().strftime('%H:%M:%S.%f')[:-3] + print(f"[{ts}] send_data: {method}={params}") + # 根据 train/predict 命令向 datamock 发送标签 + if method == 'train': + if params == 0: + self._send_label(1) + print(f"[Label] train 0 -> datamock label=1") + elif params == 1: + self._send_label(2) + print(f"[Label] train 1 -> datamock label=2") + elif method == 'predict': + self._send_label(99) + print(f"[Label] predict -> datamock label=99") + except Exception as e: + print(f"[Client] send error: {e}") + + +# ========== 主流程 ========== +def run_headless(): + server = ZmqResultServer(port=8088) + server.start() + + _dh = str(IniRead('system', 'Decoder_Host')) + _dp = int(IniRead('system', 'Decoder_Port')) + client = ZmqCmdClient(_dh, _dp) + client.connect() + client.start_recv_thread(server) # 启动后台接收线程,监听 decoder 8099 回发的 paradigm/result 消息 + + time.sleep(1) # 等待连接建立 + client.send_data('decoderClass', 'ssmvep') + + train_time = 2.5 # 每轮训练刺激时长 (s) + test_time = 2.5 # 每轮测试刺激时长 (s) + right_rehabilitation = float(IniRead('system', 'Right_rehabilitation')) + fault_rehabilitation = float(IniRead('system', 'Fault_rehabilitation')) + rest_time = float(IniRead('system', 'Rest_time')) + + num_blocks = int(IniRead('system', 'Num_blocks')) + num_trials = int(IniRead('system', 'Num_trials')) + + position = [0, 1] + truePos_seq = position * int(num_trials / len(position)) + truePos_seq = np.random.permutation(truePos_seq).tolist() + user_choice = [] + + os.makedirs('EEGFiles', exist_ok=True) + seq_file_path = f'EEGFiles/pos_seq_{personname}{session}_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json' + seq_info = { + 'position': position, + 'sequence': truePos_seq, + 'start_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S') + } + with open(seq_file_path, 'w', encoding='utf-8') as f: + json.dump(seq_info, f, ensure_ascii=False, indent=2) + + trained = 0 + Num_Total = 0 + Num_Success = 0 + + print("=" * 50) + print("[Headless] 开始运行 SSMVEP 通讯流程(无界面)") + print(f" num_blocks={num_blocks}, num_trials={num_trials}") + print(f" train_time={train_time}s, test_time={test_time}s") + print("=" * 50) + + try: + while True: + # -------- 个体校准阶段 -------- + print("\n[Phase] 个体校准阶段 (paradigm=0)") + client.send_data('rest', 0) + time.sleep(1) + + # epoch完成需要的额外等待时间:train_latency=120包×20ms=2.4s + # 在train_time后需再等epoch_wait秒,decoder才能完成epoch采集并取出数据 + epoch_wait = 2.4 # 秒,与train_latency对应 + + while server.paradigm == 0: + # 左腿刺激 + print(f"\n[Train] 左腿刺激 (train 0) trained={trained}") + client.send_data('train', 0) + time.sleep(train_time + epoch_wait) # 等待刺激时间+epoch完成时间 + + trained += 1 + client.send_data('rest', 0) + time.sleep(max(0, abs(fault_rehabilitation - train_time) - epoch_wait)) + + # 右腿刺激 + print(f"\n[Train] 右腿刺激 (train 1) trained={trained}") + client.send_data('train', 1) + time.sleep(train_time + epoch_wait) # 等待刺激时间+epoch完成时间 + + trained += 1 + client.send_data('rest', 0) + time.sleep(max(0, fault_rehabilitation - epoch_wait)) + + # 个体校准阶段结束 + print("\n[Phase] 个体校准结束,等待 paradigm=1 ...") + trained = 0 + time.sleep(1) + + # -------- 康复训练阶段 -------- + while server.paradigm == 1: + print("\n[Phase] 康复训练阶段 (paradigm=1)") + for block_idx in range(num_blocks): + print(f"\n [Block {block_idx+1}/{num_blocks}]") + time.sleep(10) # 每轮开始前等待 + + for trial_idx in range(num_trials): + true_position = truePos_seq[trial_idx] + print(f" [Trial {trial_idx+1}/{num_trials}] true_pos={true_position}") + + time.sleep(0.5) # 提示 + 叮声 + server.ChoosenNum = -1 + + # 开始测试 + # predict epoch latency = 115包×20ms = 2.3s,需额外等待epoch完成 + predict_epoch_wait = 2.3 # 秒,与predict latency=115包对应 + client.send_data('predict', 1) + t_start = time.perf_counter() + while time.perf_counter() - t_start < test_time + predict_epoch_wait: + if server.ChoosenNum >= 0: + Num_Total += 1 + user_choice.append(server.ChoosenNum) + if server.ChoosenNum in [0, 1]: + Num_Success += 1 + rest_time = right_rehabilitation + break + time.sleep(0.02) + + trained += 1 + client.send_data('rest', 0) + time.sleep(0.5) + time.sleep(rest_time) + server.ChoosenNum = -1 + + # 训练结束 + print("\n[Phase] 康复训练结束") + break # 退出康复训练循环 + + # 统计结果 + overall_accuracy = Num_Success / Num_Total if Num_Total > 0 else 0 + expected_seq = truePos_seq * num_blocks + min_len = min(len(user_choice), len(expected_seq)) + same_count = sum(1 for a, b in zip(user_choice[:min_len], expected_seq[:min_len]) if a == b) + true_accuracy = same_count / min_len if min_len > 0 else 0 + print(f"\n[Result] Overall={overall_accuracy:.3f} ({Num_Success}/{Num_Total})") + print(f"[Result] TrueAcc={true_accuracy:.3f} ({same_count}/{min_len})") + break # 完成一个完整流程后退出 + + except KeyboardInterrupt: + print("\n[Headless] 用户中断") + finally: + client.send_data('predict', 2) # 关闭系统 + client.send_data('saveData', 0) + server.stop() + print("[Headless] 已发送关闭指令,退出。") + + +if __name__ == '__main__': + run_headless() diff --git a/upperHost_stimmock/ssvep_main.py b/upperHost_stimmock/ssvep_main.py new file mode 100644 index 0000000..90127a6 --- /dev/null +++ b/upperHost_stimmock/ssvep_main.py @@ -0,0 +1,364 @@ +import time + +from psychopy import visual, core, logging # import some libraries from PsychoPy +import random +from datetime import datetime + +# LAB STREAMING LAYER1 +from pylsl import StreamInfo, StreamOutlet +from psychopy import event +import numpy as np +from DecoderDW.Server import TCPServer +from DecoderDW.Client import TCPClient +# import subprocess + +# ---------------------- +# constants +# size of the window +WINWIDTH = 1920 +WINHEIGHT = 1080 +REFRESH_RATE = 144 + + + +def get_keypress(): + keys = event.getKeys() + if keys: + return keys[0] + else: + return None + + +def shutdown(win,client): + client.send_data('saveData', 0) + client.send_data('predict',2) + win.close() + core.quit() + + +# end of configuration +# ---------------------- + +def generate_square_wave(frequency, sampling_rate=REFRESH_RATE, duration=5): + """ + 生成方波序列 + + 参数: + frequency (float): 频率(Hz) + sampling_rate (int): 采样率(Hz),应与屏幕刷新率一致 + duration (float): 时长(秒) + + 返回: + square_wave (list): 方波序列 + """ + # 计算总点数 + n_points = int(duration * sampling_rate) + + # 生成时间序列 + time = np.linspace(0, duration, n_points, endpoint=False) + + # 生成正弦波数据 + sin_wave = np.sin(2 * np.pi * frequency * time) + # 生成方波数据 + square_wave = np.where(sin_wave >= 0, 1, 0) + + return square_wave.tolist() + + +# 启动一个进程,不等待其完成 +import os +if __name__ == "__main__": + # ---------------------------------------------------------------------------------- + # main window settings + main_win = visual.Window(size=(WINWIDTH, WINHEIGHT), units='height', screen=0, fullscr=False, + gammaErrorPolicy='warn', color=(0.7, 0.7, 0.7)) + print('starting 1') + # Set up LabStreamingLayer stream. + info = StreamInfo(name='psychopy_stimuli', type='Markers', channel_count=1, channel_format='string', + source_id='psychopy_stimuli_001') + outlet = StreamOutlet(info) # Broadcast the stream. + + imageStim1 = visual.ImageStim(main_win, size=(300, 300), pos=(-600, 300), units='pix', image='UI/figures/xy.jpg') + txtStim1 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True, + italic=False, pos=(-600, 30)) + + imageStim2 = visual.ImageStim(main_win, size=(300, 300), pos=(0, 300), units='pix', image='UI/figures/xy.jpg') + txtStim2 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True, + italic=False, pos=(0, 30)) + + imageStim3 = visual.ImageStim(main_win, size=(300, 300), pos=(600, 300), units='pix', image='UI/figures/xy.jpg') + txtStim3 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True, + italic=False, pos=(600, 30)) + imageStim4 = visual.ImageStim(main_win, size=(300, 300), pos=(-600, -200), units='pix', image='UI/figures/xy.jpg') + txtStim4 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True, + italic=False, pos=(-600, -470)) + imageStim5 = visual.ImageStim(main_win, size=(300, 300), pos=(0, -200), units='pix', image='UI/figures/xy.jpg') + txtStim5 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True, + italic=False, pos=(0, -470)) + imageStim6 = visual.ImageStim(main_win, size=(300, 300), pos=(600, -200), units='pix', image='UI/figures/xy.jpg') + txtStim6 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True, + italic=False, pos=(600, -470)) + imageStim1red = visual.ImageStim(main_win, size=(300, 300), pos=(-600, 300), units='pix', image='UI/figures/xy_red.jpg') + imageStim2red = visual.ImageStim(main_win, size=(300, 300), pos=(0, 300), units='pix', image='UI/figures/xy_red.jpg') + imageStim3red = visual.ImageStim(main_win, size=(300, 300), pos=(600, 300), units='pix', image='UI/figures/xy_red.jpg') + imageStim4red = visual.ImageStim(main_win, size=(300, 300), pos=(-600, -200), units='pix', image='UI/figures/xy_red.jpg') + imageStim5red = visual.ImageStim(main_win, size=(300, 300), pos=(0, -200), units='pix', image='UI/figures/xy_red.jpg') + imageStim6red = visual.ImageStim(main_win, size=(300, 300), pos=(600, -200), units='pix', image='UI/figures/xy_red.jpg') + + + frequencies = [25,26,27,28,29,30] #[9,10,11,12,13,14] #[30,31,32,33,34,35] [25,26,27,28,29,30] + # 生成方波数据 + square_wave_9 = generate_square_wave(frequencies[0], REFRESH_RATE, 5) + square_wave_11 = generate_square_wave(frequencies[1], REFRESH_RATE, 5) + square_wave_12 = generate_square_wave(frequencies[2], REFRESH_RATE, 5) + square_wave_13 = generate_square_wave(frequencies[3], REFRESH_RATE, 5) + square_wave_14 = generate_square_wave(frequencies[4], REFRESH_RATE, 5) + square_wave_15 = generate_square_wave(frequencies[5], REFRESH_RATE, 5) + + # 创建刺激对象列表,便于管理 + image_stims = [imageStim1, imageStim2, imageStim3, imageStim4, imageStim5, imageStim6] + txt_stims = [txtStim1, txtStim2, txtStim3, txtStim4, txtStim5, txtStim6] + square_waves = [square_wave_9, square_wave_11, square_wave_12, square_wave_13, square_wave_14, square_wave_15] + + time.sleep(2) + # grating.color = 'black' + server = TCPServer() + server.start() + client = TCPClient('127.0.0.1', 8099) + client.connect() + print('Connected decoder_main') + # client.send_data('impedance', 1) + # time.sleep(20) + # client.send_data('impedance', 2) + client.send_data('targetFreqs', frequencies) # 使用frequencies变量,确保与刺激频率一致 + time.sleep(1) + # 开启全程数据保存到 EEGFiles + client.send_data('saveData',1) + # client.send_data('impedance',1) + + + + # 实验参数 + repeats = 3 + seq_freq = frequencies * repeats + seq_freq = np.random.permutation(seq_freq).tolist() + num_trials = len(seq_freq) # 总试验次数, 6*6=36 + trial_count = 0 + + # 在线解码精度计算 + online_results = [] # 存储每个trial的解码结果 + correct_predictions = 0 # 正确预测计数 + + # 保存序列信息 + seq_info = { + 'total_trials': num_trials, + 'frequencies': frequencies, + 'sequence': seq_freq, + 'start_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S') + } + # 保存序列信息到文件 + import json + seq_file_path = f'EEGFiles/sequence_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json' + with open(seq_file_path, 'a', encoding='utf-8') as f: + json.dump(seq_info, f, ensure_ascii=False, indent=2) + + +#========================Trials Started======================# + while trial_count < num_trials: + # 从序列中获取当前试验的目标频率 + target_freq = seq_freq[trial_count] + target_freq_index = frequencies.index(target_freq) + print(f'Trials {trial_count + 1}/{num_trials} - Target Frequency: {target_freq}Hz (Label: {target_freq_index + 1})') + + # Stage 1: Cue Stage + # print('Cue Stage: The target frequency is in Red') + client.send_data('setLabelAndTrialInfo', { + 'label': 0, + 'trial_info': { + 'trial': trial_count + 1, + 'phase': 'cue', + 'target_freq': target_freq + } + }) + + for frameN in range(int(1 * REFRESH_RATE)): # 1秒提示 + key_press = get_keypress() + if key_press in ['q']: + shutdown(main_win, client) + + # 显示所有刺激,目标刺激为红色 + for i, stim in enumerate(image_stims): + if i == target_freq_index: + # 目标刺激显示红色 + if i == 0: + imageStim1red.draw() + elif i == 1: + imageStim2red.draw() + elif i == 2: + imageStim3red.draw() + elif i == 3: + imageStim4red.draw() + elif i == 4: + imageStim5red.draw() + elif i == 5: + imageStim6red.draw() + else: + # 其他刺激显示正常颜色 + stim.draw() + + main_win.flip() + + # Stage 2: Flanker Stimulus + # print('Flanker Stage: flank all frequencies') + client.send_data('predict', 1) + client.send_data('setLabelAndTrialInfo', { + 'label': target_freq_index + 1, # 设置目标频率标签 这里+1,是因为0代表不记录数据 + 'trial_info': { + 'trial': trial_count + 1, # trial 从0开始 + 'phase': 'stimulus', + 'target_freq': target_freq + } + }) + outlet.push_sample(['S 1']) + + for frameN in range(6 * REFRESH_RATE): # 6秒刺激 + key_press = get_keypress() + if key_press in ['q']: + shutdown(main_win, client) + + # 所有频率按照方波闪烁 + if square_wave_9[frameN % len(square_wave_9)] == 1: + imageStim1.draw() + if square_wave_11[frameN % len(square_wave_11)] == 1: + imageStim2.draw() + if square_wave_12[frameN % len(square_wave_12)] == 1: + imageStim3.draw() + if square_wave_13[frameN % len(square_wave_13)] == 1: + imageStim4.draw() + if square_wave_14[frameN % len(square_wave_14)] == 1: + imageStim5.draw() + if square_wave_15[frameN % len(square_wave_15)] == 1: + imageStim6.draw() + + main_win.flip() + if server.ChoosenNum != -1: + break + + # 记录在线解码结果 + predicted_freq_index = server.ChoosenNum # 解码结果 + predicted_freq = frequencies[predicted_freq_index] if predicted_freq_index != -1 else -1 + + # 判断解码是否正确 + is_correct = (predicted_freq_index == target_freq_index) if predicted_freq_index != -1 else False + if is_correct: + correct_predictions += 1 + + # 记录trial结果 + trial_result = { + 'trial': trial_count + 1, + 'target_freq': target_freq, + 'target_freq_index': target_freq_index, + 'predicted_freq': predicted_freq, + 'predicted_freq_index': predicted_freq_index, + 'is_correct': is_correct, + 'status': 'Success' if predicted_freq_index != -1 else 'Failed' + } + online_results.append(trial_result) + + # 打印当前trial结果 + status_symbol = "✓" if is_correct else "✗" + if predicted_freq_index == -1: + print(f'Trial {trial_count + 1}: 目标{target_freq}Hz -> 解码失败 - {status_symbol}') + else: + print(f'Trial {trial_count + 1}: 目标{target_freq}Hz -> 预测{predicted_freq}Hz - {status_symbol}') + + + # Stage 3: Decoding Feedback + outlet.push_sample(['S 2']) + client.send_data('setLabelAndTrialInfo', { + 'label': 0, # 反馈阶段标签为0 + 'trial_info': { + 'trial': trial_count + 1, + 'phase': 'feedback', + 'target_freq': target_freq + } + }) + # print('反馈阶段: 显示解码结果') + + for frameN in range(1 * REFRESH_RATE): # 1秒反馈 + key_press = get_keypress() + if key_press in ['q']: + shutdown(main_win, client) + + # 显示所有刺激但不闪烁 + for stim in image_stims: + stim.draw() + + # 显示解码结果 + if server.ChoosenNum == 0: + txtStim1.draw() + elif server.ChoosenNum == 1: + txtStim2.draw() + elif server.ChoosenNum == 2: + txtStim3.draw() + elif server.ChoosenNum == 3: + txtStim4.draw() + elif server.ChoosenNum == 4: + txtStim5.draw() + elif server.ChoosenNum == 5: + txtStim6.draw() + + main_win.flip() + + server.ChoosenNum = -1 + trial_count += 1 + +# 计算总体在线解码精度 +total_trials = len(online_results) +successful_trials = len([r for r in online_results if r['status'] == 'Success']) +failed_trials = len([r for r in online_results if r['status'] == 'Failed']) +overall_accuracy = correct_predictions / total_trials if total_trials > 0 else 0 + +# Print Accuracy +print(f"Total Accuracy: {overall_accuracy:.3f} ({correct_predictions}/{total_trials})") + +# 按频率分析准确率 +print(f"\n=== 按频率分析准确率 ===") +freq_accuracy = {} +for result in online_results: + freq = result['target_freq'] + if freq not in freq_accuracy: + freq_accuracy[freq] = {'correct': 0, 'total': 0, 'failed': 0} + + freq_accuracy[freq]['total'] += 1 + if result['status'] == 'Failed': + freq_accuracy[freq]['failed'] += 1 + elif result['is_correct']: + freq_accuracy[freq]['correct'] += 1 + +print(f"{'频率':<8} {'准确率':<8} {'正确/总数':<10} {'失败数':<8}") +print("-" * 40) +for freq in sorted(freq_accuracy.keys()): + stats = freq_accuracy[freq] + accuracy = stats['correct'] / stats['total'] if stats['total'] > 0 else 0 + print(f"{freq}Hz{'':<4} {accuracy:.3f}{'':<4} {stats['correct']}/{stats['total']}{'':<6} {stats['failed']}") + +# 保存在线解码结果到文件 +online_results_file = f'EEGFiles/online_results_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json' +online_summary = { + 'total_trials': total_trials, + 'successful_trials': successful_trials, + 'failed_trials': failed_trials, + 'correct_predictions': correct_predictions, + 'overall_accuracy': overall_accuracy, + # 'freq_accuracy': freq_accuracy, + 'trial_results': online_results, + # 'end_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S') +} + +with open(online_results_file, 'w', encoding='utf-8') as f: + json.dump(online_summary, f, ensure_ascii=False, indent=2) + + +client.send_data('predict',2) # 关闭系统 +main_win.close()