# Source file: bci_algo/concentration/algorithm/calculate_focus.py
# Snapshot taken from the repository viewer: 2026-06-06 14:57:52 +08:00
# (reported as 425 lines, 17 KiB, Python).
#
# NOTE: the repository viewer flagged this file as containing ambiguous Unicode
# characters, i.e. characters that might be confused with other characters
# (for example full-width punctuation inside comments). If that is intentional,
# the warning can safely be ignored.
import numpy as np
from scipy.signal import welch
from scipy.fft import fft
from scipy import signal
from collections import deque
import time
import os
# import logging
import base64
import io
import math
# logger = logging.getLogger(__name__)
#
# try:
# import matplotlib
# matplotlib.use('Agg')
# import matplotlib.pyplot as plt
# MATPLOTLIB_AVAILABLE = True
# except ImportError:
# MATPLOTLIB_AVAILABLE = False
# logger.warning("matplotlib未安装报告图表功能不可用")
class Calculate():
    """Streaming focus (concentration) estimator based on EEG band power.

    Incoming chunks are buffered in a fixed-length queue. Once the queue is
    full, the chunks are concatenated, a Welch PSD is computed, and the
    beta / alpha / theta band powers are turned into a 0-100 focus score and
    a CLI value (log ratio of theta to alpha + beta).

    NOTE: chart recording/export helpers (start_recording, stop_recording,
    add_data_point, save_chart_to_file, generate_chart_base64) used to live in
    this class as commented-out code. They referenced attributes that no longer
    exist and were removed; recover them from version control if needed.
    """

    # Instant scores inside [FREE_ZONE_LOW, FREE_ZONE_HIGH] bypass the temporal
    # filter; scores outside of it ("extreme zone") are smoothed asymmetrically.
    # NOTE(review): older comments and log tags mentioned an upper bound of 80,
    # but the executed condition has always used 85 - confirm the intended value.
    FREE_ZONE_LOW = 60.0
    FREE_ZONE_HIGH = 85.0

    # Guards divisions and logarithms against zero band power.
    _EPS = 1e-10
    # Number of most recent values kept for the output moving averages.
    _FOCUS_HISTORY_LEN = 3
    _CLI_HISTORY_LEN = 5

    def __init__(self, Threshold_value_low, Threshold_value_high, fs=250, win_len=10, config=None):
        """Create the estimator.

        Args:
            Threshold_value_low: stored on the instance; not used by this class.
            Threshold_value_high: stored on the instance; not used by this class.
            fs: sampling rate in Hz used for filter design and PSD estimation.
            win_len: number of chunks buffered before a score is produced.
            config: optional mapping supporting ``get(key, default)`` with the
                keys ``alpha_down`` and ``shrink_factor``; defaults are
                0.8 and 0.5 when it is missing or empty.
        """
        self.Threshold_value_low = Threshold_value_low
        self.Threshold_value_high = Threshold_value_high
        self.fs = fs
        self.focus_result = []
        self.CLI_result = []
        self.EVI_result = []
        self.eegQueue = deque(maxlen=win_len)

        # Filter coefficients: 50 Hz notch (Q=30) and a 65-tap 2-40 Hz band-pass.
        # Frequencies are normalised by Nyquist, so fs must be high enough for
        # both to stay below 1. The filters are designed here but are not
        # applied anywhere in this class at the moment.
        nyquist = self.fs / 2
        self.b_notch, self.a_notch = signal.iirnotch(50 / nyquist, 30)
        self.b_design = signal.firwin(65, [2 / nyquist, 40 / nyquist], pass_zero=False)

        self.last_focus = None

        # Asymmetric smoothing gains. alpha_up == 1 means a rising score is
        # followed immediately (no smoothing on the way up).
        self.alpha_up = 1
        # alpha_down / shrink_factor come from the configuration so they can be
        # tuned without touching the code.
        if config:
            self.alpha_down = float(config.get('alpha_down', 0.8))
            self.shrink_factor = float(config.get('shrink_factor', 0.5))
        else:
            self.alpha_down = 0.8
            self.shrink_factor = 0.5
        print("[调试] Calculate 类初始化完成")

    @staticmethod
    def _push_bounded(history, value, limit):
        """Append ``value`` to ``history`` and drop the oldest entries beyond ``limit``."""
        history.append(value)
        del history[:-limit]

    def calculate_focus(self, beta, alpha, theta):
        """Map band powers to a focus score with zone-gated asymmetric smoothing.

        The raw ratio ``beta / (alpha + theta**0.7)`` is squared, scaled by
        ``shrink_factor`` and passed through a saturating exponential that
        yields an instant score in [0, 120). Instant scores inside the free
        zone are returned as-is; scores outside of it are blended with the
        previous score using ``alpha_up`` (rising) or ``alpha_down`` (falling).

        Args:
            beta: beta band power.
            alpha: alpha band power.
            theta: theta band power.

        Returns:
            The filtered score truncated to ``int``. If the ratio is NaN the
            previous score is returned unchanged (0 when there is none) and the
            filter state is left untouched.
        """
        # Soften the theta contribution before building the ratio.
        theta_mod = theta ** 0.7
        raw = beta / (alpha + theta_mod + self._EPS)

        # A NaN ratio would make int() raise further down; hold the last score.
        if np.isnan(raw):
            return int(self.last_focus) if self.last_focus is not None else 0

        # Negative ratios are treated as zero.
        raw_input = max(raw, 0.0)
        # Quadratic compression scaled by the configurable shrink factor.
        focus_raw = 100 * self.shrink_factor * (raw_input ** 2.0)
        # Saturating map to the instant score (baseline range 0-120).
        instant_focus = 120 * (1.0 - np.exp(-focus_raw / 100.0))

        previous = self.last_focus
        in_free_zone = self.FREE_ZONE_LOW <= instant_focus <= self.FREE_ZONE_HIGH
        if previous is None or in_free_zone:
            # Cold start, or free zone: follow the instant value directly.
            focus = instant_focus
        else:
            # Extreme zone: pick the gain from the direction of the change.
            gain = self.alpha_up if instant_focus >= previous else self.alpha_down
            focus = gain * instant_focus + (1 - gain) * previous

        self.last_focus = focus
        print(f"原始特征比值 raw: {raw:.4f} | 瞬时分数: {instant_focus:.1f} | 滤波后分数: {focus:.1f}")
        return int(focus)

    def calculate_all(self, data, fs, nperseg=1000):
        """Compute the smoothed focus score and CLI for one analysis window.

        Args:
            data: samples with time on the last axis; a 1-D array is treated
                as a single channel.
            fs: sampling rate in Hz.
            nperseg: Welch segment length (shortened when data is shorter).

        Returns:
            ``(final_focus, final_CLI, beta_psd, alpha_psd, theta_psd)`` where
            ``final_focus`` is an int in [0, 100], ``final_CLI`` is rounded to
            two decimals and the band powers are summed over all channels.
        """
        data = np.atleast_2d(np.asarray(data, dtype=float))
        # Remove the per-channel DC offset.
        data = data - np.mean(data, axis=-1, keepdims=True)

        freqs, psd = self.compute_psd_multichannel(data, fs, nperseg)
        beta_psd = np.sum(self.band_psd(freqs, psd, (13, 30)))
        alpha_psd = np.sum(self.band_psd(freqs, psd, (8, 13)))
        theta_psd = np.sum(self.band_psd(freqs, psd, (4, 8)))
        print(f"[功率] β={beta_psd:.2f} | α={alpha_psd:.2f} | θ={theta_psd:.2f}")

        focus_score = self.calculate_focus(beta_psd, alpha_psd, theta_psd)
        focus_score = max(0, min(100, focus_score))
        self._push_bounded(self.focus_result, focus_score, self._FOCUS_HISTORY_LEN)
        final_focus = int(self.simple_moving_average(self.focus_result, window_size=5))

        cli_denom = alpha_psd + beta_psd
        if cli_denom > 0:
            # The epsilon in the numerator keeps log() finite when theta is 0,
            # so one silent frame cannot push -inf into the moving average.
            CLI_score = float(np.log((theta_psd + self._EPS) / (cli_denom + self._EPS)))
        else:
            CLI_score = 0
        self._push_bounded(self.CLI_result, CLI_score, self._CLI_HISTORY_LEN)
        final_CLI = round(self.simple_moving_average(self.CLI_result, window_size=5), 2)

        return final_focus, final_CLI, beta_psd, alpha_psd, theta_psd

    def compute_psd_multichannel(self, data, fs=250, nperseg=1000):
        """Return ``(freqs, psd)`` from Welch's method along the last axis.

        ``nperseg`` is reduced to the number of samples when the data is
        shorter, and the overlap of 500 samples falls back to half a segment
        when it would not be smaller than the segment. Empty input yields an
        empty frequency array and a ``(channels, 0)`` PSD.
        """
        data = np.atleast_2d(data)
        n_samples = data.shape[-1]
        nperseg = min(nperseg, n_samples)
        if nperseg == 0:
            return np.array([]), np.zeros((data.shape[0], 0))

        noverlap = 500
        if noverlap >= nperseg:
            noverlap = nperseg // 2
        freqs, psd = welch(data, fs=fs, nperseg=nperseg, noverlap=noverlap, axis=-1)
        return freqs, psd

    def band_psd(self, freqs, psd, band):
        """Sum the PSD of every channel over ``band[0] <= f <= band[1]``.

        A 1-D ``psd`` is treated as a single channel. Returns one value per
        channel.
        """
        psd = np.atleast_2d(psd)
        in_band = (freqs >= band[0]) & (freqs <= band[1])
        return np.sum(psd[:, in_band], axis=-1)

    def simple_moving_average(self, data, window_size=5):
        """Average the last ``window_size`` items of ``data`` (30 when empty)."""
        if len(data) == 0:
            return 30
        window = data[-window_size:]
        return sum(window) / len(window)

    def reset_queue(self):
        """Drop every buffered chunk."""
        self.eegQueue.clear()

    def queueOpt(self, data):
        """Buffer one chunk and, once the queue is full, score the window.

        Args:
            data: array-like chunk with time on the last axis. ``None`` or an
                empty chunk is ignored.

        Returns:
            ``(focus_score, beta_psd)`` when the queue is full, otherwise
            ``None``.
        """
        if data is None:
            return None
        data = np.asarray(data)
        if data.size == 0:
            return None

        # The deque discards its oldest chunk automatically once it is full.
        self.eegQueue.append(data)
        if not self.eegQueue or len(self.eegQueue) < self.eegQueue.maxlen:
            return None

        # Cast to float so that integer sample arrays can be de-meaned; an
        # in-place subtraction on an integer array raises a casting error.
        eegData = np.hstack(list(self.eegQueue)).astype(float)
        if eegData.size == 0:
            return None
        eegData = eegData - np.mean(eegData, axis=-1, keepdims=True)

        # Notch / band-pass filtering with self.b_notch, self.a_notch and
        # self.b_design is currently disabled; the PSD is computed on the
        # de-meaned signal only.
        focus_score, _cli, beta_psd, _alpha, _theta = self.calculate_all(
            eegData, fs=self.fs, nperseg=1000)
        return (focus_score, beta_psd)
class Calculate2():
    """FFT-based focus / distraction / flow estimator.

    Band powers are taken from a single-sided amplitude spectrum and combined
    into three indicators; focus and flow are smoothed over the last three
    windows, and theta / alpha baselines over the last thirty.
    """

    # Guards ratios against zero band power (same value as used by Calculate).
    _EPS = 1e-10
    _SCORE_HISTORY_LEN = 3
    _BASELINE_HISTORY_LEN = 30

    def __init__(self, Threshold_value_low, Threshold_value_high):
        """Store the normalisation thresholds.

        Args:
            Threshold_value_low: beta/(alpha+theta) ratio mapped to a focus of 0.
            Threshold_value_high: ratio mapped to a focus of 100; also used to
                scale the flow score. Must differ from the low threshold and be
                non-zero, since both are used as divisors.
        """
        self.Threshold_value_low = Threshold_value_low
        self.Threshold_value_high = Threshold_value_high
        self.focus_result = []
        self.theta_result = []
        self.alpha_result = []
        self.flow_result = []

    @staticmethod
    def _push_bounded(history, value, limit):
        """Append ``value`` to ``history`` and drop the oldest entries beyond ``limit``."""
        history.append(value)
        del history[:-limit]

    def calculate_all(self, data, fs, L=2500):
        """Compute the indicators for one analysis window.

        Args:
            data: samples with time on the last axis; a 1-D array is treated
                as a single channel.
            fs: sampling rate in Hz.
            L: FFT length. The signal is zero-padded or truncated to ``L`` so
                that the amplitude scaling and the frequency-to-bin mapping in
                ``PSD`` always agree with the spectrum length.

        Returns:
            ``(final_focus, distraction_score, final_flow)``; focus and flow
            are ints (moving average of the last three windows).
        """
        eps = self._EPS
        data = np.atleast_2d(np.asarray(data, dtype=float))
        # Remove the per-channel DC offset.
        data = data - np.mean(data, axis=-1, keepdims=True)

        # Single-sided amplitude spectrum: scale by L, keep bins 0..L/2 and
        # double everything except the first and last retained bin.
        spectrum = fft(data, n=L, axis=-1)
        P1 = np.abs(spectrum / L)[:, :L // 2 + 1]
        P1[:, 1:-1] *= 2

        beta_power = self.PSD(P1, L, fs, 13, 30)
        alpha_power = self.PSD(P1, L, fs, 8, 13)
        theta_power = self.PSD(P1, L, fs, 4, 8)
        gamma_power = self.PSD(P1, L, fs, 30, 100)

        # The epsilon keeps the ratio finite for silent input; a NaN here
        # would make the int() conversion below raise.
        focus_score = beta_power / (alpha_power + theta_power + eps)
        print('focus score:', focus_score)
        # Linear map of [low, high] onto [0, 100].
        focus_score = ((focus_score - self.Threshold_value_low) * 100) / (
            self.Threshold_value_high - self.Threshold_value_low)
        self._push_bounded(self.focus_result, focus_score, self._SCORE_HISTORY_LEN)
        final_focus = int(self.simple_moving_average(self.focus_result, window_size=3))

        # Running baselines of theta and alpha power.
        self._push_bounded(self.theta_result, theta_power, self._BASELINE_HISTORY_LEN)
        self._push_bounded(self.alpha_result, alpha_power, self._BASELINE_HISTORY_LEN)
        rest_theta = self.simple_moving_average(self.theta_result, window_size=30)
        rest_alpha = self.simple_moving_average(self.alpha_result, window_size=30)
        distraction_score = (theta_power / (rest_theta + eps)) * (
            1 - (alpha_power / (rest_alpha + eps)))

        flow_score = gamma_power / (beta_power + eps)
        flow_score = (flow_score / self.Threshold_value_high) * 100
        self._push_bounded(self.flow_result, flow_score, self._SCORE_HISTORY_LEN)
        final_flow = int(self.simple_moving_average(self.flow_result, window_size=3))

        return final_focus, distraction_score, final_flow

    def PSD(self, P1, L, Fs, s_freq, e_freq):
        """Sum of squared amplitudes over all channels for ``s_freq <= f < e_freq``.

        Frequencies are converted to bin indices with ``round(f * L / Fs)``;
        the end bin is exclusive. Indices are clamped to the spectrum width so
        a band reaching beyond the available bins does not raise.

        Returns:
            The band power as a float (0.0 for an empty band).
        """
        P1 = np.atleast_2d(P1)
        n_bins = P1.shape[1]
        s_point = min(max(round(s_freq * L / Fs), 0), n_bins)
        e_point = min(max(round(e_freq * L / Fs), 0), n_bins)
        return float(np.sum(np.square(P1[:, s_point:e_point])))

    def simple_moving_average(self, data, window_size=3):
        """Average the last ``window_size`` items of ``data``.

        Returns an empty list when ``data`` is empty (kept for backward
        compatibility; callers in this class always append before averaging).
        """
        if len(data) == 0:
            return []
        window = data[-window_size:]
        return sum(window) / len(window)