Files
bci_algo/Zmq/filterProcess.py
2026-06-08 17:29:27 +08:00

204 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*-coding:utf-8 -*-
"""
数据滤波模块
"""
import numpy as np
import time
import threading
from scipy import signal
from logs.log import algo_log
class FilterRingBuffer:
def __init__(self, n_chan, n_points):
self.n_chan = n_chan
self.n_points = n_points
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
self.current_ptr = 0
self.total_samples = 0
self.lock = threading.Lock() # 仅保护元数据
self.has_new_data = False
def appendBuffer(self, data):
n = data.shape[1]
if n == 0:
return
# 仅加锁读取/更新元数据
with self.lock:
old_ptr = self.current_ptr
new_ptr = (old_ptr + n) % self.n_points
new_total = min(self.total_samples + n, self.n_points)
self.has_new_data = True
# 数组写入(耗时操作,移出锁外)
write_end = old_ptr + n
if write_end <= self.n_points:
self.buffer[:, old_ptr:write_end] = data
else:
split = self.n_points - old_ptr
self.buffer[:, old_ptr:] = data[:, :split]
self.buffer[:, :write_end - self.n_points] = data[:, split:]
# 再次加锁更新最终元数据
with self.lock:
self.current_ptr = new_ptr
self.total_samples = new_total
# ========== 新增:获取&清空新数据标记的方法 ==========
def check_and_clear_new_data(self):
"""检查是否有新数据,并一次性清空标记(消费后重置)"""
with self.lock:
flag = self.has_new_data
if flag:
self.has_new_data = False
return flag
def getData(self, count):
# 加锁获取最新元数据
with self.lock:
count = min(count, self.total_samples)
if count == 0:
return np.zeros((self.n_chan, 0))
end = self.current_ptr
start = end - count
# 数据读取、切片、拼接(无锁)
if start >= 0:
res = self.buffer[:, start:end].copy()
else:
part1 = self.buffer[:, start:]
part2 = self.buffer[:, :end]
res = np.concatenate((part1, part2), axis=1).copy()
return res
def get_latest_n_points(self, n):
with self.lock:
if self.total_samples < n:
return None
return self.getData(n)
def GetDataLenCount(self):
with self.lock:
return self.total_samples
def resetAllPara(self):
with self.lock:
self.buffer.fill(0.0)
self.current_ptr = 0
self.total_samples = 0
self.has_new_data = False # 重置时清空新数据标记
# -----------------------------------------------------------------------------
# 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
# -----------------------------------------------------------------------------
class SlidingFilter(threading.Thread):
def __init__(
self,
ring_buffer: FilterRingBuffer,
n_chan=66,
srate=250,
window_sec=3,
step_sec=0.2
):
super().__init__(daemon=True)
# 核心参数
self.n_chan = n_chan
self.srate = srate
self.step_sec = step_sec # 200ms滑动步长
self.window_sec = window_sec # 3秒窗口
self.step_sec = step_sec # 200ms滑动步长
self.window_size = int(srate * window_sec) # 3秒点数250*3=750
self.step_size = int(srate * step_sec) # 200ms点数250*0.2=50
# 关联ZMQServer的环形缓存解耦仅依赖接口
self.ring_buffer = ring_buffer
# 线程控制
self.running = threading.Event()
self.running.set()
# 滤波结果回调(外部可注册,获取滤波后的数据)
self.filter_result_callback = None
# 预计算滤波器系数(仅执行一次)
self._init_filters()
def _init_filters(self):
"""预计算所有滤波器系数(仅执行一次)"""
# 50Hz工频陷波Q=30工业标准
self.b_notch, self.a_notch = signal.iirnotch(50, 30, self.srate)
# 8~30Hz带通FIR65阶线性相位
self.b_bp = signal.firwin(
numtaps=65,
cutoff=[8/(self.srate/2), 30/(self.srate/2)],
pass_zero=False,
window='hamming'
)
self.a_bp = np.array([1.0])
def _filter_window_data(self, window_data):
"""对3秒窗口数据执行滤波返回无边界效应的200ms数据"""
# 零相位滤波(无延迟,无边界效应)
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
filtered = signal.filtfilt(self.b_bp, self.a_bp, filtered, axis=-1)
# 提取倒数第二个200ms的数据完全避开两端边界效应
# 窗口长度750步长50 → start=750-100=650end=750-50=700
start_idx = self.window_size - 2 * self.step_size
end_idx = self.window_size - self.step_size
output_data = filtered[:, start_idx:end_idx].copy()
return output_data
def run(self):
"""线程主逻辑精确200ms触发一次滤波"""
interval = self.step_sec # 200ms = 0.2秒
next_run_time = time.perf_counter()
while self.running.is_set():
# 1. 精确定时等待
current_time = time.perf_counter()
if current_time < next_run_time:
time.sleep(next_run_time - current_time)
next_run_time += interval
else:
algo_log("滤波耗时超过200ms定时偏移", level='debug')
next_run_time = time.perf_counter() + interval
# ========== 新增核心判断:无新数据则直接跳过 ==========
if not self.ring_buffer.check_and_clear_new_data():
# 无新数据,不执行滤波、不发送数据
continue
# 2. 有新数据,才执行原有滤波逻辑
try:
window_data = self.ring_buffer.get_latest_n_points(self.window_size)
if window_data is None:
algo_log(f"缓存数据不足,当前缓存{self.ring_buffer.GetDataLenCount()}点,需{self.window_size}", level='debug')
continue
filtered_data = self._filter_window_data(window_data)
algo_log(f"滤波后{filtered_data.shape}数据", level='debug')
if self.filter_result_callback is not None:
self.filter_result_callback(filtered_data[:64, :])
except Exception as e:
algo_log(f"滤波执行异常: {e}", level='error')
def set_result_callback(self, callback):
"""注册滤波结果回调函数"""
self.filter_result_callback = callback
def stop(self):
"""停止滤波线程(安全版)"""
# 1. 先设置停止标志Event.clear()是线程安全的)
self.running.clear()
# 2. 核心修复只有线程已启动且正在运行时才调用join
if self.is_alive():
# 等待线程正常退出最多1秒
self.join(timeout=1)
# 超时未退出时打印警告,便于排查问题
if self.is_alive():
algo_log("警告滤波线程在1秒内未正常退出可能存在阻塞操作", level="WARNING")
# 3. 无论线程是否启动,都打印停止日志
algo_log("滤波线程已停止")