419 lines
13 KiB
Python
419 lines
13 KiB
Python
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
#
|
|||
|
|
# Authors: Swolf <swolfforever@gmail.com>
|
|||
|
|
# Date: 2021/1/07
|
|||
|
|
# License: MIT License
|
|||
|
|
|
|||
|
|
|
|||
|
|
from typing import Optional, List, Tuple, Union
|
|||
|
|
import warnings
|
|||
|
|
import numpy as np
|
|||
|
|
from numpy import ndarray
|
|||
|
|
from numpy.linalg import linalg
|
|||
|
|
from scipy.linalg import solve, qr
|
|||
|
|
from scipy.signal import sosfiltfilt, cheby1, cheb1ord
|
|||
|
|
from sklearn.base import BaseEstimator, TransformerMixin, clone
|
|||
|
|
|
|||
|
|
|
|||
|
|
def robust_pattern(W : ndarray, Cx: ndarray, Cs: ndarray) -> ndarray:
    """Transform spatial filters to spatial patterns based on paper [1]_.

    Referring to the method mentioned in article [1], the constructed spatial filter only shows how to combine
    information from different channels to extract signals of interest from EEG signals, but if our goal is
    neurophysiological interpretation or visualization of weights, activation patterns need to be constructed
    from the obtained spatial filters. The pattern matrix is computed as

        A = Cx @ W @ inv(Cs)

    without forming the inverse explicitly.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    W : ndarray
        Spatial filters, shape (n_channels, n_filters).
    Cx : ndarray
        Covariance matrix of eeg data, shape (n_channels, n_channels).
    Cs : ndarray
        Covariance matrix of the source (spatially filtered) signals, shape (n_filters, n_filters).
        In [1] this is ``W.T @ Cx @ W``.

    Returns
    -------
    A : ndarray
        Spatial patterns, shape (n_channels, n_filters), each column is a spatial pattern.

    References
    ----------
    .. [1] Haufe, Stefan, et al. "On the interpretation of weight vectors of linear models in multivariate
       neuroimaging." Neuroimage 87 (2014): 96-110.
    """
    # A = (Cx W) Cs^-1  <=>  A.T = Cs^-T (Cx W).T, which is solved directly.
    # Using solve instead of inv is more numerically stable; see
    # https://github.com/robintibor/fbcsp/blob/master/fbcsp/signalproc.py
    # and https://ww2.mathworks.cn/help/matlab/ref/mldivide.html
    A = solve(Cs.T, (Cx @ W).T).T
    return A
|
|||
|
|
|
|||
|
|
|
|||
|
|
class FilterBank(BaseEstimator, TransformerMixin):
    """Filter-bank decomposition wrapped around a base estimator.

    The input signal is split into multiple subband components by a bank of
    band-pass filters. One clone of ``base_estimator`` is fitted on each
    subband, and ``transform`` concatenates the per-subband features.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    base_estimator : BaseEstimator
        Estimator used for model training and feature extraction. It is cloned
        (``sklearn.base.clone``) once per subband; the instance passed here is
        never fitted itself.
    filterbank : list[ndarray]
        Band-pass filters in second-order-sections (sos) form, one per subband,
        as consumed by ``scipy.signal.sosfiltfilt``.
    n_jobs : int, optional
        Number of parallel jobs. It is stored for API compatibility only:
        subbands are processed sequentially and this value is not used.
        The default is None.

    Attributes
    ----------
    estimators_ : list of BaseEstimator
        Fitted clones of ``base_estimator``, in the same order as ``filterbank``.

    References
    ----------
    .. [1] Chen X, Wang Y, Nakanishi M, et al. High-speed spelling with a noninvasive brain-computer interface[J].
       Proceedings of the national academy of sciences, 2015, 112(44): E6058-E6067.
    """

    def __init__(
        self,
        base_estimator: BaseEstimator,
        filterbank: List[ndarray],
        n_jobs: Optional[int] = None,
    ):
        # Parameters are stored unchanged so sklearn's get_params/clone work.
        self.base_estimator = base_estimator
        self.filterbank = filterbank
        self.n_jobs = n_jobs

    def fit(self, X: ndarray, y: Optional[ndarray] = None, **kwargs):
        """Fit one clone of ``base_estimator`` on each subband of ``X``.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : ndarray, shape(n_trials, n_channels, n_samples)
            Training signal. It is filtered through every filter of the bank,
            and subband ``i`` is used to fit ``estimators_[i]``.
        y : ndarray, optional
            Labels, forwarded unchanged to each sub-estimator's ``fit``.
        **kwargs
            Extra keyword arguments forwarded unchanged to each sub-estimator's
            ``fit``. For example, some estimators accept reference signals ``Yf``.

        Returns
        -------
        self : FilterBank
            The fitted instance.
        """
        self.estimators_ = [
            clone(self.base_estimator) for _ in range(len(self.filterbank))
        ]
        Xs = self.transform_filterbank(X)
        for i, est in enumerate(self.estimators_):
            est.fit(Xs[i], y, **kwargs)
        return self

    def transform(self, X: ndarray, **kwargs):
        """Filter ``X`` through the filter bank and extract features from every subband.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : ndarray, shape(n_trials, n_channels, n_samples)
            Test signal.
        **kwargs
            Extra keyword arguments forwarded to each sub-estimator's ``transform``.

        Returns
        -------
        feat : ndarray
            Per-subband features concatenated along the last axis, in filter-bank
            order. When every sub-estimator returns (n_trials, n_features), the
            result is (n_trials, n_filterbanks * n_features).
        """
        Xs = self.transform_filterbank(X)
        feat = [est.transform(Xs[i], **kwargs) for i, est in enumerate(self.estimators_)]
        return np.concatenate(feat, axis=-1)

    def transform_filterbank(self, X: ndarray):
        """Filter the input signal through every filter of the bank.

        Filtering is zero-phase (forward-backward, ``sosfiltfilt``) along the last axis.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : ndarray, shape(n_trials, n_channels, n_samples)
            Input signal.

        Returns
        -------
        Xs : ndarray, shape(n_filterbanks, n_trials, n_channels, n_samples)
            Individual subband components of the input signal, stacked in
            filter-bank order.
        """
        Xs = np.stack([sosfiltfilt(sos, X, axis=-1) for sos in self.filterbank])
        return Xs
|
|||
|
|
|
|||
|
|
|
|||
|
|
class FilterBankSSVEP(FilterBank):
    """Filter-bank analysis for SSVEP.

    The SSVEP signal is decomposed by a bank of band-pass filters into subband
    components. A clone of ``base_estimator`` extracts features from each
    subband. If ``filterweights`` is given, the subband features are combined
    by a weighted sum; otherwise they are returned concatenated.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    filterbank : list[ndarray]
        Band-pass filters in second-order-sections (sos) form, one per subband.
    base_estimator : BaseEstimator
        Estimator for model training and feature extraction (cloned once per subband).
    filterweights : array-like, optional
        One weight per subband, applied in ``transform``. Default is None, which
        means no weighting: the concatenated subband features are returned.
    n_jobs : int, optional
        Forwarded to :class:`FilterBank`. The default is None.
    """

    def __init__(
        self,
        filterbank: List[ndarray],
        base_estimator: BaseEstimator,
        filterweights: Optional[ndarray] = None,
        n_jobs: Optional[int] = None,
    ):
        self.filterweights = filterweights
        super().__init__(base_estimator, filterbank, n_jobs=n_jobs)

    def transform(self, X: ndarray):  # type: ignore[override]
        """Extract subband features from ``X`` and optionally combine them with ``filterweights``.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : ndarray, shape(n_trials, n_channels, n_samples)
            Test signal.

        Returns
        -------
        features : ndarray
            If ``filterweights`` is None, the concatenated per-subband features
            (see :meth:`FilterBank.transform`). Otherwise an array of shape
            (n_trials, n_features) equal to ``sum_k filterweights[k] * features_k``.
        """
        features = super().transform(X)
        if self.filterweights is None:
            return features
        # The parent concatenates subband features along the last axis. This
        # reshape assumes every sub-estimator returns equally wide 2-D features,
        # so the columns can be split back into one slab per subband.
        features = np.reshape(
            features, (features.shape[0], len(self.filterbank), -1)
        )
        # asarray lets lists/tuples of weights work as well as ndarrays; the
        # original fancy indexing failed on plain Python sequences.
        weights = np.reshape(np.asarray(self.filterweights), (1, -1, 1))
        return np.sum(features * weights, axis=1)
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
def generate_filterbank(
    passbands: List[Tuple[float, float]],
    stopbands: List[Tuple[float, float]],
    srate: int,
    order: Optional[int] = None,
    rp: float = 0.5,
    gpass: float = 3,
    gstop: float = 40,
):
    """Create a bank of Chebyshev type-I band-pass filters, one per passband.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    passbands : list of (float, float)
        Passband edge frequencies, in the same units as ``srate`` (Hz).
    stopbands : list of (float, float)
        Stopband edge frequencies, paired element-wise with ``passbands``.
        Only used when ``order`` is None, to compute the minimal filter order.
    srate : float
        Sampling rate.
    order : int, optional
        Filter order. If None, each filter's order is the minimum returned by
        ``scipy.signal.cheb1ord`` for the ``gpass``/``gstop`` specification.
    rp : float
        Maximum passband ripple (dB) used when designing the filter with
        ``cheby1``. The default is 0.5.
    gpass : float
        Maximum passband loss (dB) used only for the order estimate. The default
        is 3, which was previously hard-coded.
    gstop : float
        Minimum stopband attenuation (dB) used only for the order estimate. The
        default is 40, which was previously hard-coded.

    Returns
    -------
    filterbank : list[ndarray]
        One second-order-sections array of shape (n_sections, 6) per passband.
        ``n_sections`` may differ between bands when ``order`` is None.

    Raises
    ------
    ValueError
        If ``passbands`` and ``stopbands`` have different lengths. Previously
        ``zip`` silently dropped the extra bands.
    """
    if len(passbands) != len(stopbands):
        raise ValueError(
            "passbands and stopbands must have the same length, got "
            f"{len(passbands)} and {len(stopbands)}."
        )
    filterbank = []
    for wp, ws in zip(passbands, stopbands):
        if order is None:
            N, wn = cheb1ord(wp, ws, gpass, gstop, fs=srate)
            sos = cheby1(N, rp, wn, btype="bandpass", output="sos", fs=srate)
        else:
            sos = cheby1(order, rp, wp, btype="bandpass", output="sos", fs=srate)
        filterbank.append(sos)
    return filterbank
|
|||
|
|
|
|||
|
|
def process(data):
    """Centre each row of ``data`` and return an orthonormal basis of the rows' span.

    This is the preprocessing step of MATLAB's ``canoncorr`` (hence the error
    identifier below). Each row (variable) has its mean over the columns
    removed. An economy-size QR decomposition of the transposed, centred data
    then gives an orthonormal basis, truncated to the numerical rank of ``R``.

    Accepts a plain ``ndarray`` or an ``np.matrix``. Previously the mean
    broadcasting was only correct for ``np.matrix`` input, and the code relied
    on ``np.mat``, which was removed in NumPy 2.0.

    Parameters
    ----------
    data : array_like, shape (n_variables, n_samples)
        Input data; 1-D input is treated as a single row.

    Returns
    -------
    Q : ndarray, shape (n_samples, rankQ)
        Orthonormal columns spanning the centred rows of ``data``.
    rankQ : int
        Numerical rank of the centred data.

    Raises
    ------
    ValueError
        ('stats:canoncorr:badData') if the centred data has rank zero, for
        example when every row is constant.
    """
    # Mean removal per row. The original comments called this "whitening", but
    # it only centres the data; no variance scaling is applied.
    centered = np.atleast_2d(np.asarray(data))
    centered = centered - centered.mean(axis=1, keepdims=True)

    # Economy QR of the (n_samples, n_variables) matrix.
    Q, R = qr(centered.T, mode='economic')

    # The rank of R equals the rank of the centred data.
    rankQ = np.linalg.matrix_rank(R)
    if rankQ == 0:
        raise ValueError('stats:canoncorr:badData')
    # Keep only the columns that span the data. The rank can never exceed the
    # number of rows, so the original `elif rankQ <= rows` check was always true.
    Q = Q[:, :rankQ]
    return Q, rankQ
|
|||
|
|
|
|||
|
|
def reference(listFreqs, fs, numberSmples, num_harms):
    """Build orthonormalised sine/cosine reference signals, one set per stimulus frequency.

    For every frequency ``f`` in ``listFreqs``, the rows
    ``[sin(2*pi*1*f*t), cos(2*pi*1*f*t), ..., sin(2*pi*H*f*t), cos(2*pi*H*f*t)]``
    (``H = num_harms``) are sampled at ``t = 1/fs, 2/fs, ..., numberSmples/fs``.
    They are then passed through ``process`` (row centring plus economy QR),
    and the transposed orthonormal basis is stored.

    Parameters
    ----------
    listFreqs : sequence of float
        Stimulus frequencies in Hz.
    fs : float
        Sampling rate in Hz.
    numberSmples : int
        Number of samples per reference signal.
    num_harms : int
        Number of harmonics (each harmonic contributes a sine and a cosine row).

    Returns
    -------
    referenceData : ndarray, shape (len(listFreqs), 2 * num_harms, numberSmples)
        Orthonormalised reference signals.

    Raises
    ------
    ValueError
        If the sine/cosine set for a frequency is rank deficient, for example
        when a harmonic falls exactly on the Nyquist frequency and its sine row
        is all zeros. Previously this surfaced as an opaque broadcasting error.
    """
    n_rows = 2 * num_harms
    # Time axis starts at 1/fs, not 0; this matches the original (MATLAB-style) indexing.
    timeIndex = np.arange(1, numberSmples + 1) / fs
    harmonics = np.arange(1, num_harms + 1)
    referenceData = np.zeros((len(listFreqs), n_rows, numberSmples))
    for freqIdx, stimFrequence in enumerate(listFreqs):
        # (num_harms, numberSmples) phase matrix; same multiplication order as before.
        phase = 2 * np.pi * timeIndex[np.newaxis, :] * harmonics[:, np.newaxis] * stimFrequence
        # Interleave rows as sin(h=1), cos(h=1), sin(h=2), cos(h=2), ...
        temp = np.empty((n_rows, numberSmples))
        temp[0::2] = np.sin(phase)
        temp[1::2] = np.cos(phase)
        # Centre the rows and QR-decompose them. np.asmatrix replaces np.mat,
        # which was removed in NumPy 2.0.
        Q, rankQ = process(np.asmatrix(temp))
        if rankQ < n_rows:
            raise ValueError(
                f"Reference signals for {stimFrequence} Hz are rank deficient "
                f"(rank {rankQ} < {n_rows}); check fs and num_harms."
            )
        referenceData[freqIdx] = np.transpose(Q)
    return referenceData
|
|||
|
|
|
|||
|
|
def generate_cca_references(
    freqs: Union[ndarray, int, float],
    srate,
    T,
    phases: Optional[Union[ndarray, int, float]] = None,
    n_harmonics: int = 1,
):
    """Construct sine-cosine reference signals for canonical correlation analysis (CCA).

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    freqs : int, float or array_like
        Stimulus frequency or frequencies in Hz. Python and NumPy scalars are
        both accepted.
    srate : int or float
        Sampling rate in Hz.
    T : int or float
        Signal duration in seconds; ``int(T * srate)`` samples are generated.
    phases : int, float or array_like, optional
        Initial phase(s) in units of pi: the phase term is ``np.pi * phases``.
        Must broadcast against ``freqs``. Default is None, meaning 0.
    n_harmonics : int
        The number of harmonics. The default value is 1.

    Returns
    -------
    Yf : ndarray, shape(n_freqs, 2 * n_harmonics, int(T * srate))
        Sine and cosine reference signals. Rows are ordered
        ``sin(h=1), cos(h=1), sin(h=2), cos(h=2), ...``.
    """
    # atleast_1d handles Python and NumPy scalars as well as sequences; the
    # previous isinstance(int/float) check missed e.g. np.int64.
    freqs = np.atleast_1d(np.asarray(freqs))[:, np.newaxis]
    if phases is None:
        phases = 0
    phases = np.atleast_1d(np.asarray(phases))[:, np.newaxis]

    # Sample exactly every 1/srate seconds. The previous np.linspace(0, T, N)
    # included the endpoint, giving a spacing of T/(N-1) and therefore slightly
    # wrong reference frequencies.
    n_samples = int(T * srate)
    t = np.arange(n_samples) / srate

    Yf = [
        np.stack(
            [
                np.sin(2 * np.pi * (i + 1) * freqs * t + np.pi * phases),
                np.cos(2 * np.pi * (i + 1) * freqs * t + np.pi * phases),
            ],
            axis=1,
        )
        for i in range(n_harmonics)
    ]
    return np.concatenate(Yf, axis=1)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def sign_flip(u, s, vh=None):
    """Resolve the sign ambiguity of SVD / EIG factors using the method in paper [1]_.

    Each component k receives a sign chosen from a projection score. The score
    is ``sum(u[:, k] * s[k])``, plus ``sum(s[k] * vh[k, :])`` when ``vh`` is
    supplied. Columns of ``u`` (and rows of ``vh``) are multiplied by the sign
    of that score. A score that is exactly zero gets sign +1, and a warning is
    emitted.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    u: ndarray
        left singular vectors, shape (M, K).
    s: ndarray
        singular values, shape (K,).
    vh: ndarray or None
        transpose of right singular vectors, shape (K, N).

    Returns
    -------
    u: ndarray
        corrected left singular vectors.
    s: ndarray
        singular values (the input object, unchanged).
    vh: ndarray
        transpose of corrected right singular vectors; only returned when
        ``vh`` was given, otherwise the result is the pair ``(u, s)``.

    References
    ----------
    .. [1] https://www.sandia.gov/~tgkolda/pubs/pubfiles/SAND2007-6422.pdf
    """
    # Projection score contributed by the left vectors (one value per component).
    score = np.sum(u * s, axis=0)
    if vh is not None:
        # Add the contribution of the right vectors.
        score = np.sum(s[:, np.newaxis] * vh, axis=-1) + score

    flips = np.sign(score)
    undecided = flips == 0
    if np.any(undecided):
        flips[undecided] = 1
        warnings.warn(
            "The magnitude is close to zero, the sign will become arbitrary."
        )

    corrected_u = u * flips
    if vh is None:
        return corrected_u, s
    return corrected_u, s, flips[:, np.newaxis] * vh
|