Files
bci_algo/Zmq/filterProcess.py
2026-06-08 17:13:25 +08:00

209 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*-coding:utf-8 -*-
"""
数据滤波模块
"""
import numpy as np
import time
import threading
from scipy import signal
from logs.log import algo_log
class FilterRingBuffer:
def __init__(self, n_chan, n_points):
"""
初始化纯数据环形缓存(线程安全)
:param n_chan: 通道数
:param n_points: 总缓存点数与paradigmRingBuffer参数完全一致
"""
self.n_chan = n_chan
self.n_points = n_points
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
self.current_ptr = 0
self.total_samples = 0
self.lock = threading.Lock() # 仅保护元数据
def appendBuffer(self, data):
"""
追加数据到缓存与paradigmRingBuffer接口一致
:param data: 输入数据shape=(n_chan, n_samples)
"""
n = data.shape[1]
if n == 0:
return
# -------- 第一步:仅加锁读取/更新元数据(持锁极短)--------
with self.lock:
old_ptr = self.current_ptr
new_ptr = (old_ptr + n) % self.n_points
new_total = min(self.total_samples + n, self.n_points)
# -------- 第二步:数组写入(耗时操作,移出锁外)--------
write_end = old_ptr + n
if write_end <= self.n_points:
self.buffer[:, old_ptr:write_end] = data
else:
split = self.n_points - old_ptr
self.buffer[:, old_ptr:] = data[:, :split]
self.buffer[:, :write_end - self.n_points] = data[:, split:]
# -------- 第三步:再次加锁更新最终元数据 --------
with self.lock:
self.current_ptr = new_ptr
self.total_samples = new_total
def getData(self, count):
"""
从最新位置向前读取count个点环形读取
核心逻辑current_ptr是下一个写入位置 → 最新数据在current_ptr之前
:param count: 读取点数
:return: np.ndarray, shape=(n_chan, count)
"""
# -------- 第一步:加锁获取最新元数据(持锁极短)--------
with self.lock:
count = min(count, self.total_samples)
if count == 0:
return np.zeros((self.n_chan, 0))
end = self.current_ptr
start = end - count
if start >= 0:
res = self.buffer[:, start:end].copy()
else:
part1 = self.buffer[:, start:]
part2 = self.buffer[:, :end]
res = np.concatenate((part1, part2), axis=1).copy()
return res
def get_latest_n_points(self, n):
with self.lock:
if self.total_samples < n:
return None
return self.getData(n)
def GetDataLenCount(self):
with self.lock:
return self.total_samples
def resetAllPara(self):
with self.lock:
self.buffer.fill(0.0)
self.current_ptr = 0
self.total_samples = 0
# -----------------------------------------------------------------------------
# 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
# -----------------------------------------------------------------------------
class SlidingFilter(threading.Thread):
def __init__(
self,
ring_buffer: FilterRingBuffer,
n_chan=66,
srate=250,
window_sec=3,
step_sec=0.2
):
super().__init__(daemon=True)
# 核心参数
self.n_chan = n_chan
self.srate = srate
self.step_sec = step_sec # 200ms滑动步长
self.window_sec = window_sec # 3秒窗口
self.step_sec = step_sec # 200ms滑动步长
self.window_size = int(srate * window_sec) # 3秒点数250*3=750
self.step_size = int(srate * step_sec) # 200ms点数250*0.2=50
# 关联ZMQServer的环形缓存解耦仅依赖接口
self.ring_buffer = ring_buffer
# 线程控制
self.running = threading.Event()
self.running.set()
# 滤波结果回调(外部可注册,获取滤波后的数据)
self.filter_result_callback = None
# 预计算滤波器系数(仅执行一次)
self._init_filters()
def _init_filters(self):
"""预计算所有滤波器系数(仅执行一次)"""
# 50Hz工频陷波Q=30工业标准
self.b_notch, self.a_notch = signal.iirnotch(50, 30, self.srate)
# 8~30Hz带通FIR65阶线性相位
self.b_bp = signal.firwin(
numtaps=65,
cutoff=[8/(self.srate/2), 30/(self.srate/2)],
pass_zero=False,
window='hamming'
)
self.a_bp = np.array([1.0])
def _filter_window_data(self, window_data):
"""对3秒窗口数据执行滤波返回无边界效应的200ms数据"""
# 零相位滤波(无延迟,无边界效应)
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
filtered = signal.filtfilt(self.b_bp, self.a_bp, filtered, axis=-1)
# 提取倒数第二个200ms的数据完全避开两端边界效应
# 窗口长度750步长50 → start=750-100=650end=750-50=700
start_idx = self.window_size - 2 * self.step_size
end_idx = self.window_size - self.step_size
output_data = filtered[:, start_idx:end_idx].copy()
return output_data
def run(self):
"""线程主逻辑精确200ms触发一次滤波"""
# 精确定时核心基于perf_counter计算下一次执行时间补偿sleep误差
interval = self.step_sec # 200ms = 0.2秒
next_run_time = time.perf_counter()
while self.running.is_set():
# 1. 等待到下一次执行时间(精确定时)
current_time = time.perf_counter()
if current_time < next_run_time:
time.sleep(next_run_time - current_time)
next_run_time += interval # 补偿:下次执行时间基于上一次目标时间
else:
# 若超时如滤波耗时超过200ms重置下一次时间避免累积误差
algo_log("滤波耗时超过200ms定时偏移", level='debug')
next_run_time = time.perf_counter() + interval
# 2. 执行滤波逻辑
try:
# 获取最新的3秒窗口数据
window_data = self.ring_buffer.get_latest_n_points(self.window_size)
algo_log(f"获取到{window_data.shape}数据", level='debug')
if window_data is None:
algo_log(f"缓存数据不足,当前缓存{self.ring_buffer.GetDataLenCount()}点,需{self.window_size}", level='debug')
continue
# 滤波并提取无边界效应的200ms数据
filtered_data = self._filter_window_data(window_data)
algo_log(f"滤波后{filtered_data.shape}数据", level='debug')
# 回调返回结果(外部可处理)
if self.filter_result_callback is not None:
self.filter_result_callback(filtered_data[:64, :]) # 只发送前64通道数据
except Exception as e:
algo_log(f"滤波执行异常: {e}", level='error')
def set_result_callback(self, callback):
"""注册滤波结果回调函数"""
self.filter_result_callback = callback
def stop(self):
"""停止滤波线程(安全版)"""
# 1. 先设置停止标志Event.clear()是线程安全的)
self.running.clear()
# 2. 核心修复只有线程已启动且正在运行时才调用join
if self.is_alive():
# 等待线程正常退出最多1秒
self.join(timeout=1)
# 超时未退出时打印警告,便于排查问题
if self.is_alive():
algo_log("警告滤波线程在1秒内未正常退出可能存在阻塞操作", level="WARNING")
# 3. 无论线程是否启动,都打印停止日志
algo_log("滤波线程已停止")