dev 1
This commit is contained in:
208
Zmq/filterProcess.py
Normal file
208
Zmq/filterProcess.py
Normal file
@@ -0,0 +1,208 @@
|
||||
# -*-coding:utf-8 -*-
|
||||
"""
|
||||
数据滤波模块
|
||||
"""
|
||||
import numpy as np
|
||||
import threading
|
||||
from logs.log import algo_log
|
||||
|
||||
class FilterRingBuffer:
|
||||
def __init__(self, n_chan, n_points):
|
||||
"""
|
||||
初始化纯数据环形缓存
|
||||
:param n_chan: 通道数
|
||||
:param n_points: 总缓存点数(与paradigmRingBuffer参数完全一致)
|
||||
"""
|
||||
self.n_chan = n_chan
|
||||
self.n_points = n_points
|
||||
|
||||
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
|
||||
self.current_ptr = 0 # 写入指针
|
||||
self.total_samples = 0 # 已写入总点数
|
||||
|
||||
# 线程安全锁(多线程环境必须)
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def appendBuffer(self, data):
|
||||
"""
|
||||
追加数据到缓存(与paradigmRingBuffer接口一致)
|
||||
:param data: 输入数据,shape=(n_chan, n_samples)
|
||||
"""
|
||||
with self.lock:
|
||||
n = data.shape[1]
|
||||
if n == 0:
|
||||
return
|
||||
|
||||
# 环形写入逻辑
|
||||
write_end = self.current_ptr + n
|
||||
if write_end <= self.n_points:
|
||||
self.buffer[:, self.current_ptr:write_end] = data
|
||||
else:
|
||||
split = self.n_points - self.current_ptr
|
||||
self.buffer[:, self.current_ptr:] = data[:, :split]
|
||||
self.buffer[:, :write_end - self.n_points] = data[:, split:]
|
||||
|
||||
# 更新指针和计数
|
||||
self.current_ptr = write_end % self.n_points
|
||||
self.total_samples = min(self.total_samples + n, self.n_points)
|
||||
|
||||
def getData(self, count):
|
||||
"""
|
||||
从读指针位置读取count个点(与paradigmRingBuffer接口一致)
|
||||
:param count: 读取点数
|
||||
:return: np.ndarray, shape=(n_chan, count)
|
||||
"""
|
||||
with self.lock:
|
||||
count = min(count, self.total_samples)
|
||||
if count == 0:
|
||||
return np.zeros((self.n_chan, 0))
|
||||
|
||||
# 环形读取逻辑(与paradigmRingBuffer完全相同)
|
||||
end = self.current_ptr
|
||||
start = end - count
|
||||
if start >= 0:
|
||||
return self.buffer[:, start:end].copy()
|
||||
else:
|
||||
part1 = self.buffer[:, start:]
|
||||
part2 = self.buffer[:, :end]
|
||||
return np.concatenate((part1, part2), axis=1)
|
||||
|
||||
def get_latest_n_points(self, n):
|
||||
"""
|
||||
扩展方法:获取最新的n个点(不移动读指针,用于滑动窗口)
|
||||
:param n: 点数
|
||||
:return: np.ndarray, shape=(n_chan, n)
|
||||
"""
|
||||
with self.lock:
|
||||
if self.total_samples < n:
|
||||
return None
|
||||
return self.getData(n)
|
||||
|
||||
def GetDataLenCount(self):
|
||||
"""获取当前缓存总点数(兼容原有接口)"""
|
||||
with self.lock:
|
||||
return self.total_samples
|
||||
|
||||
def resetAllPara(self):
|
||||
"""重置所有缓存和指针(兼容原有接口)"""
|
||||
with self.lock:
|
||||
self.buffer.fill(0.0)
|
||||
self.current_ptr = 0
|
||||
self.total_samples = 0
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
|
||||
# 可替换任意缓存实现,只要实现appendBuffer、get_latest_n_points接口
|
||||
# -----------------------------------------------------------------------------
|
||||
class SlidingFilter:
|
||||
def __init__(
|
||||
self,
|
||||
n_chan=66,
|
||||
srate=250,
|
||||
buffer_sec=5,
|
||||
window_sec=3,
|
||||
step_sec=0.2,
|
||||
packet_size=5
|
||||
):
|
||||
"""
|
||||
初始化滑动滤波器
|
||||
:param n_chan: 通道数
|
||||
:param srate: 采样率
|
||||
:param buffer_sec: 总缓存时长(秒)
|
||||
:param window_sec: 滤波窗口时长(秒)
|
||||
:param step_sec: 滑动步长/输出时长(秒)
|
||||
:param packet_size: 每包数据点数(20ms一包=5点)
|
||||
"""
|
||||
# 核心参数
|
||||
self.n_chan = n_chan
|
||||
self.srate = srate
|
||||
self.buffer_size = int(srate * buffer_sec)
|
||||
self.window_size = int(srate * window_sec)
|
||||
self.step_size = int(srate * step_sec)
|
||||
self.packet_size = packet_size
|
||||
|
||||
# 初始化纯数据缓存(解耦核心)
|
||||
self.buffer = FilterRingBuffer(n_chan, self.buffer_size)
|
||||
|
||||
# 滤波触发计数器
|
||||
self.packet_count = 0
|
||||
self.ready_to_filter = False
|
||||
|
||||
# 预计算滤波器系数
|
||||
self._init_filters()
|
||||
|
||||
def _init_filters(self):
|
||||
"""预计算所有滤波器系数(仅执行一次)"""
|
||||
# 50Hz工频陷波(Q=30,工业标准)
|
||||
self.b_notch, self.a_notch = signal.iirnotch(50, 30, self.srate)
|
||||
# 8~30Hz带通FIR(65阶,线性相位)
|
||||
self.b_bp = signal.firwin(
|
||||
numtaps=65,
|
||||
cutoff=[8/(self.srate/2), 30/(self.srate/2)],
|
||||
pass_zero=False,
|
||||
window='hamming'
|
||||
)
|
||||
self.a_bp = np.array([1.0])
|
||||
|
||||
def append_and_check_trigger(self, raw_data):
|
||||
"""
|
||||
追加单包原始数据并检查是否触发滤波
|
||||
:param raw_data: 上位机原始数据,shape=(packet_size, n_chan)
|
||||
:return: bool: 是否触发本次滤波
|
||||
"""
|
||||
# 转置为标准格式:(通道数, 点数)
|
||||
data = raw_data.T.astype(np.float64)
|
||||
|
||||
# 写入缓存(纯缓存操作)
|
||||
self.buffer.appendBuffer(data)
|
||||
|
||||
# 更新包计数器
|
||||
self.packet_count += 1
|
||||
|
||||
# 检查滤波触发条件:数据≥窗口长度 且 累计满一个步长的包数
|
||||
packets_per_step = int(self.step_size / self.packet_size) # 10包=200ms
|
||||
if (self.buffer.GetDataLenCount() >= self.window_size
|
||||
and self.packet_count >= packets_per_step):
|
||||
self.packet_count = 0
|
||||
self.ready_to_filter = True
|
||||
return True
|
||||
return False
|
||||
|
||||
def filter_and_get_output(self):
|
||||
"""
|
||||
执行滤波并返回无边界效应的输出数据
|
||||
:return: np.ndarray: 滤波后数据,shape=(n_chan, step_size)
|
||||
"""
|
||||
if not self.ready_to_filter:
|
||||
return None
|
||||
|
||||
# 获取最新的完整滤波窗口数据
|
||||
window_data = self.buffer.get_latest_n_points(self.window_size)
|
||||
if window_data is None:
|
||||
self.ready_to_filter = False
|
||||
return None
|
||||
|
||||
# 零相位滤波(无延迟,无边界效应)
|
||||
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
|
||||
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
|
||||
filtered = signal.filtfilt(self.b_bp, self.a_bp, filtered, axis=-1)
|
||||
|
||||
# 提取倒数第二个步长的数据(完全避开两端边界效应)
|
||||
start_idx = self.window_size - 2 * self.step_size
|
||||
end_idx = self.window_size - self.step_size
|
||||
output_data = filtered[:, start_idx:end_idx].copy()
|
||||
|
||||
# 重置触发标志
|
||||
self.ready_to_filter = False
|
||||
|
||||
return output_data
|
||||
|
||||
def reset(self):
|
||||
"""重置滤波器和缓存"""
|
||||
self.buffer.resetAllPara()
|
||||
self.packet_count = 0
|
||||
self.ready_to_filter = False
|
||||
|
||||
def get_buffer_length(self):
|
||||
"""获取当前缓存数据长度"""
|
||||
return self.buffer.GetDataLenCount()
|
||||
Reference in New Issue
Block a user