Files
bci_algo/upperHost_stimmock/ssmvep_headless.py
2026-06-10 09:25:11 +08:00

302 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
ssmvep_headless.py
无界面版 SSMVEP 范式通讯流程模拟脚本。
复现 ssmvep_main.py 的完整指令序列（train 0/1/2, rest, predict, saveData），
但不依赖 psychopy 也不打开任何窗口/音频,用 time.sleep 替代帧循环等待。
启动顺序:
1. runDecoder.py
2. datamock.py
3. ssmvep_headless.py
"""
import sys
import os
import json
import time
import threading
import zmq
import numpy as np
from datetime import datetime
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from PubLibrary.InifileHelper import IniRead
# Endpoint of datamock's label-command channel; ZmqCmdClient connects a
# PUSH socket here to forward stimulus labels.
DATAMOCK_LABEL_ADDR = 'tcp://127.0.0.1:8101'

# Subject name and session id, combined into the saved sequence file name.
personname = 'demo'
session = '01'
# ========== ZMQ 结果接收服务 ==========
class ZmqResultServer(threading.Thread):
    """Daemon thread that receives JSON status messages on a ROUTER socket.

    Binds ``tcp://0.0.0.0:<port>``. Each message ``{"method": ..., "params": ...}``
    updates the shared state read by the main flow:

    * ``energy``     -- last ``energy`` payload
    * ``paradigm``   -- 0 = individual calibration, 1 = rehabilitation training,
      2 = waiting for model training
    * ``ChoosenNum`` -- last ``result`` payload (-1 until a result arrives)
    * ``trial_idx``  -- count of ``result`` messages received

    Call ``stop`` to end the loop and release the socket and context.
    """

    def __init__(self, port=8088):
        threading.Thread.__init__(self)
        self.port = port
        self.running = True
        self.energy = 0
        self.paradigm = 0  # 0=calibration, 1=rehab training, 2=waiting for model training
        self.ChoosenNum = -1
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.ROUTER)
        self.socket.bind(f"tcp://0.0.0.0:{self.port}")
        self.daemon = True
        self.trial_idx = 0

    def _handle_message(self, message):
        """Apply one decoded ``{"method", "params"}`` message to the shared state."""
        method = message.get('method')
        params = message.get('params')
        if method == 'energy':
            self.energy = params
        elif method == 'paradigm':
            self.paradigm = params
            print(f"[Server] paradigm -> {params}")
        elif method == 'result':
            self.ChoosenNum = params
            self.trial_idx += 1
            print(f"[Server] result={self.ChoosenNum} (trial {self.trial_idx})")

    def run(self):
        """Poll the ROUTER socket until ``stop`` clears ``running``."""
        print(f"[Server] UpperHost_Server listening on {self.port}")
        while self.running:
            try:
                frames = self.socket.recv_multipart(zmq.NOBLOCK)
                # A DEALER peer that sends [b'', payload] arrives on a ROUTER as
                # [identity, b'', payload]; shorter messages are ignored.
                if len(frames) < 3:
                    continue
                self._handle_message(json.loads(frames[2].decode('utf-8')))
            except zmq.Again:
                # Nothing queued: back off briefly instead of spinning.
                time.sleep(0.005)
            except Exception as e:
                if not self.running:
                    break
                print(f"[Server] error: {e}")
                # Avoid a hot error loop if the failure is persistent.
                time.sleep(0.01)

    def stop(self, timeout=1.0):
        """Stop the receive loop, wait for it to exit, then free ZMQ resources.

        ZMQ sockets are not thread-safe, so the socket is closed only after the
        receiving thread has stopped using it. ``linger=0`` keeps
        ``context.term()`` from blocking on undelivered messages.

        Args:
            timeout: maximum seconds to wait for the receive thread to exit.
        """
        self.running = False
        if self.is_alive() and threading.current_thread() is not self:
            self.join(timeout)
        self.socket.close(linger=0)
        self.context.term()
# ========== ZMQ 命令发送客户端 ==========
class ZmqCmdClient:
    """Command client for the decoder plus a label channel to datamock.

    * ``socket`` (DEALER) connects to ``tcp://<host>:<port>`` in ``connect``.
      ``send_data`` writes ``[b'', json]`` commands to it, and the optional
      background thread started by ``start_recv_thread`` reads replies from it.
    * ``_label_sock`` (PUSH) connects to ``DATAMOCK_LABEL_ADDR`` and carries the
      labels derived from ``train``/``predict`` commands.

    The DEALER socket is touched by both the caller's thread and the receive
    thread. ZMQ sockets are not thread-safe, so every access goes through
    ``_sock_lock``.
    """

    def __init__(self, host, port):
        self.host = host
        self.port = port
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.DEALER)
        # Serializes DEALER socket use between send_data and the recv thread.
        self._sock_lock = threading.Lock()
        # PUSH socket used to send label commands to datamock.py
        self._label_sock = self.context.socket(zmq.PUSH)
        self._label_sock.connect(DATAMOCK_LABEL_ADDR)
        print(f"[Client] label PUSH connected to {DATAMOCK_LABEL_ADDR}")

    def connect(self):
        """Connect the DEALER socket to the decoder endpoint."""
        self.socket.connect(f"tcp://{self.host}:{self.port}")
        print(f"[Client] connected to {self.host}:{self.port}")

    def _handle_reply(self, frames):
        """Decode one decoder reply and mirror it into the result server's state.

        Args:
            frames: multipart message as received by the DEALER socket; the JSON
                payload is the last frame (normally ``[b'', json_bytes]``).
        """
        message = json.loads(frames[-1].decode('utf-8'))
        method = message.get('method')
        params = message.get('params')
        ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
        print(f"[{ts}] [CmdClient] recv: {method}={params}")
        rs = self._result_server
        if method == 'paradigm':
            rs.paradigm = params
            print(f"[{ts}] [CmdClient] paradigm updated -> {params}")
        elif method == 'result':
            rs.ChoosenNum = params
            rs.trial_idx += 1
            print(f"[{ts}] [CmdClient] result={params} (trial {rs.trial_idx})")
        elif method == 'energy':
            rs.energy = params

    def start_recv_thread(self, result_server):
        """Start a daemon thread that forwards decoder replies to ``result_server``.

        Args:
            result_server: object whose ``paradigm``, ``ChoosenNum``,
                ``trial_idx`` and ``energy`` attributes are updated.
        """
        self._result_server = result_server
        self._stop_recv = threading.Event()

        def _recv_loop():
            while not self._stop_recv.is_set():
                try:
                    # Hold the lock only for the non-blocking receive itself.
                    with self._sock_lock:
                        frames = self.socket.recv_multipart(zmq.NOBLOCK)
                    self._handle_reply(frames)
                except zmq.Again:
                    time.sleep(0.005)
                except Exception as e:
                    print(f"[CmdClient recv] error: {e}")
                    time.sleep(0.01)

        self._recv_thread = threading.Thread(target=_recv_loop, daemon=True)
        self._recv_thread.start()
        print("[Client] 后台接收线程已启动(监听 decoder 8099 回发消息)")

    def stop_recv_thread(self, timeout=1.0):
        """Signal the receive thread to stop and wait up to ``timeout`` seconds.

        Safe to call when no thread was started.
        """
        if hasattr(self, '_stop_recv'):
            self._stop_recv.set()
        thread = getattr(self, '_recv_thread', None)
        if (thread is not None and thread.is_alive()
                and thread is not threading.current_thread()):
            thread.join(timeout)

    def _send_label(self, label_value):
        """Send a label command to datamock.py (best effort, non-blocking)."""
        try:
            self._label_sock.send_string(str(label_value), zmq.NOBLOCK)
        except Exception as e:
            print(f"[Client] label send error: {e}")

    def send_data(self, method, params):
        """Send ``{"method", "params"}`` to the decoder and forward a datamock label.

        ``train 0`` -> label 1, ``train 1`` -> label 2, any ``predict`` -> label 99.
        Errors are logged, not raised.
        """
        msg = {'method': method, 'params': params}
        try:
            payload = json.dumps(msg).encode('utf-8')
            with self._sock_lock:
                self.socket.send_multipart([b'', payload])
            ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
            print(f"[{ts}] send_data: {method}={params}")
            # Forward the matching label to datamock for train/predict commands.
            if method == 'train':
                if params == 0:
                    self._send_label(1)
                    print("[Label] train 0 -> datamock label=1")
                elif params == 1:
                    self._send_label(2)
                    print("[Label] train 1 -> datamock label=2")
            elif method == 'predict':
                self._send_label(99)
                print("[Label] predict -> datamock label=99")
        except Exception as e:
            print(f"[Client] send error: {e}")
# ========== 主流程 ==========
def _make_position_sequence(position, num_trials):
    """Return a shuffled list of exactly ``num_trials`` target positions.

    ``position`` is repeated enough times to cover ``num_trials`` and then
    truncated. An odd ``num_trials`` therefore yields a full-length sequence,
    not one that is a trial short (which made the last
    ``true_pos_seq[trial_idx]`` lookup raise IndexError).

    Args:
        position: candidate target positions, e.g. ``[0, 1]``.
        num_trials: number of trials per block.

    Returns:
        A list of length ``max(num_trials, 0)``.
    """
    if num_trials <= 0:
        return []
    repeats = -(-num_trials // len(position))  # ceiling division
    balanced = (list(position) * repeats)[:num_trials]
    return np.random.permutation(balanced).tolist()


def _run_calibration(server, client, train_time, fault_rehabilitation):
    """Alternate left/right training trials while ``server.paradigm == 0``.

    Each iteration sends ``train 0`` then ``train 1``, each followed by
    ``rest 0``, and sleeps long enough for the decoder to collect the epoch.
    """
    print("\n[Phase] 个体校准阶段 (paradigm=0)")
    client.send_data('rest', 0)
    time.sleep(1)
    # Extra wait after each stimulus so the decoder can finish collecting the
    # training epoch: train_latency = 120 packets x 20 ms = 2.4 s.
    epoch_wait = 2.4
    trained = 0
    while server.paradigm == 0:
        # Left-leg stimulus
        print(f"\n[Train] 左腿刺激 (train 0) trained={trained}")
        client.send_data('train', 0)
        time.sleep(train_time + epoch_wait)
        trained += 1
        client.send_data('rest', 0)
        time.sleep(max(0, abs(fault_rehabilitation - train_time) - epoch_wait))
        # Right-leg stimulus
        print(f"\n[Train] 右腿刺激 (train 1) trained={trained}")
        client.send_data('train', 1)
        time.sleep(train_time + epoch_wait)
        trained += 1
        client.send_data('rest', 0)
        time.sleep(max(0, fault_rehabilitation - epoch_wait))
    print("\n[Phase] 个体校准结束,等待 paradigm=1 ...")


def _run_rehabilitation(server, client, true_pos_seq, num_blocks, num_trials,
                        test_time, rest_time, right_rehabilitation):
    """Run ``num_blocks`` blocks of ``num_trials`` predict trials.

    Returns:
        ``(user_choice, num_total, num_success)``. ``user_choice`` has one entry
        per trial: the decoder's result, or -1 if none arrived before the
        deadline. This keeps it aligned with the expected position sequence.
    """
    print("\n[Phase] 康复训练阶段 (paradigm=1)")
    # Extra wait for the decoder to finish the predict epoch:
    # predict latency = 115 packets x 20 ms = 2.3 s.
    predict_epoch_wait = 2.3
    user_choice = []
    num_total = 0
    num_success = 0
    for block_idx in range(num_blocks):
        print(f"\n [Block {block_idx+1}/{num_blocks}]")
        time.sleep(10)  # pause before each block
        for trial_idx in range(num_trials):
            true_position = true_pos_seq[trial_idx]
            print(f" [Trial {trial_idx+1}/{num_trials}] true_pos={true_position}")
            time.sleep(0.5)  # cue + beep
            server.ChoosenNum = -1
            client.send_data('predict', 1)
            choice = -1
            # Decided per trial, so right_rehabilitation applies only after a
            # successful trial instead of sticking for all later trials.
            trial_rest = rest_time
            deadline = time.perf_counter() + test_time + predict_epoch_wait
            while time.perf_counter() < deadline:
                result = server.ChoosenNum  # read once; other threads write it
                if result >= 0:
                    choice = result
                    num_total += 1
                    if result in (0, 1):
                        num_success += 1
                        trial_rest = right_rehabilitation
                    break
                time.sleep(0.02)
            user_choice.append(choice)
            client.send_data('rest', 0)
            time.sleep(0.5)
            time.sleep(trial_rest)
            server.ChoosenNum = -1
    print("\n[Phase] 康复训练结束")
    return user_choice, num_total, num_success


def _report_accuracy(user_choice, expected_seq, num_total, num_success):
    """Print the overall success rate and the position-matched accuracy."""
    overall_accuracy = num_success / num_total if num_total > 0 else 0
    min_len = min(len(user_choice), len(expected_seq))
    same_count = sum(1 for a, b in zip(user_choice, expected_seq) if a == b)
    true_accuracy = same_count / min_len if min_len > 0 else 0
    print(f"\n[Result] Overall={overall_accuracy:.3f} ({num_success}/{num_total})")
    print(f"[Result] TrueAcc={true_accuracy:.3f} ({same_count}/{min_len})")


def run_headless():
    """Replay the full SSMVEP command sequence against the decoder, headless.

    Steps:
    1. Start the result server (port 8088) and the command client. The decoder
       host and port come from the ini file.
    2. Save the shuffled target sequence under ``EEGFiles/``.
    3. Calibrate while ``paradigm == 0``, then wait while ``paradigm == 2``.
    4. Run the rehabilitation phase if ``paradigm == 1`` and print accuracy.
    5. Always send the shutdown commands (``predict 2``, ``saveData 0``) on exit.
    """
    server = ZmqResultServer(port=8088)
    server.start()
    decoder_host = str(IniRead('system', 'Decoder_Host'))
    decoder_port = int(IniRead('system', 'Decoder_Port'))
    client = ZmqCmdClient(decoder_host, decoder_port)
    client.connect()
    # Background thread mirrors the decoder's paradigm/result replies into `server`.
    client.start_recv_thread(server)
    time.sleep(1)  # give the connections time to establish
    client.send_data('decoderClass', 'ssmvep')

    train_time = 2.5  # stimulus duration per training trial (s)
    test_time = 2.5  # stimulus duration per test trial (s)
    right_rehabilitation = float(IniRead('system', 'Right_rehabilitation'))
    fault_rehabilitation = float(IniRead('system', 'Fault_rehabilitation'))
    rest_time = float(IniRead('system', 'Rest_time'))
    num_blocks = int(IniRead('system', 'Num_blocks'))
    num_trials = int(IniRead('system', 'Num_trials'))

    position = [0, 1]
    true_pos_seq = _make_position_sequence(position, num_trials)

    os.makedirs('EEGFiles', exist_ok=True)
    seq_file_path = f'EEGFiles/pos_seq_{personname}{session}_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
    seq_info = {
        'position': position,
        'sequence': true_pos_seq,
        'start_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    }
    with open(seq_file_path, 'w', encoding='utf-8') as f:
        json.dump(seq_info, f, ensure_ascii=False, indent=2)

    print("=" * 50)
    print("[Headless] 开始运行 SSMVEP 通讯流程(无界面)")
    print(f" num_blocks={num_blocks}, num_trials={num_trials}")
    print(f" train_time={train_time}s, test_time={test_time}s")
    print("=" * 50)
    try:
        _run_calibration(server, client, train_time, fault_rehabilitation)
        time.sleep(1)
        # paradigm 2 means the decoder is still training the model: wait for it
        # rather than skipping straight past the rehabilitation phase.
        while server.paradigm == 2:
            time.sleep(0.1)
        user_choice, num_total, num_success = [], 0, 0
        if server.paradigm == 1:
            user_choice, num_total, num_success = _run_rehabilitation(
                server, client, true_pos_seq, num_blocks, num_trials,
                test_time, rest_time, right_rehabilitation)
        _report_accuracy(user_choice, true_pos_seq * num_blocks,
                         num_total, num_success)
    except KeyboardInterrupt:
        print("\n[Headless] 用户中断")
    finally:
        client.send_data('predict', 2)  # shut the system down
        client.send_data('saveData', 0)
        client.stop_recv_thread()
        server.stop()
        print("[Headless] 已发送关闭指令,退出。")
if __name__ == '__main__':
    # Script entry point: run the whole headless command sequence once.
    run_headless()