初始化zmq 项目
This commit is contained in:
418
SSMVEP/algorithm/base.py
Normal file
418
SSMVEP/algorithm/base.py
Normal file
@@ -0,0 +1,418 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Authors: Swolf <swolfforever@gmail.com>
|
||||
# Date: 2021/1/07
|
||||
# License: MIT License
|
||||
|
||||
|
||||
from typing import Optional, List, Tuple, Union
|
||||
import warnings
|
||||
import numpy as np
|
||||
from numpy import ndarray
|
||||
from numpy.linalg import linalg
|
||||
from scipy.linalg import solve, qr
|
||||
from scipy.signal import sosfiltfilt, cheby1, cheb1ord
|
||||
from sklearn.base import BaseEstimator, TransformerMixin, clone
|
||||
|
||||
|
||||
def robust_pattern(W : ndarray, Cx: ndarray, Cs: ndarray) -> ndarray:
|
||||
"""Transform spatial filters to spatial patterns based on paper [1]_.
|
||||
Referring to the method mentioned in article [1],the constructed spatial filter only shows how to combine
|
||||
information from different channels to extract signals of interest from EEG signals, but if our goal is
|
||||
neurophysiological interpretation or visualization of weights, activation patterns need to be constructed
|
||||
from the obtained spatial filters.
|
||||
|
||||
update log:
|
||||
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
W : ndarray
|
||||
Spatial filters, shape (n_channels, n_filters).
|
||||
Cx : ndarray
|
||||
Covariance matrix of eeg data, shape (n_channels, n_channels).
|
||||
Cs : ndarray
|
||||
Covariance matrix of source data, shape (n_channels, n_channels).
|
||||
|
||||
Returns
|
||||
-------
|
||||
A : ndarray
|
||||
Spatial patterns, shape (n_channels, n_patterns), each column is a spatial pattern.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] Haufe, Stefan, et al. "On the interpretation of weight vectors of linear models in multivariate neuroimaging.
|
||||
Neuroimage 87 (2014): 96-110.
|
||||
"""
|
||||
# use linalg.solve instead of inv, makes it more stable
|
||||
# see https://github.com/robintibor/fbcsp/blob/master/fbcsp/signalproc.py
|
||||
# and https://ww2.mathworks.cn/help/matlab/ref/mldivide.html
|
||||
A = solve(Cs.T, np.dot(Cx, W).T).T
|
||||
return A
|
||||
|
||||
|
||||
class FilterBank(BaseEstimator, TransformerMixin):
|
||||
"""
|
||||
Filter bank decomposition is a bandpass filter array that divides the input signal into
|
||||
multiple subband components and obtains the eigenvalues of each subband component.
|
||||
|
||||
update log:
|
||||
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
base_estimator : class
|
||||
Estimator for model training and feature extraction.
|
||||
filterbank : list[ndarray]
|
||||
A bandpass filter bank used to divide the input signal into multiple subband components.
|
||||
n_jobs : int
|
||||
Sets the number of CPU working cores. The default is None.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] Chen X, Wang Y, Nakanishi M, et al. High-speed spelling with a noninvasive brain-computer interface[J].
|
||||
Proceedings of the national academy of sciences, 2015, 112(44): E6058-E6067.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
base_estimator: BaseEstimator,
|
||||
filterbank: List[ndarray],
|
||||
n_jobs: Optional[int] = None,
|
||||
):
|
||||
self.base_estimator = base_estimator
|
||||
self.filterbank = filterbank
|
||||
self.n_jobs = n_jobs
|
||||
|
||||
def fit(self, X: ndarray, y: Optional[ndarray] = None, **kwargs):
|
||||
"""
|
||||
Training model
|
||||
|
||||
update log:
|
||||
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : None
|
||||
Training signal (parameters can be ignored, only used to maintain code structure).
|
||||
y : None
|
||||
Label data (ibid., ignorable).
|
||||
Yf : None
|
||||
Reference signal (ibid., ignorable).
|
||||
"""
|
||||
self.estimators_ = [
|
||||
clone(self.base_estimator) for _ in range(len(self.filterbank))
|
||||
]
|
||||
X = self.transform_filterbank(X)
|
||||
for i, est in enumerate(self.estimators_):
|
||||
est.fit(X[i], y, **kwargs)
|
||||
# def wrapper(est, X, y, kwargs):
|
||||
# est.fit(X, y, **kwargs)
|
||||
# return est
|
||||
# self.estimators_ = Parallel(n_jobs=self.n_jobs)(
|
||||
# delayed(wrapper)(est, X[i], y, kwargs) for i, est in enumerate(self.estimators_))
|
||||
return self
|
||||
|
||||
def transform(self, X: ndarray, **kwargs):
|
||||
"""
|
||||
The parameters stored in self are used to convert X into features, and X is filtered through the filter bank to
|
||||
obtain the eigenvalues of each subband component.
|
||||
|
||||
update log:
|
||||
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : ndarray, shape(n_trials, n_channels, n_samples)
|
||||
Test the signal.
|
||||
|
||||
Returns
|
||||
-------
|
||||
feat : ndarray, shape(n_trials, n_fre)
|
||||
Feature array.
|
||||
"""
|
||||
X = self.transform_filterbank(X)
|
||||
feat = [est.transform(X[i], **kwargs) for i, est in enumerate(self.estimators_)]
|
||||
# def wrapper(est, X, kwargs):
|
||||
# retval = est.transform(X, **kwargs)
|
||||
# return retval
|
||||
# feat = Parallel(n_jobs=self.n_jobs)(
|
||||
# delayed(wrapper)(est, X[i], kwargs) for i, est in enumerate(self.estimators_))
|
||||
feat = np.concatenate(feat, axis=-1)
|
||||
return feat
|
||||
|
||||
def transform_filterbank(self, X: ndarray):
|
||||
"""
|
||||
The input signal is filtered through a filter bank.
|
||||
|
||||
update log:
|
||||
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : ndarray, shape(n_trials, n_channels, n_samples)
|
||||
Input signal.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Xs: ndarray, shape(Nfb, n_trials, n_channels, n_samples)
|
||||
Individual subband components of the input signal.
|
||||
"""
|
||||
Xs = np.stack([sosfiltfilt(sos, X, axis=-1) for sos in self.filterbank])
|
||||
return Xs
|
||||
|
||||
|
||||
class FilterBankSSVEP(FilterBank):
|
||||
"""
|
||||
Filter bank analysis for SSVEP.
|
||||
The SSVEP is analyzed using filter banks, that is, multiple filters are combined to decompose the SSVEP signal
|
||||
into specific segments (subbands containing the original data) and obtain its characteristic data.
|
||||
|
||||
update log:
|
||||
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filterbank : list[ndarray]
|
||||
The filter bank.
|
||||
base_estimator : class
|
||||
Estimator for model training and feature extraction.
|
||||
filterweights : ndarray
|
||||
Filter weight, default is None.
|
||||
n_jobs : int
|
||||
Sets the number of CPU working cores. The default is None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filterbank: List[ndarray],
|
||||
base_estimator: BaseEstimator,
|
||||
filterweights: Optional[ndarray] = None,
|
||||
n_jobs: Optional[int] = None,
|
||||
):
|
||||
self.filterweights = filterweights
|
||||
super().__init__(base_estimator, filterbank, n_jobs=n_jobs)
|
||||
|
||||
def transform(self, X: ndarray): # type: ignore[override]
|
||||
"""
|
||||
X is converted into features by using the parameters stored in self, and the eigenvalues of each subband
|
||||
component are obtained after the input signal is filtered by the filter bank.
|
||||
|
||||
update log:
|
||||
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : ndarray, shape(n_trials, n_channels, n_samples)
|
||||
Test the signal.
|
||||
|
||||
Returns
|
||||
-------
|
||||
features : ndarray, shape(n_trials, n_fre)
|
||||
Feature array.
|
||||
"""
|
||||
features = super().transform(X)
|
||||
if self.filterweights is None:
|
||||
return features
|
||||
else:
|
||||
features = np.reshape(
|
||||
features, (features.shape[0], len(self.filterbank), -1)
|
||||
)
|
||||
return np.sum(
|
||||
features * self.filterweights[np.newaxis, :, np.newaxis], axis=1
|
||||
)
|
||||
|
||||
|
||||
|
||||
def generate_filterbank(
|
||||
passbands: List[Tuple[float, float]],
|
||||
stopbands: List[Tuple[float, float]],
|
||||
srate: int,
|
||||
order: Optional[int] = None,
|
||||
rp: float = 0.5,
|
||||
):
|
||||
"""
|
||||
Create a filter bank, that is, obtain a bandpass filter coefficient that can divide the input signal into multiple
|
||||
subband components.
|
||||
|
||||
update log:
|
||||
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
passbands : list or tuple(float, float)
|
||||
Passband parameters.
|
||||
stopbands : list or tuple(float, float)
|
||||
Stopband parameters.
|
||||
srate : float
|
||||
Sampling rate.
|
||||
order : int
|
||||
Filter order.
|
||||
rp : float
|
||||
The maximum ripple allowed in the passband below the unit gain is 0.5 by default.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Filterbank:ndarray, shape(len(passbands), N, 6)
|
||||
Filter bank coefficient.
|
||||
"""
|
||||
filterbank = []
|
||||
for wp, ws in zip(passbands, stopbands):
|
||||
if order is None:
|
||||
N, wn = cheb1ord(wp, ws, 3, 40, fs=srate)
|
||||
sos = cheby1(N, rp, wn, btype="bandpass", output="sos", fs=srate)
|
||||
else:
|
||||
sos = cheby1(order, rp, wp, btype="bandpass", output="sos", fs=srate)
|
||||
|
||||
filterbank.append(sos)
|
||||
return filterbank
|
||||
|
||||
def process(data):
|
||||
# 白化操作
|
||||
meanValue = np.mat(data.mean(axis=1))
|
||||
meanData = np.repeat(meanValue, data.shape[1], axis=1)
|
||||
whiteTemp = data - meanData
|
||||
# QR 分解
|
||||
rankWhiteTemp = whiteTemp.shape[0]
|
||||
whiteTemp = np.transpose(whiteTemp)
|
||||
Q, R = qr(whiteTemp.A, mode='economic')
|
||||
# 计算矩阵的秩
|
||||
rankQ = linalg.matrix_rank(R)
|
||||
if rankQ == 0:
|
||||
raise ValueError('stats:canoncorr:badData')
|
||||
elif rankQ <= rankWhiteTemp:
|
||||
# warnings.warn('stats:canoncorr:NotFullRank')
|
||||
Q = Q[:, 0:rankQ]
|
||||
return Q, rankQ
|
||||
|
||||
def reference(listFreqs,fs, numberSmples, num_harms):
|
||||
numberFrequence = len(listFreqs)
|
||||
timeIndex = np.arange(1, numberSmples + 1) / fs # time index
|
||||
referenceData = np.zeros((numberFrequence, 2 * num_harms, numberSmples))
|
||||
for frequenceIndex in range(numberFrequence):
|
||||
temp = []
|
||||
for harmIndex in range(1, num_harms + 1):
|
||||
stimFrequence = listFreqs[frequenceIndex] # in HZ
|
||||
# Sin and Cos
|
||||
temp.extend([np.sin(2 * np.pi * timeIndex * harmIndex * stimFrequence),
|
||||
np.cos(2 * np.pi * timeIndex * harmIndex * stimFrequence)])
|
||||
referenceTemp = np.mat(temp)
|
||||
# 白化操作和QR分解
|
||||
Q, rankQ = process(referenceTemp)
|
||||
referenceData[frequenceIndex] = np.transpose(Q)
|
||||
return referenceData
|
||||
|
||||
def generate_cca_references(
|
||||
freqs: Union[ndarray, int, float],
|
||||
srate,
|
||||
T,
|
||||
phases: Optional[Union[ndarray, int, float]] = None,
|
||||
n_harmonics: int = 1,
|
||||
):
|
||||
"""
|
||||
Construct a sine-cosine reference signal for canonical correlation analysis (CCA).
|
||||
|
||||
update log:
|
||||
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
freqs : int or float
|
||||
Frequency.
|
||||
srate : int
|
||||
Sampling rate.
|
||||
T : int
|
||||
Sampling time.
|
||||
phases : int or float
|
||||
Phase, default is None.
|
||||
n_harmonics : int
|
||||
The number of harmonics. The default value is 1.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Yf:ndarray, shape(srate*T, n_harmonics*2)
|
||||
Sine and cosine reference signal.
|
||||
"""
|
||||
if isinstance(freqs, int) or isinstance(freqs, float):
|
||||
freqs = np.array([freqs])
|
||||
freqs = np.array(freqs)[:, np.newaxis]
|
||||
if phases is None:
|
||||
phases = 0
|
||||
if isinstance(phases, int) or isinstance(phases, float):
|
||||
phases = np.array([phases])
|
||||
phases = np.array(phases)[:, np.newaxis]
|
||||
t = np.linspace(0, T, int(T * srate))
|
||||
|
||||
Yf = []
|
||||
for i in range(n_harmonics):
|
||||
Yf.append(
|
||||
np.stack(
|
||||
[
|
||||
np.sin(2 * np.pi * (i + 1) * freqs * t + np.pi * phases),
|
||||
np.cos(2 * np.pi * (i + 1) * freqs * t + np.pi * phases),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
)
|
||||
Yf = np.concatenate(Yf, axis=1)
|
||||
return Yf
|
||||
|
||||
|
||||
def sign_flip(u, s, vh=None):
|
||||
"""Flip signs of SVD or EIG using the method in paper [1]_.
|
||||
|
||||
update log:
|
||||
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
u: ndarray
|
||||
left singular vectors, shape (M, K).
|
||||
s: ndarray
|
||||
singular values, shape (K,).
|
||||
vh: ndarray or None
|
||||
transpose of right singular vectors, shape (K, N).
|
||||
|
||||
Returns
|
||||
-------
|
||||
u: ndarray
|
||||
corrected left singular vectors.
|
||||
s: ndarray
|
||||
singular values.
|
||||
vh: ndarray
|
||||
transpose of corrected right singular vectors.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] https://www.sandia.gov/~tgkolda/pubs/pubfiles/SAND2007-6422.pdf
|
||||
"""
|
||||
if vh is None:
|
||||
total_proj = np.sum(u * s, axis=0)
|
||||
signs = np.sign(total_proj)
|
||||
|
||||
random_idx = signs == 0
|
||||
if np.any(random_idx):
|
||||
signs[random_idx] = 1
|
||||
warnings.warn(
|
||||
"The magnitude is close to zero, the sign will become arbitrary."
|
||||
)
|
||||
|
||||
u = u * signs
|
||||
|
||||
return u, s
|
||||
else:
|
||||
left_proj = np.sum(s[:, np.newaxis] * vh, axis=-1)
|
||||
right_proj = np.sum(u * s, axis=0)
|
||||
total_proj = left_proj + right_proj
|
||||
signs = np.sign(total_proj)
|
||||
|
||||
random_idx = signs == 0
|
||||
if np.any(random_idx):
|
||||
signs[random_idx] = 1
|
||||
warnings.warn(
|
||||
"The magnitude is close to zero, the sign will become arbitrary."
|
||||
)
|
||||
|
||||
u = u * signs
|
||||
vh = signs[:, np.newaxis] * vh
|
||||
|
||||
return u, s, vh
|
||||
436
SSMVEP/algorithm/dsp.py
Normal file
436
SSMVEP/algorithm/dsp.py
Normal file
@@ -0,0 +1,436 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# DSP: Discriminal Spatial Patterns
|
||||
# Authors: Swolf <swolfforever@gmail.com>
|
||||
# Junyang Wang <2144755928@qq.com>
|
||||
# Last update date: 2022-8-11
|
||||
# License: MIT License
|
||||
|
||||
from typing import Optional, List, Tuple
|
||||
from itertools import combinations
|
||||
import numpy as np
|
||||
from scipy.linalg import eigh
|
||||
from numpy import ndarray
|
||||
from scipy.linalg import solve
|
||||
from sklearn.base import BaseEstimator, TransformerMixin, ClassifierMixin
|
||||
|
||||
|
||||
|
||||
def robust_pattern(W : ndarray, Cx: ndarray, Cs: ndarray) -> ndarray:
|
||||
"""Transform spatial filters to spatial patterns based on paper [1]_.
|
||||
Referring to the method mentioned in article [1],the constructed spatial filter only shows how to combine
|
||||
information from different channels to extract signals of interest from EEG signals, but if our goal is
|
||||
neurophysiological interpretation or visualization of weights, activation patterns need to be constructed
|
||||
from the obtained spatial filters.
|
||||
|
||||
update log:
|
||||
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
W : ndarray
|
||||
Spatial filters, shape (n_channels, n_filters).
|
||||
Cx : ndarray
|
||||
Covariance matrix of eeg data, shape (n_channels, n_channels).
|
||||
Cs : ndarray
|
||||
Covariance matrix of source data, shape (n_channels, n_channels).
|
||||
|
||||
Returns
|
||||
-------
|
||||
A : ndarray
|
||||
Spatial patterns, shape (n_channels, n_patterns), each column is a spatial pattern.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] Haufe, Stefan, et al. "On the interpretation of weight vectors of linear models in multivariate neuroimaging.
|
||||
Neuroimage 87 (2014): 96-110.
|
||||
"""
|
||||
# use linalg.solve instead of inv, makes it more stable
|
||||
# see https://github.com/robintibor/fbcsp/blob/master/fbcsp/signalproc.py
|
||||
# and https://ww2.mathworks.cn/help/matlab/ref/mldivide.html
|
||||
A = solve(Cs.T, np.dot(Cx, W).T).T
|
||||
return A
|
||||
|
||||
def isPD(B: ndarray) -> bool:
|
||||
"""Returns true when input matrix is positive-definite, via Cholesky decompositon method.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
B : ndarray
|
||||
Any matrix, shape (N, N)
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if B is positve-definite.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Use numpy.linalg rather than scipy.linalg. In this case, scipy.linalg has unpredictable behaviors.
|
||||
"""
|
||||
|
||||
try:
|
||||
_ = np.linalg.cholesky(B)
|
||||
return True
|
||||
except np.linalg.LinAlgError:
|
||||
return False
|
||||
|
||||
def nearestPD(A: ndarray) -> ndarray:
|
||||
"""Find the nearest positive-definite matrix to input.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
A : ndarray
|
||||
Any square matrxi, shape (N, N)
|
||||
|
||||
Returns
|
||||
-------
|
||||
A3 : ndarray
|
||||
positive-definite matrix to A
|
||||
|
||||
Notes
|
||||
-----
|
||||
A Python/Numpy port of John D'Errico's `nearestSPD` MATLAB code [1]_, which
|
||||
origins at [2]_.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] https://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd
|
||||
.. [2] N.J. Higham, "Computing a nearest symmetric positive semidefinite matrix" (1988):
|
||||
https://doi.org/10.1016/0024-3795(88)90223-6
|
||||
"""
|
||||
|
||||
B = (A + A.T) / 2
|
||||
_, s, V = np.linalg.svd(B)
|
||||
|
||||
H = np.dot(V.T, np.dot(np.diag(s), V))
|
||||
|
||||
A2 = (B + H) / 2
|
||||
|
||||
A3 = (A2 + A2.T) / 2
|
||||
|
||||
if isPD(A3):
|
||||
return A3
|
||||
|
||||
print("Replace current matrix with the nearest positive-definite matrix.")
|
||||
|
||||
spacing = np.spacing(np.linalg.norm(A))
|
||||
# The above is different from [1]. It appears that MATLAB's `chol` Cholesky
|
||||
# decomposition will accept matrixes with exactly 0-eigenvalue, whereas
|
||||
# Numpy's will not. So where [1] uses `eps(mineig)` (where `eps` is Matlab
|
||||
# for `numpy.spacing`), we use the above definition. CAVEAT: our `spacing`
|
||||
# will be much larger than [1]'s `eps(mineig)`, since `mineig` is usually on
|
||||
# the order of 1e-16, and `eps(1e-16)` is on the order of 1e-34, whereas
|
||||
# `spacing` will, for Gaussian random matrixes of small dimension, be on
|
||||
# othe order of 1e-16. In practice, both ways converge, as the unit test
|
||||
# below suggests.
|
||||
eye = np.eye(A.shape[0])
|
||||
k = 1
|
||||
while not isPD(A3):
|
||||
mineig = np.min(np.real(np.linalg.eigvals(A3)))
|
||||
A3 += eye * (-mineig * k**2 + spacing)
|
||||
k += 1
|
||||
|
||||
return A3
|
||||
|
||||
def xiang_dsp_kernel(
|
||||
X: ndarray, y: ndarray
|
||||
) -> Tuple[ndarray, ndarray, ndarray, ndarray]:
|
||||
"""
|
||||
DSP: Discriminal Spatial Patterns, only for two classes[1]_.
|
||||
Import train data to solve spatial filters with DSP,
|
||||
finds a projection matrix that maximize the between-class scatter matrix and
|
||||
minimize the within-class scatter matrix. Currently only support for two types of data.
|
||||
|
||||
Author: Swolf <swolfforever@gmail.com>
|
||||
|
||||
Created on: 2021-1-07
|
||||
|
||||
Update log:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : ndarray
|
||||
EEG train data assuming removing mean, shape (n_trials, n_channels, n_samples)
|
||||
y : ndarray
|
||||
labels of EEG data, shape (n_trials, )
|
||||
|
||||
Returns
|
||||
-------
|
||||
W : ndarray
|
||||
spatial filters, shape (n_channels, n_filters)
|
||||
D : ndarray
|
||||
eigenvalues in descending order
|
||||
M : ndarray
|
||||
mean value of all classes and trials, i.e. common mode signals, shape (n_channel, n_samples)
|
||||
A : ndarray
|
||||
spatial patterns, shape (n_channels, n_filters)
|
||||
|
||||
Notes
|
||||
-----
|
||||
the implementation removes regularization on within-class scatter matrix Sw.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] Liao, Xiang, et al. "Combining spatial filters for the classification of single-trial EEG in
|
||||
a finger movement task." IEEE Transactions on Biomedical Engineering 54.5 (2007): 821-831.
|
||||
"""
|
||||
X, y = np.copy(X), np.copy(y)
|
||||
labels = np.unique(y)
|
||||
X = np.reshape(X, (-1, *X.shape[-2:]))
|
||||
X = X - np.mean(X, axis=-1, keepdims=True)
|
||||
# the number of each label
|
||||
n_labels = np.array([np.sum(y == label) for label in labels])
|
||||
# average template of all trials
|
||||
M = np.mean(X, axis=0)
|
||||
# class conditional template
|
||||
Ms, Ss = zip(
|
||||
*[
|
||||
(
|
||||
np.mean(X[y == label], axis=0),
|
||||
np.sum(
|
||||
np.matmul(X[y == label], np.swapaxes(X[y == label], -1, -2)), axis=0
|
||||
),
|
||||
)
|
||||
for label in labels
|
||||
]
|
||||
)
|
||||
Ms, Ss = np.stack(Ms), np.stack(Ss)
|
||||
# within-class scatter matrix
|
||||
Sw = np.sum(
|
||||
Ss
|
||||
- n_labels[:, np.newaxis, np.newaxis] * np.matmul(Ms, np.swapaxes(Ms, -1, -2)),
|
||||
axis=0,
|
||||
)
|
||||
Ms = Ms - M
|
||||
# between-class scatter matrix
|
||||
Sb = np.sum(
|
||||
n_labels[:, np.newaxis, np.newaxis] * np.matmul(Ms, np.swapaxes(Ms, -1, -2)),
|
||||
axis=0,
|
||||
)
|
||||
|
||||
D, W = eigh(nearestPD(Sb), nearestPD(Sw))
|
||||
ix = np.argsort(D)[::-1] # in descending order
|
||||
D, W = D[ix], W[:, ix]
|
||||
A = robust_pattern(W, Sb, W.T @ Sb @ W)
|
||||
|
||||
return W, D, M, A
|
||||
|
||||
|
||||
def xiang_dsp_feature(
|
||||
W: ndarray, M: ndarray, X: ndarray, n_components: int = 1
|
||||
) -> ndarray:
|
||||
"""
|
||||
Return DSP features in paper [1]_.
|
||||
|
||||
Author: Swolf <swolfforever@gmail.com>
|
||||
|
||||
Created on: 2021-1-07
|
||||
|
||||
Update log:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
W : ndarray
|
||||
spatial filters from csp_kernel, shape (n_channels, n_filters)
|
||||
M : ndarray
|
||||
common template for all classes, shape (n_channel, n_samples)
|
||||
X : ndarray
|
||||
eeg test data, shape (n_trials, n_channels, n_samples)
|
||||
n_components : int, optional
|
||||
length of the spatial filters, first k components to use, by default 1
|
||||
|
||||
Returns
|
||||
-------
|
||||
features: ndarray
|
||||
features, shape (n_trials, n_components, n_samples)
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
n_components should less than half of the number of channels
|
||||
|
||||
Notes
|
||||
-----
|
||||
1. instead of meaning of filtered signals in paper [1]_., we directly return filtered signals.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] Liao, Xiang, et al. "Combining spatial filters for the classification of single-trial EEG in
|
||||
a finger movement task." IEEE Transactions on Biomedical Engineering 54.5 (2007): 821-831.
|
||||
"""
|
||||
W, M, X = np.copy(W), np.copy(M), np.copy(X)
|
||||
max_components = W.shape[1]
|
||||
if n_components > max_components:
|
||||
raise ValueError("n_components should less than the number of channels")
|
||||
X = np.reshape(X, (-1, *X.shape[-2:]))
|
||||
X = X - np.mean(X, axis=-1, keepdims=True)
|
||||
# print('************: ',np.shape(W),np.shape(X),np.shape(M))
|
||||
features = np.matmul(W[:, :n_components].T, X - M)
|
||||
return features
|
||||
|
||||
|
||||
class DSP(BaseEstimator, TransformerMixin, ClassifierMixin):
|
||||
"""
|
||||
DSP: Discriminal Spatial Patterns
|
||||
|
||||
Author: Swolf <swolfforever@gmail.com>
|
||||
|
||||
Created on: 2021-1-07
|
||||
|
||||
Update log:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_components : int
|
||||
length of the spatial filter, first k components to use, by default 1
|
||||
transform_method : str
|
||||
method of template matching, by default ’corr‘ (pearson correlation coefficient)
|
||||
classes_ : int
|
||||
number of the EEG classes
|
||||
|
||||
Attributes
|
||||
----------
|
||||
n_components : int
|
||||
length of the spatial filter, first k components to use, by default 1
|
||||
transform_method : str
|
||||
method of template matching, by default ’corr‘ (pearson correlation coefficient)
|
||||
classes_ : int
|
||||
number of the EEG classes
|
||||
W_ : ndarray, shape(n_channels, n_filters)
|
||||
Spatial filters, shape(n_channels, n_filters), in which n_channels = n_filters
|
||||
D_ : ndarray, shape(n_filters, )
|
||||
eigenvalues in descending order, shape(n_filters, )
|
||||
M_ : ndarray, shape(n_channels, n_samples)
|
||||
mean value of all classes and trials, i.e. common mode signals, shape(n_channels, n_samples)
|
||||
A_ : ndarray, shape(n_channels, n_filters)
|
||||
spatial patterns, shape(n_channels, n_filters)
|
||||
templates_: ndarray, shape(n_classes, n_filters, n_samples)
|
||||
templates of train data, shape(n_classes, n_filters, n_samples)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, n_components: int = 1, transform_method: str = "corr"):
|
||||
self.n_components = n_components
|
||||
self.transform_method = transform_method
|
||||
|
||||
def fit(self, X: ndarray, y: ndarray, Yf: Optional[ndarray] = None):
|
||||
"""
|
||||
Import the train data to get a model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : ndarray
|
||||
train data, shape(n_trials, n_channels, n_samples)
|
||||
y : ndarray
|
||||
labels of train data, shape (n_trials, )
|
||||
Yf : ndarray
|
||||
optional parameter
|
||||
|
||||
Returns
|
||||
-------
|
||||
W_ : ndarray
|
||||
spatial filters, shape (n_channels, n_filters), in which n_channels = n_filters
|
||||
D_ : ndarray
|
||||
eigenvalues in descending order, shape (n_filters, )
|
||||
M_ : ndarray
|
||||
template for all classes, shape (n_channel, n_samples)
|
||||
A_ : ndarray
|
||||
spatial patterns, shape (n_channels, n_filters)
|
||||
templates_ : ndarray
|
||||
templates of train data, shape (n_channels, n_filters, n_samples)
|
||||
"""
|
||||
X -= np.mean(X, axis=-1, keepdims=True)
|
||||
self.classes_ = np.unique(y)
|
||||
self.W_, self.D_, self.M_, self.A_ = xiang_dsp_kernel(X, y)
|
||||
|
||||
self.templates_ = np.stack(
|
||||
[
|
||||
np.mean(
|
||||
xiang_dsp_feature(
|
||||
self.W_, self.M_, X[y == label], n_components=self.W_.shape[1]
|
||||
),
|
||||
axis=0,
|
||||
)
|
||||
for label in self.classes_
|
||||
]
|
||||
)
|
||||
return self
|
||||
|
||||
def transform(self, X: ndarray):
|
||||
"""
|
||||
Import the test data to get features.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : ndarray
|
||||
test data, shape(n_trials, n_channels, n_samples)
|
||||
|
||||
Returns
|
||||
-------
|
||||
feature : ndarray, shape(n_trials,n_classes)
|
||||
correlation coefficients of templates of train data and features of test data, shape(n_trials, n_classes)
|
||||
"""
|
||||
n_components = self.n_components
|
||||
X -= np.mean(X, axis=-1, keepdims=True)
|
||||
features = xiang_dsp_feature(self.W_, self.M_, X, n_components=n_components)
|
||||
if self.transform_method is None:
|
||||
return features.reshape((features.shape[0], -1))
|
||||
elif self.transform_method == "mean":
|
||||
return np.mean(features, axis=-1)
|
||||
elif self.transform_method == "corr":
|
||||
return self._pearson_features(
|
||||
features, self.templates_[:, :n_components, :]
|
||||
)
|
||||
else:
|
||||
raise ValueError("non-supported transform method")
|
||||
|
||||
def _pearson_features(self, X: ndarray, templates: ndarray):
|
||||
"""
|
||||
Calculate pearson correlation coefficient.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : ndarray
|
||||
features of test data after spatial filters, shape(n_trials, n_components, n_samples)
|
||||
templates : ndarray
|
||||
templates of train data, shape(n_classes, n_components, n_samples)
|
||||
|
||||
Returns
|
||||
-------
|
||||
corr : ndarray
|
||||
pearson correlation coefficient, shape(n_trials, n_classes)
|
||||
"""
|
||||
X = np.reshape(X, (-1, *X.shape[-2:]))
|
||||
templates = np.reshape(templates, (-1, *templates.shape[-2:]))
|
||||
X = X - np.mean(X, axis=-1, keepdims=True)
|
||||
templates = templates - np.mean(templates, axis=-1, keepdims=True)
|
||||
X = np.reshape(X, (X.shape[0], -1))
|
||||
templates = np.reshape(templates, (templates.shape[0], -1))
|
||||
istd_X = 1 / np.std(X, axis=-1, keepdims=True)
|
||||
istd_templates = 1 / np.std(templates, axis=-1, keepdims=True)
|
||||
corr = (X @ templates.T) / (templates.shape[1] - 1)
|
||||
corr = istd_X * corr * istd_templates.T
|
||||
return corr
|
||||
|
||||
def predict(self, X: ndarray):
|
||||
"""
|
||||
Import the templates and the test data to get prediction labels.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : ndarray
|
||||
test data, shape(n_trials, n_channels, n_samples)
|
||||
|
||||
Returns
|
||||
-------
|
||||
labels : ndarray
|
||||
prediction labels of test data, shape(n_trials,)
|
||||
"""
|
||||
feat = self.transform(X)
|
||||
if self.transform_method == "corr":
|
||||
labels = self.classes_[np.argmax(feat, axis=-1)]
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return labels
|
||||
|
||||
|
||||
175
SSMVEP/algorithm/tdca.py
Normal file
175
SSMVEP/algorithm/tdca.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Authors: Swolf <swolfforever@gmail.com>
|
||||
# Date: 2021/10/10
|
||||
# License: MIT License
|
||||
"""
|
||||
Task Decomposition Component Analysis.
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from scipy.linalg import qr
|
||||
from scipy.stats import pearsonr
|
||||
from numpy import ndarray
|
||||
from sklearn.base import BaseEstimator, TransformerMixin, ClassifierMixin
|
||||
from typing import Optional, List
|
||||
from SSMVEP.algorithm.base import FilterBankSSVEP
|
||||
from SSMVEP.algorithm.dsp import xiang_dsp_kernel, xiang_dsp_feature
|
||||
|
||||
|
||||
def proj_ref(Yf: ndarray):
|
||||
Q, R = qr(Yf.T, mode="economic")
|
||||
P = Q @ Q.T
|
||||
return P
|
||||
|
||||
|
||||
def aug_2(X: ndarray, n_samples: int, padding_len: int, P: ndarray, training: bool = True):
|
||||
X = X.reshape((-1, *X.shape[-2:]))
|
||||
n_trials, n_channels, n_points = X.shape
|
||||
# if n_points < padding_len + n_samples:
|
||||
# raise ValueError("the length of X should be larger than l+n_samples.")
|
||||
aug_X = np.zeros((n_trials, (padding_len + 1) * n_channels, n_samples))
|
||||
if training:
|
||||
for i in range(padding_len + 1):
|
||||
aug_X[:, i * n_channels : (i + 1) * n_channels, :] = X[
|
||||
..., i : i + n_samples
|
||||
]
|
||||
else:
|
||||
for i in range(padding_len + 1):
|
||||
aug_X[:, i * n_channels : (i + 1) * n_channels, : n_samples - i] = X[
|
||||
..., i:n_samples
|
||||
]
|
||||
aug_Xp = aug_X @ P
|
||||
aug_X = np.concatenate([aug_X, aug_Xp], axis=-1)
|
||||
return aug_X
|
||||
|
||||
|
||||
def tdca_feature(
|
||||
X: ndarray,
|
||||
templates: ndarray,
|
||||
W: ndarray,
|
||||
M: ndarray,
|
||||
Ps: List[ndarray],
|
||||
padding_len: int,
|
||||
n_components: int = 1,
|
||||
training=False,
|
||||
):
|
||||
rhos = []
|
||||
for Xk, P in zip(templates, Ps):
|
||||
a = xiang_dsp_feature(
|
||||
W,
|
||||
M,
|
||||
aug_2(X, P.shape[0], padding_len, P, training=training),
|
||||
n_components=n_components,
|
||||
)
|
||||
b = Xk[:n_components, :]
|
||||
a = np.reshape(a, (-1))
|
||||
b = np.reshape(b, (-1))
|
||||
rhos.append(pearsonr(a, b)[0])
|
||||
return rhos
|
||||
|
||||
|
||||
class TDCA(BaseEstimator, TransformerMixin, ClassifierMixin):
|
||||
def __init__(self, padding_len: int, n_components: int = 1):
|
||||
self.padding_len = padding_len
|
||||
self.n_components = n_components
|
||||
|
||||
def fit(self, X: ndarray, y: ndarray, Yf: ndarray):
|
||||
X -= np.mean(X, axis=-1, keepdims=True)
|
||||
self.classes_ = np.unique(y)
|
||||
self.Ps_ = [proj_ref(Yf[i]) for i in range(len(self.classes_))]
|
||||
# print(np.shape(self.Ps_))
|
||||
|
||||
aug_X_list, aug_Y_list = [], []
|
||||
for i, label in enumerate(self.classes_):
|
||||
aug_X_list.append(
|
||||
aug_2(
|
||||
X[y == label],
|
||||
self.Ps_[i].shape[0],
|
||||
self.padding_len,
|
||||
self.Ps_[i],
|
||||
training=True,
|
||||
)
|
||||
)
|
||||
aug_Y_list.append(y[y == label])
|
||||
|
||||
aug_X = np.concatenate(aug_X_list, axis=0)
|
||||
aug_Y = np.concatenate(aug_Y_list, axis=0)
|
||||
self.W_, _, self.M_, _ = xiang_dsp_kernel(aug_X, aug_Y)
|
||||
|
||||
self.templates_ = np.stack(
|
||||
[
|
||||
np.mean(
|
||||
xiang_dsp_feature(
|
||||
self.W_,
|
||||
self.M_,
|
||||
aug_X[aug_Y == label],
|
||||
n_components=self.W_.shape[1],
|
||||
),
|
||||
axis=0,
|
||||
)
|
||||
for label in self.classes_
|
||||
]
|
||||
)
|
||||
return self
|
||||
|
||||
def transform(self, X: ndarray):
|
||||
n_components = self.n_components
|
||||
X -= np.mean(X, axis=-1, keepdims=True)
|
||||
X = X.reshape((-1, *X.shape[-2:]))
|
||||
rhos = [
|
||||
tdca_feature(
|
||||
tmp,
|
||||
self.templates_,
|
||||
self.W_,
|
||||
self.M_,
|
||||
self.Ps_,
|
||||
self.padding_len,
|
||||
n_components=n_components,
|
||||
)
|
||||
for tmp in X
|
||||
]
|
||||
rhos = np.stack(rhos)
|
||||
return rhos
|
||||
|
||||
def predict(self, X: ndarray):
|
||||
feat = self.transform(X)
|
||||
labels = self.classes_[np.argmax(feat, axis=-1)]
|
||||
return labels,feat
|
||||
|
||||
|
||||
class FBTDCA(FilterBankSSVEP, ClassifierMixin):
|
||||
def __init__(
|
||||
self,
|
||||
filterbank: List[ndarray],
|
||||
padding_len: int,
|
||||
n_components: int = 1,
|
||||
filterweights: Optional[ndarray] = None,
|
||||
n_jobs: Optional[int] = None,
|
||||
):
|
||||
self.padding_len = padding_len
|
||||
self.n_components = n_components
|
||||
self.filterweights = filterweights
|
||||
self.n_jobs = n_jobs
|
||||
super().__init__(
|
||||
filterbank,
|
||||
TDCA(padding_len, n_components=n_components),
|
||||
filterweights=filterweights,
|
||||
n_jobs=n_jobs,
|
||||
)
|
||||
|
||||
def fit(self, X: ndarray, y: ndarray, Yf: Optional[ndarray] = None): # type: ignore[override]
|
||||
self.classes_ = np.unique(y)
|
||||
super().fit(X, y, Yf=Yf)
|
||||
return self
|
||||
|
||||
def predict(self, X: ndarray):
|
||||
features = self.transform(X)
|
||||
if self.filterweights is None:
|
||||
features = np.reshape(
|
||||
features, (features.shape[0], len(self.filterbank), -1)
|
||||
)
|
||||
features = np.mean(features, axis=1)
|
||||
labels = self.classes_[np.argmax(features, axis=-1)]
|
||||
return labels,features
|
||||
Reference in New Issue
Block a user