upper mock
This commit is contained in:
305
upperHost_stimmock/MI_headless.py
Normal file
305
upperHost_stimmock/MI_headless.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
MI_headless.py
|
||||
无界面版 MI 运动想象范式通讯流程模拟脚本。
|
||||
复现 MI_main.py 的完整指令序列(train 0/1, rest, predict, saveData),
|
||||
但不依赖 psychopy 也不打开任何窗口/音频,用 time.sleep 替代帧循环等待。
|
||||
|
||||
启动顺序:
|
||||
1. runDecoder.py
|
||||
2. datamock.py
|
||||
3. MI_headless.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
import zmq
|
||||
import numpy as np
|
||||
import ast
|
||||
from datetime import datetime
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from PubLibrary.InifileHelper import IniRead
|
||||
|
||||
personname = 'demo'
|
||||
session = '01'
|
||||
|
||||
DATAMOCK_LABEL_ADDR = 'tcp://127.0.0.1:8101' # datamock 标签命令地址
|
||||
|
||||
|
||||
# ========== ZMQ 结果接收服务 ==========
|
||||
class ZmqResultServer(threading.Thread):
|
||||
def __init__(self, port=8088):
|
||||
threading.Thread.__init__(self)
|
||||
self.port = port
|
||||
self.running = True
|
||||
self.energy = 0
|
||||
self.paradigm = 0 # 0=个体校准, 1=康复训练, 2=等待模型训练
|
||||
self.ChoosenNum = -1
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.ROUTER)
|
||||
self.socket.bind(f"tcp://0.0.0.0:{self.port}")
|
||||
self.daemon = True
|
||||
self.trial_idx = 0
|
||||
|
||||
def run(self):
|
||||
print(f"[Server] UpperHost_Server listening on {self.port}")
|
||||
while self.running:
|
||||
try:
|
||||
frames = self.socket.recv_multipart(zmq.NOBLOCK)
|
||||
if len(frames) < 3:
|
||||
continue
|
||||
message = json.loads(frames[2].decode('utf-8'))
|
||||
method = message.get('method')
|
||||
params = message.get('params')
|
||||
if method == 'energy':
|
||||
self.energy = params
|
||||
elif method == 'paradigm':
|
||||
self.paradigm = params
|
||||
print(f"[Server] paradigm -> {params}")
|
||||
elif method == 'result':
|
||||
self.ChoosenNum = params
|
||||
self.trial_idx += 1
|
||||
print(f"[Server] result={self.ChoosenNum} (trial {self.trial_idx})")
|
||||
except zmq.Again:
|
||||
time.sleep(0.005)
|
||||
except Exception as e:
|
||||
print(f"[Server] error: {e}")
|
||||
|
||||
def stop(self):
|
||||
self.running = False
|
||||
self.socket.close()
|
||||
self.context.term()
|
||||
|
||||
|
||||
# ========== ZMQ 命令发送客户端 ==========
|
||||
class ZmqCmdClient:
|
||||
def __init__(self, host, port):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.DEALER)
|
||||
# PUSH socket 用于向 datamock.py 发送标签命令
|
||||
self._label_sock = self.context.socket(zmq.PUSH)
|
||||
self._label_sock.connect(DATAMOCK_LABEL_ADDR)
|
||||
print(f"[Client] label PUSH connected to {DATAMOCK_LABEL_ADDR}")
|
||||
|
||||
def connect(self):
|
||||
self.socket.connect(f"tcp://{self.host}:{self.port}")
|
||||
print(f"[Client] connected to {self.host}:{self.port}")
|
||||
|
||||
def start_recv_thread(self, result_server):
|
||||
"""启动后台线程,持续接收 decoder 通过 8099 ROUTER 回发的消息,并更新 result_server 的状态"""
|
||||
self._result_server = result_server
|
||||
self._stop_recv = threading.Event()
|
||||
|
||||
def _recv_loop():
|
||||
while not self._stop_recv.is_set():
|
||||
try:
|
||||
frames = self.socket.recv_multipart(zmq.NOBLOCK)
|
||||
# DEALER 收到的格式: [b'', json_bytes]
|
||||
data_bytes = frames[-1]
|
||||
message = json.loads(data_bytes.decode('utf-8'))
|
||||
method = message.get('method')
|
||||
params = message.get('params')
|
||||
ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
|
||||
print(f"[{ts}] [CmdClient] recv: {method}={params}")
|
||||
if method == 'paradigm':
|
||||
self._result_server.paradigm = params
|
||||
print(f"[{ts}] [CmdClient] paradigm updated -> {params}")
|
||||
elif method == 'result':
|
||||
self._result_server.ChoosenNum = params
|
||||
self._result_server.trial_idx += 1
|
||||
print(f"[{ts}] [CmdClient] result={params} (trial {self._result_server.trial_idx})")
|
||||
elif method == 'energy':
|
||||
self._result_server.energy = params
|
||||
except zmq.Again:
|
||||
time.sleep(0.005)
|
||||
except Exception as e:
|
||||
print(f"[CmdClient recv] error: {e}")
|
||||
time.sleep(0.01)
|
||||
|
||||
self._recv_thread = threading.Thread(target=_recv_loop, daemon=True)
|
||||
self._recv_thread.start()
|
||||
print(f"[Client] 后台接收线程已启动(监听 decoder 8099 回发消息)")
|
||||
|
||||
def stop_recv_thread(self):
|
||||
if hasattr(self, '_stop_recv'):
|
||||
self._stop_recv.set()
|
||||
|
||||
def _send_label(self, label_value):
|
||||
"""向 datamock.py 发送标签命令"""
|
||||
try:
|
||||
self._label_sock.send_string(str(label_value), zmq.NOBLOCK)
|
||||
except Exception as e:
|
||||
print(f"[Client] label send error: {e}")
|
||||
|
||||
def send_data(self, method, params):
|
||||
msg = {'method': method, 'params': params}
|
||||
try:
|
||||
self.socket.send_multipart([b'', json.dumps(msg).encode('utf-8')])
|
||||
ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
|
||||
print(f"[{ts}] send_data: {method}={params}")
|
||||
# 根据 train/predict 命令向 datamock 发送标签
|
||||
if method == 'train':
|
||||
if params == 0:
|
||||
self._send_label(1)
|
||||
print(f"[Label] train 0 -> datamock label=1")
|
||||
elif params == 1:
|
||||
self._send_label(2)
|
||||
print(f"[Label] train 1 -> datamock label=2")
|
||||
elif method == 'predict':
|
||||
self._send_label(99)
|
||||
print(f"[Label] predict -> datamock label=99")
|
||||
except Exception as e:
|
||||
print(f"[Client] send error: {e}")
|
||||
|
||||
|
||||
# ========== 主流程 ==========
|
||||
def run_headless():
|
||||
server = ZmqResultServer(port=8088)
|
||||
server.start()
|
||||
|
||||
_dh = str(IniRead('system', 'Decoder_Host'))
|
||||
_dp = int(IniRead('system', 'Decoder_Port'))
|
||||
client = ZmqCmdClient(_dh, _dp)
|
||||
client.connect()
|
||||
client.start_recv_thread(server) # 启动后台接收线程,监听 decoder 8099 回发的 paradigm/result 消息
|
||||
|
||||
time.sleep(1) # 等待连接建立
|
||||
client.send_data('decoderClass', 'mi')
|
||||
|
||||
# MI_IntervalEpoch = [0.5, 4.5],trial时长 = 4.5-0.5 = 4.0s
|
||||
_mi_iv = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
||||
_trial_sec = float(_mi_iv[1] - _mi_iv[0])
|
||||
_margin = 1.0
|
||||
train_time = max(5.0, _trial_sec + _margin) # 训练刺激时长(与 MI_main.py 保持一致)
|
||||
|
||||
# MI epoch latency = interval_epoch[1] // 5 = (4.5*250)//5 = 225包 × 20ms = 4.5s
|
||||
# train_latency = 225包(MI中 train_latency == latency)
|
||||
# 在 train_time 后需再等 epoch_wait 秒,decoder 才能完成 epoch 采集
|
||||
epoch_wait = _mi_iv[1] / _mi_iv[1] * (_mi_iv[1] * 250 // 5) * 0.02 # = latency * 20ms
|
||||
# 更直接的计算:latency = interval_epoch[1] // 5 = int(4.5*250)//5 = 225,225*0.02 = 4.5s
|
||||
epoch_wait = (int(_mi_iv[1] * 250) // 5) * 0.02 # 4.5s
|
||||
|
||||
# predict epoch wait(与 train 相同,MI中 latency == train_latency)
|
||||
predict_epoch_wait = epoch_wait # 4.5s
|
||||
|
||||
test_time = 7.0 # 预测窗口时长(与 MI_main.py 保持一致)
|
||||
right_rehabilitation = float(IniRead('system', 'Right_rehabilitation'))
|
||||
fault_rehabilitation = float(IniRead('system', 'Fault_rehabilitation'))
|
||||
rest_time = float(IniRead('system', 'Rest_time'))
|
||||
|
||||
num_blocks = int(IniRead('system', 'Num_blocks'))
|
||||
num_trials = int(IniRead('system', 'Num_trials'))
|
||||
|
||||
trained = 0
|
||||
Num_Total = 0
|
||||
Num_Success = 0
|
||||
user_choice = []
|
||||
|
||||
print("=" * 50)
|
||||
print("[Headless] 开始运行 MI 通讯流程(无界面)")
|
||||
print(f" MI_IntervalEpoch={_mi_iv}, trial_sec={_trial_sec:.2f}s")
|
||||
print(f" train_time={train_time:.2f}s, epoch_wait={epoch_wait:.2f}s")
|
||||
print(f" test_time={test_time:.2f}s, predict_epoch_wait={predict_epoch_wait:.2f}s")
|
||||
print(f" num_blocks={num_blocks}, num_trials={num_trials}")
|
||||
print("=" * 50)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# -------- 个体校准阶段 --------
|
||||
print("\n[Phase] 个体校准阶段 (paradigm=0)")
|
||||
client.send_data('rest', 0)
|
||||
time.sleep(1)
|
||||
|
||||
while server.paradigm == 0:
|
||||
# 左侧 MI 刺激(train 0,label=1)
|
||||
print(f"\n[Train] 左侧 MI 刺激 (train 0) trained={trained}")
|
||||
client.send_data('rest', 0)
|
||||
time.sleep(0.5) # ding 提示后等待
|
||||
|
||||
client.send_data('train', 0)
|
||||
time.sleep(train_time + epoch_wait) # 等待刺激时间 + epoch 完成时间
|
||||
|
||||
trained += 1
|
||||
client.send_data('rest', 0)
|
||||
time.sleep(1.0) # 类间休息
|
||||
|
||||
# 空闲态样本采集(train 1,label=2)
|
||||
print(f"\n[Train] 空闲态采集 (train 1) trained={trained}")
|
||||
client.send_data('train', 1)
|
||||
time.sleep(train_time + epoch_wait) # 等待刺激时间 + epoch 完成时间
|
||||
|
||||
trained += 1
|
||||
client.send_data('rest', 0)
|
||||
time.sleep(1.0) # 类间休息
|
||||
|
||||
# 个体校准阶段结束
|
||||
print("\n[Phase] 个体校准结束,等待模型训练 (paradigm=2) ...")
|
||||
trained = 0
|
||||
time.sleep(1)
|
||||
|
||||
# 等待模型训练完成 (paradigm=2 -> paradigm=1)
|
||||
while server.paradigm == 2:
|
||||
print("[Phase] 等待模型训练完成 ...")
|
||||
time.sleep(0.5)
|
||||
|
||||
# -------- 康复训练阶段 --------
|
||||
while server.paradigm == 1:
|
||||
print("\n[Phase] 康复训练阶段 (paradigm=1)")
|
||||
for block_idx in range(num_blocks):
|
||||
print(f"\n [Block {block_idx+1}/{num_blocks}]")
|
||||
time.sleep(10) # 每轮开始前等待
|
||||
|
||||
for trial_idx in range(num_trials):
|
||||
print(f" [Trial {trial_idx+1}/{num_trials}]")
|
||||
|
||||
time.sleep(0.5) # ding 提示
|
||||
server.ChoosenNum = -1
|
||||
|
||||
# 开始预测
|
||||
# MI predict epoch latency = 225包 × 20ms = 4.5s,需额外等待 epoch 完成
|
||||
client.send_data('predict', 1)
|
||||
t_start = time.perf_counter()
|
||||
while time.perf_counter() - t_start < test_time + predict_epoch_wait:
|
||||
if server.ChoosenNum >= 0:
|
||||
Num_Total += 1
|
||||
user_choice.append(server.ChoosenNum)
|
||||
if server.ChoosenNum == 0:
|
||||
Num_Success += 1
|
||||
rest_time = right_rehabilitation
|
||||
elif server.ChoosenNum == 1:
|
||||
rest_time = fault_rehabilitation
|
||||
break
|
||||
time.sleep(0.02)
|
||||
|
||||
trained += 1
|
||||
client.send_data('rest', 0)
|
||||
time.sleep(0.5)
|
||||
time.sleep(rest_time)
|
||||
server.ChoosenNum = -1
|
||||
|
||||
# 训练结束
|
||||
print("\n[Phase] 康复训练结束")
|
||||
break # 退出康复训练循环
|
||||
|
||||
# 统计结果
|
||||
overall_accuracy = Num_Success / Num_Total if Num_Total > 0 else 0
|
||||
print(f"\n[Result] Overall={overall_accuracy:.3f} ({Num_Success}/{Num_Total})")
|
||||
print(f"[Result] user_choice={user_choice}")
|
||||
break # 完成一个完整流程后退出
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n[Headless] 用户中断")
|
||||
finally:
|
||||
client.send_data('predict', 2) # 关闭系统
|
||||
client.send_data('saveData', 0)
|
||||
server.stop()
|
||||
print("[Headless] 已发送关闭指令,退出。")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_headless()
|
||||
301
upperHost_stimmock/ssmvep_headless.py
Normal file
301
upperHost_stimmock/ssmvep_headless.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""
|
||||
ssmvep_headless.py
|
||||
无界面版 SSMVEP 范式通讯流程模拟脚本。
|
||||
复现 ssmvep_main.py 的完整指令序列(train 0/1/2, rest, predict, saveData),
|
||||
但不依赖 psychopy 也不打开任何窗口/音频,用 time.sleep 替代帧循环等待。
|
||||
|
||||
启动顺序:
|
||||
1. runDecoder.py
|
||||
2. datamock.py
|
||||
3. ssmvep_headless.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
import zmq
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from PubLibrary.InifileHelper import IniRead
|
||||
|
||||
personname = 'demo'
|
||||
session = '01'
|
||||
|
||||
DATAMOCK_LABEL_ADDR = 'tcp://127.0.0.1:8101' # datamock 标签命令地址
|
||||
|
||||
|
||||
# ========== ZMQ 结果接收服务 ==========
|
||||
class ZmqResultServer(threading.Thread):
|
||||
def __init__(self, port=8088):
|
||||
threading.Thread.__init__(self)
|
||||
self.port = port
|
||||
self.running = True
|
||||
self.energy = 0
|
||||
self.paradigm = 0 # 0=个体校准, 1=康复训练, 2=等待模型训练
|
||||
self.ChoosenNum = -1
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.ROUTER)
|
||||
self.socket.bind(f"tcp://0.0.0.0:{self.port}")
|
||||
self.daemon = True
|
||||
self.trial_idx = 0
|
||||
|
||||
def run(self):
|
||||
print(f"[Server] UpperHost_Server listening on {self.port}")
|
||||
while self.running:
|
||||
try:
|
||||
frames = self.socket.recv_multipart(zmq.NOBLOCK)
|
||||
if len(frames) < 3:
|
||||
continue
|
||||
message = json.loads(frames[2].decode('utf-8'))
|
||||
method = message.get('method')
|
||||
params = message.get('params')
|
||||
if method == 'energy':
|
||||
self.energy = params
|
||||
elif method == 'paradigm':
|
||||
self.paradigm = params
|
||||
print(f"[Server] paradigm -> {params}")
|
||||
elif method == 'result':
|
||||
self.ChoosenNum = params
|
||||
self.trial_idx += 1
|
||||
print(f"[Server] result={self.ChoosenNum} (trial {self.trial_idx})")
|
||||
except zmq.Again:
|
||||
time.sleep(0.005)
|
||||
except Exception as e:
|
||||
print(f"[Server] error: {e}")
|
||||
|
||||
def stop(self):
|
||||
self.running = False
|
||||
self.socket.close()
|
||||
self.context.term()
|
||||
|
||||
|
||||
# ========== ZMQ 命令发送客户端 ==========
|
||||
class ZmqCmdClient:
|
||||
def __init__(self, host, port):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.DEALER)
|
||||
# PUSH socket 用于向 datamock.py 发送标签命令
|
||||
self._label_sock = self.context.socket(zmq.PUSH)
|
||||
self._label_sock.connect(DATAMOCK_LABEL_ADDR)
|
||||
print(f"[Client] label PUSH connected to {DATAMOCK_LABEL_ADDR}")
|
||||
|
||||
def connect(self):
|
||||
self.socket.connect(f"tcp://{self.host}:{self.port}")
|
||||
print(f"[Client] connected to {self.host}:{self.port}")
|
||||
|
||||
def start_recv_thread(self, result_server):
|
||||
"""启动后台线程,持续接收 decoder 通过 8099 ROUTER 回发的消息,并更新 result_server 的状态"""
|
||||
self._result_server = result_server
|
||||
self._stop_recv = threading.Event()
|
||||
|
||||
def _recv_loop():
|
||||
while not self._stop_recv.is_set():
|
||||
try:
|
||||
frames = self.socket.recv_multipart(zmq.NOBLOCK)
|
||||
# DEALER 收到的格式: [b'', json_bytes]
|
||||
data_bytes = frames[-1]
|
||||
message = json.loads(data_bytes.decode('utf-8'))
|
||||
method = message.get('method')
|
||||
params = message.get('params')
|
||||
ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
|
||||
print(f"[{ts}] [CmdClient] recv: {method}={params}")
|
||||
if method == 'paradigm':
|
||||
self._result_server.paradigm = params
|
||||
print(f"[{ts}] [CmdClient] paradigm updated -> {params}")
|
||||
elif method == 'result':
|
||||
self._result_server.ChoosenNum = params
|
||||
self._result_server.trial_idx += 1
|
||||
print(f"[{ts}] [CmdClient] result={params} (trial {self._result_server.trial_idx})")
|
||||
elif method == 'energy':
|
||||
self._result_server.energy = params
|
||||
except zmq.Again:
|
||||
time.sleep(0.005)
|
||||
except Exception as e:
|
||||
print(f"[CmdClient recv] error: {e}")
|
||||
time.sleep(0.01)
|
||||
|
||||
self._recv_thread = threading.Thread(target=_recv_loop, daemon=True)
|
||||
self._recv_thread.start()
|
||||
print(f"[Client] 后台接收线程已启动(监听 decoder 8099 回发消息)")
|
||||
|
||||
def stop_recv_thread(self):
|
||||
if hasattr(self, '_stop_recv'):
|
||||
self._stop_recv.set()
|
||||
|
||||
def _send_label(self, label_value):
|
||||
"""向 datamock.py 发送标签命令"""
|
||||
try:
|
||||
self._label_sock.send_string(str(label_value), zmq.NOBLOCK)
|
||||
except Exception as e:
|
||||
print(f"[Client] label send error: {e}")
|
||||
|
||||
def send_data(self, method, params):
|
||||
msg = {'method': method, 'params': params}
|
||||
try:
|
||||
self.socket.send_multipart([b'', json.dumps(msg).encode('utf-8')])
|
||||
ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
|
||||
print(f"[{ts}] send_data: {method}={params}")
|
||||
# 根据 train/predict 命令向 datamock 发送标签
|
||||
if method == 'train':
|
||||
if params == 0:
|
||||
self._send_label(1)
|
||||
print(f"[Label] train 0 -> datamock label=1")
|
||||
elif params == 1:
|
||||
self._send_label(2)
|
||||
print(f"[Label] train 1 -> datamock label=2")
|
||||
elif method == 'predict':
|
||||
self._send_label(99)
|
||||
print(f"[Label] predict -> datamock label=99")
|
||||
except Exception as e:
|
||||
print(f"[Client] send error: {e}")
|
||||
|
||||
|
||||
# ========== 主流程 ==========
|
||||
def run_headless():
|
||||
server = ZmqResultServer(port=8088)
|
||||
server.start()
|
||||
|
||||
_dh = str(IniRead('system', 'Decoder_Host'))
|
||||
_dp = int(IniRead('system', 'Decoder_Port'))
|
||||
client = ZmqCmdClient(_dh, _dp)
|
||||
client.connect()
|
||||
client.start_recv_thread(server) # 启动后台接收线程,监听 decoder 8099 回发的 paradigm/result 消息
|
||||
|
||||
time.sleep(1) # 等待连接建立
|
||||
client.send_data('decoderClass', 'ssmvep')
|
||||
|
||||
train_time = 2.5 # 每轮训练刺激时长 (s)
|
||||
test_time = 2.5 # 每轮测试刺激时长 (s)
|
||||
right_rehabilitation = float(IniRead('system', 'Right_rehabilitation'))
|
||||
fault_rehabilitation = float(IniRead('system', 'Fault_rehabilitation'))
|
||||
rest_time = float(IniRead('system', 'Rest_time'))
|
||||
|
||||
num_blocks = int(IniRead('system', 'Num_blocks'))
|
||||
num_trials = int(IniRead('system', 'Num_trials'))
|
||||
|
||||
position = [0, 1]
|
||||
truePos_seq = position * int(num_trials / len(position))
|
||||
truePos_seq = np.random.permutation(truePos_seq).tolist()
|
||||
user_choice = []
|
||||
|
||||
os.makedirs('EEGFiles', exist_ok=True)
|
||||
seq_file_path = f'EEGFiles/pos_seq_{personname}{session}_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
|
||||
seq_info = {
|
||||
'position': position,
|
||||
'sequence': truePos_seq,
|
||||
'start_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
}
|
||||
with open(seq_file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(seq_info, f, ensure_ascii=False, indent=2)
|
||||
|
||||
trained = 0
|
||||
Num_Total = 0
|
||||
Num_Success = 0
|
||||
|
||||
print("=" * 50)
|
||||
print("[Headless] 开始运行 SSMVEP 通讯流程(无界面)")
|
||||
print(f" num_blocks={num_blocks}, num_trials={num_trials}")
|
||||
print(f" train_time={train_time}s, test_time={test_time}s")
|
||||
print("=" * 50)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# -------- 个体校准阶段 --------
|
||||
print("\n[Phase] 个体校准阶段 (paradigm=0)")
|
||||
client.send_data('rest', 0)
|
||||
time.sleep(1)
|
||||
|
||||
# epoch完成需要的额外等待时间:train_latency=120包×20ms=2.4s
|
||||
# 在train_time后需再等epoch_wait秒,decoder才能完成epoch采集并取出数据
|
||||
epoch_wait = 2.4 # 秒,与train_latency对应
|
||||
|
||||
while server.paradigm == 0:
|
||||
# 左腿刺激
|
||||
print(f"\n[Train] 左腿刺激 (train 0) trained={trained}")
|
||||
client.send_data('train', 0)
|
||||
time.sleep(train_time + epoch_wait) # 等待刺激时间+epoch完成时间
|
||||
|
||||
trained += 1
|
||||
client.send_data('rest', 0)
|
||||
time.sleep(max(0, abs(fault_rehabilitation - train_time) - epoch_wait))
|
||||
|
||||
# 右腿刺激
|
||||
print(f"\n[Train] 右腿刺激 (train 1) trained={trained}")
|
||||
client.send_data('train', 1)
|
||||
time.sleep(train_time + epoch_wait) # 等待刺激时间+epoch完成时间
|
||||
|
||||
trained += 1
|
||||
client.send_data('rest', 0)
|
||||
time.sleep(max(0, fault_rehabilitation - epoch_wait))
|
||||
|
||||
# 个体校准阶段结束
|
||||
print("\n[Phase] 个体校准结束,等待 paradigm=1 ...")
|
||||
trained = 0
|
||||
time.sleep(1)
|
||||
|
||||
# -------- 康复训练阶段 --------
|
||||
while server.paradigm == 1:
|
||||
print("\n[Phase] 康复训练阶段 (paradigm=1)")
|
||||
for block_idx in range(num_blocks):
|
||||
print(f"\n [Block {block_idx+1}/{num_blocks}]")
|
||||
time.sleep(10) # 每轮开始前等待
|
||||
|
||||
for trial_idx in range(num_trials):
|
||||
true_position = truePos_seq[trial_idx]
|
||||
print(f" [Trial {trial_idx+1}/{num_trials}] true_pos={true_position}")
|
||||
|
||||
time.sleep(0.5) # 提示 + 叮声
|
||||
server.ChoosenNum = -1
|
||||
|
||||
# 开始测试
|
||||
# predict epoch latency = 115包×20ms = 2.3s,需额外等待epoch完成
|
||||
predict_epoch_wait = 2.3 # 秒,与predict latency=115包对应
|
||||
client.send_data('predict', 1)
|
||||
t_start = time.perf_counter()
|
||||
while time.perf_counter() - t_start < test_time + predict_epoch_wait:
|
||||
if server.ChoosenNum >= 0:
|
||||
Num_Total += 1
|
||||
user_choice.append(server.ChoosenNum)
|
||||
if server.ChoosenNum in [0, 1]:
|
||||
Num_Success += 1
|
||||
rest_time = right_rehabilitation
|
||||
break
|
||||
time.sleep(0.02)
|
||||
|
||||
trained += 1
|
||||
client.send_data('rest', 0)
|
||||
time.sleep(0.5)
|
||||
time.sleep(rest_time)
|
||||
server.ChoosenNum = -1
|
||||
|
||||
# 训练结束
|
||||
print("\n[Phase] 康复训练结束")
|
||||
break # 退出康复训练循环
|
||||
|
||||
# 统计结果
|
||||
overall_accuracy = Num_Success / Num_Total if Num_Total > 0 else 0
|
||||
expected_seq = truePos_seq * num_blocks
|
||||
min_len = min(len(user_choice), len(expected_seq))
|
||||
same_count = sum(1 for a, b in zip(user_choice[:min_len], expected_seq[:min_len]) if a == b)
|
||||
true_accuracy = same_count / min_len if min_len > 0 else 0
|
||||
print(f"\n[Result] Overall={overall_accuracy:.3f} ({Num_Success}/{Num_Total})")
|
||||
print(f"[Result] TrueAcc={true_accuracy:.3f} ({same_count}/{min_len})")
|
||||
break # 完成一个完整流程后退出
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n[Headless] 用户中断")
|
||||
finally:
|
||||
client.send_data('predict', 2) # 关闭系统
|
||||
client.send_data('saveData', 0)
|
||||
server.stop()
|
||||
print("[Headless] 已发送关闭指令,退出。")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_headless()
|
||||
364
upperHost_stimmock/ssvep_main.py
Normal file
364
upperHost_stimmock/ssvep_main.py
Normal file
@@ -0,0 +1,364 @@
|
||||
import time
|
||||
|
||||
from psychopy import visual, core, logging # import some libraries from PsychoPy
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
# LAB STREAMING LAYER1
|
||||
from pylsl import StreamInfo, StreamOutlet
|
||||
from psychopy import event
|
||||
import numpy as np
|
||||
from DecoderDW.Server import TCPServer
|
||||
from DecoderDW.Client import TCPClient
|
||||
# import subprocess
|
||||
|
||||
# ----------------------
|
||||
# constants
|
||||
# size of the window
|
||||
WINWIDTH = 1920
|
||||
WINHEIGHT = 1080
|
||||
REFRESH_RATE = 144
|
||||
|
||||
|
||||
|
||||
def get_keypress():
|
||||
keys = event.getKeys()
|
||||
if keys:
|
||||
return keys[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def shutdown(win,client):
|
||||
client.send_data('saveData', 0)
|
||||
client.send_data('predict',2)
|
||||
win.close()
|
||||
core.quit()
|
||||
|
||||
|
||||
# end of configuration
|
||||
# ----------------------
|
||||
|
||||
def generate_square_wave(frequency, sampling_rate=REFRESH_RATE, duration=5):
|
||||
"""
|
||||
生成方波序列
|
||||
|
||||
参数:
|
||||
frequency (float): 频率(Hz)
|
||||
sampling_rate (int): 采样率(Hz),应与屏幕刷新率一致
|
||||
duration (float): 时长(秒)
|
||||
|
||||
返回:
|
||||
square_wave (list): 方波序列
|
||||
"""
|
||||
# 计算总点数
|
||||
n_points = int(duration * sampling_rate)
|
||||
|
||||
# 生成时间序列
|
||||
time = np.linspace(0, duration, n_points, endpoint=False)
|
||||
|
||||
# 生成正弦波数据
|
||||
sin_wave = np.sin(2 * np.pi * frequency * time)
|
||||
# 生成方波数据
|
||||
square_wave = np.where(sin_wave >= 0, 1, 0)
|
||||
|
||||
return square_wave.tolist()
|
||||
|
||||
|
||||
# 启动一个进程,不等待其完成
|
||||
import os
|
||||
if __name__ == "__main__":
|
||||
# ----------------------------------------------------------------------------------
|
||||
# main window settings
|
||||
main_win = visual.Window(size=(WINWIDTH, WINHEIGHT), units='height', screen=0, fullscr=False,
|
||||
gammaErrorPolicy='warn', color=(0.7, 0.7, 0.7))
|
||||
print('starting 1')
|
||||
# Set up LabStreamingLayer stream.
|
||||
info = StreamInfo(name='psychopy_stimuli', type='Markers', channel_count=1, channel_format='string',
|
||||
source_id='psychopy_stimuli_001')
|
||||
outlet = StreamOutlet(info) # Broadcast the stream.
|
||||
|
||||
imageStim1 = visual.ImageStim(main_win, size=(300, 300), pos=(-600, 300), units='pix', image='UI/figures/xy.jpg')
|
||||
txtStim1 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||
italic=False, pos=(-600, 30))
|
||||
|
||||
imageStim2 = visual.ImageStim(main_win, size=(300, 300), pos=(0, 300), units='pix', image='UI/figures/xy.jpg')
|
||||
txtStim2 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||
italic=False, pos=(0, 30))
|
||||
|
||||
imageStim3 = visual.ImageStim(main_win, size=(300, 300), pos=(600, 300), units='pix', image='UI/figures/xy.jpg')
|
||||
txtStim3 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||
italic=False, pos=(600, 30))
|
||||
imageStim4 = visual.ImageStim(main_win, size=(300, 300), pos=(-600, -200), units='pix', image='UI/figures/xy.jpg')
|
||||
txtStim4 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||
italic=False, pos=(-600, -470))
|
||||
imageStim5 = visual.ImageStim(main_win, size=(300, 300), pos=(0, -200), units='pix', image='UI/figures/xy.jpg')
|
||||
txtStim5 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||
italic=False, pos=(0, -470))
|
||||
imageStim6 = visual.ImageStim(main_win, size=(300, 300), pos=(600, -200), units='pix', image='UI/figures/xy.jpg')
|
||||
txtStim6 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||
italic=False, pos=(600, -470))
|
||||
imageStim1red = visual.ImageStim(main_win, size=(300, 300), pos=(-600, 300), units='pix', image='UI/figures/xy_red.jpg')
|
||||
imageStim2red = visual.ImageStim(main_win, size=(300, 300), pos=(0, 300), units='pix', image='UI/figures/xy_red.jpg')
|
||||
imageStim3red = visual.ImageStim(main_win, size=(300, 300), pos=(600, 300), units='pix', image='UI/figures/xy_red.jpg')
|
||||
imageStim4red = visual.ImageStim(main_win, size=(300, 300), pos=(-600, -200), units='pix', image='UI/figures/xy_red.jpg')
|
||||
imageStim5red = visual.ImageStim(main_win, size=(300, 300), pos=(0, -200), units='pix', image='UI/figures/xy_red.jpg')
|
||||
imageStim6red = visual.ImageStim(main_win, size=(300, 300), pos=(600, -200), units='pix', image='UI/figures/xy_red.jpg')
|
||||
|
||||
|
||||
frequencies = [25,26,27,28,29,30] #[9,10,11,12,13,14] #[30,31,32,33,34,35] [25,26,27,28,29,30]
|
||||
# 生成方波数据
|
||||
square_wave_9 = generate_square_wave(frequencies[0], REFRESH_RATE, 5)
|
||||
square_wave_11 = generate_square_wave(frequencies[1], REFRESH_RATE, 5)
|
||||
square_wave_12 = generate_square_wave(frequencies[2], REFRESH_RATE, 5)
|
||||
square_wave_13 = generate_square_wave(frequencies[3], REFRESH_RATE, 5)
|
||||
square_wave_14 = generate_square_wave(frequencies[4], REFRESH_RATE, 5)
|
||||
square_wave_15 = generate_square_wave(frequencies[5], REFRESH_RATE, 5)
|
||||
|
||||
# 创建刺激对象列表,便于管理
|
||||
image_stims = [imageStim1, imageStim2, imageStim3, imageStim4, imageStim5, imageStim6]
|
||||
txt_stims = [txtStim1, txtStim2, txtStim3, txtStim4, txtStim5, txtStim6]
|
||||
square_waves = [square_wave_9, square_wave_11, square_wave_12, square_wave_13, square_wave_14, square_wave_15]
|
||||
|
||||
time.sleep(2)
|
||||
# grating.color = 'black'
|
||||
server = TCPServer()
|
||||
server.start()
|
||||
client = TCPClient('127.0.0.1', 8099)
|
||||
client.connect()
|
||||
print('Connected decoder_main')
|
||||
# client.send_data('impedance', 1)
|
||||
# time.sleep(20)
|
||||
# client.send_data('impedance', 2)
|
||||
client.send_data('targetFreqs', frequencies) # 使用frequencies变量,确保与刺激频率一致
|
||||
time.sleep(1)
|
||||
# 开启全程数据保存到 EEGFiles
|
||||
client.send_data('saveData',1)
|
||||
# client.send_data('impedance',1)
|
||||
|
||||
|
||||
|
||||
# 实验参数
|
||||
repeats = 3
|
||||
seq_freq = frequencies * repeats
|
||||
seq_freq = np.random.permutation(seq_freq).tolist()
|
||||
num_trials = len(seq_freq) # 总试验次数, 6*6=36
|
||||
trial_count = 0
|
||||
|
||||
# 在线解码精度计算
|
||||
online_results = [] # 存储每个trial的解码结果
|
||||
correct_predictions = 0 # 正确预测计数
|
||||
|
||||
# 保存序列信息
|
||||
seq_info = {
|
||||
'total_trials': num_trials,
|
||||
'frequencies': frequencies,
|
||||
'sequence': seq_freq,
|
||||
'start_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
}
|
||||
# 保存序列信息到文件
|
||||
import json
|
||||
seq_file_path = f'EEGFiles/sequence_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
|
||||
with open(seq_file_path, 'a', encoding='utf-8') as f:
|
||||
json.dump(seq_info, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
#========================Trials Started======================#
|
||||
while trial_count < num_trials:
|
||||
# 从序列中获取当前试验的目标频率
|
||||
target_freq = seq_freq[trial_count]
|
||||
target_freq_index = frequencies.index(target_freq)
|
||||
print(f'Trials {trial_count + 1}/{num_trials} - Target Frequency: {target_freq}Hz (Label: {target_freq_index + 1})')
|
||||
|
||||
# Stage 1: Cue Stage
|
||||
# print('Cue Stage: The target frequency is in Red')
|
||||
client.send_data('setLabelAndTrialInfo', {
|
||||
'label': 0,
|
||||
'trial_info': {
|
||||
'trial': trial_count + 1,
|
||||
'phase': 'cue',
|
||||
'target_freq': target_freq
|
||||
}
|
||||
})
|
||||
|
||||
for frameN in range(int(1 * REFRESH_RATE)): # 1秒提示
|
||||
key_press = get_keypress()
|
||||
if key_press in ['q']:
|
||||
shutdown(main_win, client)
|
||||
|
||||
# 显示所有刺激,目标刺激为红色
|
||||
for i, stim in enumerate(image_stims):
|
||||
if i == target_freq_index:
|
||||
# 目标刺激显示红色
|
||||
if i == 0:
|
||||
imageStim1red.draw()
|
||||
elif i == 1:
|
||||
imageStim2red.draw()
|
||||
elif i == 2:
|
||||
imageStim3red.draw()
|
||||
elif i == 3:
|
||||
imageStim4red.draw()
|
||||
elif i == 4:
|
||||
imageStim5red.draw()
|
||||
elif i == 5:
|
||||
imageStim6red.draw()
|
||||
else:
|
||||
# 其他刺激显示正常颜色
|
||||
stim.draw()
|
||||
|
||||
main_win.flip()
|
||||
|
||||
# Stage 2: Flanker Stimulus
|
||||
# print('Flanker Stage: flank all frequencies')
|
||||
client.send_data('predict', 1)
|
||||
client.send_data('setLabelAndTrialInfo', {
|
||||
'label': target_freq_index + 1, # 设置目标频率标签 这里+1,是因为0代表不记录数据
|
||||
'trial_info': {
|
||||
'trial': trial_count + 1, # trial 从0开始
|
||||
'phase': 'stimulus',
|
||||
'target_freq': target_freq
|
||||
}
|
||||
})
|
||||
outlet.push_sample(['S 1'])
|
||||
|
||||
for frameN in range(6 * REFRESH_RATE): # 6秒刺激
|
||||
key_press = get_keypress()
|
||||
if key_press in ['q']:
|
||||
shutdown(main_win, client)
|
||||
|
||||
# 所有频率按照方波闪烁
|
||||
if square_wave_9[frameN % len(square_wave_9)] == 1:
|
||||
imageStim1.draw()
|
||||
if square_wave_11[frameN % len(square_wave_11)] == 1:
|
||||
imageStim2.draw()
|
||||
if square_wave_12[frameN % len(square_wave_12)] == 1:
|
||||
imageStim3.draw()
|
||||
if square_wave_13[frameN % len(square_wave_13)] == 1:
|
||||
imageStim4.draw()
|
||||
if square_wave_14[frameN % len(square_wave_14)] == 1:
|
||||
imageStim5.draw()
|
||||
if square_wave_15[frameN % len(square_wave_15)] == 1:
|
||||
imageStim6.draw()
|
||||
|
||||
main_win.flip()
|
||||
if server.ChoosenNum != -1:
|
||||
break
|
||||
|
||||
# 记录在线解码结果
|
||||
predicted_freq_index = server.ChoosenNum # 解码结果
|
||||
predicted_freq = frequencies[predicted_freq_index] if predicted_freq_index != -1 else -1
|
||||
|
||||
# 判断解码是否正确
|
||||
is_correct = (predicted_freq_index == target_freq_index) if predicted_freq_index != -1 else False
|
||||
if is_correct:
|
||||
correct_predictions += 1
|
||||
|
||||
# 记录trial结果
|
||||
trial_result = {
|
||||
'trial': trial_count + 1,
|
||||
'target_freq': target_freq,
|
||||
'target_freq_index': target_freq_index,
|
||||
'predicted_freq': predicted_freq,
|
||||
'predicted_freq_index': predicted_freq_index,
|
||||
'is_correct': is_correct,
|
||||
'status': 'Success' if predicted_freq_index != -1 else 'Failed'
|
||||
}
|
||||
online_results.append(trial_result)
|
||||
|
||||
# 打印当前trial结果
|
||||
status_symbol = "✓" if is_correct else "✗"
|
||||
if predicted_freq_index == -1:
|
||||
print(f'Trial {trial_count + 1}: 目标{target_freq}Hz -> 解码失败 - {status_symbol}')
|
||||
else:
|
||||
print(f'Trial {trial_count + 1}: 目标{target_freq}Hz -> 预测{predicted_freq}Hz - {status_symbol}')
|
||||
|
||||
|
||||
# Stage 3: Decoding Feedback
|
||||
outlet.push_sample(['S 2'])
|
||||
client.send_data('setLabelAndTrialInfo', {
|
||||
'label': 0, # 反馈阶段标签为0
|
||||
'trial_info': {
|
||||
'trial': trial_count + 1,
|
||||
'phase': 'feedback',
|
||||
'target_freq': target_freq
|
||||
}
|
||||
})
|
||||
# print('反馈阶段: 显示解码结果')
|
||||
|
||||
for frameN in range(1 * REFRESH_RATE): # 1秒反馈
|
||||
key_press = get_keypress()
|
||||
if key_press in ['q']:
|
||||
shutdown(main_win, client)
|
||||
|
||||
# 显示所有刺激但不闪烁
|
||||
for stim in image_stims:
|
||||
stim.draw()
|
||||
|
||||
# 显示解码结果
|
||||
if server.ChoosenNum == 0:
|
||||
txtStim1.draw()
|
||||
elif server.ChoosenNum == 1:
|
||||
txtStim2.draw()
|
||||
elif server.ChoosenNum == 2:
|
||||
txtStim3.draw()
|
||||
elif server.ChoosenNum == 3:
|
||||
txtStim4.draw()
|
||||
elif server.ChoosenNum == 4:
|
||||
txtStim5.draw()
|
||||
elif server.ChoosenNum == 5:
|
||||
txtStim6.draw()
|
||||
|
||||
main_win.flip()
|
||||
|
||||
server.ChoosenNum = -1
|
||||
trial_count += 1
|
||||
|
||||
# 计算总体在线解码精度
|
||||
total_trials = len(online_results)
|
||||
successful_trials = len([r for r in online_results if r['status'] == 'Success'])
|
||||
failed_trials = len([r for r in online_results if r['status'] == 'Failed'])
|
||||
overall_accuracy = correct_predictions / total_trials if total_trials > 0 else 0
|
||||
|
||||
# Print Accuracy
|
||||
print(f"Total Accuracy: {overall_accuracy:.3f} ({correct_predictions}/{total_trials})")
|
||||
|
||||
# 按频率分析准确率
|
||||
print(f"\n=== 按频率分析准确率 ===")
|
||||
freq_accuracy = {}
|
||||
for result in online_results:
|
||||
freq = result['target_freq']
|
||||
if freq not in freq_accuracy:
|
||||
freq_accuracy[freq] = {'correct': 0, 'total': 0, 'failed': 0}
|
||||
|
||||
freq_accuracy[freq]['total'] += 1
|
||||
if result['status'] == 'Failed':
|
||||
freq_accuracy[freq]['failed'] += 1
|
||||
elif result['is_correct']:
|
||||
freq_accuracy[freq]['correct'] += 1
|
||||
|
||||
print(f"{'频率':<8} {'准确率':<8} {'正确/总数':<10} {'失败数':<8}")
|
||||
print("-" * 40)
|
||||
for freq in sorted(freq_accuracy.keys()):
|
||||
stats = freq_accuracy[freq]
|
||||
accuracy = stats['correct'] / stats['total'] if stats['total'] > 0 else 0
|
||||
print(f"{freq}Hz{'':<4} {accuracy:.3f}{'':<4} {stats['correct']}/{stats['total']}{'':<6} {stats['failed']}")
|
||||
|
||||
# 保存在线解码结果到文件
|
||||
online_results_file = f'EEGFiles/online_results_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
|
||||
online_summary = {
|
||||
'total_trials': total_trials,
|
||||
'successful_trials': successful_trials,
|
||||
'failed_trials': failed_trials,
|
||||
'correct_predictions': correct_predictions,
|
||||
'overall_accuracy': overall_accuracy,
|
||||
# 'freq_accuracy': freq_accuracy,
|
||||
'trial_results': online_results,
|
||||
# 'end_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
}
|
||||
|
||||
with open(online_results_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(online_summary, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
client.send_data('predict',2) # 关闭系统
|
||||
main_win.close()
|
||||
Reference in New Issue
Block a user