# Seg_networks.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
from Diffeo_layers import *  # assumed to provide ResizeTransform and default_unet_features used below
from Diffeo_modelio import LoadableModel, store_config_args


class Unet(nn.Module):
"""
A unet architecture. Layer features can be specified directly as a list of encoder and decoder
features or as a single integer along with a number of unet levels. The default network features
per layer (when no options are specified) are:
encoder: [16, 32, 32, 32]
decoder: [32, 32, 32, 32, 32, 16, 16]
"""
def __init__(self,
inshape=None,
infeats=None,
nb_features=None,
nb_levels=None,
max_pool=2,
feat_mult=1,
nb_conv_per_level=1,
half_res=False):
"""
Parameters:
inshape: Input shape. e.g. (192, 192, 192)
infeats: Number of input features.
nb_features: Unet convolutional features. Can be specified via a list of lists with
the form [[encoder feats], [decoder feats]], or as a single integer.
If None (default), the unet features are defined by the default config described in
the class documentation.
nb_levels: Number of levels in unet. Only used when nb_features is an integer.
Default is None.
            max_pool: Kernel size of the max-pooling per level. Can be an integer (applied
                to all levels) or a list with one value per level. Default is 2.
            feat_mult: Per-level feature multiplier. Only used when nb_features is an integer.
                Default is 1.
nb_conv_per_level: Number of convolutions per unet level. Default is 1.
half_res: Skip the last decoder upsampling. Default is False.
"""
super().__init__()
# ensure correct dimensionality
ndims = len(inshape)
assert ndims in [1, 2, 3], 'ndims should be one of 1, 2, or 3. found: %d' % ndims
# cache some parameters
self.half_res = half_res
# default encoder and decoder layer features if nothing provided
if nb_features is None:
nb_features = default_unet_features()
# build feature list automatically
if isinstance(nb_features, int):
if nb_levels is None:
raise ValueError('must provide unet nb_levels if nb_features is an integer')
feats = np.round(nb_features * feat_mult ** np.arange(nb_levels)).astype(int)
nb_features = [
np.repeat(feats[:-1], nb_conv_per_level),
np.repeat(np.flip(feats), nb_conv_per_level)
]
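            # worked example: nb_features=16, nb_levels=4, feat_mult=2 and
            # nb_conv_per_level=1 give feats = [16, 32, 64, 128], hence
            # encoder = [16, 32, 64] and decoder = [128, 64, 32, 16]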
elif nb_levels is not None:
raise ValueError('cannot use nb_levels if nb_features is not an integer')
# extract any surplus (full resolution) decoder convolutions
enc_nf, dec_nf = nb_features
nb_dec_convs = len(enc_nf)
final_convs = dec_nf[nb_dec_convs:]
dec_nf = dec_nf[:nb_dec_convs]
        self.nb_levels = nb_dec_convs // nb_conv_per_level + 1
if isinstance(max_pool, int):
max_pool = [max_pool] * self.nb_levels
# cache downsampling / upsampling operations
MaxPooling = getattr(nn, 'MaxPool%dd' % ndims)
self.pooling = [MaxPooling(s) for s in max_pool]
self.upsampling = [nn.Upsample(scale_factor=s, mode='nearest') for s in max_pool]
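        # note: plain Python lists suffice here because MaxPool and Upsample
        # modules carry no learnable parameters, so registering them in an
        # nn.ModuleList is not required for training or state_dict handling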
# configure encoder (down-sampling path)
prev_nf = infeats
encoder_nfs = [prev_nf]
self.encoder = nn.ModuleList()
for level in range(self.nb_levels - 1):
convs = nn.ModuleList()
for conv in range(nb_conv_per_level):
nf = enc_nf[level * nb_conv_per_level + conv]
convs.append(ConvBlock(ndims, prev_nf, nf))
prev_nf = nf
self.encoder.append(convs)
encoder_nfs.append(prev_nf)
# configure decoder (up-sampling path)
encoder_nfs = np.flip(encoder_nfs)
self.decoder = nn.ModuleList()
for level in range(self.nb_levels - 1):
convs = nn.ModuleList()
for conv in range(nb_conv_per_level):
nf = dec_nf[level * nb_conv_per_level + conv]
convs.append(ConvBlock(ndims, prev_nf, nf))
prev_nf = nf
self.decoder.append(convs)
if not half_res or level < (self.nb_levels - 2):
prev_nf += encoder_nfs[level]
# now we take care of any remaining convolutions
self.remaining = nn.ModuleList()
for num, nf in enumerate(final_convs):
self.remaining.append(ConvBlock(ndims, prev_nf, nf))
prev_nf = nf
# cache final number of features
self.final_nf = prev_nf

    def forward(self, x):
# encoder forward pass
x_history = [x]
for level, convs in enumerate(self.encoder):
for conv in convs:
x = conv(x)
x_history.append(x)
x = self.pooling[level](x)
        # decoder forward pass with upsampling and concatenation
for level, convs in enumerate(self.decoder):
for conv in convs:
x = conv(x)
if not self.half_res or level < (self.nb_levels - 2):
x = self.upsampling[level](x)
x = torch.cat([x, x_history.pop()], dim=1)
# remaining convs at full resolution
for conv in self.remaining:
x = conv(x)
return x


class SegDense(LoadableModel):
    """
    Segmentation network: a Unet backbone followed by a single-channel
    convolutional head and a sigmoid producing per-voxel probabilities.
    """

@store_config_args
def __init__(self,
inshape,
nb_unet_features=None,
nb_unet_levels=None,
unet_feat_mult=1,
nb_unet_conv_per_level=1,
int_steps=7,
int_downsize=2,
bidir=False,
use_probs=False,
src_feats=1,
trg_feats=0,
unet_half_res=False):
super().__init__()
# ensure correct dimensionality
ndims = len(inshape)
assert ndims in [1, 2, 3], 'ndims should be one of 1, 2, or 3. found: %d' % ndims
# configure core unet model
self.unet_model = Unet(
inshape,
infeats=(src_feats + trg_feats),
nb_features=nb_unet_features,
nb_levels=nb_unet_levels,
feat_mult=unet_feat_mult,
nb_conv_per_level=nb_unet_conv_per_level,
half_res=unet_half_res,
)
        # configure conv head mapping unet features to a single-channel output
Conv = getattr(nn, 'Conv%dd' % ndims)
self.flow = Conv(self.unet_model.final_nf, 1, kernel_size=3, padding=1)
# init flow layer with small weights and bias
self.flow.weight = nn.Parameter(Normal(0, 1e-5).sample(self.flow.weight.shape))
self.flow.bias = nn.Parameter(torch.zeros(self.flow.bias.shape))
# probabilities are not supported in pytorch
if use_probs:
raise NotImplementedError(
'Flow variance has not been implemented in pytorch - set use_probs to False')
        # configure optional resize layer (downsize); note that it is not
        # applied in forward below and is kept for config compatibility
        if not unet_half_res and int_steps > 0 and int_downsize > 1:
            self.resize = ResizeTransform(int_downsize, ndims)
else:
self.resize = None
        # resize back to full resolution (only meaningful when the head output
        # is at reduced resolution, e.g. with unet_half_res=True)
if int_steps > 0 and int_downsize > 1:
self.fullsize = ResizeTransform(1 / int_downsize, ndims)
else:
self.fullsize = None
        # bidirectional-training flag (stored for config compatibility; unused here)
        self.bidir = bidir
        # sigmoid maps the single-channel output to per-voxel probabilities in [0, 1]
        self.labelize = nn.Sigmoid()

    def forward(self, in_vol):
        # propagate through the unet backbone; the original code concatenated
        # in_vol with itself here, which doubles the channel count and breaks
        # the unet's infeats=(src_feats + trg_feats) under the defaults
        # (src_feats=1, trg_feats=0), so the volume is passed through directly
        x = self.unet_model(in_vol)
        # conv head produces a single-channel segmentation map
        seg_field = self.flow(x)
        # restore full resolution when a fullsize layer was configured; it is
        # None when int_steps == 0 or int_downsize <= 1
        if self.fullsize is not None:
            seg_field = self.fullsize(seg_field)
        seg_field = self.labelize(seg_field)
        return seg_field


class ConvBlock(nn.Module):
    """
    Basic convolutional block: one convolution followed by a LeakyReLU activation.
    """

    def __init__(self, ndims, in_channels, out_channels, stride=1):
        super().__init__()
        Conv = getattr(nn, 'Conv%dd' % ndims)
        self.main = Conv(in_channels, out_channels, 3, stride, 1)
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
out = self.main(x)
out = self.activation(out)
return out
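

if __name__ == '__main__':
    # Minimal smoke test; a hedged sketch, not part of the original module.
    # It assumes Diffeo_layers provides default_unet_features (and, when
    # enabled, ResizeTransform), and uses a small illustrative volume size.
    # int_steps=0 keeps fullsize as None so the output matches the input size.
    model = SegDense(inshape=(32, 32, 32), int_steps=0)
    in_vol = torch.randn(1, 1, 32, 32, 32)  # (batch, channels, D, H, W)
    seg = model(in_vol)
    print(seg.shape)  # expected: torch.Size([1, 1, 32, 32, 32])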