Files
bci_algo/datamock.py

189 lines
7.7 KiB
Python
Raw Normal View History

2026-06-06 14:49:38 +08:00
import zmq
import numpy as np
import time
import threading
2026-06-06 14:49:38 +08:00
from datetime import datetime
# ========== 参数配置 ==========
FS = 250 # 采样率 Hz
N_SAMPLES_PER_PKT = 5 # 每包采样点数
N_CHAN = 66 # 通道数: 64 EEG + 1 标签值 + 1 标签序号
EEG_FREQ = 10 # EEG 正弦波频率 Hz
2026-06-06 16:05:32 +08:00
EEG_AMP = 100.0 # EEG 幅值 100μV
2026-06-06 14:49:38 +08:00
LABEL_INTERVAL = 5 # 标签间隔秒数
2026-06-09 19:30:27 +08:00
SERVER_ADDR = 'tcp://127.0.0.1:8100'
LABEL_CMD_ADDR = 'tcp://127.0.0.1:8101' # 接收来自上位机范式的标签命令
2026-06-06 14:49:38 +08:00
# 发送间隔: 每包 5 采样点 / 250Hz = 20ms
PKT_INTERVAL = N_SAMPLES_PER_PKT / FS
def build_packet(global_sample_idx):
"""
2026-06-08 15:47:25 +08:00
生成一包 [5, 66] float64 数据
2026-06-06 14:49:38 +08:00
:param global_sample_idx: 当前包第一个采样点在全局序列中的索引 ( 0 开始)
:return: np.ndarray shape [5, 66]
"""
# 当前包内 5 个采样点对应的时间(秒)
t = (global_sample_idx + np.arange(N_SAMPLES_PER_PKT)) / FS
# Ch0-63: EEG 10Hz 正弦波,幅值 100μV
# t shape [5,]sin 乘以标量后仍是 [5,],需要 reshape 为 [5,1] 再广播到 64 通道
eeg = (EEG_AMP * np.sin(2 * np.pi * EEG_FREQ * t)).reshape(N_SAMPLES_PER_PKT, 1) # [5, 1]
eeg = np.tile(eeg, (1, 64)) # [5, 64]
# Ch64: 标签值通道,初始化为 0
2026-06-08 15:47:25 +08:00
event = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64)
2026-06-06 14:49:38 +08:00
# Ch65: 标签序号通道,初始化为 0
2026-06-08 15:47:25 +08:00
label_idx = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64)
2026-06-06 14:49:38 +08:00
# 拼成 [5, 66]
2026-06-08 15:47:25 +08:00
packet = np.concatenate([eeg, event, label_idx], axis=1).astype(np.float64)
2026-06-06 14:49:38 +08:00
return packet
def should_send_label(global_sample_idx):
"""
判断当前包是否包含标签触发点 5s 的最后一个采样点
采样点索引从 0 开始 5s = 1250 个采样点
最后一个采样点索引: 1249, 2499, 3749, ...
由于每包 5 个采样点标签点落在包内的最后一个采样点位置
即当前包起始索引 global_sample_idx 必须使得:
global_sample_idx <= 标签点索引 < global_sample_idx + N_SAMPLES_PER_PKT
也就是 global_sample_idx <= 1249 < global_sample_idx + 5
global_sample_idx = 1245, 2495, 3745, ...
global_sample_idx = n * LABEL_INTERVAL * FS - N_SAMPLES_PER_PKT
"""
samples_per_interval = LABEL_INTERVAL * FS
# 检查当前包是否包含 interval 的最后一个采样点
# 标签点索引 = n * 1250 - 1当 global_sample_idx = n*1250-5 时,标签在包内索引 4
return (global_sample_idx + N_SAMPLES_PER_PKT - 1) % samples_per_interval == samples_per_interval - 1
def main():
ctx = zmq.Context()
sock = ctx.socket(zmq.DEALER)
sock.connect(SERVER_ADDR)
print(f"[{datetime.now().strftime('%H:%M:%S')}] ZMQ Dealer 连接到 {SERVER_ADDR}")
2026-06-09 19:30:27 +08:00
# ========== 上位机标签命令监听 ==========
# 使用线程安全的队列接收来自 ssmvep_main.py 的标签命令
# 标签值: 1 (train 0), 2 (train 1), 99 (predict)
pending_label = [None] # [label_value or None]
label_lock = threading.Lock()
label_cmd_sock = ctx.socket(zmq.PULL)
label_cmd_sock.bind(LABEL_CMD_ADDR)
print(f"[{datetime.now().strftime('%H:%M:%S')}] 标签命令监听绑定到 {LABEL_CMD_ADDR}")
stop_recv = threading.Event()
def label_cmd_thread():
"""监听来自上位机范式的标签命令,写入 pending_label"""
while not stop_recv.is_set():
try:
msg = label_cmd_sock.recv_string(zmq.NOBLOCK)
label_val = int(msg)
with label_lock:
pending_label[0] = label_val
ts = datetime.now().strftime('%H:%M:%S')
label_name = {1: 'train_0', 2: 'train_1', 99: 'predict'}.get(label_val, str(label_val))
print(f"[{ts}] 收到标签命令: {label_name} -> label={label_val}")
except zmq.Again:
time.sleep(0.005)
except Exception as e:
print(f"[label_cmd_thread] 错误: {e}")
time.sleep(0.01)
label_thread = threading.Thread(target=label_cmd_thread, daemon=True)
label_thread.start()
print(f"[{datetime.now().strftime('%H:%M:%S')}] 标签命令监听线程已启动")
# 后台消费线程:持续 recv 从 ROUTER 返回的数据,避免 server 发送队列积压
recv_count = [0]
def consumer_thread():
"""消费线程:阻塞 recv丢弃收到的数据仅用于清空 ROUTER 发送队列"""
while not stop_recv.is_set():
try:
frames = sock.recv_multipart(zmq.NOBLOCK)
recv_count[0] += 1
# 收到的格式: [identity, '', filtered_data_bytes]
if recv_count[0] % 500 == 0:
print(f"[{datetime.now().strftime('%H:%M:%S')}] 消费线程已丢弃 {recv_count[0]} 帧滤波数据")
except zmq.Again:
time.sleep(0.01)
except zmq.error.Again: # 兼容旧版
time.sleep(0.01)
consumer = threading.Thread(target=consumer_thread, daemon=True)
consumer.start()
print(f"[{datetime.now().strftime('%H:%M:%S')}] 消费线程已启动daemon")
2026-06-06 14:49:38 +08:00
global_sample_idx = 0 # 全局采样点计数器
label_type = 1 # 当前标签类型: 1 或 2
label1_count = 0 # label=1 的序号计数器
label2_count = 0 # label=2 的序号计数器
packet_count = 0 # 已发送包数
print(f"[{datetime.now().strftime('%H:%M:%S')}] 开始发送模拟数据 ...")
print(f" 采样率: {FS}Hz | 每包 {N_SAMPLES_PER_PKT} 采样点 | 发送间隔 {PKT_INTERVAL*1000:.0f}ms")
2026-06-06 16:05:32 +08:00
print(f" EEG: {EEG_FREQ}Hz 正弦波 | 幅值 {EEG_AMP}μV")
2026-06-09 19:30:27 +08:00
print(f" 标签: 来自上位机范式命令 (train_0=1, train_1=2, predict=99)")
2026-06-06 14:49:38 +08:00
print("-" * 50)
try:
while True:
t_start = time.perf_counter()
# 构建当前包
packet = build_packet(global_sample_idx)
2026-06-09 19:30:27 +08:00
# 检查是否有来自上位机范式的挂起标签命令
with label_lock:
ext_label = pending_label[0]
if ext_label is not None:
pending_label[0] = None
2026-06-06 14:49:38 +08:00
2026-06-09 19:30:27 +08:00
if ext_label is not None:
# 将标签写入当前包所有5个采样点的第65通道 (index 64)
# 覆盖全部采样点确保 event_inner_idx 无论落在哪个位置都能被正确检测
packet[:, 64] = float(ext_label)
2026-06-06 14:49:38 +08:00
ts = datetime.now().strftime('%H:%M:%S')
2026-06-09 19:30:27 +08:00
print(f"[{ts}] 打标签: label={ext_label} -> ch64[all 5 samples] (global_sample_idx={global_sample_idx})")
2026-06-06 14:49:38 +08:00
2026-06-09 19:30:27 +08:00
# 发送: multipart 2帧 ['', data]
# 使用标准格式ROUTER 会自动附加 ZMQ 分配的客户端身份
2026-06-06 14:49:38 +08:00
sock.send_multipart([
b'',
2026-06-06 14:49:38 +08:00
packet.tobytes()
])
# 每 50 包打印一次进度
if packet_count % 50 == 0:
ts = datetime.now().strftime('%H:%M:%S')
print(f"[{ts}] 已发送 {packet_count} 包 (global_sample_idx={global_sample_idx})")
global_sample_idx += N_SAMPLES_PER_PKT
packet_count += 1
# 精确控制发送节奏: 等待到 PKT_INTERVAL 秒
elapsed = time.perf_counter() - t_start
sleep_time = PKT_INTERVAL - elapsed
if sleep_time > 0:
time.sleep(sleep_time)
except KeyboardInterrupt:
print(f"\n[{datetime.now().strftime('%H:%M:%S')}] 停止发送,共发送 {packet_count}")
finally:
stop_recv.set()
consumer.join(timeout=2)
2026-06-09 19:30:27 +08:00
label_cmd_sock.close()
2026-06-06 14:49:38 +08:00
sock.close()
ctx.term()
if __name__ == '__main__':
main()