-
Notifications
You must be signed in to change notification settings - Fork 0
/
UNet_2Plus.py
108 lines (90 loc) · 4.15 KB
/
UNet_2Plus.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers import unetConv2, unetUp_origin
from init_weights import init_weights
import numpy as np
from torchvision import models
class UNet_2Plus(nn.Module):
    """UNet++ (nested U-Net) built from ``unetConv2`` / ``unetUp_origin`` blocks.

    Encoder: five ``unetConv2`` stages (``conv00`` .. ``conv40``) separated by
    2x2 max pooling, with channel widths ``[64, 128, 256, 512, 1024]``.

    Decoder: a grid of ``unetUp_origin`` nodes ``X_ij`` (``i`` = depth, 0 being
    the shallowest; ``j`` = column). Each node receives the node one level
    deeper in the previous column plus every earlier node at its own depth
    (see ``forward``).

    Heads: four 1x1 convolutions map the top-row nodes ``X_01 .. X_04`` to
    ``n_classes`` channels. No activation is applied to their outputs.

    Args:
        in_channels: Channel count of the input tensor.
        n_classes: Output channels of each 1x1 head.
        feature_scale: Stored on the instance but not used; the channel
            widths are fixed (see the note in ``__init__``).
        is_deconv: Forwarded to every ``unetUp_origin`` node (presumably
            selects transposed convolution vs. interpolation for upsampling
            -- confirm in ``layers.py``).
        is_batchnorm: Forwarded to every encoder ``unetConv2`` block.
        is_ds: Deep supervision. When True (default) ``forward`` returns the
            average of all four heads; when False it returns only the head
            on ``X_04``.
    """

    def __init__(self, in_channels=3, n_classes=1, feature_scale=4, is_deconv=True, is_batchnorm=True, is_ds=True):
        super(UNet_2Plus, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.is_ds = is_ds
        self.feature_scale = feature_scale

        # NOTE: channel widths are fixed. feature_scale is kept for API
        # compatibility but not applied; a slimmer variant would use
        # [int(x / feature_scale) for x in filters] (e.g. [32, 64, ...]).
        filters = [64, 128, 256, 512, 1024]

        # Encoder (column 0).
        self.conv00 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
        self.maxpool0 = nn.MaxPool2d(kernel_size=2)
        self.conv10 = unetConv2(filters[0], filters[1], self.is_batchnorm)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv20 = unetConv2(filters[1], filters[2], self.is_batchnorm)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv30 = unetConv2(filters[2], filters[3], self.is_batchnorm)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
        self.conv40 = unetConv2(filters[3], filters[4], self.is_batchnorm)

        # Decoder nodes up_concat{i}{j} produce X_ij.
        # - The first argument matches the channel count of the deeper input.
        # - The second argument is the node's output channel count.
        # - The trailing integer equals the number of tensors the node gets in
        #   forward() (one upsampled input plus j skips). Column-1 nodes omit
        #   it (presumably a default of 2 -- confirm in layers.py).
        self.up_concat01 = unetUp_origin(filters[1], filters[0], self.is_deconv)
        self.up_concat11 = unetUp_origin(filters[2], filters[1], self.is_deconv)
        self.up_concat21 = unetUp_origin(filters[3], filters[2], self.is_deconv)
        self.up_concat31 = unetUp_origin(filters[4], filters[3], self.is_deconv)

        self.up_concat02 = unetUp_origin(filters[1], filters[0], self.is_deconv, 3)
        self.up_concat12 = unetUp_origin(filters[2], filters[1], self.is_deconv, 3)
        self.up_concat22 = unetUp_origin(filters[3], filters[2], self.is_deconv, 3)

        self.up_concat03 = unetUp_origin(filters[1], filters[0], self.is_deconv, 4)
        self.up_concat13 = unetUp_origin(filters[2], filters[1], self.is_deconv, 4)

        self.up_concat04 = unetUp_origin(filters[1], filters[0], self.is_deconv, 5)

        # One 1x1 prediction head per top-row node (X_01 .. X_04).
        self.final_1 = nn.Conv2d(filters[0], n_classes, 1)
        self.final_2 = nn.Conv2d(filters[0], n_classes, 1)
        self.final_3 = nn.Conv2d(filters[0], n_classes, 1)
        self.final_4 = nn.Conv2d(filters[0], n_classes, 1)

        # self.modules() recurses into submodules, so this also reaches every
        # Conv2d / BatchNorm2d nested inside the unetConv2 / unetUp_origin
        # blocks.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                init_weights(m, init_type='kaiming')

    def forward(self, inputs):
        """Run the nested U-Net on ``inputs``.

        Args:
            inputs: Tensor with ``in_channels`` channels. Presumably
                (N, C, H, W) with H and W divisible by 16, so that the four 2x2
                poolings and the decoder upsamplings line up for
                concatenation -- confirm against ``unetUp_origin``.

        Returns:
            ``n_classes``-channel output. With ``is_ds`` True it is the mean of
            the four heads; otherwise it is the ``X_04`` head alone.
        """
        # Column 0: encoder backbone.
        X_00 = self.conv00(inputs)
        X_10 = self.conv10(self.maxpool0(X_00))
        X_20 = self.conv20(self.maxpool1(X_10))
        X_30 = self.conv30(self.maxpool2(X_20))
        X_40 = self.conv40(self.maxpool3(X_30))

        # Column 1: upsample the deeper node and fuse it with the encoder skip.
        X_01 = self.up_concat01(X_10, X_00)
        X_11 = self.up_concat11(X_20, X_10)
        X_21 = self.up_concat21(X_30, X_20)
        X_31 = self.up_concat31(X_40, X_30)

        # Columns 2-4: dense skips from every earlier node at the same depth.
        X_02 = self.up_concat02(X_11, X_00, X_01)
        X_12 = self.up_concat12(X_21, X_10, X_11)
        X_22 = self.up_concat22(X_31, X_20, X_21)

        X_03 = self.up_concat03(X_12, X_00, X_01, X_02)
        X_13 = self.up_concat13(X_22, X_10, X_11, X_12)

        X_04 = self.up_concat04(X_13, X_00, X_01, X_02, X_03)

        final_4 = self.final_4(X_04)
        if not self.is_ds:
            # Deep supervision disabled: only the deepest-column head is used.
            return final_4

        final_1 = self.final_1(X_01)
        final_2 = self.final_2(X_02)
        final_3 = self.final_3(X_03)
        return (final_1 + final_2 + final_3 + final_4) / 4
# Smoke check: build a default UNet_2Plus, print its parameter count in
# millions, then print every named parameter with its shape.
# NOTE(review): this executes on every import of this module (building the
# full network and printing ~100 lines); consider moving it under an
# ``if __name__ == "__main__":`` guard once no caller relies on the
# module-level ``model`` / ``params`` names.
model = UNet_2Plus()
print('# generator parameters:', sum(param.numel() for param in model.parameters()) / 1000000)
params = list(model.named_parameters())
for name, param in params:
    print(name)
    print(param.shape)