"""EEG classification networks: ShallowConvNet, EEGNet and LMDA.

Every network takes input shaped ``(batch, 1, chans, samples)`` (this is the
shape of the dummy tensor each constructor uses to size its classifier) and
returns ``(batch, num_classes)`` scores from a final ``nn.Linear``.
"""
import torch
import torch.nn as nn

try:
    from torchsummary import summary
except ImportError:  # only the __main__ demo needs torchsummary
    summary = None


def weights_init(m):
    """Initialise one module in place; intended for ``model.apply(weights_init)``.

    * ``nn.Conv2d``: Xavier-uniform weight; the bias is left untouched.
    * ``nn.BatchNorm2d``: weight set to 1 and bias to 0 (when affine).
    * ``nn.Linear``: Xavier-uniform weight, bias set to 0 (when present).

    Any other module type is ignored.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)
        # The conv bias keeps its default init on purpose (it may be None).
    elif isinstance(m, nn.BatchNorm2d):
        # weight/bias are None when the layer was built with affine=False.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:  # Linear(..., bias=False) has no bias
            nn.init.constant_(m.bias, 0)


def square_activation(x):
    """Return the element-wise square of ``x``."""
    return torch.square(x)


def safe_log(x):
    """Return ``log(x)`` with ``x`` first clamped to ``[1e-7, 1e7]``.

    The clamp is applied to the input of the logarithm so that zeros do not
    produce ``-inf``. (Clamping the output instead would replace every
    non-positive log value, i.e. every input <= 1, by ``1e-7``.)
    """
    return torch.log(torch.clamp(x, min=1e-7, max=1e7))


def _probe_output_shape(owner, forward_fn, chans, samples):
    """Return the shape ``forward_fn`` yields for a ``(1, 1, chans, samples)`` input.

    The probe runs with ``owner`` in eval mode and autograd disabled, so
    BatchNorm running statistics are not updated by the dummy tensor and
    Dropout is inactive. The previous train/eval mode of ``owner`` is restored.
    """
    was_training = owner.training
    owner.eval()
    try:
        with torch.no_grad():
            out = forward_fn(torch.ones((1, 1, chans, samples)))
    finally:
        owner.train(was_training)
    return tuple(out.shape)


class ShallowConvNet(nn.Module):
    """Two stacked convolutions + BatchNorm, then square / avg-pool / log.

    Args:
        num_classes: size of the output layer.
        chans: height of the input (third dimension); also the height of the
            second convolution's kernel.
        samples: width of the input (last dimension).
    """

    def __init__(self, num_classes=3, chans=19, samples=768):
        super().__init__()
        self.conv_nums = 40
        self.features = nn.Sequential(
            nn.Conv2d(1, self.conv_nums, (1, 25)),
            nn.Conv2d(self.conv_nums, self.conv_nums, (chans, 1), bias=False),
            nn.BatchNorm2d(self.conv_nums),
        )
        self.avgpool = nn.AvgPool2d(kernel_size=(1, 75), stride=(1, 15))
        self.dropout = nn.Dropout()

        # Size the classifier from the pooled feature-map shape. The
        # element-wise square/log steps of forward() do not change the shape,
        # so they are skipped here.
        shape = _probe_output_shape(
            self, lambda t: self.avgpool(self.features(t)), chans, samples
        )
        self.classifier = nn.Linear(shape[-1] * shape[-2] * shape[-3], num_classes)

    def forward(self, x):
        """Map ``(batch, 1, chans, samples)`` input to ``(batch, num_classes)``."""
        x = self.features(x)
        x = square_activation(x)
        x = self.avgpool(x)
        x = safe_log(x)
        x = self.dropout(x)
        features = torch.flatten(x, 1)
        return self.classifier(features)


class EEGNet(nn.Module):
    """Stack of full, depthwise and pointwise convolutions with a linear head.

    Args:
        num_classes: size of the output layer.
        chans: height of the input; also the height of the second conv kernel.
        samples: width of the input (last dimension).
        dropout_rate: probability used by both Dropout layers.
        kernel_length: width of the first convolution kernel.
        F1: number of feature maps up to the pointwise convolution.
        F2: number of feature maps produced by the pointwise convolution.
    """

    def __init__(self, num_classes=2, chans=8, samples=1000, dropout_rate=0.5,
                 kernel_length=64, F1=8, F2=16):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, F1, kernel_size=(1, kernel_length), bias=False),
            nn.BatchNorm2d(F1),
            # groups=F1 makes this a depthwise convolution.
            nn.Conv2d(F1, F1, kernel_size=(chans, 1), groups=F1, bias=False),
            nn.BatchNorm2d(F1),
            nn.ELU(inplace=True),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(dropout_rate),
            # Separable convolution = depthwise (groups=F1) + pointwise (1x1).
            nn.Conv2d(F1, F1, kernel_size=(1, 16), groups=F1, bias=False),
            nn.BatchNorm2d(F1),
            nn.ELU(inplace=True),
            nn.Conv2d(F1, F2, kernel_size=(1, 1), groups=1, bias=False),
            nn.BatchNorm2d(F2),
            nn.ELU(inplace=True),
            nn.AvgPool2d((1, 8)),
            nn.Dropout(p=dropout_rate),
        )

        # Size the classifier from the feature-map shape of a dummy input.
        shape = _probe_output_shape(self, self.features, chans, samples)
        self.classifier = nn.Linear(shape[-1] * shape[-2] * shape[-3], num_classes)

    def forward(self, x):
        """Map ``(batch, 1, chans, samples)`` input to ``(batch, num_classes)``."""
        conv_features = self.features(x)
        features = torch.flatten(conv_features, 1)
        return self.classifier(features)


class LMDA(nn.Module):
    """LMDA-Net for the paper.

    Args:
        chans: height of the input; also the height of the depthwise kernel
            in ``chanel_conv``.
        samples: width of the input (last dimension).
        num_classes: size of the output layer.
        depth: number of maps produced by the learned ``channel_weight`` mix.
        kernel: width of the depthwise kernel in ``time_conv``.
        channel_depth1: number of feature maps inside ``time_conv``.
        channel_depth2: number of feature maps inside ``chanel_conv``.
        ave_depth: stored on the instance; not otherwise used by this class.
        avepool: pooling window along the last dimension.

    Note:
        The depth-attention convolution is registered as
        ``depth_attention_conv`` so that it is trained and saved with the
        model. A state dict saved before this layer was registered lacks the
        ``depth_attention_conv.*`` keys and must be loaded with
        ``strict=False``.
    """

    def __init__(self, chans=19, samples=768, num_classes=3, depth=9, kernel=75,
                 channel_depth1=24, channel_depth2=9, ave_depth=1, avepool=5):
        super().__init__()
        self.ave_depth = ave_depth
        self.channel_weight = nn.Parameter(torch.randn(depth, 1, chans), requires_grad=True)
        nn.init.xavier_uniform_(self.channel_weight.data)

        self.time_conv = nn.Sequential(
            nn.Conv2d(depth, channel_depth1, kernel_size=(1, 1), groups=1, bias=False),
            nn.BatchNorm2d(channel_depth1),
            nn.Conv2d(channel_depth1, channel_depth1, kernel_size=(1, kernel),
                      groups=channel_depth1, bias=False),
            nn.BatchNorm2d(channel_depth1),
            nn.GELU(),
        )

        # Convolution used by EEGDepthAttention. It is created once here so its
        # weights are registered parameters (trained, saved, moved by .to())
        # rather than being re-drawn at random on every forward pass.
        attention_kernel = 7
        self.depth_attention_conv = nn.Conv2d(
            1, 1, kernel_size=(attention_kernel, 1),
            padding=(attention_kernel // 2, 0), bias=True,
        )
        nn.init.xavier_uniform_(self.depth_attention_conv.weight)
        nn.init.constant_(self.depth_attention_conv.bias, 0)

        self.chanel_conv = nn.Sequential(
            nn.Conv2d(channel_depth1, channel_depth2, kernel_size=(1, 1), groups=1, bias=False),
            nn.BatchNorm2d(channel_depth2),
            nn.Conv2d(channel_depth2, channel_depth2, kernel_size=(chans, 1),
                      groups=channel_depth2, bias=False),
            nn.BatchNorm2d(channel_depth2),
            nn.GELU(),
        )

        self.norm = nn.Sequential(
            # The (1, 1, avepool) window only pools along the last dimension.
            nn.AvgPool3d(kernel_size=(1, 1, avepool)),
            nn.Dropout(p=0.65),
        )

        # Size the classifier automatically from the feature-map shape of a
        # dummy input.
        n_out_time = _probe_output_shape(self, self._extract_features, chans, samples)
        print('In LMDA, n_out_time shape: ', n_out_time)
        self.classifier = nn.Linear(
            n_out_time[-1] * n_out_time[-2] * n_out_time[-3], num_classes
        )

    def EEGDepthAttention(self, x):
        """Re-weight ``x`` along its second (depth) dimension.

        Args:
            x: input features with shape ``[N, C, H, W]``.

        Returns:
            Tensor with the same shape as ``x``: ``x`` multiplied by ``C``
            times a softmax (taken over the ``C`` dimension) of the
            convolved, height-averaged features.
        """
        N, C, H, W = x.size()
        # Output width equals input width, so this averages over H only.
        x_pool = nn.functional.adaptive_avg_pool2d(x, (1, W))   # [N, C, 1, W]
        x_transpose = x_pool.transpose(-2, -3)                  # [N, 1, C, W]
        y = self.depth_attention_conv(x_transpose)              # conv slides over C
        y = torch.softmax(y, dim=-2)
        y = y.transpose(-2, -3)                                 # [N, C, 1, W]
        return y * C * x

    def _extract_features(self, x):
        """Run every stage before the classifier and return the feature map."""
        # Mix the input rows with the learned weights:
        # (b, d, c, w) x (h, d, c) -> (b, h, c, w).
        x = torch.einsum('bdcw, hdc->bhcw', x, self.channel_weight)
        x_time = self.time_conv(x)
        x_time = self.EEGDepthAttention(x_time)
        x = self.chanel_conv(x_time)
        return self.norm(x)

    def forward(self, x):
        """Map ``(batch, 1, chans, samples)`` input to ``(batch, num_classes)``."""
        features = torch.flatten(self._extract_features(x), 1)
        return self.classifier(features)


def _demo():
    """Build a ShallowConvNet, run one random batch and print a summary."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    chans, samples = 22, 1125
    model = ShallowConvNet(num_classes=4, chans=chans, samples=samples).to(device)
    # The batch must use the same chans/samples the model was built for.
    a = torch.randn(12, 1, chans, samples, device=device)
    l2 = model(a)
    if summary is not None:
        summary(model, input_size=(1, chans, samples), device=device.type)
    print(l2.shape)


if __name__ == '__main__':
    _demo()