# -*- coding: utf-8 -*-
#
# Authors: Swolf
# Date: 2021/10/10
# License: MIT License
"""
Task Decomposition Component Analysis.
"""
from typing import List, Optional

import numpy as np
from numpy import ndarray
from scipy.linalg import qr
from scipy.stats import pearsonr
from sklearn.base import BaseEstimator, TransformerMixin, ClassifierMixin

from SSMVEP.algorithm.base import FilterBankSSVEP
from SSMVEP.algorithm.dsp import xiang_dsp_kernel, xiang_dsp_feature


def proj_ref(Yf: ndarray):
    """Build the orthogonal projector onto the row space of ``Yf``.

    Parameters
    ----------
    Yf : ndarray
        2-D reference matrix whose last axis has length ``n_samples``
        (presumably the sine/cosine reference signals of one stimulus
        class — confirm against callers).

    Returns
    -------
    P : ndarray
        ``(n_samples, n_samples)`` matrix ``Q @ Q.T``, where ``Q`` is the
        orthonormal factor of the economic QR decomposition of ``Yf.T``.
        When ``Yf.T`` has full column rank this is the orthogonal
        projection onto its column space.
    """
    Q, _ = qr(Yf.T, mode="economic")
    return Q @ Q.T


def aug_2(
    X: ndarray,
    n_samples: int,
    padding_len: int,
    P: ndarray,
    training: bool = True,
):
    """Delay-embed trials and append their projection through ``P``.

    The trials are stacked with ``padding_len + 1`` time-shifted copies of
    themselves along the channel axis. The stacked signal is then
    concatenated, along time, with its product ``aug_X @ P``.

    Parameters
    ----------
    X : ndarray
        Trials; all leading axes are flattened, so the last two axes must
        be ``(n_channels, n_points)``.
    n_samples : int
        Length of each delay window (must equal ``P.shape[0]``).
    padding_len : int
        Number of extra delayed copies; shift ``i`` runs over
        ``0..padding_len``.
    P : ndarray
        ``(n_samples, n_samples)`` projection matrix, e.g. from ``proj_ref``.
    training : bool
        If True, each shift ``i`` takes the full window
        ``X[..., i:i + n_samples]``; this needs
        ``n_points >= padding_len + n_samples``. If False, only the first
        ``n_samples`` points are used: shift ``i`` takes
        ``X[..., i:n_samples]``, and the last ``i`` columns are left as
        zeros. This needs ``n_points >= n_samples``.

    Returns
    -------
    ndarray
        Shape ``(n_trials, (padding_len + 1) * n_channels, 2 * n_samples)``.

    Raises
    ------
    ValueError
        If ``X`` has too few time points for the requested windows.
    """
    X = X.reshape((-1, *X.shape[-2:]))
    n_trials, n_channels, n_points = X.shape

    # Validate explicitly. Without this, numpy either raises an opaque
    # broadcasting error or, when a slice has length 1, silently broadcasts
    # one column over the whole window.
    required = padding_len + n_samples if training else n_samples
    if n_points < required:
        raise ValueError(
            f"X has {n_points} time points but at least {required} are "
            f"required (n_samples={n_samples}, padding_len={padding_len}, "
            f"training={training})."
        )

    aug_X = np.zeros((n_trials, (padding_len + 1) * n_channels, n_samples))
    for i in range(padding_len + 1):
        rows = slice(i * n_channels, (i + 1) * n_channels)
        if training:
            aug_X[:, rows, :] = X[..., i : i + n_samples]
        else:
            # Test data is not read past n_samples; the tail is zero-filled.
            aug_X[:, rows, : n_samples - i] = X[..., i:n_samples]

    aug_Xp = aug_X @ P
    return np.concatenate([aug_X, aug_Xp], axis=-1)


def tdca_feature(
    X: ndarray,
    templates: ndarray,
    W: ndarray,
    M: ndarray,
    Ps: List[ndarray],
    padding_len: int,
    n_components: int = 1,
    training=False,
):
    """Correlate the filtered input with each class template.

    For each class ``k``, the input is augmented with that class's
    projector ``Ps[k]`` (via ``aug_2``) and filtered by
    ``xiang_dsp_feature(W, M, ...)``. The first ``n_components`` rows of
    the result and of ``templates[k]`` are flattened and compared with
    Pearson's r.

    Returns
    -------
    list of float
        One correlation coefficient per (template, projector) pair.
    """
    rhos = []
    for Xk, P in zip(templates, Ps):
        a = xiang_dsp_feature(
            W,
            M,
            aug_2(X, P.shape[0], padding_len, P, training=training),
            n_components=n_components,
        )
        b = Xk[:n_components, :]
        rhos.append(pearsonr(np.ravel(a), np.ravel(b))[0])
    return rhos


class TDCA(BaseEstimator, TransformerMixin, ClassifierMixin):
    """Task Decomposition Component Analysis classifier.

    Parameters
    ----------
    padding_len : int
        Number of delayed copies used in the delay embedding (see ``aug_2``).
    n_components : int
        Number of leading filtered components compared at prediction time.
    """

    def __init__(self, padding_len: int, n_components: int = 1):
        self.padding_len = padding_len
        self.n_components = n_components

    def fit(self, X: ndarray, y: ndarray, Yf: ndarray):
        """Learn filters and per-class templates.

        Parameters
        ----------
        X : ndarray
            Training trials, ``(..., n_channels, n_points)`` with trials on
            the first axis.
        y : ndarray
            Label per trial.
        Yf : ndarray
            Reference signals. ``Yf[i]`` is paired with the i-th entry of
            ``np.unique(y)``, i.e. with the labels in sorted order.

        Returns
        -------
        self
        """
        # Center into a new array rather than in place: the old `X -= ...`
        # mutated the caller's data and failed on integer-typed input.
        X = X - np.mean(X, axis=-1, keepdims=True)
        y = np.asarray(y)
        self.classes_ = np.unique(y)
        self.Ps_ = [proj_ref(Yf[i]) for i in range(len(self.classes_))]

        aug_X_list, aug_Y_list = [], []
        for i, label in enumerate(self.classes_):
            mask = y == label
            P = self.Ps_[i]
            aug_X_list.append(
                aug_2(X[mask], P.shape[0], self.padding_len, P, training=True)
            )
            aug_Y_list.append(y[mask])
        aug_X = np.concatenate(aug_X_list, axis=0)
        aug_Y = np.concatenate(aug_Y_list, axis=0)

        self.W_, _, self.M_, _ = xiang_dsp_kernel(aug_X, aug_Y)
        # Template per class: mean (over the first axis, presumably trials)
        # of the features computed with all available filter components.
        self.templates_ = np.stack(
            [
                np.mean(
                    xiang_dsp_feature(
                        self.W_,
                        self.M_,
                        aug_X[aug_Y == label],
                        n_components=self.W_.shape[1],
                    ),
                    axis=0,
                )
                for label in self.classes_
            ]
        )
        return self

    def transform(self, X: ndarray):
        """Return the per-class correlation scores.

        Returns
        -------
        ndarray
            ``(n_trials, n_classes)`` Pearson correlations, one column per
            entry of ``self.classes_``.
        """
        # Same non-mutating centering as in ``fit``.
        X = X - np.mean(X, axis=-1, keepdims=True)
        X = X.reshape((-1, *X.shape[-2:]))
        rhos = [
            tdca_feature(
                trial,
                self.templates_,
                self.W_,
                self.M_,
                self.Ps_,
                self.padding_len,
                n_components=self.n_components,
            )
            for trial in X
        ]
        return np.stack(rhos)

    def predict(self, X: ndarray):
        """Predict labels.

        Returns
        -------
        tuple
            ``(labels, scores)``. The scores are what ``transform`` returns.
            NOTE: this returns a tuple rather than labels only, so
            sklearn's ``ClassifierMixin.score`` will not work as-is. The
            tuple is kept for compatibility with existing callers.
        """
        feat = self.transform(X)
        labels = self.classes_[np.argmax(feat, axis=-1)]
        return labels, feat


class FBTDCA(FilterBankSSVEP, ClassifierMixin):
    """Filter-bank wrapper that runs one ``TDCA`` per sub-band.

    The per-band fitting and transform are delegated to ``FilterBankSSVEP``.
    """

    def __init__(
        self,
        filterbank: List[ndarray],
        padding_len: int,
        n_components: int = 1,
        filterweights: Optional[ndarray] = None,
        n_jobs: Optional[int] = None,
    ):
        self.padding_len = padding_len
        self.n_components = n_components
        self.filterweights = filterweights
        self.n_jobs = n_jobs
        super().__init__(
            filterbank,
            TDCA(padding_len, n_components=n_components),
            filterweights=filterweights,
            n_jobs=n_jobs,
        )

    def fit(self, X: ndarray, y: ndarray, Yf: Optional[ndarray] = None):  # type: ignore[override]
        """Record the sorted class labels and fit every sub-band estimator."""
        self.classes_ = np.unique(y)
        super().fit(X, y, Yf=Yf)
        return self

    def predict(self, X: ndarray):
        """Predict labels from the combined sub-band scores.

        When ``filterweights`` is None, the transform output is reshaped to
        ``(n_trials, n_bands, -1)`` and averaged over bands. Otherwise it is
        used as-is (presumably already weighted by
        ``FilterBankSSVEP.transform`` — confirm).

        Returns
        -------
        tuple
            ``(labels, features)``. NOTE: this is a tuple, not labels only.
        """
        features = self.transform(X)
        if self.filterweights is None:
            features = np.reshape(
                features, (features.shape[0], len(self.filterbank), -1)
            )
            features = np.mean(features, axis=1)
        labels = self.classes_[np.argmax(features, axis=-1)]
        return labels, features