Files
bci_algo/Zmq/dataBuffer.py

123 lines
4.4 KiB
Python
Raw Permalink Normal View History

2026-06-06 09:16:49 +08:00
# -*-coding:utf-8 -*-
"""
Paradigm buffer and filter buffer, plus the filtering functions.
"""
import numpy as np
from scipy import signal
import threading
2026-06-06 17:08:09 +08:00
from logs.log import algo_log
2026-06-06 09:16:49 +08:00
class ParadigmRingBuffer:
    """Fixed-capacity ring buffer holding multi-channel samples.

    Samples are stored column-wise in ``buffer`` (shape ``(n_chan, n_points)``).
    ``appendBuffer`` writes at ``currentPtr``, ``getData`` consumes from
    ``readPtr``, and ``nUpdate`` counts the samples written but not yet read.
    Both pointers wrap around at ``n_points``.
    """

    # Row indices extracted by each paradigm accessor. The original docstrings
    # describe rows 64/65 as event rows appended after the signal channels.
    _MI_ROWS = (8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10,
                19, 20, 61, 51, 60, 21, 64, 65)
    _SSMVEP_ROWS = (13, 3, 2, 46, 9, 54, 47, 55, 64, 65)
    _SSVEP_ROWS = (13, 3, 2, 46, 9, 54, 47, 55, 64)
    _CONCENTRATE_ROWS = (0, 1)
    _BLINK_ROWS = (0, 1)

    def __init__(self, n_chan, n_points):
        """Allocate a zero-filled buffer.

        Args:
            n_chan: number of rows (channels) stored per sample.
            n_points: capacity of the buffer in samples (columns).
        """
        self.n_chan = n_chan
        self.n_points = n_points
        self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
        self.currentPtr = 0  # next column to write
        self.readPtr = 0     # next column to read
        self.nUpdate = 0     # number of unread samples
        self.rawData = np.zeros((n_chan, 1))

    def appendBuffer(self, data):
        """Append the columns of ``data`` and advance the write pointer.

        Only as many leading columns as there is free space for are stored;
        any remaining columns are dropped.

        Args:
            data: 2-D array whose columns are samples (indexed as
                ``data[:, :k]``, so its row count must match the buffer).
        """
        if self.nUpdate == self.n_points:
            # Deliberately log instead of raising so the producer keeps running.
            algo_log("ParadigmRingBuffer is full", record_once=True)
        n = data.shape[1]
        # Number of samples that still fit.
        write_count = min(self.n_points - self.nUpdate, n)
        # Destination columns, wrapped around the end of the storage.
        cols = np.mod(
            np.arange(self.currentPtr, self.currentPtr + write_count),
            self.n_points,
        )
        self.buffer[:, cols] = data[:, :write_count]
        self.currentPtr = (self.currentPtr + write_count) % self.n_points
        self.nUpdate += write_count

    def getData(self, count=50):
        """Consume and return up to ``count`` of the oldest unread samples.

        Args:
            count: maximum number of samples to read. It is clamped to the
                range ``[0, nUpdate]``.

        Returns:
            A new array of shape ``(n_chan, k)`` with ``k <= count``. The
            result never aliases the internal storage, so later appends or
            resets cannot change it. ``k`` is 0 when nothing is buffered.
        """
        # Never read more than is buffered; a negative request reads nothing.
        count = max(0, min(count, self.nUpdate))
        end = self.readPtr + count
        if end <= self.n_points:
            # Contiguous region. Using the un-wrapped ``end`` keeps a
            # zero-length read empty even when readPtr == 0, and the copy
            # detaches the result from the ring storage.
            data = self.buffer[:, self.readPtr:end].copy()
        else:
            # Region wraps: tail of the storage followed by its head.
            tail = self.buffer[:, self.readPtr:]
            head = self.buffer[:, :end - self.n_points]
            data = np.concatenate((tail, head), axis=1)
        self.readPtr = end % self.n_points
        self.nUpdate -= count
        return data

    def GetDataLenCount(self):
        """Return the number of unread samples currently buffered."""
        return self.nUpdate

    # ========== Per-paradigm data accessors ==========
    @staticmethod
    def _select_rows(data, rows):
        """Return the given rows of ``data``, or an empty ``(len(rows), 0)`` array."""
        if data.shape[1] > 0:
            return data[np.array(rows), :]
        return np.zeros((len(rows), 0))

    def get_MIData(self):
        """Consume all unread samples and return the MI rows (21 channels + events)."""
        return self._select_rows(self.getData(self.GetDataLenCount()), self._MI_ROWS)

    def get_SSMVEPData(self):
        """Consume all unread samples and return the SSMVEP rows (8 channels + events)."""
        return self._select_rows(self.getData(self.GetDataLenCount()), self._SSMVEP_ROWS)

    def getDataViaSSVEP(self, count):
        """Consume up to ``count`` samples and return the SSVEP rows (8 channels + event)."""
        return self._select_rows(self.getData(count), self._SSVEP_ROWS)

    def get_concentrateData(self, count):
        """Consume up to ``count`` samples and return the concentration rows (2 channels)."""
        return self._select_rows(self.getData(count), self._CONCENTRATE_ROWS)

    def get_blinkData(self, count):
        """Consume up to ``count`` samples and return the blink rows (2 channels)."""
        return self._select_rows(self.getData(count), self._BLINK_ROWS)

    def resetAllPara(self):
        """Reset both pointers and the unread count, and zero the storage."""
        self.nUpdate = 0
        self.currentPtr = 0
        self.readPtr = 0
        self.buffer.fill(0.0)
2026-06-06 09:16:49 +08:00