Files
bci_algo/Zmq/zmqServer1.py
2026-06-06 09:16:49 +08:00

446 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import numpy as np
import zmq
import threading
import json
import queue
import time
from Device.SunnyLinker import SunnyLinker64, RingBuffer
from collections import deque
class zmqServer(threading.Thread):
    """ZeroMQ server thread that receives commands and EEG packets from the host.

    Two ROUTER sockets are bound in ``__init__``:

    * command channel (``cmd_port``): frames ``[identity, b'', json]`` where the
      JSON object carries ``method`` and ``params``; each command updates the
      public state flags (``state_mode``, ``StartTrain``, ``StartDecode`` ...).
      Messages queued with :meth:`broadcast_message` are sent back to every
      identity that has sent a command.
    * data channel (``data_port``): frames ``[identity, b'', meta_json, payload]``
      carrying float32 samples that are stored in a RingBuffer and scanned for
      event triggers (see :meth:`detect_event`).

    Once :meth:`start` has been called the sockets are used only by the server
    thread; call :meth:`stop` to shut it down.
    """

    # Sampling rate used to convert seconds to samples (the original comments state 250 Hz).
    _FS = 250
    # Latencies are counted in packets of this many samples.
    # NOTE(review): assumes the host always sends 5-sample packets (the default
    # meta shape [5, 66]) — confirm.
    _SAMPLES_PER_PACKET = 5
    # Row of the 66-row data matrix holding the event (trigger) value.
    _EVENT_ROW = 64
    # Row of the 66-row data matrix holding the label index.
    _LABEL_ROW = 65
    # Rows returned by the per-paradigm accessors (0-based indices into the 66 rows).
    _MI_ROWS = (8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10,
                19, 20, 61, 51, 60, 21, 64, 65)
    _SSMVEP_ROWS = (13, 3, 2, 46, 9, 54, 47, 55, 64, 65)
    _SSVEP_ROWS = (13, 3, 2, 46, 9, 54, 47, 55, 64)
    _CONCENTRATE_ROWS = (0, 1)
    _BLINK_ROWS = (0, 1)
    # Linger (ms) used when closing sockets, so context.term() cannot block forever
    # on undeliverable outgoing messages.
    _CLOSE_LINGER_MS = 100

    def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100):
        threading.Thread.__init__(self)
        self.daemon = True
        self.host = host
        self.cmd_port = cmd_port
        self.data_port = data_port

        # ---- state flags written by command-channel messages ----
        self.running = False
        self.get_Impedance = False
        self.open_Impedance = None
        self.StartDecode = False
        self.StartTrain = False
        self.state_mode = None
        self.IsExitApp = False
        self.getReport = False
        self.targetFreqs = []
        self.changeTarget = False
        self.decoder_switch = False
        self.decoder_class = None
        self.labels = [0x01, 0x02, 0x03]

        # ---- shutdown bookkeeping ----
        self._stop_requested = threading.Event()
        self._close_lock = threading.Lock()
        self._closed = False

        # ---- ZMQ sockets ----
        self.context = zmq.Context()
        self.cmd_socket = None
        self.data_socket = None
        try:
            # Command channel (ROUTER)
            self.cmd_socket = self.context.socket(zmq.ROUTER)
            self.cmd_socket.setsockopt(zmq.RCVHWM, 1000)
            self.cmd_socket.setsockopt(zmq.SNDHWM, 1000)
            self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
            # Data channel (ROUTER)
            self.data_socket = self.context.socket(zmq.ROUTER)
            self.data_socket.setsockopt(zmq.RCVHWM, 1000)
            self.data_socket.setsockopt(zmq.RCVTIMEO, 50)
            self.data_socket.bind(f"tcp://{self.host}:{data_port}")
        except zmq.ZMQError:
            # e.g. port already in use: release what was created before re-raising
            self._close_sockets()
            raise

        self.sunnyLinker = SunnyLinker64(None, None, None, None, None)
        self.cmd_clients = set()
        self.data_clients = set()
        # Registered clients (kept for compatibility with the old interface)
        self.clients = set()
        self.send_queue = queue.Queue()

        # ---- data buffer (RingBuffer, same type as SunnyLinker uses) ----
        # 66 rows = 64 EEG channels + event channel (row 64) + label-index channel (row 65).
        # Holds about t_buffer seconds at _FS (10 s * 250 Hz = 2500 samples).
        self.n_chan = 66
        self.t_buffer = 10.0  # buffer length in seconds
        self.__ringBuffer = RingBuffer(self.n_chan, int(self.t_buffer * self._FS))

        # ---- event detection ----
        self._event_lock = threading.Lock()
        self._epoch_finished = False
        self._event_inner_idx = -1
        self.pack_contain_event = False
        self.predict_event = 99
        self.events = [1, 2, self.predict_event]
        self.count_events = {}
        self.latency = 50
        self.train_latency = 50
        # Label index of the latest trigger (taken from row 65)
        self.current_label_index = 0
        self.current_event_label = 0
        self._interval_inited = False
        self._currentLabel = -1

    # ========== event attributes: lock-protected access ==========
    @property
    def epoch_finished(self):
        with self._event_lock:
            return self._epoch_finished

    @epoch_finished.setter
    def epoch_finished(self, value):
        with self._event_lock:
            self._epoch_finished = value

    @property
    def event_inner_idx(self):
        with self._event_lock:
            return self._event_inner_idx

    @event_inner_idx.setter
    def event_inner_idx(self, value):
        with self._event_lock:
            self._event_inner_idx = value

    @property
    def interval_inited(self):
        return self._interval_inited

    @interval_inited.setter
    def interval_inited(self, value):
        self._interval_inited = value

    @property
    def currentLabel(self):
        return self._currentLabel

    @currentLabel.setter
    def currentLabel(self, value):
        self._currentLabel = value

    def broadcast_message(self, method, params):
        """Queue ``{'method': method, 'params': params}`` for every command client.

        The message is sent by the server thread on its next loop iteration; it is
        dropped if no command client has connected yet.
        """
        self.send_queue.put((method, params))

    # ========== data buffer operations ==========
    def GetDataLenCount(self):
        """Return the number of samples currently held in the buffer (``nUpdate``)."""
        return self.__ringBuffer.nUpdate

    def getData(self, count):
        """Return a copy of the newest ``count`` samples without consuming them.

        @param count: number of samples wanted; clipped to the number currently
                      buffered, and values <= 0 give an empty result.
        @return: array with one row per buffered channel and up to ``count``
                 columns, oldest sample first.
        """
        rb = self.__ringBuffer
        with rb.RingBufferLock:
            count = min(int(count), rb.nUpdate)
            if count <= 0:
                return np.zeros((self.n_chan, 0))
            # The newest sample sits just before the write pointer; indices wrap
            # modulo the buffer length. Advanced indexing returns a copy, so the
            # data thread's later writes cannot alter the array handed out here.
            cols = np.arange(rb.currentPtr - count, rb.currentPtr) % rb.n_points
            return rb.buffer[:, cols]

    def consumeData(self, count):
        """Discard up to ``count`` of the oldest buffered samples (advances ``readPtr``)."""
        rb = self.__ringBuffer
        with rb.RingBufferLock:
            count = max(0, min(int(count), rb.nUpdate))
            rb.readPtr = (rb.readPtr + count) % rb.n_points
            rb.nUpdate -= count

    def ResetAll(self):
        """Reset the ring buffer and all event-detection state."""
        with self.__ringBuffer.RingBufferLock:
            self.__ringBuffer.resetAllPara()
        # Locks are taken one after the other, never nested in this order, so this
        # cannot deadlock with _handle_data_socket (which nests event inside buffer).
        with self._event_lock:
            self._epoch_finished = False
            self._event_inner_idx = -1
            self.pack_contain_event = False
            self.count_events.clear()
            self.current_label_index = 0

    def reset_data_buffer(self):
        """Alias of :meth:`ResetAll`."""
        self.ResetAll()

    def reset_state(self):
        """Alias of :meth:`ResetAll`."""
        self.ResetAll()

    def interval_init(self, decoder_class):
        """Load the epoch windows for ``decoder_class`` from the ini file and reset event state.

        ``interval_epoch`` / ``train_epoch`` are in samples (seconds * _FS);
        ``latency`` / ``train_latency`` are in packets (samples // _SAMPLES_PER_PACKET).
        Supported classes are 'ssmvep' and 'mi'. For any other class the windows and
        latencies are left unchanged and a warning is printed; event state is still
        reset and ``interval_inited`` is still set, as before.
        """
        import ast
        from PubLibrary.InifileHelper import IniRead
        fs = self._FS
        per_packet = self._SAMPLES_PER_PACKET
        if decoder_class == 'ssmvep':
            seconds = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
            self.interval_epoch = [int(s * fs) for s in seconds]
            pad = int(0.1 * fs)  # 100 ms extension in samples
            self.train_epoch = [self.interval_epoch[0], self.interval_epoch[1] + pad]
            self.latency = (self.interval_epoch[1] + pad) // per_packet
            # NOTE(review): train_epoch[1] already includes `pad`, so the pad is
            # counted twice here (same as before) — confirm this is intended.
            self.train_latency = (self.train_epoch[1] + pad) // per_packet
        elif decoder_class == 'mi':
            seconds = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
            self.interval_epoch = [int(s * fs) for s in seconds]
            self.train_epoch = self.interval_epoch.copy()
            self.latency = self.interval_epoch[1] // per_packet
            self.train_latency = self.latency
        else:
            print(f"interval_init: unsupported decoder class {decoder_class!r}; "
                  f"keeping latency={self.latency}, train_latency={self.train_latency}")
        with self._event_lock:
            self.count_events = {}
            self._event_inner_idx = -1
            self._epoch_finished = False
            self.pack_contain_event = False
            self.predict_event = 99
            self.events = [1, 2, self.predict_event]
        self._interval_inited = True

    # ========== event detection ==========
    def detect_event(self, data_matrix):
        """Scan one packet for trigger codes and advance the epoch countdowns.

        Every sample of the event row (64) whose value, truncated to int, is in
        ``self.events`` starts a countdown of ``latency + 1`` packets
        (``self.latency`` for ``predict_event``, ``self.train_latency`` otherwise).
        It also sets ``event_inner_idx`` to that sample's column and
        ``current_label_index`` to the label row (65) at that column; non-finite
        event samples are ignored. Each call then decrements all countdowns, and
        ``epoch_finished`` is True exactly for the call in which at least one
        countdown reaches zero.

        NOTE(review): ``epoch_finished`` is therefore a one-packet pulse (cleared by
        the next packet); consumers must poll at least once per packet period or
        clear it themselves — confirm this matches the decoder loop.

        This method only updates event state. Clearing the ring buffer when a
        trigger arrives is done by the caller, before the packet is stored.

        @param data_matrix: array of shape (66, N) — one packet, rows are channels.
        @return: True if the packet contains at least one trigger.
        """
        if data_matrix.shape[1] == 0:
            return False
        event_row = np.asarray(data_matrix[self._EVENT_ROW, :], dtype=np.float64)
        label_row = data_matrix[self._LABEL_ROW, :]
        finite = np.isfinite(event_row)
        # Same truncation as int(); non-finite samples are masked out before casting.
        codes = np.where(finite, event_row, -1).astype(np.int64)
        hits = np.flatnonzero(finite & np.isin(codes, self.events))

        with self._event_lock:
            self.pack_contain_event = hits.size > 0
            self._event_inner_idx = -1
            self.current_event_label = 0
            for idx in hits:
                idx = int(idx)
                self._event_inner_idx = idx
                self.current_label_index = int(label_row[idx])
                latency = self.latency if codes[idx] == self.predict_event else self.train_latency
                self.count_events[f"{float(event_row[idx])}_{time.time()}"] = latency + 1

            # Decrement every countdown; `<= 0` also retires countdowns created
            # with a non-positive latency, which would otherwise never expire.
            expired = []
            for key in list(self.count_events):
                remaining = self.count_events[key] - 1
                if remaining <= 0:
                    expired.append(key)
                    del self.count_events[key]
                else:
                    self.count_events[key] = remaining
            self._epoch_finished = bool(expired)
            return self.pack_contain_event

    # ========== server loop ==========
    def run(self):
        """Server loop: flush queued broadcasts, then service whichever socket is readable."""
        self.running = True
        print(f"Server running - CMD: {self.cmd_port}, DATA: {self.data_port}")
        poller = zmq.Poller()
        poller.register(self.cmd_socket, zmq.POLLIN)
        poller.register(self.data_socket, zmq.POLLIN)
        try:
            while self.running and not self._stop_requested.is_set():
                self._flush_send_queue()
                # One poll over both sockets returns as soon as either is readable,
                # so an idle command channel no longer delays data handling.
                ready = dict(poller.poll(10))
                if self.cmd_socket in ready:
                    self._handle_cmd_socket()
                if self.data_socket in ready:
                    self._handle_data_socket()
        except Exception as e:
            print(f"Server error: {e}")
        finally:
            self.running = False
            # The server thread owns the sockets, so it is the one that closes them.
            self._close_sockets()

    def _flush_send_queue(self):
        """Send every queued broadcast to all known command clients (server thread only)."""
        while True:
            try:
                method, params = self.send_queue.get_nowait()
            except queue.Empty:
                return
            if not self.cmd_clients:
                continue  # nobody to deliver to: message dropped
            try:
                payload = json.dumps({'method': method, 'params': params}).encode('utf-8')
            except (TypeError, ValueError) as e:
                print(f"Broadcast encode error ({method}): {e}")
                continue
            for client_id in list(self.cmd_clients):
                try:
                    self.cmd_socket.send_multipart([client_id, b'', payload])
                except zmq.ZMQError as e:
                    print(f"Broadcast send error ({method}): {e}")

    def _handle_cmd_socket(self):
        """Receive one command message and apply it; malformed messages are logged and dropped.

        Expected frames: ``[identity, b'', json_bytes]`` with a JSON object holding
        ``method`` and ``params``.
        """
        try:
            frames = self.cmd_socket.recv_multipart(zmq.NOBLOCK)
            if len(frames) < 3:
                return
            ident, _, message_bytes = frames[:3]
            self.cmd_clients.add(ident)
            self.clients.add(ident)
            message = json.loads(message_bytes.decode('utf-8'))
            if not isinstance(message, dict):
                raise ValueError(f"expected a JSON object, got {type(message).__name__}")
            method = message.get("method")
            params = message.get("params")
            print(f"[CMD] {method}: {params}")
            self._apply_command(method, params)
        except zmq.Again:
            return
        except Exception as e:
            print(f"CMD socket error: {e}")

    def _apply_command(self, method, params):
        """Update the state flags for one decoded command; unknown methods are ignored."""
        if method == "sync":
            self.state_mode = 'sync'
        elif method == "targetFreqs":
            if isinstance(params, list) and params != self.targetFreqs:
                self.targetFreqs = params
                self.changeTarget = True
        elif method == "decoderClass":
            if isinstance(params, str) and params != self.decoder_class:
                self.decoder_class = params
                self.decoder_switch = True
        elif method == "getReport":
            self.getReport = True
        elif method == "train":
            self.state_mode = 'train'
            self.StartTrain = True
            self.currentLabel = params
        elif method == "predict":
            self.state_mode = 'predict'
            if params == 1:
                self.StartDecode = True
            elif params == 2:
                # Exit request: the loop in run() ends and closes the sockets.
                self.IsExitApp = True
                self.running = False
        elif method == "rest":
            self.state_mode = 'rest'
        elif method == "impedance":
            if params == 1:
                self.open_Impedance = True
                self.get_Impedance = True
            elif params == 2:
                self.open_Impedance = False
                self.get_Impedance = False

    def _handle_data_socket(self):
        """Receive one EEG packet, run event detection and store it in the ring buffer.

        Expected frames: ``[identity, b'', meta_json, payload]`` where ``payload`` is
        float32 data of shape ``meta['shape']`` = [N, 66] (default [5, 66]); it is
        transposed to [66, N] before storage.
        """
        try:
            frames = self.data_socket.recv_multipart(zmq.NOBLOCK)
            if len(frames) < 4:
                return
            ident, _, meta_bytes = frames[:3]
            self.data_clients.add(ident)
            meta = json.loads(meta_bytes.decode('utf-8'))
            data_matrix = self._decode_packet(meta, frames[3])
            rb = self.__ringBuffer
            # Detection and storage happen under the buffer lock, so a reader that
            # sees epoch_finished and then calls getData() waits until this packet
            # has been appended.
            with rb.RingBufferLock:
                if self.detect_event(data_matrix):
                    # Drop everything older than the trigger packet. The packet
                    # itself (including the trigger sample) is appended right below,
                    # instead of being wiped together with the old data.
                    rb.resetAllPara()
                rb.appendBuffer(data_matrix)
        except zmq.Again:
            return
        except Exception as e:
            print(f"DATA socket error: {e}")

    def _decode_packet(self, meta, payload):
        """Convert the float32 ``payload`` described by ``meta['shape']`` into an (n_chan, N) array.

        @raise ValueError: if the meta shape is malformed, does not match the payload
                           size, or its channel count differs from ``n_chan``.
        """
        shape = meta.get('shape', [5, 66]) if isinstance(meta, dict) else None
        if not isinstance(shape, (list, tuple)) or len(shape) != 2:
            raise ValueError(f"bad packet meta: {meta!r}")
        n_samples, n_channels = (int(v) for v in shape)
        if n_channels != self.n_chan:
            raise ValueError(f"expected {self.n_chan} channels, got {n_channels}")
        raw = np.frombuffer(payload, dtype=np.float32)
        # Host layout is [N, channels]; the buffer stores [channels, N].
        return raw.reshape(n_samples, n_channels).T.astype(np.float32)

    # ========== per-paradigm data accessors ==========
    @staticmethod
    def _select_rows(data, rows):
        """Return ``data[rows, :]``, or an empty (len(rows), 0) array when there are no samples."""
        if data.shape[1] > 0:
            return data[list(rows), :]
        return np.zeros((len(rows), 0))

    def get_MIData(self):
        """Return all buffered samples for the MI montage (21 EEG rows + rows 64 and 65)."""
        return self._select_rows(self.getData(self.GetDataLenCount()), self._MI_ROWS)

    def get_SSMVEPData(self):
        """Return all buffered samples for the SSMVEP montage (8 EEG rows + rows 64 and 65)."""
        return self._select_rows(self.getData(self.GetDataLenCount()), self._SSMVEP_ROWS)

    def getDataViaSSVEP(self, count):
        """Return the newest ``count`` samples for the SSVEP montage (8 EEG rows + row 64)."""
        return self._select_rows(self.getData(count), self._SSVEP_ROWS)

    def get_concentrateData(self, count):
        """Return the newest ``count`` samples of rows 0 and 1 (concentration paradigm)."""
        return self._select_rows(self.getData(count), self._CONCENTRATE_ROWS)

    def get_blinkData(self, count):
        """Return the newest ``count`` samples of rows 0 and 1 (blink paradigm)."""
        return self._select_rows(self.getData(count), self._BLINK_ROWS)

    def getImpedance(self, data, decoder_class):
        """Impedance is not available in ZMQ mode; always returns ``np.zeros(8)``."""
        return np.zeros(8)

    # ========== shutdown ==========
    def stop(self, timeout=1.0):
        """Request shutdown and release the ZMQ resources; safe to call more than once.

        When called from another thread while the server thread is alive, this waits
        up to ``timeout`` seconds for the thread to leave its loop. The thread closes
        its own sockets, because ZMQ sockets must not be used from two threads at
        once. If the thread never started, or has already finished, the sockets are
        closed here.
        """
        self.running = False
        self._stop_requested.set()
        if threading.current_thread() is self:
            return  # run() closes the sockets when its loop exits
        if self.is_alive():
            self.join(timeout)
            if self.is_alive():
                print("zmqServer.stop: server thread still running; sockets left to it")
                return
        self._close_sockets()

    def _close_sockets(self):
        """Close both sockets (short linger) and terminate the context, exactly once."""
        with self._close_lock:
            if self._closed:
                return
            self._closed = True
            for sock in (self.cmd_socket, self.data_socket):
                if sock is not None:
                    sock.close(linger=self._CLOSE_LINGER_MS)
            self.context.term()
if __name__ == '__main__':
    server = zmqServer()
    server.start()
    try:
        # The server thread is a daemon, so the process would exit (killing it)
        # as soon as the main thread finished. Wait here until it stops; joining
        # with a short timeout lets Ctrl+C be handled between waits.
        while server.is_alive():
            server.join(0.5)
    except KeyboardInterrupt:
        pass
    finally:
        server.stop()