Compare commits
2 Commits
ba4ae92647
...
1bbe84eb56
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1bbe84eb56 | ||
|
|
f21367bc20 |
@@ -62,6 +62,8 @@ class Decoder_main(threading.Thread):
|
|||||||
|
|
||||||
# 注册滤波结果回调(示例:打印数据形状)
|
# 注册滤波结果回调(示例:打印数据形状)
|
||||||
self.sliding_filter.filter_result_callback = self.zmqServer.send_filtered_data
|
self.sliding_filter.filter_result_callback = self.zmqServer.send_filtered_data
|
||||||
|
# 注册 beta_psd 广播回调,每秒通过 8099 端口发送给上位机
|
||||||
|
self.sliding_filter.beta_broadcast_callback = lambda v: self.zmqServer.broadcast_message('beta_psd', v)
|
||||||
|
|
||||||
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
|
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
|
||||||
# data: (chans, samples)
|
# data: (chans, samples)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ class Beta_Calculate():
|
|||||||
alpha_psd = np.sum(self.band_psd(freqs, psd, (8, 13)))
|
alpha_psd = np.sum(self.band_psd(freqs, psd, (8, 13)))
|
||||||
theta_psd = np.sum(self.band_psd(freqs, psd, (4, 8)))
|
theta_psd = np.sum(self.band_psd(freqs, psd, (4, 8)))
|
||||||
|
|
||||||
print(f"[功率] β={beta_psd:.2f} | α={alpha_psd:.2f} | θ={theta_psd:.2f}")
|
# print(f"[功率] β={beta_psd:.2f} | α={alpha_psd:.2f} | θ={theta_psd:.2f}")
|
||||||
|
|
||||||
return beta_psd, alpha_psd, theta_psd
|
return beta_psd, alpha_psd, theta_psd
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,10 @@ import time
|
|||||||
import threading
|
import threading
|
||||||
from scipy import signal
|
from scipy import signal
|
||||||
from logs.log import algo_log
|
from logs.log import algo_log
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
from Tools.beta_calculate import Beta_Calculate
|
||||||
|
|
||||||
class FilterRingBuffer:
|
class FilterRingBuffer:
|
||||||
def __init__(self, n_chan, n_points):
|
def __init__(self, n_chan, n_points):
|
||||||
@@ -117,6 +121,14 @@ class SlidingFilter(threading.Thread):
|
|||||||
self.running.set()
|
self.running.set()
|
||||||
# 滤波结果回调(外部可注册,获取滤波后的数据)
|
# 滤波结果回调(外部可注册,获取滤波后的数据)
|
||||||
self.filter_result_callback = None
|
self.filter_result_callback = None
|
||||||
|
# beta_psd 广播回调(外部注册,用于走 zmqServer 8099 端口发送)
|
||||||
|
self.beta_broadcast_callback = None
|
||||||
|
|
||||||
|
# beta 计算器(Fp1/Fp2 通道,索引 0/1)
|
||||||
|
self._beta_calc = Beta_Calculate(Threshold_value_low=0, Threshold_value_high=0, fs=srate)
|
||||||
|
# beta 每秒触发计数(200ms步长,5次 = 1s)
|
||||||
|
self._beta_step_counter = 0
|
||||||
|
self._beta_steps_per_second = max(1, int(round(1.0 / step_sec))) # 5
|
||||||
|
|
||||||
# 预计算滤波器系数(仅执行一次)
|
# 预计算滤波器系数(仅执行一次)
|
||||||
self._init_filters()
|
self._init_filters()
|
||||||
@@ -135,7 +147,7 @@ class SlidingFilter(threading.Thread):
|
|||||||
self.a_bp = np.array([1.0])
|
self.a_bp = np.array([1.0])
|
||||||
|
|
||||||
def _filter_window_data(self, window_data):
|
def _filter_window_data(self, window_data):
|
||||||
"""对3秒窗口数据执行滤波,返回无边界效应的200ms数据"""
|
"""对3秒窗口数据执行滤波,返回 (无边界效应的200ms数据, 完整3s滤波数据)"""
|
||||||
# 零相位滤波(无延迟,无边界效应)
|
# 零相位滤波(无延迟,无边界效应)
|
||||||
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
|
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
|
||||||
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
|
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
|
||||||
@@ -146,7 +158,7 @@ class SlidingFilter(threading.Thread):
|
|||||||
start_idx = self.window_size - 2 * self.step_size
|
start_idx = self.window_size - 2 * self.step_size
|
||||||
end_idx = self.window_size - self.step_size
|
end_idx = self.window_size - self.step_size
|
||||||
output_data = filtered[:, start_idx:end_idx].copy()
|
output_data = filtered[:, start_idx:end_idx].copy()
|
||||||
return output_data
|
return output_data, filtered
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""线程主逻辑:精确200ms触发一次滤波"""
|
"""线程主逻辑:精确200ms触发一次滤波"""
|
||||||
@@ -174,9 +186,24 @@ class SlidingFilter(threading.Thread):
|
|||||||
algo_log(f"缓存数据不足,当前缓存{self.ring_buffer.GetDataLenCount()}点,需{self.window_size}点", level='debug')
|
algo_log(f"缓存数据不足,当前缓存{self.ring_buffer.GetDataLenCount()}点,需{self.window_size}点", level='debug')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
filtered_data = self._filter_window_data(window_data)
|
filtered_data, filtered_full = self._filter_window_data(window_data)
|
||||||
# algo_log(f"滤波后{filtered_data.shape}数据", level='debug')
|
# algo_log(f"滤波后{filtered_data.shape}数据", level='debug')
|
||||||
|
|
||||||
|
# ========== beta_psd 每秒计算一次(Fp1/Fp2,通道索引 0/1)==========
|
||||||
|
self._beta_step_counter += 1
|
||||||
|
if self._beta_step_counter >= self._beta_steps_per_second:
|
||||||
|
self._beta_step_counter = 0
|
||||||
|
try:
|
||||||
|
# 直接使用已滤波的完整3s数据的前两通道(Fp1/Fp2)
|
||||||
|
filter_betadata = filtered_full[:2, :] # shape (2, 750)
|
||||||
|
beta_psd, _, _ = self._beta_calc.calculate_all(
|
||||||
|
filter_betadata, fs=self.srate, nperseg=min(self.window_size, filter_betadata.shape[1])
|
||||||
|
)
|
||||||
|
if self.beta_broadcast_callback is not None:
|
||||||
|
self.beta_broadcast_callback(round(float(beta_psd), 3))
|
||||||
|
except Exception as be:
|
||||||
|
algo_log(f"beta_psd计算异常: {be}", level='error')
|
||||||
|
|
||||||
if self.filter_result_callback is not None:
|
if self.filter_result_callback is not None:
|
||||||
self.filter_result_callback(filtered_data[:64, :])
|
self.filter_result_callback(filtered_data[:64, :])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -405,10 +405,14 @@ class zmqServer(threading.Thread):
|
|||||||
frames = self.cmd_socket.recv_multipart()
|
frames = self.cmd_socket.recv_multipart()
|
||||||
self._handle_cmd_message(frames)
|
self._handle_cmd_message(frames)
|
||||||
|
|
||||||
# 处理8100数据端口消息
|
# 处理8100数据端口消息(排空积压,消除标签延迟)
|
||||||
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
|
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
|
||||||
frames = self.data_socket.recv_multipart()
|
while True:
|
||||||
self._handle_data_message(frames)
|
try:
|
||||||
|
frames = self.data_socket.recv_multipart(zmq.NOBLOCK)
|
||||||
|
self._handle_data_message(frames)
|
||||||
|
except zmq.Again:
|
||||||
|
break
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
algo_log(f"服务器主循环异常: {e}", level="ERROR")
|
algo_log(f"服务器主循环异常: {e}", level="ERROR")
|
||||||
|
|||||||
@@ -170,6 +170,7 @@ def run_headless():
|
|||||||
|
|
||||||
time.sleep(1) # 等待连接建立
|
time.sleep(1) # 等待连接建立
|
||||||
client.send_data('decoderClass', 'mi')
|
client.send_data('decoderClass', 'mi')
|
||||||
|
time.sleep(4) # 等待 zmqServer 排空启动积压包(datamock 提前连接会积压 ~3s 数据)
|
||||||
|
|
||||||
# MI_IntervalEpoch = [0.5, 4.5],trial时长 = 4.5-0.5 = 4.0s
|
# MI_IntervalEpoch = [0.5, 4.5],trial时长 = 4.5-0.5 = 4.0s
|
||||||
_mi_iv = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) # [0.5, 4.5]
|
_mi_iv = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) # [0.5, 4.5]
|
||||||
@@ -222,7 +223,7 @@ def run_headless():
|
|||||||
time.sleep(0.5) # ding 提示后等待
|
time.sleep(0.5) # ding 提示后等待
|
||||||
|
|
||||||
client.send_data('train', 0)
|
client.send_data('train', 0)
|
||||||
time.sleep(train_time + epoch_wait) # 等待刺激时间 + epoch 完成时间
|
time.sleep(train_time + 0.2) # 等待刺激时间 + epoch 完成时间
|
||||||
|
|
||||||
trained += 1
|
trained += 1
|
||||||
client.send_data('rest', 0)
|
client.send_data('rest', 0)
|
||||||
@@ -231,7 +232,7 @@ def run_headless():
|
|||||||
# 空闲态样本采集(train 1,label=2)
|
# 空闲态样本采集(train 1,label=2)
|
||||||
print(f"\n[Train] 空闲态采集 (train 1) trained={trained}")
|
print(f"\n[Train] 空闲态采集 (train 1) trained={trained}")
|
||||||
client.send_data('train', 1)
|
client.send_data('train', 1)
|
||||||
time.sleep(train_time + epoch_wait) # 等待刺激时间 + epoch 完成时间
|
time.sleep(train_time + 0.2) # 等待刺激时间 + epoch 完成时间
|
||||||
|
|
||||||
trained += 1
|
trained += 1
|
||||||
client.send_data('rest', 0)
|
client.send_data('rest', 0)
|
||||||
|
|||||||
Reference in New Issue
Block a user