Files
bci_algo/MI/Algorithm/otherModels.py

185 lines
6.4 KiB
Python
Raw Permalink Normal View History

2026-06-05 09:34:29 +08:00
from torchsummary import summary
import torch
import torch.nn as nn
def weights_init(m):
    """Initialise the parameters of a single layer in place.

    The function takes one module, so it can be passed to
    ``nn.Module.apply`` to initialise every layer of a model.

    * ``nn.Conv2d``: Xavier-uniform weight; the bias is not touched.
    * ``nn.BatchNorm2d``: weight set to 1 and bias set to 0.
    * ``nn.Linear``: Xavier-uniform weight and bias set to 0.

    Any other module type is left unchanged.

    Args:
        m: The module whose parameters should be initialised.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)
        # The convolution bias (which may be None) keeps PyTorch's default.
    elif isinstance(m, nn.BatchNorm2d):
        # weight/bias are None when the layer was created with affine=False.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        # bias is None when the layer was created with bias=False.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
def square_activation(x):
    """Return the element-wise square of the tensor ``x``."""
    squared = x.square()
    return squared
def safe_log(x, eps=1e-7):
    """Return the natural logarithm of ``x`` without producing ``-inf``.

    The input is floored at ``eps`` *before* taking the logarithm, so values
    in ``(eps, 1]`` keep their true (negative) log. Flooring the result of
    the logarithm instead would map every input <= 1 to the same constant.

    Args:
        x: Input tensor.
        eps: Smallest value passed to ``torch.log``; inputs below it
            (including 0 and negative values) are treated as ``eps``.

    Returns:
        A tensor of the same shape as ``x`` with values no larger than 1e7.
    """
    floored = torch.clip(x, min=eps)
    # Keep the upper cap so an infinite input still yields a finite value.
    return torch.clip(torch.log(floored), max=1e7)
class ShallowConvNet(nn.Module):
    """Shallow convolutional classifier.

    The pipeline is: convolution along the last axis (kernel ``(1, 25)``),
    convolution across the ``chans`` axis (kernel ``(chans, 1)``), batch
    normalisation, squaring, average pooling, logarithm, dropout and a
    linear classifier.

    Args:
        num_classes: Number of outputs of the final linear layer.
        chans: Size of the third input axis; also the height of the second
            convolution kernel.
        samples: Size of the last input axis.

    ``forward`` expects input of shape ``(batch, 1, chans, samples)`` and
    returns a tensor of shape ``(batch, num_classes)``.
    """

    def __init__(self, num_classes=3, chans=19, samples=768):
        super().__init__()
        self.conv_nums = 40
        self.features = nn.Sequential(
            nn.Conv2d(1, self.conv_nums, (1, 25)),
            nn.Conv2d(self.conv_nums, self.conv_nums, (chans, 1), bias=False),
            nn.BatchNorm2d(self.conv_nums),
        )
        self.avgpool = nn.AvgPool2d(kernel_size=(1, 75), stride=(1, 15))
        self.dropout = nn.Dropout()
        self.classifier = nn.Linear(self._flattened_size(chans, samples), num_classes)

    def _flattened_size(self, chans, samples):
        """Return the number of features per sample after ``features`` + ``avgpool``."""
        was_training = self.training
        # Probe in eval mode and without autograd so that the dummy input
        # does not update the BatchNorm running statistics.
        self.eval()
        try:
            with torch.no_grad():
                probe = self.avgpool(self.features(torch.ones(1, 1, chans, samples)))
        finally:
            self.train(was_training)
        return probe[0].numel()

    def forward(self, x):
        """Compute class scores for ``x`` of shape ``(batch, 1, chans, samples)``."""
        x = self.features(x)
        x = square_activation(x)
        x = self.avgpool(x)
        x = safe_log(x)
        x = self.dropout(x)
        features = torch.flatten(x, 1)
        return self.classifier(features)
class EEGNet(nn.Module):
    """Compact convolutional classifier built from depthwise and pointwise convolutions.

    Args:
        num_classes: Number of outputs of the final linear layer.
        chans: Size of the third input axis; also the height of the
            depthwise kernel that collapses that axis.
        samples: Size of the last input axis.
        dropout_rate: Drop probability of both dropout layers.
        kernel_length: Width of the first convolution kernel.
        F1: Number of feature maps produced by the first convolution.
        F2: Number of feature maps produced by the point-wise convolution.

    ``forward`` expects input of shape ``(batch, 1, chans, samples)`` and
    returns a tensor of shape ``(batch, num_classes)``.
    """

    def __init__(self, num_classes=2, chans=8, samples=1000, dropout_rate=0.5, kernel_length=64, F1=8,
                 F2=16,):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, F1, kernel_size=(1, kernel_length), bias=False),
            nn.BatchNorm2d(F1),
            # groups=F1 makes this a depthwise convolution over the chans axis.
            nn.Conv2d(F1, F1, kernel_size=(chans, 1), groups=F1, bias=False),
            nn.BatchNorm2d(F1),
            nn.ELU(inplace=True),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(dropout_rate),
            # Separable convolution: depthwise (1, 16) followed by point-wise (1, 1).
            nn.Conv2d(F1, F1, kernel_size=(1, 16), groups=F1, bias=False),
            nn.BatchNorm2d(F1),
            nn.ELU(inplace=True),
            nn.Conv2d(F1, F2, kernel_size=(1, 1), groups=1, bias=False),
            nn.BatchNorm2d(F2),
            nn.ELU(inplace=True),
            nn.AvgPool2d((1, 8)),
            nn.Dropout(p=dropout_rate),
        )
        self.classifier = nn.Linear(self._flattened_size(chans, samples), num_classes)

    def _flattened_size(self, chans, samples):
        """Return the number of features per sample produced by ``features``."""
        was_training = self.training
        # Probe in eval mode and without autograd: the dummy input must not
        # update the BatchNorm running statistics, and Dropout must not
        # consume random-number-generator state.
        self.eval()
        try:
            with torch.no_grad():
                probe = self.features(torch.ones(1, 1, chans, samples))
        finally:
            self.train(was_training)
        return probe[0].numel()

    def forward(self, x):
        """Compute class scores for ``x`` of shape ``(batch, 1, chans, samples)``."""
        conv_features = self.features(x)
        features = torch.flatten(conv_features, 1)
        return self.classifier(features)
class LMDA(nn.Module):
    """LMDA-Net classifier.

    The input is first mixed across the ``chans`` axis with a learnable
    weight tensor (``channel_weight``), then passed through a temporal
    convolution stack, a depth-attention step, a convolution stack that
    collapses the ``chans`` axis, average pooling with dropout, and a
    linear classifier.

    Args:
        chans: Size of the third input axis.
        samples: Size of the last input axis.
        num_classes: Number of outputs of the final linear layer.
        depth: Number of feature maps created by ``channel_weight``.
        kernel: Width of the depthwise temporal convolution kernel.
        channel_depth1: Number of feature maps in the temporal stack.
        channel_depth2: Number of feature maps in the channel stack.
        ave_depth: Stored on the instance; not used by this class.
        avepool: Pooling width along the last axis.

    ``forward`` expects input of shape ``(batch, 1, chans, samples)`` and
    returns a tensor of shape ``(batch, num_classes)``.

    Note:
        The depth-attention convolution is a registered sub-module
        (``depth_attention_conv``). State dicts saved by a version of this
        class without that sub-module do not contain its keys and must be
        loaded with ``strict=False``.
    """

    # Kernel height of the depth-attention convolution.
    _ATTENTION_KERNEL = 7

    def __init__(self, chans=19, samples=768, num_classes=3, depth=9, kernel=75, channel_depth1=24, channel_depth2=9,
                 ave_depth=1, avepool=5):
        super().__init__()
        self.ave_depth = ave_depth
        self.channel_weight = nn.Parameter(torch.randn(depth, 1, chans), requires_grad=True)
        nn.init.xavier_uniform_(self.channel_weight.data)

        self.time_conv = nn.Sequential(
            nn.Conv2d(depth, channel_depth1, kernel_size=(1, 1), groups=1, bias=False),
            nn.BatchNorm2d(channel_depth1),
            nn.Conv2d(channel_depth1, channel_depth1, kernel_size=(1, kernel),
                      groups=channel_depth1, bias=False),
            nn.BatchNorm2d(channel_depth1),
            nn.GELU(),
        )
        # The attribute name (including its spelling) is kept because it is
        # part of the state-dict keys.
        self.chanel_conv = nn.Sequential(
            nn.Conv2d(channel_depth1, channel_depth2, kernel_size=(1, 1), groups=1, bias=False),
            nn.BatchNorm2d(channel_depth2),
            nn.Conv2d(channel_depth2, channel_depth2, kernel_size=(chans, 1), groups=channel_depth2, bias=False),
            nn.BatchNorm2d(channel_depth2),
            nn.GELU(),
        )
        self.norm = nn.Sequential(
            # Kernel (1, 1, avepool) pools only along the last axis.
            nn.AvgPool3d(kernel_size=(1, 1, avepool)),
            nn.Dropout(p=0.65),
        )

        # Created once here so that its weights are trained with the rest of
        # the model and inference is deterministic. Building the layer inside
        # forward would re-draw random weights on every call.
        k = self._ATTENTION_KERNEL
        self.depth_attention_conv = nn.Conv2d(1, 1, kernel_size=(k, 1), padding=(k // 2, 0), bias=True)
        nn.init.xavier_uniform_(self.depth_attention_conv.weight)
        nn.init.constant_(self.depth_attention_conv.bias, 0)

        # Derive the classifier input size from a dummy forward pass.
        n_out_time = self._probe_shape(chans, samples)
        print('In LMDA, n_out_time shape: ', n_out_time)
        self.classifier = nn.Linear(n_out_time[-1] * n_out_time[-2] * n_out_time[-3], num_classes)

    def _probe_shape(self, chans, samples):
        """Return the shape of the pre-classifier features for one dummy sample."""
        was_training = self.training
        # Probe in eval mode and without autograd: the dummy input must not
        # update the BatchNorm running statistics, and Dropout must not
        # consume random-number-generator state.
        self.eval()
        try:
            with torch.no_grad():
                probe = self._extract_features(torch.ones(1, 1, chans, samples))
        finally:
            self.train(was_training)
        return tuple(probe.shape)

    def _extract_features(self, x):
        """Run every stage before the classifier and return the 4-D feature map."""
        x = torch.einsum('bdcw, hdc->bhcw', x, self.channel_weight)
        x_time = self.time_conv(x)  # batch, depth1, channel, samples_
        x_time = self.EEGDepthAttention(x_time)  # DA1
        x = self.chanel_conv(x_time)  # batch, depth2, 1, samples_
        return self.norm(x)

    def EEGDepthAttention(self, x):
        """Re-weight the feature maps of ``x`` with softmax attention over the depth axis.

        Args:
            x: Input features with shape ``[N, C, H, W]``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        _, C, _, W = x.size()
        # Average over H, keep W: [N, C, 1, W].
        x_pool = nn.functional.adaptive_avg_pool2d(x, (1, W))
        # Move the depth axis into the convolution's height: [N, 1, C, W].
        y = self.depth_attention_conv(x_pool.transpose(-2, -3))
        # Normalise across depth, then restore the layout: [N, C, 1, W].
        y = torch.softmax(y, dim=-2).transpose(-2, -3)
        return y * C * x

    def forward(self, x):
        """Compute class scores for ``x`` of shape ``(batch, 1, chans, samples)``."""
        x = self._extract_features(x)
        features = torch.flatten(x, 1)
        return self.classifier(features)
if __name__ == '__main__':
    # Smoke test: build a ShallowConvNet, run one batch through it and print
    # a layer summary plus the output shape.
    # Fall back to the CPU when no CUDA device is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_classes, chans, samples = 4, 22, 1125
    model = ShallowConvNet(num_classes=num_classes, chans=chans, samples=samples).to(device)
    # The input must match the chans/samples the model was built for;
    # otherwise the (chans, 1) convolution and the classifier size do not fit.
    a = torch.randn(12, 1, chans, samples, dtype=torch.float32, device=device)
    l2 = model(a)
    model_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
    # torchsummary needs the per-sample input size and the device type.
    summary(model, input_size=(1, chans, samples), device=device.type)
    print(l2.shape)