Files
bci_algo/Zmq/filterProcess.py

209 lines
8.4 KiB
Python
Raw Normal View History

2026-06-06 09:16:49 +08:00
# -*-coding:utf-8 -*-
"""
数据滤波模块
"""
import numpy as np
2026-06-08 11:56:42 +08:00
import time
2026-06-06 09:16:49 +08:00
import threading
2026-06-07 11:05:24 +08:00
from scipy import signal
2026-06-06 09:16:49 +08:00
from logs.log import algo_log
class FilterRingBuffer:
def __init__(self, n_chan, n_points):
"""
2026-06-08 11:56:42 +08:00
初始化纯数据环形缓存线程安全
2026-06-06 09:16:49 +08:00
:param n_chan: 通道数
:param n_points: 总缓存点数与paradigmRingBuffer参数完全一致
"""
self.n_chan = n_chan
self.n_points = n_points
2026-06-07 11:05:24 +08:00
self.buffer = np.zeros((n_chan, n_points), dtype=np.float32)
2026-06-08 11:56:42 +08:00
self.current_ptr = 0 # 写入指针:指向下一个要写入的位置
2026-06-06 09:16:49 +08:00
self.total_samples = 0 # 已写入总点数
2026-06-08 11:56:42 +08:00
self.lock = threading.Lock() # 线程安全锁
2026-06-06 09:16:49 +08:00
def appendBuffer(self, data):
"""
追加数据到缓存与paradigmRingBuffer接口一致
:param data: 输入数据shape=(n_chan, n_samples)
"""
with self.lock:
n = data.shape[1]
if n == 0:
return
2026-06-08 11:56:42 +08:00
# 环形写入逻辑:指针到末尾则绕回
2026-06-06 09:16:49 +08:00
write_end = self.current_ptr + n
if write_end <= self.n_points:
self.buffer[:, self.current_ptr:write_end] = data
else:
split = self.n_points - self.current_ptr
self.buffer[:, self.current_ptr:] = data[:, :split]
self.buffer[:, :write_end - self.n_points] = data[:, split:]
2026-06-08 11:56:42 +08:00
# 更新指针(取模保证环形)和计数(不超过缓存总长度)
2026-06-06 09:16:49 +08:00
self.current_ptr = write_end % self.n_points
self.total_samples = min(self.total_samples + n, self.n_points)
def getData(self, count):
"""
2026-06-08 11:56:42 +08:00
从最新位置向前读取count个点环形读取
核心逻辑current_ptr是下一个写入位置 最新数据在current_ptr之前
2026-06-06 09:16:49 +08:00
:param count: 读取点数
:return: np.ndarray, shape=(n_chan, count)
"""
with self.lock:
count = min(count, self.total_samples)
if count == 0:
return np.zeros((self.n_chan, 0))
2026-06-08 11:56:42 +08:00
# 环形读取end是当前写入指针最新数据的下一位start是end - count
2026-06-06 09:16:49 +08:00
end = self.current_ptr
start = end - count
if start >= 0:
return self.buffer[:, start:end].copy()
else:
2026-06-08 11:56:42 +08:00
# 跨环形边界:前半部分从缓存末尾取,后半部分从开头取
part1 = self.buffer[:, start:] # start为负等价于n_points + start
2026-06-06 09:16:49 +08:00
part2 = self.buffer[:, :end]
return np.concatenate((part1, part2), axis=1)
def get_latest_n_points(self, n):
"""
扩展方法获取最新的n个点不移动读指针用于滑动窗口
:param n: 点数
2026-06-08 11:56:42 +08:00
:return: np.ndarray, shape=(n_chan, n) | None数据不足时
2026-06-06 09:16:49 +08:00
"""
with self.lock:
if self.total_samples < n:
return None
return self.getData(n)
def GetDataLenCount(self):
"""获取当前缓存总点数(兼容原有接口)"""
with self.lock:
return self.total_samples
def resetAllPara(self):
"""重置所有缓存和指针(兼容原有接口)"""
with self.lock:
self.buffer.fill(0.0)
self.current_ptr = 0
self.total_samples = 0
# -----------------------------------------------------------------------------
# 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
# -----------------------------------------------------------------------------
2026-06-08 11:56:42 +08:00
class SlidingFilter(threading.Thread):
2026-06-06 09:16:49 +08:00
def __init__(
self,
2026-06-08 11:56:42 +08:00
ring_buffer: FilterRingBuffer,
2026-06-06 09:16:49 +08:00
n_chan=66,
srate=250,
window_sec=3,
2026-06-08 15:23:47 +08:00
step_sec=0.2
2026-06-06 09:16:49 +08:00
):
2026-06-08 11:56:42 +08:00
super().__init__(daemon=True)
2026-06-06 09:16:49 +08:00
# 核心参数
self.n_chan = n_chan
self.srate = srate
2026-06-08 11:56:42 +08:00
self.step_sec = step_sec # 200ms滑动步长
self.window_sec = window_sec # 3秒窗口
self.step_sec = step_sec # 200ms滑动步长
self.window_size = int(srate * window_sec) # 3秒点数250*3=750
self.step_size = int(srate * step_sec) # 200ms点数250*0.2=50
# 关联ZMQServer的环形缓存解耦仅依赖接口
self.ring_buffer = ring_buffer
# 线程控制
self.running = threading.Event()
self.running.set()
# 滤波结果回调(外部可注册,获取滤波后的数据)
self.filter_result_callback = None
# 预计算滤波器系数(仅执行一次)
2026-06-06 09:16:49 +08:00
self._init_filters()
def _init_filters(self):
"""预计算所有滤波器系数(仅执行一次)"""
# 50Hz工频陷波Q=30工业标准
self.b_notch, self.a_notch = signal.iirnotch(50, 30, self.srate)
# 8~30Hz带通FIR65阶线性相位
self.b_bp = signal.firwin(
numtaps=65,
cutoff=[8/(self.srate/2), 30/(self.srate/2)],
pass_zero=False,
window='hamming'
)
self.a_bp = np.array([1.0])
2026-06-08 11:56:42 +08:00
def _filter_window_data(self, window_data):
"""对3秒窗口数据执行滤波返回无边界效应的200ms数据"""
2026-06-06 09:16:49 +08:00
# 零相位滤波(无延迟,无边界效应)
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
filtered = signal.filtfilt(self.b_bp, self.a_bp, filtered, axis=-1)
2026-06-08 11:56:42 +08:00
# 提取倒数第二个200ms的数据完全避开两端边界效应
# 窗口长度750步长50 → start=750-100=650end=750-50=700
2026-06-06 09:16:49 +08:00
start_idx = self.window_size - 2 * self.step_size
end_idx = self.window_size - self.step_size
output_data = filtered[:, start_idx:end_idx].copy()
return output_data
2026-06-08 11:56:42 +08:00
def run(self):
"""线程主逻辑精确200ms触发一次滤波"""
# 精确定时核心基于perf_counter计算下一次执行时间补偿sleep误差
interval = self.step_sec # 200ms = 0.2秒
next_run_time = time.perf_counter()
while self.running.is_set():
# 1. 等待到下一次执行时间(精确定时)
current_time = time.perf_counter()
if current_time < next_run_time:
time.sleep(next_run_time - current_time)
next_run_time += interval # 补偿:下次执行时间基于上一次目标时间
else:
# 若超时如滤波耗时超过200ms重置下一次时间避免累积误差
algo_log("滤波耗时超过200ms定时偏移", level='debug')
next_run_time = time.perf_counter() + interval
# 2. 执行滤波逻辑
try:
# 获取最新的3秒窗口数据
window_data = self.ring_buffer.get_latest_n_points(self.window_size)
if window_data is None:
algo_log(f"缓存数据不足,当前缓存{self.ring_buffer.GetDataLenCount()}点,需{self.window_size}", level='debug')
continue
# 滤波并提取无边界效应的200ms数据
filtered_data = self._filter_window_data(window_data)
# 回调返回结果(外部可处理)
if self.filter_result_callback is not None:
self.filter_result_callback(filtered_data[:64, :]) # 只发送前64通道数据
except Exception as e:
algo_log(f"滤波执行异常: {e}", level='error')
def set_result_callback(self, callback):
"""注册滤波结果回调函数"""
self.filter_result_callback = callback
2026-06-06 09:16:49 +08:00
2026-06-08 11:56:42 +08:00
def stop(self):
2026-06-08 15:23:47 +08:00
"""停止滤波线程(安全版)"""
# 1. 先设置停止标志Event.clear()是线程安全的)
2026-06-08 11:56:42 +08:00
self.running.clear()
2026-06-08 15:23:47 +08:00
# 2. 核心修复只有线程已启动且正在运行时才调用join
if self.is_alive():
# 等待线程正常退出最多1秒
self.join(timeout=1)
# 超时未退出时打印警告,便于排查问题
if self.is_alive():
algo_log("警告滤波线程在1秒内未正常退出可能存在阻塞操作", level="WARNING")
# 3. 无论线程是否启动,都打印停止日志
algo_log("滤波线程已停止")