delete server1
This commit is contained in:
@@ -1,445 +0,0 @@
|
||||
import numpy as np
|
||||
import zmq
|
||||
import threading
|
||||
import json
|
||||
import queue
|
||||
import time
|
||||
from Device.SunnyLinker import SunnyLinker64, RingBuffer
|
||||
from collections import deque
|
||||
|
||||
|
||||
class zmqServer(threading.Thread):
|
||||
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100):
|
||||
threading.Thread.__init__(self)
|
||||
self.host = host
|
||||
self.cmd_port = cmd_port
|
||||
self.data_port = data_port
|
||||
self.running = False
|
||||
self.get_Impedance = False
|
||||
self.open_Impedance = None
|
||||
self.StartDecode = False
|
||||
self.StartTrain = False
|
||||
self.state_mode = None
|
||||
self.currentLabel = -1
|
||||
self.IsExitApp = False
|
||||
self.getReport = False
|
||||
self.daemon = True
|
||||
|
||||
# ZMQ Context
|
||||
self.context = zmq.Context()
|
||||
|
||||
# 指令通道 (8099) - ROUTER
|
||||
self.cmd_socket = self.context.socket(zmq.ROUTER)
|
||||
self.cmd_socket.setsockopt(zmq.RCVHWM, 1000)
|
||||
self.cmd_socket.setsockopt(zmq.SNDHWM, 1000)
|
||||
self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
|
||||
|
||||
# 数据通道 (8100)) - ROUTER
|
||||
self.data_socket = self.context.socket(zmq.ROUTER)
|
||||
self.data_socket.setsockopt(zmq.RCVHWM, 1000)
|
||||
self.data_socket.setsockopt(zmq.RCVTIMEO, 50)
|
||||
self.data_socket.bind(f"tcp://{self.host}:{data_port}")
|
||||
|
||||
self.targetFreqs = []
|
||||
self.changeTarget = False
|
||||
self.sunnyLinker = SunnyLinker64(None, None, None, None, None)
|
||||
self.labels = [0x01, 0x02, 0x03]
|
||||
|
||||
self.decoder_switch = False
|
||||
self.decoder_class = None
|
||||
self.cmd_clients = set()
|
||||
self.data_clients = set()
|
||||
self.send_queue = queue.Queue()
|
||||
|
||||
# ========== 数据缓冲区 (RingBuffer) ==========
|
||||
# 与 SunnyLinker 保持一致,使用 RingBuffer
|
||||
# 66 = 64 EEG通道 + 1 事件通道(第65) + 1 标签序号通道(第66)
|
||||
# 缓存约 10 秒数据 (250Hz * 10s = 2500 点)
|
||||
self.n_chan = 66
|
||||
self.t_buffer = 10.0 # 缓冲区时长(秒)
|
||||
self.__ringBuffer = RingBuffer(self.n_chan, int(self.t_buffer * 250))
|
||||
|
||||
# 事件检测相关
|
||||
self._event_lock = threading.Lock()
|
||||
self._epoch_finished = False
|
||||
self._event_inner_idx = -1
|
||||
self.pack_contain_event = False
|
||||
self.predict_event = 99
|
||||
self.events = [1, 2, self.predict_event]
|
||||
self.count_events = {}
|
||||
self.latency = 50
|
||||
self.train_latency = 50
|
||||
|
||||
# 当前事件标签序号 (从第66通道获取)
|
||||
self.current_label_index = 0
|
||||
|
||||
# 初始化标志
|
||||
self._interval_inited = False
|
||||
self._currentLabel = -1
|
||||
|
||||
# 注册的客户端(兼容旧接口)
|
||||
self.clients = set()
|
||||
|
||||
# ========== 事件属性:线程安全访问 ==========
|
||||
@property
|
||||
def epoch_finished(self):
|
||||
with self._event_lock:
|
||||
return self._epoch_finished
|
||||
|
||||
@epoch_finished.setter
|
||||
def epoch_finished(self, value):
|
||||
with self._event_lock:
|
||||
self._epoch_finished = value
|
||||
|
||||
@property
|
||||
def event_inner_idx(self):
|
||||
with self._event_lock:
|
||||
return self._event_inner_idx
|
||||
|
||||
@event_inner_idx.setter
|
||||
def event_inner_idx(self, value):
|
||||
with self._event_lock:
|
||||
self._event_inner_idx = value
|
||||
|
||||
@property
|
||||
def interval_inited(self):
|
||||
return self._interval_inited
|
||||
|
||||
@interval_inited.setter
|
||||
def interval_inited(self, value):
|
||||
self._interval_inited = value
|
||||
|
||||
@property
|
||||
def currentLabel(self):
|
||||
return self._currentLabel
|
||||
|
||||
@currentLabel.setter
|
||||
def currentLabel(self, value):
|
||||
self._currentLabel = value
|
||||
|
||||
def broadcast_message(self, method, params):
|
||||
"""Put message into queue to be sent to all connected clients"""
|
||||
self.send_queue.put((method, params))
|
||||
|
||||
# ========== 数据缓冲区操作接口 ==========
|
||||
def GetDataLenCount(self):
|
||||
"""返回缓冲区当前数据点数"""
|
||||
return self.__ringBuffer.nUpdate
|
||||
|
||||
def getData(self, count):
|
||||
"""获取最新count个数据点,不消费(只读)"""
|
||||
with self.__ringBuffer.RingBufferLock:
|
||||
count = min(count, self.__ringBuffer.nUpdate)
|
||||
if count == 0:
|
||||
return np.zeros((self.n_chan, 0))
|
||||
|
||||
# 计算读取范围(从尾部取最新数据)
|
||||
read_end = (self.__ringBuffer.currentPtr - 1) % self.__ringBuffer.n_points
|
||||
read_start = (read_end - count + 1) % self.__ringBuffer.n_points
|
||||
|
||||
if self.__ringBuffer.currentPtr == 0:
|
||||
read_start = self.__ringBuffer.n_points - count
|
||||
read_end = self.__ringBuffer.n_points - 1
|
||||
|
||||
if read_start <= read_end:
|
||||
data = self.__ringBuffer.buffer[:, read_start:read_end + 1]
|
||||
else:
|
||||
part1 = self.__ringBuffer.buffer[:, read_start:]
|
||||
part2 = self.__ringBuffer.buffer[:, :read_end + 1]
|
||||
data = np.concatenate((part1, part2), axis=1)
|
||||
|
||||
return data
|
||||
|
||||
def consumeData(self, count):
|
||||
"""消费(丢弃)指定数量的数据点,从头部移除"""
|
||||
with self.__ringBuffer.RingBufferLock:
|
||||
count = min(count, self.__ringBuffer.nUpdate)
|
||||
self.__ringBuffer.readPtr = (self.__ringBuffer.readPtr + count) % self.__ringBuffer.n_points
|
||||
self.__ringBuffer.nUpdate -= count
|
||||
|
||||
def ResetAll(self):
|
||||
"""重置缓冲区"""
|
||||
with self.__ringBuffer.RingBufferLock:
|
||||
self.__ringBuffer.resetAllPara()
|
||||
with self._event_lock:
|
||||
self._epoch_finished = False
|
||||
self._event_inner_idx = -1
|
||||
self.pack_contain_event = False
|
||||
self.count_events.clear()
|
||||
self.current_label_index = 0
|
||||
|
||||
def reset_data_buffer(self):
|
||||
self.ResetAll()
|
||||
|
||||
def reset_state(self):
|
||||
self.ResetAll()
|
||||
|
||||
def interval_init(self, decoder_class):
|
||||
"""初始化事件检测参数"""
|
||||
import ast
|
||||
from PubLibrary.InifileHelper import IniRead
|
||||
|
||||
if decoder_class == 'ssmvep':
|
||||
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
||||
self.interval_epoch = [int(i * 250) for i in interval_epoch]
|
||||
self.train_epoch = [int(self.interval_epoch[0]),
|
||||
int(self.interval_epoch[1] + 0.1 * 250)]
|
||||
self.latency = (self.interval_epoch[1] + 0.1 * 250) // 5
|
||||
self.train_latency = (self.train_epoch[1] + 0.1 * 250) // 5
|
||||
|
||||
elif decoder_class == 'mi':
|
||||
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
||||
self.interval_epoch = [int(i * 250) for i in interval_epoch]
|
||||
self.train_epoch = self.interval_epoch.copy()
|
||||
self.latency = self.interval_epoch[1] // 5
|
||||
self.train_latency = self.latency
|
||||
|
||||
self.count_events = {}
|
||||
self._event_inner_idx = -1
|
||||
self._epoch_finished = False
|
||||
self.pack_contain_event = False
|
||||
self.predict_event = 99
|
||||
self.events = [1, 2, self.predict_event]
|
||||
self._interval_inited = True
|
||||
|
||||
# ========== 事件检测 ==========
|
||||
def detect_event(self, data_matrix):
|
||||
"""
|
||||
检测事件通道中的触发信号
|
||||
|
||||
@param data_matrix: shape (66, N) - N个采样点的数据
|
||||
第65行(索引64) = 事件通道
|
||||
第66行(索引65) = 标签通道
|
||||
@return: 是否检测到事件
|
||||
"""
|
||||
if data_matrix.shape[1] == 0:
|
||||
return False
|
||||
|
||||
self.pack_contain_event = False
|
||||
event_channel = data_matrix[64, :] # 第65通道 = 标签值(event值)
|
||||
label_channel = data_matrix[65, :] # 第66通道 = 标签序号(label index)
|
||||
|
||||
events = event_channel.tolist()
|
||||
|
||||
with self._event_lock:
|
||||
self._event_inner_idx = -1
|
||||
self.current_event_label = 0
|
||||
|
||||
for idx, event in enumerate(events):
|
||||
if int(event) in self.events:
|
||||
self._event_inner_idx = idx
|
||||
self.current_label_index = int(label_channel[idx])
|
||||
self.pack_contain_event = True
|
||||
|
||||
new_key = f"{event}_{time.time()}"
|
||||
latency = self.latency if event == self.predict_event else self.train_latency
|
||||
self.count_events[new_key] = latency + 1
|
||||
|
||||
# 延迟计数递减
|
||||
drop_items = []
|
||||
for key, value in self.count_events.items():
|
||||
value = value - 1
|
||||
if value == 0:
|
||||
drop_items.append(key)
|
||||
self.count_events[key] = value
|
||||
for key in drop_items:
|
||||
del self.count_events[key]
|
||||
|
||||
if drop_items:
|
||||
self._epoch_finished = True
|
||||
# 检测到事件时,清除RingBuffer中之前的数据,只保留当前包
|
||||
if self.pack_contain_event:
|
||||
self.__ringBuffer.resetAllPara()
|
||||
return True
|
||||
|
||||
self._epoch_finished = False
|
||||
return False
|
||||
|
||||
def run(self):
|
||||
self.running = True
|
||||
print(f"Server running - CMD: {self.cmd_port}, DATA: {self.data_port}")
|
||||
|
||||
cmd_poller = zmq.Poller()
|
||||
cmd_poller.register(self.cmd_socket, zmq.POLLIN)
|
||||
|
||||
data_poller = zmq.Poller()
|
||||
data_poller.register(self.data_socket, zmq.POLLIN)
|
||||
|
||||
try:
|
||||
while self.running:
|
||||
# --- 处理发送队列 (指令通道) ---
|
||||
while not self.send_queue.empty():
|
||||
method, params = self.send_queue.get()
|
||||
if self.cmd_clients:
|
||||
try:
|
||||
msg = {'method': method, 'params': params}
|
||||
msg_bytes = json.dumps(msg).encode('utf-8')
|
||||
for client_id in list(self.cmd_clients):
|
||||
try:
|
||||
self.cmd_socket.send_multipart([client_id, b'', msg_bytes])
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# --- 处理指令通道 ---
|
||||
socks = dict(cmd_poller.poll(10))
|
||||
if self.cmd_socket in socks:
|
||||
self._handle_cmd_socket()
|
||||
|
||||
# --- 处理数据通道 ---
|
||||
socks = dict(data_poller.poll(10))
|
||||
if self.data_socket in socks:
|
||||
self._handle_data_socket()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Server error: {e}")
|
||||
finally:
|
||||
self.running = False
|
||||
self.cmd_socket.close()
|
||||
self.data_socket.close()
|
||||
self.context.term()
|
||||
|
||||
def _handle_cmd_socket(self):
|
||||
"""处理指令通道消息"""
|
||||
try:
|
||||
frames = self.cmd_socket.recv_multipart()
|
||||
if len(frames) < 3:
|
||||
return
|
||||
ident, _, message_bytes = frames[:3]
|
||||
self.cmd_clients.add(ident)
|
||||
self.clients.add(ident)
|
||||
|
||||
message = json.loads(message_bytes.decode('utf-8'))
|
||||
method = message.get("method")
|
||||
params = message.get("params")
|
||||
|
||||
print(f"[CMD] {method}: {params}")
|
||||
|
||||
if method == "sync":
|
||||
self.state_mode = 'sync'
|
||||
elif method == "targetFreqs":
|
||||
if isinstance(params, list) and params != self.targetFreqs:
|
||||
self.targetFreqs = params
|
||||
self.changeTarget = True
|
||||
elif method == "decoderClass":
|
||||
if isinstance(params, str) and params != self.decoder_class:
|
||||
self.decoder_class = params
|
||||
self.decoder_switch = True
|
||||
elif method == "getReport":
|
||||
self.getReport = True
|
||||
elif method == "train":
|
||||
self.state_mode = 'train'
|
||||
self.StartTrain = True
|
||||
self.currentLabel = params
|
||||
elif method == "predict":
|
||||
self.state_mode = 'predict'
|
||||
if params == 1:
|
||||
self.StartDecode = True
|
||||
elif params == 2:
|
||||
self.IsExitApp = True
|
||||
self.running = False
|
||||
elif method == "rest":
|
||||
self.state_mode = 'rest'
|
||||
elif method == "impedance":
|
||||
if params == 1:
|
||||
self.open_Impedance = True
|
||||
self.get_Impedance = True
|
||||
elif params == 2:
|
||||
self.open_Impedance = False
|
||||
self.get_Impedance = False
|
||||
|
||||
except Exception as e:
|
||||
print(f"CMD socket error: {e}")
|
||||
|
||||
def _handle_data_socket(self):
|
||||
"""处理数据通道消息 (EEG数据)
|
||||
|
||||
上位机数据格式:
|
||||
- 数据帧: [identity, '', meta_json, data_buffer]
|
||||
data_buffer = [N, 66] float32 -> 转置为 [66, N]
|
||||
"""
|
||||
try:
|
||||
frames = self.data_socket.recv_multipart()
|
||||
if len(frames) < 4:
|
||||
return
|
||||
ident, _, message_bytes = frames[:3]
|
||||
self.data_clients.add(ident)
|
||||
|
||||
meta = json.loads(message_bytes.decode('utf-8'))
|
||||
|
||||
# data: [N, 66] -> 转置 -> [66, N]
|
||||
raw_data = np.frombuffer(frames[3], dtype=np.float32)
|
||||
n_samples, n_channels = meta.get('shape', [5, 66])
|
||||
data_matrix = raw_data.reshape(n_samples, n_channels).T.astype(np.float32)
|
||||
|
||||
# 写入 RingBuffer
|
||||
with self.__ringBuffer.RingBufferLock:
|
||||
self.__ringBuffer.appendBuffer(data_matrix)
|
||||
|
||||
# 事件检测
|
||||
self.detect_event(data_matrix)
|
||||
|
||||
except Exception as e:
|
||||
print(f"DATA socket error: {e}")
|
||||
|
||||
# ========== 各范式数据访问接口 ==========
|
||||
def get_MIData(self):
|
||||
"""获取MI导联数据 (21通道 + 事件)"""
|
||||
data = self.getData(self.GetDataLenCount())
|
||||
rows_to_extract = [8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10, 19, 20, 61, 51, 60, 21, 64, 65]
|
||||
row_to_select = np.array(rows_to_extract)
|
||||
if data.shape[1] > 0:
|
||||
return data[row_to_select, :]
|
||||
return np.zeros((len(rows_to_extract), 0))
|
||||
|
||||
def get_SSMVEPData(self):
|
||||
"""获取SSMVEP导联数据 (8通道 + 事件)"""
|
||||
data = self.getData(self.GetDataLenCount())
|
||||
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64, 65]
|
||||
row_to_select = np.array(rows_to_extract)
|
||||
if data.shape[1] > 0:
|
||||
return data[row_to_select, :]
|
||||
return np.zeros((len(rows_to_extract), 0))
|
||||
|
||||
def getDataViaSSVEP(self, count):
|
||||
"""获取SSVEP数据 (8通道 + 事件)"""
|
||||
data = self.getData(count)
|
||||
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64]
|
||||
row_to_select = np.array(rows_to_extract)
|
||||
if data.shape[1] > 0:
|
||||
return data[row_to_select, :]
|
||||
return np.zeros((len(rows_to_extract), 0))
|
||||
|
||||
def get_concentrateData(self, count):
|
||||
"""获取专注力数据 (2通道)"""
|
||||
data = self.getData(count)
|
||||
rows_to_extract = [0, 1]
|
||||
row_to_select = np.array(rows_to_extract)
|
||||
if data.shape[1] > 0:
|
||||
return data[row_to_select, :]
|
||||
return np.zeros((len(rows_to_extract), 0))
|
||||
|
||||
def get_blinkData(self, count):
|
||||
"""获取眨眼数据 (2通道)"""
|
||||
data = self.getData(count)
|
||||
rows_to_extract = [0, 1]
|
||||
row_to_select = np.array(rows_to_extract)
|
||||
if data.shape[1] > 0:
|
||||
return data[row_to_select, :]
|
||||
return np.zeros((len(rows_to_extract), 0))
|
||||
|
||||
def getImpedance(self, data, decoder_class):
|
||||
"""计算阻抗(ZMQ模式下不可用)"""
|
||||
return np.zeros(8)
|
||||
|
||||
def stop(self):
|
||||
self.running = False
|
||||
self.cmd_socket.close()
|
||||
self.data_socket.close()
|
||||
self.context.term()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
server = zmqServer()
|
||||
server.start()
|
||||
Reference in New Issue
Block a user