# image.png

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :Pytorch_learn
@File :Myattention.py
@IDE :PyCharm
@Author :咋
@Date :2023/7/14 17:42
"""
import torch
import torch.nn as nn
from torch.nn import Conv2d

class Channel_attention(nn.Module):
    """Channel attention: produce one sigmoid weight per input channel.

    A global max-pool and a global average-pool of the input are each fed
    through the same two-layer bottleneck MLP (``fc``). The two MLP outputs
    are summed and passed through a sigmoid. Only the weights are returned,
    shaped ``(b, c, 1, 1)``, so the caller can broadcast-multiply them onto
    the input.

    Args:
        channel: number of input channels ``c``.
        ratio: bottleneck reduction factor. The hidden width is
            ``channel // ratio``, clamped to at least 1.
    """

    def __init__(self, channel, ratio=16):
        super(Channel_attention, self).__init__()
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Clamp so that channel < ratio cannot create a zero-width hidden
        # layer. Such a layer would make the bias-free MLP output all zeros,
        # i.e. a constant 0.5 weight for every channel.
        hidden = max(channel // ratio, 1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, False),
            nn.ReLU(),
            nn.Linear(hidden, channel, False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return channel weights of shape ``(b, c, 1, 1)``.

        ``x`` must be 4-D; it is unpacked as ``(b, c, h, w)``, and ``c``
        must equal the ``channel`` given to the constructor.
        """
        b, c, _, _ = x.size()
        # Global pooling yields (b, c, 1, 1); flatten to (b, c) for the MLP.
        maxpool = self.max_pool(x).view([b, c])
        avgpool = self.avg_pool(x).view([b, c])
        # The same MLP (shared weights) processes both pooled descriptors.
        out = self.fc(maxpool) + self.fc(avgpool)
        return self.sigmoid(out).view([b, c, 1, 1])


class Spacial_attention(nn.Module):
    """Spatial attention: produce one sigmoid weight per spatial position.

    The input is reduced across dim 1 with a max and with a mean. The two
    single-channel maps are concatenated, and a ``kernel_size`` x
    ``kernel_size`` convolution maps them to one channel before a sigmoid.
    The weight map returned has shape ``(b, 1, h, w)``, the same spatial
    size as the input.

    Args:
        kernel_size: convolution kernel size. It must be odd so that the
            ``kernel_size // 2`` padding preserves the spatial size.

    Raises:
        ValueError: if ``kernel_size`` is even.
    """

    def __init__(self, kernel_size=7):
        super(Spacial_attention, self).__init__()
        if kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be odd to preserve spatial size, got {kernel_size}"
            )
        # Stride-1 "same" padding derived from the kernel size, so the
        # output map can be multiplied element-wise with the input.
        self.conv1 = Conv2d(2, 1, kernel_size, 1, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a ``(b, 1, h, w)`` spatial weight map for a 4-D input ``x``."""
        max_pool_out, _ = torch.max(x, dim=1, keepdim=True)
        avg_pool_out = torch.mean(x, dim=1, keepdim=True)
        # Order [max, avg] fixes which conv input channel sees which map.
        pool_out = torch.cat([max_pool_out, avg_pool_out], dim=1)
        return self.sigmoid(self.conv1(pool_out))


class My_attention(nn.Module):
    """Apply spatial and channel attention in parallel, plus a residual path.

    ``forward`` returns ``SA(x) * x + CA(x) * x + x``. Here ``SA`` is the
    ``Spacial_attention`` map and ``CA`` the ``Channel_attention`` weights,
    both computed from the same input ``x``.

    Args:
        channel: number of input channels.
        ratio: reduction factor passed to ``Channel_attention``.
        kernel_size: convolution kernel size passed to ``Spacial_attention``.
    """

    def __init__(self, channel, ratio=16, kernel_size=7):
        super(My_attention, self).__init__()
        # Pass ratio and kernel_size through so the constructor arguments
        # actually configure the sub-modules.
        self.channel_attention = Channel_attention(channel, ratio)
        self.spacial_attention = Spacial_attention(kernel_size)

    def forward(self, x):
        sa_x = self.spacial_attention(x) * x  # result of spatial attention
        ca_x = self.channel_attention(x) * x  # result of channel attention
        # Add the three feature maps together (both attended maps + identity).
        return sa_x + ca_x + x

if __name__ == "__main__":
    # Smoke test: build the module for 512 channels, print its structure,
    # run one forward pass on a constant batch, and print the result.
    # Guarded so that importing this module has no side effects.
    myattention = My_attention(512)
    print(myattention)
    inputs = torch.ones([2, 512, 26, 26])
    outputs = myattention(inputs)
    print(outputs)
    print(outputs.size())