Files
bci_algo/upperHost_stimmock/MI_headless.py
2026-06-10 09:25:11 +08:00

306 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
MI_headless.py
无界面版 MI 运动想象范式通讯流程模拟脚本。
复现 MI_main.py 的完整指令序列（train 0/1, rest, predict, saveData），
但不依赖 psychopy 也不打开任何窗口/音频,用 time.sleep 替代帧循环等待。
启动顺序:
1. runDecoder.py
2. datamock.py
3. MI_headless.py
"""
import sys
import os
import json
import time
import threading
import zmq
import numpy as np
import ast
from datetime import datetime
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from PubLibrary.InifileHelper import IniRead
# Subject name and session id.
# NOTE(review): neither is referenced by the code visible in this script;
# presumably kept for parity with MI_main.py.
personname, session = 'demo', '01'
# Endpoint of datamock.py's label-command channel; ZmqCmdClient connects a
# PUSH socket to it and mirrors train/predict commands there as labels.
DATAMOCK_LABEL_ADDR = 'tcp://127.0.0.1:8101'
# ========== ZMQ 结果接收服务 ==========
class ZmqResultServer(threading.Thread):
    """Background ROUTER server that receives decoder status messages.

    Binds a ROUTER socket on ``tcp://0.0.0.0:<port>`` and, from its own daemon
    thread, polls it for JSON messages of the form
    ``{"method": ..., "params": ...}``. Recognised methods update public
    attributes that the main flow reads:

    * ``energy``   -> ``self.energy``
    * ``paradigm`` -> ``self.paradigm`` (0 = individual calibration,
      1 = rehabilitation training, 2 = waiting for model training)
    * ``result``   -> ``self.ChoosenNum``; also increments ``self.trial_idx``

    Args:
        port: TCP port to bind (default 8088).
    """

    def __init__(self, port=8088):
        super().__init__(daemon=True)
        self.port = port
        self.running = True
        self.energy = 0
        # 0 = individual calibration, 1 = rehabilitation training,
        # 2 = waiting for model training
        self.paradigm = 0
        self.ChoosenNum = -1
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.ROUTER)
        self.socket.bind(f"tcp://0.0.0.0:{self.port}")
        self.trial_idx = 0

    def run(self):
        """Poll the ROUTER socket until :meth:`stop` clears ``running``."""
        print(f"[Server] UpperHost_Server listening on {self.port}")
        while self.running:
            try:
                frames = self.socket.recv_multipart(zmq.NOBLOCK)
                # ROUTER prepends the peer identity; the expected layout is
                # [identity, b'', payload]. Shorter messages are ignored.
                if len(frames) < 3:
                    continue
                message = json.loads(frames[2].decode('utf-8'))
                method = message.get('method')
                params = message.get('params')
                if method == 'energy':
                    self.energy = params
                elif method == 'paradigm':
                    self.paradigm = params
                    print(f"[Server] paradigm -> {params}")
                elif method == 'result':
                    self.ChoosenNum = params
                    self.trial_idx += 1
                    print(f"[Server] result={self.ChoosenNum} (trial {self.trial_idx})")
            except zmq.Again:
                # Nothing queued: back off briefly instead of busy-spinning.
                time.sleep(0.005)
            except Exception as e:
                print(f"[Server] error: {e}")

    def stop(self, timeout=1.0):
        """Stop the polling loop, then release the socket and the context.

        ZMQ sockets are not thread-safe, so the socket is closed only after
        the polling thread has exited (or ``timeout`` seconds have elapsed).
        This prevents the thread from touching a socket that another thread
        is closing. Safe to call on a server whose thread was never started.

        Args:
            timeout: Maximum seconds to wait for the polling thread to exit.
        """
        self.running = False
        if self.is_alive() and threading.current_thread() is not self:
            self.join(timeout)
        # The server never sends, so there is nothing to flush; linger=0
        # keeps context.term() from blocking.
        self.socket.close(linger=0)
        self.context.term()
# ========== ZMQ 命令发送客户端 ==========
class ZmqCmdClient:
    """Decoder command client plus a label side-channel to datamock.py.

    Owns two sockets on one ZMQ context:

    * ``socket``: a DEALER that :meth:`connect` attaches to
      ``tcp://<host>:<port>``. :meth:`send_data` writes JSON commands to it
      as ``[b'', payload]``, and :meth:`start_recv_thread` reads replies from it.
    * ``_label_sock``: a PUSH connected to ``DATAMOCK_LABEL_ADDR`` at
      construction. ``train`` and ``predict`` commands are mirrored to it as
      label strings.

    ZMQ sockets must not be used from two threads at once. The DEALER is
    shared by the receive thread and by whoever calls :meth:`send_data`, so
    every operation on it is serialised through ``_sock_lock``.

    Args:
        host: Decoder host name or IP.
        port: Decoder TCP port.
    """

    def __init__(self, host, port):
        self.host = host
        self.port = port
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.DEALER)
        # Serialises access to self.socket (receive thread vs. send_data()).
        self._sock_lock = threading.Lock()
        # PUSH socket used to send label commands to datamock.py.
        self._label_sock = self.context.socket(zmq.PUSH)
        self._label_sock.connect(DATAMOCK_LABEL_ADDR)
        print(f"[Client] label PUSH connected to {DATAMOCK_LABEL_ADDR}")

    def connect(self):
        """Connect the DEALER socket to the decoder endpoint."""
        with self._sock_lock:
            self.socket.connect(f"tcp://{self.host}:{self.port}")
        print(f"[Client] connected to {self.host}:{self.port}")

    def start_recv_thread(self, result_server):
        """Start a daemon thread that mirrors decoder replies into *result_server*.

        Each reply is a JSON ``{"method", "params"}`` message. ``paradigm``
        sets ``result_server.paradigm``. ``result`` sets
        ``result_server.ChoosenNum`` and increments ``trial_idx``. ``energy``
        sets ``result_server.energy``.

        Args:
            result_server: Object exposing ``paradigm``, ``ChoosenNum``,
                ``trial_idx`` and ``energy`` attributes (the ZmqResultServer
                in this script).
        """
        self._result_server = result_server
        self._stop_recv = threading.Event()

        def _recv_loop():
            while not self._stop_recv.is_set():
                try:
                    # Hold the lock only for the non-blocking receive itself.
                    with self._sock_lock:
                        frames = self.socket.recv_multipart(zmq.NOBLOCK)
                    # DEALER delivers [b'', json_bytes]; the payload is the last frame.
                    data_bytes = frames[-1]
                    message = json.loads(data_bytes.decode('utf-8'))
                    method = message.get('method')
                    params = message.get('params')
                    ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
                    print(f"[{ts}] [CmdClient] recv: {method}={params}")
                    if method == 'paradigm':
                        self._result_server.paradigm = params
                        print(f"[{ts}] [CmdClient] paradigm updated -> {params}")
                    elif method == 'result':
                        self._result_server.ChoosenNum = params
                        self._result_server.trial_idx += 1
                        print(f"[{ts}] [CmdClient] result={params} (trial {self._result_server.trial_idx})")
                    elif method == 'energy':
                        self._result_server.energy = params
                except zmq.Again:
                    time.sleep(0.005)
                except Exception as e:
                    print(f"[CmdClient recv] error: {e}")
                    time.sleep(0.01)

        self._recv_thread = threading.Thread(target=_recv_loop, daemon=True)
        self._recv_thread.start()
        print(f"[Client] 后台接收线程已启动(监听 decoder 8099 回发消息)")

    def stop_recv_thread(self, timeout=1.0):
        """Signal the receive thread to exit and wait up to *timeout* seconds for it.

        This is a no-op if :meth:`start_recv_thread` was never called.

        Args:
            timeout: Maximum seconds to wait for the thread to finish.
        """
        if hasattr(self, '_stop_recv'):
            self._stop_recv.set()
        thread = getattr(self, '_recv_thread', None)
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout)

    def _send_label(self, label_value):
        """Send *label_value* as a string to datamock.py without blocking.

        Errors, including a full or unconnected PUSH queue, are logged
        rather than raised.
        """
        try:
            self._label_sock.send_string(str(label_value), zmq.NOBLOCK)
        except Exception as e:
            print(f"[Client] label send error: {e}")

    def send_data(self, method, params):
        """Send a ``{"method", "params"}`` command to the decoder.

        Also mirrors it to datamock.py as a label: train 0 -> 1,
        train 1 -> 2, predict -> 99. Errors are logged, not raised.

        Args:
            method: Command name (e.g. ``'train'``, ``'predict'``, ``'rest'``).
            params: JSON-serialisable command parameter.
        """
        msg = {'method': method, 'params': params}
        try:
            with self._sock_lock:
                self.socket.send_multipart([b'', json.dumps(msg).encode('utf-8')])
            ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
            print(f"[{ts}] send_data: {method}={params}")
            # Mirror train/predict commands to datamock as label values.
            if method == 'train':
                if params == 0:
                    self._send_label(1)
                    print(f"[Label] train 0 -> datamock label=1")
                elif params == 1:
                    self._send_label(2)
                    print(f"[Label] train 1 -> datamock label=2")
            elif method == 'predict':
                self._send_label(99)
                print(f"[Label] predict -> datamock label=99")
        except Exception as e:
            print(f"[Client] send error: {e}")
# ========== 主流程 ==========
def _compute_timing(mi_interval):
    """Derive stimulus and epoch wait durations (seconds) from MI_IntervalEpoch.

    Args:
        mi_interval: ``[start, end]`` epoch window in seconds, e.g. ``[0.5, 4.5]``.

    Returns:
        ``(trial_sec, train_time, epoch_wait)``:

        * ``trial_sec = end - start``
        * ``train_time = max(5.0, trial_sec + 1.0)``
        * ``epoch_wait = (int(end * 250) // 5) * 0.02``, i.e. the decoder's
          epoch latency in 20 ms packets (225 packets = 4.5 s for ``end = 4.5``).
    """
    trial_sec = float(mi_interval[1] - mi_interval[0])
    margin = 1.0
    # Training stimulus duration (kept consistent with MI_main.py).
    train_time = max(5.0, trial_sec + margin)
    # After train_time the decoder still needs epoch_wait seconds to finish
    # collecting the epoch. NOTE(review): 250 is presumably the sampling rate
    # (Hz) and 5 the samples per 20 ms packet; confirm against the decoder.
    epoch_wait = (int(mi_interval[1] * 250) // 5) * 0.02
    return trial_sec, train_time, epoch_wait


def _run_calibration(client, server, train_time, epoch_wait):
    """Run individual calibration: alternate train 0 / train 1 while paradigm == 0."""
    trained = 0
    print("\n[Phase] 个体校准阶段 (paradigm=0)")
    client.send_data('rest', 0)
    time.sleep(1)
    while server.paradigm == 0:
        # Left-hand MI stimulus (train 0 -> datamock label 1).
        print(f"\n[Train] 左侧 MI 刺激 (train 0) trained={trained}")
        client.send_data('rest', 0)
        time.sleep(0.5)  # wait after the cue
        client.send_data('train', 0)
        time.sleep(train_time + epoch_wait)  # stimulus duration + epoch completion
        trained += 1
        client.send_data('rest', 0)
        time.sleep(1.0)  # rest between classes
        # Idle-state sample collection (train 1 -> datamock label 2).
        print(f"\n[Train] 空闲态采集 (train 1) trained={trained}")
        client.send_data('train', 1)
        time.sleep(train_time + epoch_wait)  # stimulus duration + epoch completion
        trained += 1
        client.send_data('rest', 0)
        time.sleep(1.0)  # rest between classes
    print("\n[Phase] 个体校准结束,等待模型训练 (paradigm=2) ...")
    time.sleep(1)


def _run_rehabilitation(client, server, num_blocks, num_trials, window_sec,
                        rest_time, right_rest, fault_rest):
    """Run the rehabilitation phase: one ``predict`` per trial.

    Args:
        client: Command client used to send ``predict``/``rest``.
        server: Status holder whose ``ChoosenNum`` carries the decoder result.
        num_blocks: Number of blocks.
        num_trials: Trials per block.
        window_sec: Seconds to wait for a result after each ``predict``.
        rest_time: Initial post-trial rest (seconds).
        right_rest: Rest used after a result of 0 (counted as a success).
        fault_rest: Rest used after a result of 1.

    Returns:
        ``(num_total, num_success, user_choice)``.
    """
    num_total = 0
    num_success = 0
    user_choice = []
    print("\n[Phase] 康复训练阶段 (paradigm=1)")
    for block_idx in range(num_blocks):
        print(f"\n [Block {block_idx+1}/{num_blocks}]")
        time.sleep(10)  # wait before each block
        for trial_idx in range(num_trials):
            print(f" [Trial {trial_idx+1}/{num_trials}]")
            time.sleep(0.5)  # cue
            server.ChoosenNum = -1
            client.send_data('predict', 1)
            t_start = time.perf_counter()
            while time.perf_counter() - t_start < window_sec:
                # Read once: ChoosenNum is written by other threads.
                choice = server.ChoosenNum
                if choice >= 0:
                    num_total += 1
                    user_choice.append(choice)
                    # NOTE(review): rest_time carries over to later trials
                    # (including timed-out ones), as in the original flow.
                    if choice == 0:
                        num_success += 1
                        rest_time = right_rest
                    elif choice == 1:
                        rest_time = fault_rest
                    break
                time.sleep(0.02)
            client.send_data('rest', 0)
            time.sleep(0.5)
            time.sleep(rest_time)
            server.ChoosenNum = -1
    print("\n[Phase] 康复训练结束")
    return num_total, num_success, user_choice


def run_headless():
    """Run the full MI command sequence once, without any GUI or audio.

    Sequence:

    1. Calibration: train 0/1 while paradigm == 0.
    2. Wait while paradigm == 2.
    3. Rehabilitation trials if paradigm == 1.
    4. Print accuracy.

    ``predict 2`` and ``saveData`` are always sent on exit, including on
    Ctrl-C. Timing and counts are read from the ``system`` section via
    ``IniRead``.
    """
    server = ZmqResultServer(port=8088)
    server.start()
    _dh = str(IniRead('system', 'Decoder_Host'))
    _dp = int(IniRead('system', 'Decoder_Port'))
    client = ZmqCmdClient(_dh, _dp)
    client.connect()
    # Background receiver for paradigm/result messages sent back by the decoder.
    client.start_recv_thread(server)
    time.sleep(1)  # let the connections establish
    client.send_data('decoderClass', 'mi')

    mi_interval = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
    trial_sec, train_time, epoch_wait = _compute_timing(mi_interval)
    # In MI the predict epoch latency equals the train latency.
    predict_epoch_wait = epoch_wait
    test_time = 7.0  # prediction window (kept consistent with MI_main.py)
    right_rehabilitation = float(IniRead('system', 'Right_rehabilitation'))
    fault_rehabilitation = float(IniRead('system', 'Fault_rehabilitation'))
    rest_time = float(IniRead('system', 'Rest_time'))
    num_blocks = int(IniRead('system', 'Num_blocks'))
    num_trials = int(IniRead('system', 'Num_trials'))

    print("=" * 50)
    print("[Headless] 开始运行 MI 通讯流程(无界面)")
    print(f" MI_IntervalEpoch={mi_interval}, trial_sec={trial_sec:.2f}s")
    print(f" train_time={train_time:.2f}s, epoch_wait={epoch_wait:.2f}s")
    print(f" test_time={test_time:.2f}s, predict_epoch_wait={predict_epoch_wait:.2f}s")
    print(f" num_blocks={num_blocks}, num_trials={num_trials}")
    print("=" * 50)
    try:
        _run_calibration(client, server, train_time, epoch_wait)
        # Wait for model training to finish (paradigm 2 -> 1).
        while server.paradigm == 2:
            print("[Phase] 等待模型训练完成 ...")
            time.sleep(0.5)
        num_total, num_success, user_choice = 0, 0, []
        if server.paradigm == 1:
            num_total, num_success, user_choice = _run_rehabilitation(
                client, server, num_blocks, num_trials,
                test_time + predict_epoch_wait, rest_time,
                right_rehabilitation, fault_rehabilitation)
        overall_accuracy = num_success / num_total if num_total > 0 else 0
        print(f"\n[Result] Overall={overall_accuracy:.3f} ({num_success}/{num_total})")
        print(f"[Result] user_choice={user_choice}")
    except KeyboardInterrupt:
        print("\n[Headless] 用户中断")
    finally:
        client.send_data('predict', 2)  # shut the system down
        client.send_data('saveData', 0)
        client.stop_recv_thread()
        server.stop()
        print("[Headless] 已发送关闭指令,退出。")


if __name__ == '__main__':
    run_headless()