Files
bci_algo/MI/Algorithm/conformer_2class_cpu.py

384 lines
13 KiB
Python
Raw Permalink Normal View History

2026-06-05 09:34:29 +08:00
"""
EEG Conformer
Convolutional Transformer for EEG decoding
Couple CNN and Transformer in a concise manner with amazing results
"""
# remember to change paths
import os
import numpy as np
import math
import random
import time
import datetime
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import nn
from torch import Tensor
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
from torch.backends import cudnn
from sklearn.model_selection import train_test_split
2026-06-10 16:04:02 +08:00
from logs.log import algo_log
2026-06-05 09:34:29 +08:00
# writer = SummaryWriter('./TensorBoardX/')
# Convolution module
# use conv to capture local features, instead of position embedding.
class PatchEmbedding(nn.Module):
    """Convolutional stem that converts a trial into a sequence of tokens.

    The stem applies a (1, 25) convolution, an (8, 1) convolution, batch
    normalisation, ELU, (1, 75) average pooling with stride (1, 15) and
    dropout.  A 1x1 convolution then maps the 40 feature maps to
    ``emb_size`` and the remaining 2-D grid is flattened into a token axis.

    Args:
        emb_size: Number of features per output token.
    """

    def __init__(self, emb_size=40):
        super().__init__()

        # Layer order is significant: it fixes the parameter names inside the
        # Sequential containers and the order in which weights are initialised.
        feature_layers = [
            nn.Conv2d(1, 40, (1, 25), (1, 1)),
            nn.Conv2d(40, 40, (8, 1), (1, 1)),
            nn.BatchNorm2d(40),
            nn.ELU(),
            # Pooling along the last axis plays the role of ViT-style patch
            # slicing: every pooled position becomes one token.
            nn.AvgPool2d((1, 75), (1, 15)),
            nn.Dropout(0.5),
        ]
        self.shallownet = nn.Sequential(*feature_layers)

        token_layers = [
            # The 1x1 convolution may slightly improve fitting ability.
            nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)),
            Rearrange('b e (h) (w) -> b (h w) e'),
        ]
        self.projection = nn.Sequential(*token_layers)

    def forward(self, x: Tensor) -> Tensor:
        """Return tokens of shape (batch, h * w, emb_size).

        Args:
            x: 4-D input; the first convolution expects a single channel in
                dimension 1.  Presumably (batch, 1, electrodes, samples) --
                confirm against the caller.
        """
        # Unpacking four values rejects any input that is not 4-D.
        _batch, _, _, _ = x.shape
        features = self.shallownet(x)
        return self.projection(features)
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a (batch, tokens, emb_size) tensor.

    Args:
        emb_size: Feature size of every token; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``emb_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        if emb_size % num_heads != 0:
            raise ValueError(
                f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})"
            )
        self.emb_size = emb_size
        self.num_heads = num_heads
        # Creation order is kept so seeded weight initialisation is unchanged.
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def _split_heads(self, t: Tensor) -> Tensor:
        """Reshape (b, n, h * d) into (b, h, n, d)."""
        batch, tokens, _ = t.shape
        return t.reshape(batch, tokens, self.num_heads, -1).permute(0, 2, 1, 3)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch, tokens, emb_size).
            mask: Optional boolean tensor broadcastable to
                (batch, heads, queries, keys).  Positions that are ``False``
                are excluded from attention.

        Returns:
            Tensor of shape (batch, tokens, emb_size).
        """
        queries = self._split_heads(self.queries(x))
        keys = self._split_heads(self.keys(x))
        values = self._split_heads(self.values(x))

        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            # Use the most negative value of the energy's own dtype so the
            # fill never overflows, and keep the result: masked_fill is not
            # in-place.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        # NOTE: scaling uses the full embedding size rather than the per-head
        # size.  It is kept as-is so existing trained weights behave the same.
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)

        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        batch, _, tokens, _ = out.shape
        # Merge heads back: (b, h, n, d) -> (b, n, h * d).
        out = out.permute(0, 2, 1, 3).reshape(batch, tokens, -1)
        return self.projection(out)
class ResidualAdd(nn.Module):
    """Wrap a callable and add its input back onto its output.

    Args:
        fn: Callable invoked as ``fn(x, **kwargs)``.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x``; the sum is written into fn's output."""
        shortcut = x
        transformed = self.fn(shortcut, **kwargs)
        # In-place accumulation into the tensor returned by ``fn``.
        transformed += shortcut
        return transformed
class FeedForwardBlock(nn.Sequential):
    """Two-layer position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Args:
        emb_size: Input and output feature size.
        expansion: Multiplier giving the hidden width (``expansion * emb_size``).
        drop_p: Dropout probability applied after the activation.
    """

    def __init__(self, emb_size, expansion, drop_p):
        hidden_size = expansion * emb_size
        layers = (
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
        )
        super().__init__(*layers)
class GELU(nn.Module):
    """Exact (erf-based) Gaussian Error Linear Unit."""

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input * 0.5 * (1 + erf(input / sqrt(2)))`` element-wise."""
        half_input = input * 0.5
        erf_term = 1.0 + torch.erf(input / math.sqrt(2.0))
        return half_input * erf_term
class TransformerEncoderBlock(nn.Sequential):
    """One pre-norm encoder layer: residual attention, then residual MLP.

    Args:
        emb_size: Token feature size.
        num_heads: Number of attention heads.
        drop_p: Dropout probability inside attention and after each branch.
        forward_expansion: Hidden-width multiplier of the feed-forward branch.
        forward_drop_p: Dropout probability inside the feed-forward branch.
    """

    def __init__(self,
                 emb_size,
                 num_heads=10,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        # The attention branch is built first, matching the order in which
        # the sub-modules (and their weights) have always been created.
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads, drop_p),
            nn.Dropout(drop_p),
        )
        feed_forward_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(
                emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(
            ResidualAdd(attention_branch),
            ResidualAdd(feed_forward_branch),
        )
class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` encoder blocks, each built with default settings.

    Args:
        depth: Number of ``TransformerEncoderBlock`` layers.
        emb_size: Token feature size passed to every block.
    """

    def __init__(self, depth, emb_size):
        blocks = (TransformerEncoderBlock(emb_size) for _ in range(depth))
        super().__init__(*blocks)
class ClassificationHead(nn.Sequential):
    """Fully connected classifier applied to the flattened token sequence.

    Args:
        emb_size: Token feature size (used by the pooled ``clshead``).
        n_classes: Number of output logits.
        in_features: Length of the flattened encoder output, i.e.
            ``tokens * emb_size``.  The default 2440 is the original
            hard-coded value (61 tokens x 40 features); it must be changed
            if the input length or ``emb_size`` changes.
    """

    def __init__(self, emb_size, n_classes, in_features=2440):
        super().__init__()
        # Global-average-pooling head.  ``forward`` does not use it, but it is
        # kept so parameter names and seeded weight initialisation stay the
        # same as before.
        self.clshead = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes)
        )
        self.fc = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(256, 32),
            nn.ELU(),
            nn.Dropout(0.3),
            # Previously hard-coded to 2 outputs, which silently ignored
            # ``n_classes``.
            nn.Linear(32, n_classes)
        )

    def forward(self, x):
        """Flatten everything after the batch axis and return class logits."""
        x = x.contiguous().view(x.size(0), -1)
        out = self.fc(x)
        return out
class Conformer(nn.Sequential):
    """EEG Conformer: patch embedding, Transformer encoder, classifier.

    Args:
        emb_size: Token feature size shared by all three stages.
        depth: Number of encoder blocks.
        n_classes: Number of classes passed to the classification head.
        **kwargs: Accepted for call compatibility; not used.
    """

    def __init__(self, emb_size=40, depth=6, n_classes=2, **kwargs):
        stages = (
            PatchEmbedding(emb_size),
            TransformerEncoder(depth, emb_size),
            ClassificationHead(emb_size, n_classes),
        )
        super().__init__(*stages)
class ExP():
    """Training harness for the two-class Conformer model.

    Class labels are expected to be 1 and 2: ``train`` subtracts 1 before
    building tensors and ``interaug`` selects trials by ``label == 1`` and
    ``label == 2``.
    """

    def __init__(self):
        super(ExP, self).__init__()
        self.batch_size = 24
        self.n_epochs = 250
        self.c_dim = 4
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.999
        self.start_epoch = 0

        # Create the log directory first so a missing folder does not abort
        # construction with FileNotFoundError.
        log_path = "./online_Models/log_result.txt"
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        self.log_write = open(log_path, "w")

        # Pick the device automatically: GPU when available, otherwise CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        algo_log(f"Using device: {self.device}", level="debug")

        # Tensor type aliases (CUDA is no longer forced).
        self.Tensor = torch.FloatTensor
        self.LongTensor = torch.LongTensor

        # Move the model and the loss function to the selected device.
        self.model = Conformer().to(self.device)
        self.criterion_cls = torch.nn.CrossEntropyLoss().to(self.device)

    # Segmentation and Reconstruction (S&R) data augmentation
    def interaug(self, timg, label):
        """Build ``batch_size`` synthetic trials by recombining time segments.

        For each class, ``batch_size / 2`` trials are assembled from eight
        125-sample segments, each copied from a randomly chosen real trial of
        that class.  The result is shuffled and moved to ``self.device``.

        Args:
            timg: Array of shape (n_trials, 1, 8, 1000).
            label: Array of shape (n_trials,) with values 1 and 2.

        Returns:
            ``(aug_data, aug_label)``: a float tensor of shape
            (batch_size, 1, 8, 1000) and a long tensor of zero-based labels.

        Raises:
            ValueError: If one of the two classes has no trials.
        """
        label = np.asarray(label)
        n_segments = 8
        segment_len = 125
        per_class = int(self.batch_size / 2)

        aug_data = []
        aug_label = []
        for cls4aug in range(2):
            cls_value = cls4aug + 1
            cls_idx = np.where(label == cls_value)
            tmp_data = timg[cls_idx]
            if tmp_data.shape[0] == 0:
                raise ValueError(
                    f"interaug: no training trials with label {cls_value}")

            tmp_aug_data = np.zeros((per_class, 1, 8, 1000))
            for ri in range(per_class):
                for rj in range(n_segments):
                    # One draw of eight indices per segment is kept on purpose
                    # so the seeded random stream matches earlier runs.
                    rand_idx = np.random.randint(0, tmp_data.shape[0], n_segments)
                    segment = slice(rj * segment_len, (rj + 1) * segment_len)
                    tmp_aug_data[ri, :, :, segment] = tmp_data[rand_idx[rj], :, :, segment]
            aug_data.append(tmp_aug_data)
            # Generate the labels explicitly.  Slicing the real labels gave
            # fewer labels than trials when a class had under ``per_class``
            # samples, which then failed during the shuffle below.
            aug_label.append(np.full(per_class, cls_value, dtype=label.dtype))

        aug_data = np.concatenate(aug_data)
        aug_label = np.concatenate(aug_label)
        aug_shuffle = np.random.permutation(len(aug_data))
        aug_data = aug_data[aug_shuffle, :, :]
        aug_label = aug_label[aug_shuffle]

        aug_data = torch.from_numpy(aug_data).float().to(self.device)
        aug_label = torch.from_numpy(aug_label - 1).long().to(self.device)
        return aug_data, aug_label

    def train(self, all_data, all_label, model_path):
        """Train the model and save the best checkpoint.

        The data is split 80/20 (stratified, fixed ``random_state``).  After
        every epoch the model is evaluated on the held-out part and, when the
        accuracy improves, the whole model object is saved to ``model_path``.

        Args:
            all_data: Array-like of trials; a channel axis is inserted at
                position 1 before use.
            all_label: Array-like of labels with values 1 and 2.
            model_path: Destination passed to ``torch.save``.

        Returns:
            ``(bestAcc, averAcc, Y_true, Y_pred)``.  ``Y_true`` and ``Y_pred``
            are the test labels and predictions of the best epoch; they stay
            0 if the accuracy never exceeds 0.
        """
        all_data = np.array(all_data)
        all_label = np.array(all_label)
        all_data = np.expand_dims(all_data, axis=1)
        train_data, test_data, train_label, test_label = train_test_split(
            all_data, all_label, test_size=0.2,
            random_state=42, stratify=all_label, shuffle=True)

        # Convert to tensors on the target device.
        img = torch.from_numpy(train_data).float().to(self.device)
        label = torch.from_numpy(train_label - 1).long().to(self.device)
        dataset = torch.utils.data.TensorDataset(img, label)
        self.dataloader = torch.utils.data.DataLoader(
            dataset=dataset, batch_size=self.batch_size, shuffle=True)

        test_data = torch.from_numpy(test_data).float().to(self.device)
        test_label = torch.from_numpy(test_label - 1).long().to(self.device)
        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
        self.test_dataloader = torch.utils.data.DataLoader(
            dataset=test_dataset, batch_size=self.batch_size, shuffle=True)

        # Optimizers
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        bestAcc = 0
        averAcc = 0
        num = 0
        Y_true = 0
        Y_pred = 0

        try:
            for e in range(self.n_epochs):
                self.model.train()
                for img, label in self.dataloader:
                    # Every real batch is extended with a freshly generated
                    # augmented batch.
                    aug_data, aug_label = self.interaug(train_data, train_label)
                    img = torch.cat((img, aug_data))
                    label = torch.cat((label, aug_label))

                    outputs = self.model(img)
                    loss = self.criterion_cls(outputs, label)

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                # Evaluate on the whole held-out set after every epoch.
                self.model.eval()
                with torch.no_grad():
                    Cls = self.model(test_data)
                    loss_test = self.criterion_cls(Cls, test_label)
                y_pred = torch.max(Cls, 1)[1]
                acc = float((y_pred == test_label).sum().item()) / float(test_label.size(0))

                # Training accuracy is measured on the last batch of the epoch.
                train_pred = torch.max(outputs, 1)[1]
                train_acc = float((train_pred == label).sum().item()) / float(label.size(0))

                print('Epoch:', e,
                      ' Train loss: %.6f' % loss.item(),
                      ' Test loss: %.6f' % loss_test.item(),
                      ' Train accuracy %.6f' % train_acc,
                      ' Test accuracy is %.6f' % acc)
                self.log_write.write(str(e) + " " + str(acc) + "\n")

                num = num + 1
                averAcc = averAcc + acc
                if acc > bestAcc:
                    bestAcc = acc
                    Y_true = test_label
                    Y_pred = y_pred
                    # The full module is pickled, so loading it later needs
                    # these class definitions to be importable.
                    torch.save(self.model, model_path)

            # Avoid ZeroDivisionError when no epoch was run.
            averAcc = averAcc / num if num else 0.0
            print('The average accuracy is:', averAcc)
            print('The best accuracy is:', bestAcc)
            self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
            self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
        finally:
            # Make sure what was logged reaches disk even if training fails.
            self.log_write.flush()

        return bestAcc, averAcc, Y_true, Y_pred
# writer.close()
def onlineTrain(data_queue, result_queue):
    """Train a model from data received on a queue and report the outcome.

    Args:
        data_queue: Queue that yields one dict with the keys ``'data'``,
            ``'label'`` and ``'modelPath'``; it is read with a 30 second
            timeout.
        result_queue: Queue that receives exactly one dict: either
            ``{'status': 'success', 'model_state': <path>, 'timestamp': ...}``
            or ``{'status': 'error', 'msg': <text>}``.
    """
    exp = None
    try:
        starttime = datetime.datetime.now()

        # Fixed seed so repeated runs are comparable.
        seed_n = 1877
        random.seed(seed_n)
        np.random.seed(seed_n)
        torch.manual_seed(seed_n)
        torch.cuda.manual_seed(seed_n)
        torch.cuda.manual_seed_all(seed_n)

        exp = ExP()

        # Fetch the training data from the queue.
        data = data_queue.get(timeout=30)
        all_data, all_label, model_path = data['data'], data['label'], data['modelPath']
        print('训练参数: ', np.shape(all_data), np.shape(all_label), model_path)

        bestAcc, _aver_acc, _y_true, _y_pred = exp.train(all_data, all_label, model_path)
        print('THE BEST ACCURACY IS ' + str(bestAcc))

        endtime = datetime.datetime.now()
        print('train duration: ', str(endtime - starttime))

        # Report where the trained model was saved.
        result_queue.put({
            'status': 'success',
            'model_state': model_path,
            'timestamp': time.time()
        })
    except Exception as e:
        # Top-level boundary of the worker: report the failure instead of
        # raising.  The exception type is included because some exceptions
        # (for example queue.Empty on timeout) have an empty str().
        result_queue.put({'status': 'error', 'msg': f"{type(e).__name__}: {e}"})
    finally:
        # Release the log file handle opened by ExP.
        log_file = getattr(exp, 'log_write', None)
        if log_file is not None:
            log_file.close()
def offlineTrain(all_data, all_label, modelPath):
    """Train a model on in-memory data and save the best checkpoint.

    Args:
        all_data: Trials passed through to ``ExP.train``.
        all_label: Labels passed through to ``ExP.train``.
        modelPath: Destination of the saved model.
    """
    starttime = datetime.datetime.now()

    # Fixed seed so repeated runs are comparable.
    seed_n = 1877
    print('seed is ' + str(seed_n))
    random.seed(seed_n)
    np.random.seed(seed_n)
    torch.manual_seed(seed_n)
    torch.cuda.manual_seed(seed_n)
    torch.cuda.manual_seed_all(seed_n)

    exp = ExP()
    try:
        bestAcc, _aver_acc, _y_true, _y_pred = exp.train(all_data, all_label, modelPath)
    finally:
        # Release the log file handle opened by ExP, even when training fails.
        exp.log_write.close()

    print('THE BEST ACCURACY IS ' + str(bestAcc))
    endtime = datetime.datetime.now()
    print('train duration: ', str(endtime - starttime))
if __name__ == "__main__":
    # Print the current local time twice when run as a script.
    # asctime() without an argument formats the current local time.
    for _ in range(2):
        print(time.asctime())