""" @Project :Pytorch_learn @File :unet_parts.py @IDE :PyCharm @Author :咋 @Date :2023/7/14 17:47 """ from torch import nn import torch import torch.nn.functional as F
class Conv_Block(nn.Module):
    """Two successive 3x3 conv + BatchNorm + ReLU layers (the UNet "double conv")."""

    def __init__(self, in_channel, out_channel):
        super(Conv_Block, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.layer(x)
class Down(nn.Module):
    """Downscaling step: 2x2 max pooling followed by a double conv."""

    def __init__(self, in_channel, out_channel):
        super(Down, self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            Conv_Block(in_channel, out_channel)
        )

    def forward(self, x):
        return self.maxpool_conv(x)
class Up(nn.Module):
    """Upscaling step: upsample the deeper feature map, pad it to match the
    skip connection, concatenate along channels, then double conv.

    Note: `bilinear` is kept in the signature, but this implementation always
    upsamples bilinearly; the flag is currently unused.
    """

    def __init__(self, in_channel, out_channel, bilinear=False):
        super(Up, self).__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = Conv_Block(in_channel, out_channel)

    def forward(self, x1, x2):
        # x1: deeper feature map to upsample; x2: skip connection from the encoder
        x1 = self.up(x1)
        # pooling floors odd sizes, so the upsampled map can end up smaller
        # than the skip connection; pad it back to the same height and width
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)  # concatenate along the channel dimension
        return self.conv(x)
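# Worked example of the padding arithmetic above (illustrative numbers, not
# from the original file): with a 5x5 input, MaxPool2d(2) floors the size to
# 2x2 and Upsample doubles it back to only 4x4, so diffY = diffX = 5 - 4 = 1
# and F.pad receives [0, 1, 0, 1], restoring the map to 5x5 so the
# concatenation with the skip connection lines up.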
class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to per-class scores."""

    def __init__(self, in_channel, out_channel):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, 1)

    def forward(self, x):
        return self.conv(x)
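# --- Usage sketch ---
# A minimal smoke test assembling the blocks above into a tiny two-level
# UNet and checking output shapes. `MiniUNet` and its channel sizes are
# illustrative assumptions for this sketch, not the project's actual
# unet_model.py.
class MiniUNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=2):
        super(MiniUNet, self).__init__()
        self.inc = Conv_Block(n_channels, 64)
        self.down1 = Down(64, 128)
        # after upsampling, 128 channels are concatenated with the 64-channel skip
        self.up1 = Up(128 + 64, 64)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x = self.up1(x2, x1)
        return self.outc(x)


if __name__ == "__main__":
    net = MiniUNet()
    x = torch.randn(1, 3, 97, 97)  # odd spatial size exercises the padding in Up
    y = net(x)
    print(y.shape)  # expected: torch.Size([1, 2, 97, 97])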