add buffer

This commit is contained in:
2026-06-06 14:40:07 +08:00
parent 868ff30238
commit 2d190d6431
5 changed files with 414 additions and 266 deletions

View File

@@ -63,8 +63,53 @@ class ParadigmRingBuffer:
获取最新缓存中每个通道的数量
@return:
'''
return self.nUpdate
return self.nUpdate
# ========== 各范式数据访问接口 ==========
def get_MIData(self):
"""获取MI导联数据 (21通道 + 事件)"""
data = self.getData(self.GetDataLenCount())
rows_to_extract = [8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10, 19, 20, 61, 51, 60, 21, 64, 65]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def get_SSMVEPData(self):
"""获取SSMVEP导联数据 (8通道 + 事件)"""
data = self.getData(self.GetDataLenCount())
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64, 65]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def getDataViaSSVEP(self, count):
"""获取SSVEP数据 (8通道 + 事件)"""
data = self.getData(count)
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def get_concentrateData(self, count):
"""获取专注力数据 (2通道)"""
data = self.getData(count)
rows_to_extract = [0, 1]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def get_blinkData(self, count):
"""获取眨眼数据 (2通道)"""
data = self.getData(count)
rows_to_extract = [0, 1]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
# reset buffer
def resetAllPara(self):
@@ -72,6 +117,4 @@ class ParadigmRingBuffer:
self.currentPtr = 0
self.readPtr = 0 # add by lizhenhua 清空读指针
self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区

View File

@@ -1,16 +1,22 @@
import ast
import numpy as np
import zmq
import threading
import json
import queue
from typing import Dict
# from Device.SunnyLinker import SunnyLinker64
from dataBuffer import ParadigmRingBuffer
from filterProcess import FilterRingBuffer
from PubLibrary.InifileHelper import IniRead
from logs.log import algo_log
import zmq
class zmqServer(threading.Thread):
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
threading.Thread.__init__(self)
self.device_info = device_info
self.host = host
self.cmd_port = cmd_port # 命令交互端口
self.data_port = data_port # 数据接收端口
@@ -28,8 +34,8 @@ class zmqServer(threading.Thread):
self.daemon = True
# 范式数据缓存
self.paradigmBuffer = ParadigmRingBuffer(66, 2500)
self.filterBuffer = FilterRingBuffer(66, 2500)
self.paradigmBuffer = ParadigmRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10)
self.filterBuffer = FilterRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10)
# 命令与数据通信
@@ -64,6 +70,77 @@ class zmqServer(threading.Thread):
self.cmd_clients = set() # 命令端口客户端ID
self.data_clients = set() # 数据端口客户端ID
self.send_queue = queue.Queue() # 发送队列(仅用于命令端口广播)
# 范式buffer参数, 事件检测相关
self._event_lock = threading.Lock()
self._epoch_finished = False
self._event_inner_idx = -1
self.pack_contain_event = False
self.predict_event = 99
self.events = [1, 2, self.predict_event]
self.count_events = {}
self.latency = 50
self.train_latency = 50
self._interval_inited = False
@property
def interval_inited(self):
return self._interval_inited
@interval_inited.setter
def interval_inited(self, value):
self._interval_inited = value
@property
def epoch_finished(self):
with self._event_lock:
return self._epoch_finished
@epoch_finished.setter
def epoch_finished(self, value):
with self._event_lock:
self._epoch_finished = value
@property
def event_inner_idx(self):
with self._event_lock:
return self._event_inner_idx
@event_inner_idx.setter
def event_inner_idx(self, value):
with self._event_lock:
self._event_inner_idx = value
def interval_init(self, decoder_class):
if decoder_class == 'ssmvep':
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息
self.train_epoch = [int(self.interval_epoch[0]),
int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch
self.latency = (self.interval_epoch[
1] + 0.1 * self.device_info['sample_rate']) // 5 # 提取epoch的延迟标记5代表每次解包得到的5位采样点;0.1表示比实际需要的长度多取0.1,会被截掉
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5
elif decoder_class == 'mi':
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息
self.train_epoch = self.interval_epoch.copy()
self.latency = (self.interval_epoch[1]) // 5 # 提取epoch的延迟标记5代表每次解包得到的5位采样点;
self.train_latency = self.latency
print('时间窗:', (interval_epoch))
self.count_events: Dict[str, int] = {} # 表示包延迟的计数信息
self.event_inner_idx = -1 # event在5位数据包内部的idx
self.epoch_finished = False # 接收epoch是否完整
self.pack_contain_event = False # 当前包是否含有event
self.predict_event = 99
self.events = [1, 2, self.predict_event]
self.interval_inited = True
# if getattr(self, 'serial', None) and self.serial.is_open:
# self.serial.close()
# self.serial = serial.Serial(self.serial_port, 460800, timeout=1) # 连接同步器串口
def broadcast_message(self, method, params):
"""Put message into queue to be sent to all command clients"""
@@ -78,15 +155,15 @@ class zmqServer(threading.Thread):
# 注册新的命令客户端
if ident not in self.cmd_clients:
self.cmd_clients.add(ident)
print(f"New CMD Client Connected: {ident} (port: {self.cmd_port})")
algo_log(f"New CMD Client Connected: {ident} (port: {self.cmd_port})")
# 解析消息
try:
message = json.loads(message_bytes.decode('utf-8'))
except json.JSONDecodeError:
print(f"Invalid JSON from CMD client {ident}")
continue
print(f"Received CMD request: {message}")
algo_log(f"Invalid JSON from CMD client {ident}")
return
algo_log(f"Received CMD request: {message}")
method = message.get("method")
params = message.get("params")
@@ -94,37 +171,40 @@ class zmqServer(threading.Thread):
# 原有命令处理逻辑
if method == "sync":
self.state_mode = 'sync'
if method == "targetFreqs":
elif method == "targetFreqs":
if not isinstance(params, list):
print('targetFreqs must be a list')
continue
algo_log(f"targetFreqs must be a list")
return
if params != self.targetFreqs:
self.targetFreqs = params
self.changeTarget = True
if method == "decoderClass":
elif method == "decoderClass":
if not isinstance(params, str):
print('decoderClass must be a str')
continue
algo_log(f"decoderClass must be a str")
return
if params != self.decoder_class:
self.decoder_class = params
self.decoder_switch = True
if method == "getReport":
self.getReport = True
if method == "train":#训练状态
elif method == "train":#训练状态
self.state_mode = 'train'
self.StartTrain = True
self.currentLabel = params # 当前刺激端的训练标签
self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
# self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
elif method == "predict":#预测状态
self.state_mode = 'predict'
if params == 1: #开始解码
self.StartDecode = True
self.sunnyLinker.push_trigger(0x63)
# self.sunnyLinker.push_trigger(0x63)
elif params == 2: #停止解码
self.IsExitApp = True
self.running = False
elif method == "rest": #休息状态
self.state_mode = 'rest'
else:
algo_log(f"未知命令:{method}", level="WARNING")
# elif method == "getReport":
# self.getReport = True
# elif method == "impedance":
# if params == 1:
# self.open_Impedance = True # 开启阻抗
@@ -153,7 +233,7 @@ class zmqServer(threading.Thread):
try:
# 3. 精确长度校验(核心:固定(5,66) float32 = 5*66*4=1320字节与int32字节数相同
EXPECTED_BYTES = 5 * 66 * 4 # 每个float32占4字节
EXPECTED_BYTES = self.device_info['frame_points'] * self.device_info['channel_nums'] * 4 # 每个float32占4字节
if len(data_bytes) != EXPECTED_BYTES:
print(f"[ERROR] 数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节")
return
@@ -162,7 +242,7 @@ class zmqServer(threading.Thread):
# 步骤:字节流 → (330,) float32数组 → (5,66) 原始格式 → 转置为 (66,5) 缓冲区标准格式
data_np = np.frombuffer(data_bytes, dtype=np.float32)
# 重塑为上位机原始维度
data_np = data_np.reshape(5, 66)
data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums'])
# 转置为(通道数, 采样点数)标准格式转换为float64保证滤波运算精度
data_np = data_np.T.astype(np.float64)
@@ -215,7 +295,7 @@ class zmqServer(threading.Thread):
self._process_send_queue()
# 2. 轮训监听两个Socket的输入事件10ms超时避免阻塞
socks = dict(self.poller.poll(10))
socks = dict(self.poller.poll(50))
# 处理命令端口消息
if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN: