Files
bci_algo/MI/Algorithm/conformer_2class_cpu.py
2026-06-05 09:34:29 +08:00

383 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
EEG Conformer
Convolutional Transformer for EEG decoding
Couple CNN and Transformer in a concise manner with amazing results
"""
# remember to change paths
import os
import numpy as np
import math
import random
import time
import datetime
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import nn
from torch import Tensor
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
from torch.backends import cudnn
from sklearn.model_selection import train_test_split
# writer = SummaryWriter('./TensorBoardX/')
# Convolution module
# Use convolutions to capture local features instead of a position embedding.
class PatchEmbedding(nn.Module):
    """Shallow convolutional front-end that turns an EEG trial into a token sequence.

    ``shallownet`` applies a (1, 25) temporal convolution, an (8, 1) convolution
    spanning 8 input rows (the electrodes, assuming 8-channel input), batch norm,
    ELU, average pooling and dropout. ``projection`` then maps the 40 feature maps
    to ``emb_size`` with a 1x1 convolution and flattens the result to
    ``(batch, h * w, emb_size)``.

    Args:
        emb_size: Width of each output token.
    """

    def __init__(self, emb_size=40):
        super().__init__()
        temporal_conv = nn.Conv2d(1, 40, (1, 25), (1, 1))
        spatial_conv = nn.Conv2d(40, 40, (8, 1), (1, 1))
        self.shallownet = nn.Sequential(
            temporal_conv,
            spatial_conv,
            nn.BatchNorm2d(40),
            nn.ELU(),
            # Pooling slices the time axis into overlapping "patches", as in ViT.
            nn.AvgPool2d((1, 75), (1, 15)),
            nn.Dropout(0.5),
        )
        self.projection = nn.Sequential(
            # 1x1 conv used instead of a plain transpose; may slightly improve fitting.
            nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def forward(self, x: Tensor) -> Tensor:
        # Unpacking raises ValueError for inputs that are not 4-D.
        _batch, _, _, _ = x.shape
        features = self.shallownet(x)
        return self.projection(features)
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        emb_size: Token width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads; head ``i`` uses the contiguous
            channel slice ``[i * d, (i + 1) * d)`` with ``d = emb_size // num_heads``.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def _split_heads(self, t: Tensor) -> Tensor:
        """Reshape ``(b, n, h*d)`` into ``(b, h, n, d)``."""
        b, n, _ = t.shape
        return t.reshape(b, n, self.num_heads, -1).permute(0, 2, 1, 3)

    @staticmethod
    def _merge_heads(t: Tensor) -> Tensor:
        """Reshape ``(b, h, n, d)`` back into ``(b, n, h*d)``."""
        b, h, n, d = t.shape
        return t.permute(0, 2, 1, 3).reshape(b, n, h * d)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Attend over the tokens of ``x``.

        Args:
            x: Tensor of shape ``(batch, tokens, emb_size)``.
            mask: Optional boolean tensor broadcastable to
                ``(batch, heads, tokens, tokens)``; ``True`` keeps a
                query/key pair, ``False`` excludes the key for that query.

        Returns:
            Tensor of shape ``(batch, tokens, emb_size)``.
        """
        queries = self._split_heads(self.queries(x))
        keys = self._split_heads(self.keys(x))
        values = self._split_heads(self.values(x))
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            # masked_fill is out-of-place, so the result must be reassigned.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)
        # NOTE: scaled by sqrt(emb_size) rather than sqrt(head_dim); kept as-is so
        # models trained with this code behave identically.
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = self._merge_heads(out)
        return self.projection(out)
class ResidualAdd(nn.Module):
    """Wrap ``fn`` with a skip connection: returns ``fn(x, **kwargs) + x``.

    The addition is performed in place on the tensor returned by ``fn``.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        out += x
        return out
class FeedForwardBlock(nn.Sequential):
    """Two-layer MLP applied over the last dimension: Linear -> GELU -> Dropout -> Linear.

    Args:
        emb_size: Input and output width.
        expansion: Hidden-width multiplier (hidden width = ``expansion * emb_size``).
        drop_p: Dropout probability between the two linear layers.
    """

    def __init__(self, emb_size, expansion, drop_p):
        hidden = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden, emb_size),
        ]
        super().__init__(*layers)
class GELU(nn.Module):
    """Exact (erf-based) Gaussian Error Linear Unit: ``x * 0.5 * (1 + erf(x / sqrt(2)))``."""

    def forward(self, input: Tensor) -> Tensor:
        gate = 1.0 + torch.erf(input / math.sqrt(2.0))
        return input * 0.5 * gate
class TransformerEncoderBlock(nn.Sequential):
    """Pre-norm Transformer encoder layer made of two residual branches.

    1. LayerNorm -> MultiHeadAttention -> Dropout, added back to the input.
    2. LayerNorm -> FeedForwardBlock -> Dropout, added back to the input.

    Args:
        emb_size: Token width.
        num_heads: Number of attention heads.
        drop_p: Dropout for the attention weights and after each branch.
        forward_expansion: Hidden-width multiplier of the feed-forward MLP.
        forward_drop_p: Dropout inside the feed-forward MLP.
    """

    def __init__(self,
                 emb_size,
                 num_heads=10,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads, drop_p),
            nn.Dropout(drop_p),
        )
        feedforward_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(
            ResidualAdd(attention_branch),
            ResidualAdd(feedforward_branch),
        )
class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` TransformerEncoderBlock layers using their default hyper-parameters."""

    def __init__(self, depth, emb_size):
        layers = [TransformerEncoderBlock(emb_size) for _ in range(depth)]
        super().__init__(*layers)
class ClassificationHead(nn.Sequential):
    """Classifier applied to the encoder's token sequence.

    ``forward`` flattens each sample and runs the ``fc`` MLP
    (``in_features -> 256 -> 32 -> n_classes``).

    ``clshead`` (token mean-pool -> LayerNorm -> Linear) is constructed but not
    used by ``forward``. It is kept so that its parameters remain in the state
    dict and existing checkpoints keep loading with matching keys.

    Args:
        emb_size: Token width (used by ``clshead``).
        n_classes: Number of output logits.
        in_features: Flattened per-sample input size expected by ``fc``. The
            default 2440 equals 61 tokens x 40 features, which is what
            PatchEmbedding's default layers produce for an 8 x 1000 trial —
            TODO confirm when the input size or ``emb_size`` changes.
    """

    def __init__(self, emb_size, n_classes, in_features=2440):
        super().__init__()
        # Global-average-pooling head (not used by forward; see docstring).
        self.clshead = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes)
        )
        self.fc = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(256, 32),
            nn.ELU(),
            nn.Dropout(0.3),
            # Output width follows n_classes, consistent with clshead.
            nn.Linear(32, n_classes)
        )

    def forward(self, x):
        """Flatten ``x`` per sample and return ``(batch, n_classes)`` logits."""
        x = x.contiguous().view(x.size(0), -1)
        return self.fc(x)
class Conformer(nn.Sequential):
    """EEG Conformer: PatchEmbedding -> TransformerEncoder -> ClassificationHead.

    Args:
        emb_size: Token width shared by all three stages.
        depth: Number of Transformer encoder blocks.
        n_classes: Number of classes passed to the classification head.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, emb_size=40, depth=6, n_classes=2, **kwargs):
        embedding = PatchEmbedding(emb_size)
        encoder = TransformerEncoder(depth, emb_size)
        head = ClassificationHead(emb_size, n_classes)
        super().__init__(embedding, encoder, head)
class ExP:
    """Training harness for the 2-class Conformer.

    Holds the hyper-parameters, the model, the loss and a text log file. It
    provides Segmentation & Reconstruction (S&R) data augmentation and a
    train/evaluate loop that checkpoints the best model.

    Note:
        Constructing an instance truncates ``./online_Models/log_result.txt``
        (creating ``./online_Models`` if needed). Call :meth:`close` when done
        to release the file handle.
    """

    def __init__(self):
        super().__init__()
        self.batch_size = 24
        self.n_epochs = 250
        self.c_dim = 4  # not read anywhere in this class
        self.lr = 0.0002
        self.b1 = 0.5  # Adam beta1
        self.b2 = 0.999  # Adam beta2
        self.start_epoch = 0
        log_path = "./online_Models/log_result.txt"
        # Create the log directory so a fresh working directory does not fail here.
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        self.log_write = open(log_path, "w", encoding="utf-8")
        # Automatically select the device: GPU if available, otherwise CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        # Tensor type aliases (CUDA is no longer forced).
        self.Tensor = torch.FloatTensor
        self.LongTensor = torch.LongTensor
        # Build the model on the selected device.
        self.model = Conformer().to(self.device)
        # Move the loss function to the device as well.
        self.criterion_cls = torch.nn.CrossEntropyLoss().to(self.device)

    def close(self):
        """Close the log file opened in ``__init__`` (safe to call repeatedly)."""
        if not self.log_write.closed:
            self.log_write.close()

    # Segmentation and Reconstruction (S&R) data augmentation
    def interaug(self, timg, label):
        """Build ``batch_size`` synthetic trials by splicing time segments.

        For each class (labels 1 and 2), ``batch_size // 2`` new trials are
        assembled. The time axis is cut into 8 consecutive segments, and each
        segment is copied from a randomly chosen real trial of that class.

        Args:
            timg: numpy array indexed as ``timg[trial, plane, channel, time]``.
            label: numpy array of 1-based class labels (1 or 2), one per trial.

        Returns:
            ``(aug_data, aug_label)`` as float / long tensors on ``self.device``,
            shuffled together; labels are shifted to 0-based.
        """
        n_segments = 8
        n_per_class = int(self.batch_size / 2)
        _, n_planes, n_chans, n_times = timg.shape
        seg_len = n_times // n_segments
        # Segment boundaries; the last segment absorbs any remainder so every
        # time sample is filled (for 1000 samples these are multiples of 125).
        bounds = [rj * seg_len for rj in range(n_segments)] + [n_times]
        aug_data = []
        aug_label = []
        for cls4aug in range(2):
            cls_idx = np.where(label == cls4aug + 1)
            tmp_data = timg[cls_idx]
            tmp_aug_data = np.zeros((n_per_class, n_planes, n_chans, n_times))
            for ri in range(n_per_class):
                for rj in range(n_segments):
                    # Draws n_segments indices but uses only one; kept so the
                    # NumPy random stream (and thus seeded results) is unchanged.
                    rand_idx = np.random.randint(0, tmp_data.shape[0], n_segments)
                    start, stop = bounds[rj], bounds[rj + 1]
                    tmp_aug_data[ri, :, :, start:stop] = tmp_data[rand_idx[rj], :, :, start:stop]
            aug_data.append(tmp_aug_data)
            # One label per synthetic trial, even if this class has fewer than
            # n_per_class real trials (slicing the real labels would come up short).
            aug_label.append(np.full(n_per_class, cls4aug + 1, dtype=label.dtype))
        aug_data = np.concatenate(aug_data)
        aug_label = np.concatenate(aug_label)
        aug_shuffle = np.random.permutation(len(aug_data))
        aug_data = aug_data[aug_shuffle]
        aug_label = aug_label[aug_shuffle]
        aug_data = torch.from_numpy(aug_data).float().to(self.device)
        aug_label = torch.from_numpy(aug_label - 1).long().to(self.device)
        return aug_data, aug_label

    def train(self, all_data, all_label, model_path):
        """Split the data, train the Conformer and checkpoint the best epoch.

        The data are split 80/20 (stratified, ``random_state=42``). Each epoch
        iterates over shuffled training mini-batches, each extended with
        ``batch_size`` S&R-augmented trials, and then evaluates on the whole test
        split. The full model is saved with ``torch.save`` to ``model_path``
        after the first evaluated epoch and whenever test accuracy improves.

        Args:
            all_data: trials, array-like indexed as ``[trial, channel, time]``
                (a singleton axis is inserted at position 1).
            all_label: 1-based class labels (1 or 2), one per trial.
            model_path: destination file for the best model.

        Returns:
            ``(bestAcc, averAcc, Y_true, Y_pred)``: best and mean test accuracy,
            plus the 0-based test labels / predictions from the best epoch
            (``0`` placeholders if no epoch ran).
        """
        all_data = np.array(all_data)
        all_label = np.array(all_label)
        all_data = np.expand_dims(all_data, axis=1)
        train_data, test_data, train_label, test_label = train_test_split(
            all_data, all_label, test_size=0.2,
            random_state=42, stratify=all_label, shuffle=True)

        # Convert to tensors (labels shifted to 0-based for CrossEntropyLoss).
        train_img = torch.from_numpy(train_data).float().to(self.device)
        train_lbl = torch.from_numpy(train_label - 1).long().to(self.device)
        dataset = torch.utils.data.TensorDataset(train_img, train_lbl)
        self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)

        test_data = torch.from_numpy(test_data).float().to(self.device)
        test_label = torch.from_numpy(test_label - 1).long().to(self.device)
        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
        # Not iterated below: evaluation uses the full test tensors in one pass.
        self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)

        # Optimizers
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        bestAcc = 0
        averAcc = 0
        num = 0
        Y_true = 0
        Y_pred = 0
        saved = False

        for e in range(self.n_epochs):
            self.model.train()
            for img, label in self.dataloader:
                # Append freshly augmented S&R trials to every mini-batch.
                aug_data, aug_label = self.interaug(train_data, train_label)
                img = torch.cat((img, aug_data))
                label = torch.cat((label, aug_label))
                outputs = self.model(img)
                loss = self.criterion_cls(outputs, label)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Evaluate on the held-out split after every epoch.
            self.model.eval()
            with torch.no_grad():
                Cls = self.model(test_data)
                loss_test = self.criterion_cls(Cls, test_label)
                y_pred = torch.max(Cls, 1)[1]
                acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
                # Training accuracy is measured on the last (augmented) mini-batch only.
                train_pred = torch.max(outputs, 1)[1]
                train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))

            print('Epoch:', e,
                  ' Train loss: %.6f' % loss.detach().cpu().numpy(),
                  ' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
                  ' Train accuracy %.6f' % train_acc,
                  ' Test accuracy is %.6f' % acc)
            self.log_write.write(str(e) + " " + str(acc) + "\n")
            num = num + 1
            averAcc = averAcc + acc
            # `not saved` guarantees a checkpoint exists even if accuracy never exceeds 0.
            if acc > bestAcc or not saved:
                bestAcc = acc
                Y_true = test_label
                Y_pred = y_pred
                torch.save(self.model, model_path)
                saved = True

        # Guard against n_epochs == 0.
        averAcc = averAcc / num if num else 0.0
        print('The average accuracy is:', averAcc)
        print('The best accuracy is:', bestAcc)
        self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
        self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
        # Push the log to disk even if the process is terminated afterwards.
        self.log_write.flush()
        return bestAcc, averAcc, Y_true, Y_pred
# writer.close()
def onlineTrain(data_queue, result_queue):
    """Train a model from data received on a queue and report the outcome.

    Seeds all RNGs with the fixed seed 1877, builds an ``ExP``, and waits up to
    30 s for one message on ``data_queue``. The message is a dict with keys
    ``'data'``, ``'label'`` and ``'modelPath'``. After training, a status dict
    is put on ``result_queue``.

    Args:
        data_queue: queue-like object providing ``get(timeout=...)``.
        result_queue: queue-like object providing ``put(...)``.

    Result messages:
        success: ``{'status': 'success', 'model_state': <model path>, 'timestamp': <time.time()>}``
        error:   ``{'status': 'error', 'msg': <description>, 'traceback': <formatted traceback>}``
    """
    import traceback

    try:
        starttime = datetime.datetime.now()
        # seed_n = np.random.randint(2025)
        seed_n = 1877
        random.seed(seed_n)
        np.random.seed(seed_n)
        torch.manual_seed(seed_n)
        torch.cuda.manual_seed(seed_n)
        torch.cuda.manual_seed_all(seed_n)
        exp = ExP()
        # Fetch the training data from the queue.
        data = data_queue.get(timeout=30)
        all_data, all_label, model_path = data['data'], data['label'], data['modelPath']
        print('训练参数: ', np.shape(all_data), np.shape(all_label), model_path)
        bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data, all_label, model_path)
        print('THE BEST ACCURACY IS ' + str(bestAcc))
        endtime = datetime.datetime.now()
        print('train duration: ', str(endtime - starttime))
        # Send the result back (the path the best model was saved to).
        result_queue.put({
            'status': 'success',
            'model_state': model_path,  # or the save path
            'timestamp': time.time()
        })
    except Exception as e:
        # Some exceptions (e.g. queue.Empty on a get() timeout) have an empty
        # str(); fall back to repr so the caller always receives a useful message.
        traceback.print_exc()
        result_queue.put({
            'status': 'error',
            'msg': str(e) or repr(e),
            'traceback': traceback.format_exc(),
        })
def offlineTrain(all_data, all_label, modelPath):
    """Train a Conformer offline on in-memory data using the fixed seed 1877.

    Prints the seed, the best test accuracy and the wall-clock duration. The
    model checkpoint itself is written to ``modelPath`` by ``ExP.train``.
    """
    started = datetime.datetime.now()
    seed = 1877
    print('seed is ' + str(seed))
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    experiment = ExP()
    best_acc, _aver_acc, _y_true, _y_pred = experiment.train(all_data, all_label, modelPath)
    print('THE BEST ACCURACY IS ' + str(best_acc))
    finished = datetime.datetime.now()
    print('train duration: ', str(finished - started))
if __name__ == "__main__":
    # Print the current local time (asctime of localtime()) twice.
    for _ in range(2):
        print(time.asctime())