-
Notifications
You must be signed in to change notification settings - Fork 105
/
Copy pathUnet.py
103 lines (97 loc) · 4.88 KB
/
Unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
class UNet(nn.Module):
    """U-Net style encoder/decoder for dense (per-pixel) prediction.

    Architecture:
    - Three contracting stages: two unpadded convolutions each, followed by
      2x2 max-pooling.
    - A bottleneck.
    - Two expansive stages. Each concatenates the centre-cropped encoder
      feature map (skip connection) with the upsampled decoder map, then
      convolves and upsamples with a stride-2 transposed convolution.
    - A final block that maps the last concatenation to ``out_channel``
      channels.

    Most convolutions use no padding, so the output is spatially smaller than
    the input. With the layer sizes used here, each spatial dimension of the
    input must be at least 92 pixels.

    Args:
        in_channel: number of channels of the input tensor.
        out_channel: number of channels produced by the final block.
    """

    def contracting_block(self, in_channels, out_channels, kernel_size=3):
        """Build one contracting (encoder) stage.

        Two unpadded ``kernel_size`` convolutions, each followed by ReLU and
        BatchNorm. Each convolution shrinks height and width by
        ``kernel_size - 1``.
        """
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_channels),
        )
        return block

    def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        """Build one expansive (decoder) stage.

        The stage runs as follows:
        - Two unpadded ``kernel_size`` convolutions (ReLU + BatchNorm after
          each) reduce ``in_channels`` to ``mid_channel``.
        - A transposed convolution maps ``mid_channel`` to ``out_channels``.
          With kernel 3, stride 2, padding 1 and output_padding 1 it exactly
          doubles height and width.
        """
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
        )
        return block

    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        """Build the output stage.

        The stage runs as follows:
        - Two unpadded ``kernel_size`` convolutions, with ReLU + BatchNorm
          after each.
        - A third convolution from ``mid_channel`` to ``out_channels`` with
          ``kernel_size // 2`` padding. This keeps height and width unchanged
          for odd kernel sizes; it is padding 1 for the default 3x3.
        """
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            # "Same" padding derived from the kernel size, so the padding
            # stays correct if a non-default odd kernel size is passed.
            torch.nn.Conv2d(
                kernel_size=kernel_size,
                in_channels=mid_channel,
                out_channels=out_channels,
                padding=kernel_size // 2,
            ),
            # NOTE(review): ReLU + BatchNorm on the final output means the
            # network does not emit raw logits, which is unusual for a
            # prediction head. Kept so existing checkpoints (state_dict keys)
            # still load.
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_channels),
        )
        return block

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # Encode
        self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_maxpool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode2 = self.contracting_block(64, 128)
        self.conv_maxpool2 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode3 = self.contracting_block(128, 256)
        self.conv_maxpool3 = torch.nn.MaxPool2d(kernel_size=2)
        # Bottleneck: two unpadded convs, then a transposed conv that
        # doubles H and W while halving channels back to 256.
        self.bottleneck = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=3, in_channels=256, out_channels=512),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(512),
            torch.nn.Conv2d(kernel_size=3, in_channels=512, out_channels=512),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(512),
            torch.nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1)
        )
        # Decode: input channel counts are doubled by the skip concatenation.
        self.conv_decode3 = self.expansive_block(512, 256, 128)
        self.conv_decode2 = self.expansive_block(256, 128, 64)
        self.final_layer = self.final_block(128, 64, out_channel)

    def crop_and_concat(self, upsampled, bypass, crop=False):
        """Concatenate a decoder feature map with its encoder skip connection.

        Args:
            upsampled: decoder tensor; its last two dims are (H, W).
            bypass: encoder tensor with the same leading dims as
                ``upsampled``.
            crop: when True, centre-crop ``bypass`` to ``upsampled``'s last
                two dims before concatenating. The two axes are handled
                independently. When a size difference is odd, the extra
                row/column is removed from the bottom/right edge. If
                ``bypass`` is smaller instead, the resulting positive pad
                amounts zero-pad it.

        Returns:
            ``torch.cat((upsampled, bypass), 1)``.
        """
        if crop:
            diff_h = bypass.size(-2) - upsampled.size(-2)
            diff_w = bypass.size(-1) - upsampled.size(-1)
            top, left = diff_h // 2, diff_w // 2
            bottom, right = diff_h - top, diff_w - left
            # F.pad orders amounts (left, right, top, bottom) for the last two
            # dims; negative amounts remove rows/columns, i.e. crop.
            bypass = F.pad(bypass, (-left, -right, -top, -bottom))
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        """Run the network.

        Args:
            x: input batch with ``in_channel`` channels. BatchNorm2d requires
                a 4-D (N, C, H, W) tensor, and each of H and W must be at
                least 92.

        Returns:
            Tensor with ``out_channel`` channels, spatially smaller than ``x``.
        """
        # Encode
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_maxpool2(encode_block2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_maxpool3(encode_block3)
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool3)
        # Decode: each step joins the upsampled map with the cropped skip.
        decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True)
        cat_layer2 = self.conv_decode3(decode_block3)
        decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)
        cat_layer1 = self.conv_decode2(decode_block2)
        decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True)
        output = self.final_layer(decode_block1)
        return output