""" EEG Conformer Convolutional Transformer for EEG decoding Couple CNN and Transformer in a concise manner with amazing results """ # remember to change paths import os import numpy as np import math import random import time import datetime import torch import torch.nn.functional as F from torch.utils.data import DataLoader from torch import nn from torch import Tensor from einops import rearrange from einops.layers.torch import Rearrange, Reduce from torch.backends import cudnn from sklearn.model_selection import train_test_split from logs.log import algo_log # writer = SummaryWriter('./TensorBoardX/') # Convolution module # use conv to capture local features, instead of postion embedding. class PatchEmbedding(nn.Module): def __init__(self, emb_size=40): # self.patch_size = patch_size super().__init__() self.shallownet = nn.Sequential( nn.Conv2d(1, 40, (1, 25), (1, 1)), nn.Conv2d(40, 40, (8, 1), (1, 1)), nn.BatchNorm2d(40), nn.ELU(), nn.AvgPool2d((1, 75), (1, 15)), # pooling acts as slicing to obtain 'patch' along the time dimension as in ViT nn.Dropout(0.5), ) self.projection = nn.Sequential( nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)), # transpose, conv could enhance fiting ability slightly Rearrange('b e (h) (w) -> b (h w) e'), ) def forward(self, x: Tensor) -> Tensor: b, _, _, _ = x.shape x = self.shallownet(x) x = self.projection(x) return x class MultiHeadAttention(nn.Module): def __init__(self, emb_size, num_heads, dropout): super().__init__() self.emb_size = emb_size self.num_heads = num_heads self.keys = nn.Linear(emb_size, emb_size) self.queries = nn.Linear(emb_size, emb_size) self.values = nn.Linear(emb_size, emb_size) self.att_drop = nn.Dropout(dropout) self.projection = nn.Linear(emb_size, emb_size) def forward(self, x: Tensor, mask: Tensor = None) -> Tensor: queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads) keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads) values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads) energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) if mask is not None: fill_value = torch.finfo(torch.float64).min energy.mask_fill(~mask, fill_value) scaling = self.emb_size ** (1 / 2) att = F.softmax(energy / scaling, dim=-1) att = self.att_drop(att) out = torch.einsum('bhal, bhlv -> bhav ', att, values) out = rearrange(out, "b h n d -> b n (h d)") out = self.projection(out) return out class ResidualAdd(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x, **kwargs): res = x x = self.fn(x, **kwargs) x += res return x class FeedForwardBlock(nn.Sequential): def __init__(self, emb_size, expansion, drop_p): super().__init__( nn.Linear(emb_size, expansion * emb_size), nn.GELU(), nn.Dropout(drop_p), nn.Linear(expansion * emb_size, emb_size), ) class GELU(nn.Module): def forward(self, input: Tensor) -> Tensor: return input*0.5*(1.0+torch.erf(input/math.sqrt(2.0))) class TransformerEncoderBlock(nn.Sequential): def __init__(self, emb_size, num_heads=10, drop_p=0.5, forward_expansion=4, forward_drop_p=0.5): super().__init__( ResidualAdd(nn.Sequential( nn.LayerNorm(emb_size), MultiHeadAttention(emb_size, num_heads, drop_p), nn.Dropout(drop_p) )), ResidualAdd(nn.Sequential( nn.LayerNorm(emb_size), FeedForwardBlock( emb_size, expansion=forward_expansion, drop_p=forward_drop_p), nn.Dropout(drop_p) ) )) class TransformerEncoder(nn.Sequential): def __init__(self, depth, emb_size): super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)]) class ClassificationHead(nn.Sequential): def __init__(self, emb_size, n_classes): super().__init__() # global average pooling self.clshead = nn.Sequential( Reduce('b n e -> b e', reduction='mean'), nn.LayerNorm(emb_size), nn.Linear(emb_size, n_classes) ) self.fc = nn.Sequential( nn.Linear(2440, 256), nn.ELU(), nn.Dropout(0.5), nn.Linear(256, 32), nn.ELU(), nn.Dropout(0.3), nn.Linear(32, 2) ) def forward(self, x): x = x.contiguous().view(x.size(0), -1) out = self.fc(x) return out class Conformer(nn.Sequential): def __init__(self, emb_size=40, depth=6, n_classes=2, **kwargs): super().__init__( PatchEmbedding(emb_size), TransformerEncoder(depth, emb_size), ClassificationHead(emb_size, n_classes) ) class ExP(): def __init__(self): super(ExP, self).__init__() self.batch_size = 24 self.n_epochs = 250 self.c_dim = 4 self.lr = 0.0002 self.b1 = 0.5 self.b2 = 0.999 self.start_epoch = 0 self.log_write = open("./online_Models/log_result.txt", "w") # 自动选择设备:有 GPU 用 GPU,否则用 CPU self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # self.device = torch.device("cpu") algo_log(f"Using device: {self.device}", level="debug") # 定义张量类型(不再强制使用 cuda) self.Tensor = torch.FloatTensor self.LongTensor = torch.LongTensor # 将模型移到指定设备 self.model = Conformer().to(self.device) # 损失函数也移到设备 self.criterion_cls = torch.nn.CrossEntropyLoss().to(self.device) # self.model = EEGNet().cuda() # self.model = nn.DataParallel(self.model,device_ids=[i for i in range(len(gpus))]) # self.model = self.model.cuda() # summary(self.model, (1, 8, 1000)) # Segmentation and Reconstruction (S&R) data augmentation def interaug(self, timg, label): aug_data = [] aug_label = [] for cls4aug in range(2): cls_idx = np.where(label == cls4aug + 1) tmp_data = timg[cls_idx] tmp_label = label[cls_idx] tmp_aug_data = np.zeros((int(self.batch_size / 2), 1, 8, 1000)) for ri in range(int(self.batch_size / 2)): for rj in range(8): rand_idx = np.random.randint(0, tmp_data.shape[0], 8) tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :, rj * 125:(rj + 1) * 125] aug_data.append(tmp_aug_data) aug_label.append(tmp_label[:int(self.batch_size / 2)]) aug_data = np.concatenate(aug_data) aug_label = np.concatenate(aug_label) aug_shuffle = np.random.permutation(len(aug_data)) aug_data = aug_data[aug_shuffle, :, :] aug_label = aug_label[aug_shuffle] aug_data = torch.from_numpy(aug_data).float().to(self.device) aug_label = torch.from_numpy(aug_label - 1).long().to(self.device) return aug_data, aug_label def train(self,all_data,all_label,model_path): all_data = np.array(all_data);all_label = np.array(all_label) all_data = np.expand_dims(all_data, axis=1) train_data, test_data, train_label, test_label = train_test_split(all_data, all_label, test_size=0.2, random_state=42, stratify=all_label,shuffle=True) # 转为 Tensor img = torch.from_numpy(train_data).float().to(self.device) label = torch.from_numpy(train_label - 1).long().to(self.device) dataset = torch.utils.data.TensorDataset(img, label) self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True) test_data = torch.from_numpy(test_data).float().to(self.device) test_label = torch.from_numpy(test_label - 1).long().to(self.device) test_dataset = torch.utils.data.TensorDataset(test_data, test_label) self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True) # Optimizers self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2)) bestAcc = 0 averAcc = 0 num = 0 Y_true = 0 Y_pred = 0 # Train the cnn model for e in range(self.n_epochs): # in_epoch = time.time() self.model.train() for i, (img, label) in enumerate(self.dataloader): # data augmentation aug_data, aug_label = self.interaug(train_data, train_label) img = torch.cat((img, aug_data)) label = torch.cat((label, aug_label)) outputs = self.model(img) loss = self.criterion_cls(outputs, label) self.optimizer.zero_grad() loss.backward() self.optimizer.step() # out_epoch = time.time() # test process if (e + 1) % 1 == 0: self.model.eval() with torch.no_grad(): Cls = self.model(test_data) loss_test = self.criterion_cls(Cls, test_label) y_pred = torch.max(Cls, 1)[1] acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0)) train_pred = torch.max(outputs, 1)[1] train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0)) print('Epoch:', e, ' Train loss: %.6f' % loss.detach().cpu().numpy(), ' Test loss: %.6f' % loss_test.detach().cpu().numpy(), ' Train accuracy %.6f' % train_acc, ' Test accuracy is %.6f' % acc) self.log_write.write(str(e) + " " + str(acc) + "\n") num = num + 1 averAcc = averAcc + acc if acc > bestAcc: bestAcc = acc Y_true = test_label Y_pred = y_pred torch.save(self.model, model_path) averAcc = averAcc / num print('The average accuracy is:', averAcc) print('The best accuracy is:', bestAcc) self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n") self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n") return bestAcc, averAcc, Y_true, Y_pred # writer.close() def onlineTrain(data_queue,result_queue): try: starttime = datetime.datetime.now() # seed_n = np.random.randint(2025) seed_n = 1877 random.seed(seed_n) np.random.seed(seed_n) torch.manual_seed(seed_n) torch.cuda.manual_seed(seed_n) torch.cuda.manual_seed_all(seed_n) exp = ExP() # 从队列获取训练数据 data = data_queue.get(timeout=30) all_data, all_label,model_path = data['data'], data['label'],data['modelPath'] print('训练参数: ',np.shape(all_data),np.shape(all_label),model_path) bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path) print('THE BEST ACCURACY IS ' + str(bestAcc)) endtime = datetime.datetime.now() print('train duration: ',str(endtime - starttime)) # 将模型或参数传回 result_queue.put({ 'status': 'success', 'model_state': model_path, # 或保存路径 'timestamp': time.time() }) except Exception as e: result_queue.put({'status': 'error', 'msg': str(e)}) def offlineTrain(all_data,all_label,modelPath): starttime = datetime.datetime.now() # seed_n = np.random.randint(2025) seed_n = 1877 print('seed is ' + str(seed_n)) random.seed(seed_n) np.random.seed(seed_n) torch.manual_seed(seed_n) torch.cuda.manual_seed(seed_n) torch.cuda.manual_seed_all(seed_n) exp = ExP() bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath) print('THE BEST ACCURACY IS ' + str(bestAcc)) endtime = datetime.datetime.now() print('train duration: ',str(endtime - starttime)) if __name__ == "__main__": print(time.asctime(time.localtime(time.time()))) print(time.asctime(time.localtime(time.time())))