#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :Pytorch_learn
@File    :Myattention.py
@IDE     :PyCharm
@Author  :咋
@Date    :2023/7/14 17:42
"""
import torch
import torch.nn as nn
from torch.nn import Conv2d
class Channel_attention(nn.Module):
    """Channel attention: squeeze the spatial dims with max/avg pooling, then weight channels."""
    def __init__(self, channel, ratio=16):
        super(Channel_attention, self).__init__()
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Shared bottleneck MLP: channel -> channel // ratio -> channel
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // ratio, bias=False),
            nn.ReLU(),
            nn.Linear(channel // ratio, channel, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, h, w = x.size()
        # Global max/avg pooling, flattened to (b, c) for the shared MLP
        maxpool = self.max_pool(x).view([b, c])
        avgpool = self.avg_pool(x).view([b, c])
        maxpool_fc = self.fc(maxpool)
        avgpool_fc = self.fc(avgpool)
        # Fuse the two branches and map to per-channel weights in (0, 1)
        out = maxpool_fc + avgpool_fc
        out = self.sigmoid(out).view([b, c, 1, 1])
        return out
class Spacial_attention(nn.Module):
    """Spatial attention: pool across channels, then a conv produces a per-pixel weight map."""
    def __init__(self, kernel_size=7):
        super(Spacial_attention, self).__init__()
        # padding = kernel_size // 2 keeps the spatial size unchanged (3 when kernel_size=7)
        self.conv1 = Conv2d(2, 1, kernel_size, stride=1, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Per-pixel max and mean over the channel dimension -> two single-channel maps
        max_pool_out, _ = torch.max(x, dim=1, keepdim=True)
        avg_pool_out = torch.mean(x, dim=1, keepdim=True)
        pool_out = torch.cat([max_pool_out, avg_pool_out], dim=1)
        out = self.conv1(pool_out)
        out = self.sigmoid(out)
        return out
class My_attention(nn.Module):
    def __init__(self, channel, ratio=16, kernel_size=7):
        super(My_attention, self).__init__()
        # Forward ratio / kernel_size so these arguments actually take effect
        self.channel_attention = Channel_attention(channel, ratio)
        self.spacial_attention = Spacial_attention(kernel_size)

    def forward(self, x):
        SA_x = self.spacial_attention(x) * x   # features weighted by spatial attention
        CA_x = self.channel_attention(x) * x   # features weighted by channel attention
        output = SA_x + CA_x + x               # add the two attended maps to the residual input
        return output
# Quick sanity check
myattention = My_attention(512)
print(myattention)
inputs = torch.ones([2, 512, 26, 26])
outputs = myattention(inputs)
print(outputs)
print(outputs.size())   # expected: torch.Size([2, 512, 26, 26])
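# The sketch below (not part of the original file) shows one way the module could be
# dropped into a small conv stage; ConvBlockWithAttention and its layer choices
# (3x3 conv + BatchNorm + ReLU) are assumptions for illustration only.
class ConvBlockWithAttention(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlockWithAttention, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.attention = My_attention(out_channels)

    def forward(self, x):
        x = self.relu(self.bn(self.conv(x)))
        # Re-weight the features spatially and channel-wise, keeping the residual path
        return self.attention(x)

# Example usage of the sketch:
# block = ConvBlockWithAttention(256, 512)
# y = block(torch.randn(2, 256, 26, 26))   # -> torch.Size([2, 512, 26, 26])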