# udepth.py — 131 lines (106 loc), 5.3 KB
"""
Model architecture of UDepth.

Paper: https://arxiv.org/pdf/2209.12358.pdf
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .miniViT import mViT
class UpSample(nn.Sequential):
    """Decoder upsampling stage.

    Bilinearly resizes ``x`` to the spatial size of ``concat_with``,
    concatenates the two along the channel axis, and refines the result
    with two 3x3 convolutions, each followed by LeakyReLU(0.2).

    NOTE(review): subclasses ``nn.Sequential`` but overrides ``forward``
    with a two-argument signature, so the Sequential call machinery is
    never used — only its module registration.
    """
    def __init__(self, skip_input, output_features):
        super(UpSample, self).__init__()
        # First conv mixes the concatenated (upsampled + skip) channels.
        self.convA = nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1)
        self.leakyreluA = nn.LeakyReLU(0.2)
        # Second conv refines at the output channel count.
        self.convB = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1)
        self.leakyreluB = nn.LeakyReLU(0.2)
    def forward(self, x, concat_with):
        target_hw = (concat_with.size(2), concat_with.size(3))
        upsampled = F.interpolate(x, size=target_hw, mode='bilinear', align_corners=True)
        fused = torch.cat([upsampled, concat_with], dim=1)
        out = self.leakyreluA(self.convA(fused))
        return self.leakyreluB(self.convB(out))
class Decoder(nn.Module):
    """Depth decoder over MobileNetV2 activations.

    ``forward`` expects the list produced by ``Encoder``: the input tensor
    followed by the output of every backbone stage. Six skip connections
    (indices 2, 4, 6, 9, 15, 18) are fused into the deepest feature map
    (index 19) through successive :class:`UpSample` stages; the result is
    a single-channel-width feature map (``features // 16`` channels).
    """
    def __init__(self, num_features=1280, decoder_width = .6):
        super(Decoder, self).__init__()
        features = int(num_features * decoder_width)
        # NOTE(review): padding=1 with a 1x1 kernel grows H and W by 2;
        # kept as-is to preserve behavior (inherited from AdaBins).
        self.conv2 = nn.Conv2d(num_features, features, kernel_size=1, stride=1, padding=1)
        # Skip channel counts (320, 160, 64, 32, 24, 16) match the
        # MobileNetV2 activations selected in forward().
        self.up0 = UpSample(skip_input=features//1 + 320, output_features=features//2)
        self.up1 = UpSample(skip_input=features//2 + 160, output_features=features//2)
        self.up2 = UpSample(skip_input=features//2 + 64, output_features=features//4)
        self.up3 = UpSample(skip_input=features//4 + 32, output_features=features//8)
        self.up4 = UpSample(skip_input=features//8 + 24, output_features=features//8)
        self.up5 = UpSample(skip_input=features//8 + 16, output_features=features//16)
        # NOTE(review): conv3 is registered but never called in forward().
        self.conv3 = nn.Conv2d(features//16, 1, kernel_size=3, stride=1, padding=1)
    def forward(self, features):
        # Skip activations, deepest first after the reversed() below.
        skips = [features[i] for i in (2, 4, 6, 9, 15, 18)]
        x = self.conv2(features[19])
        stages = (self.up0, self.up1, self.up2, self.up3, self.up4, self.up5)
        for stage, skip in zip(stages, reversed(skips)):
            x = stage(x, skip)
        return x
class Encoder(nn.Module):
    """MobileNetV2 feature extractor.

    ``forward`` returns the input tensor followed by the activation after
    every sub-module of ``mobilenet_v2.features``, in order.
    """
    def __init__(self):
        super(Encoder, self).__init__()
        # Imported lazily so torchvision is only required when an Encoder
        # is actually constructed.
        import torchvision.models as models
        # NOTE(review): `pretrained=` is deprecated in newer torchvision
        # (replaced by `weights=`); kept for compatibility.
        self.original_model = models.mobilenet_v2( pretrained=True )
    def forward(self, x):
        activations = [x]
        for layer in self.original_model.features:
            activations.append(layer(activations[-1]))
        return activations
class Model(nn.Module):
    """Plain encoder-decoder depth network (no adaptive binning head)."""
    def __init__(self):
        super(Model, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
    def forward(self, x):
        # Feed the full list of encoder activations to the decoder.
        features = self.encoder(x)
        return self.decoder(features)
class UDepth(nn.Module):
    """UDepth: encoder-decoder depth estimator with an AdaBins-style
    adaptive-binning head (mViT) on top of the decoder output.

    Args:
        backend: backbone module accepted for interface compatibility with
            ``build()``. NOTE(review): it is currently unused — the model
            always constructs its own MobileNetV2 ``Encoder``.
        n_bins: number of adaptive depth bins.
        min_val / max_val: depth range spanned by the bin edges.
        norm: bin-width normalization mode, passed through to ``mViT``.
    """
    def __init__(self, backend, n_bins=100, min_val=0.001, max_val=1, norm='linear'):
        super(UDepth, self).__init__()
        self.num_classes = n_bins
        self.min_val = min_val
        self.max_val = max_val
        self.encoder = Encoder()
        # 48 = decoder output channels with the default Decoder config:
        # int(1280 * 0.6) // 16 == 48.
        self.adaptive_bins_layer = mViT(48, n_query_channels=48, patch_size=16,
                                        dim_out=n_bins,
                                        embedding_dim=48, norm=norm)
        self.decoder = Decoder()
        # Per-pixel softmax over bins -> mixing weights for the bin centers.
        self.conv_out = nn.Sequential(nn.Conv2d(48, n_bins, kernel_size=1, stride=1, padding=0),
                                      nn.Softmax(dim=1))
    def forward(self, x, **kwargs):
        """Return ``(bin_edges, pred)``: per-image absolute bin edges of
        shape (N, n_bins + 1) and the predicted depth map (N, 1, H', W')."""
        unet_out = self.decoder(self.encoder(x))
        bin_widths_normed, range_attention_maps = self.adaptive_bins_layer(unet_out)
        out = self.conv_out(range_attention_maps)
        # Scale normalized widths to the depth range and prepend min_val so
        # the cumulative sum yields absolute bin edges.
        bin_widths = (self.max_val - self.min_val) * bin_widths_normed
        bin_widths = F.pad(bin_widths, (1, 0), mode='constant', value=self.min_val)
        bin_edges = torch.cumsum(bin_widths, dim=1)
        # Predicted depth = softmax-weighted sum of the bin centers.
        centers = 0.5 * (bin_edges[:, :-1] + bin_edges[:, 1:])
        n, dout = centers.size()
        centers = centers.view(n, dout, 1, 1)
        pred = torch.sum(out * centers, dim=1, keepdim=True)
        return bin_edges, pred
    def get_1x_lr_params(self):  # lr/10 learning rate
        """Parameters trained at the reduced (lr/10) rate: the encoder."""
        return self.encoder.parameters()
    def get_10x_lr_params(self):  # lr learning rate
        """Parameters trained at the full rate: decoder, bins head, output conv."""
        modules = [self.decoder, self.adaptive_bins_layer, self.conv_out]
        for m in modules:
            yield from m.parameters()
    @classmethod
    def build(cls, n_bins, **kwargs):
        """Construct a UDepth model with ``n_bins`` adaptive bins.

        NOTE(review): downloads tf_efficientnet_b5_ap via torch.hub and
        strips its head, but ``__init__`` ignores ``backend`` — the download
        appears vestigial.
        """
        basemodel_name = 'tf_efficientnet_b5_ap'
        # Fix: the original format string had no '{}' placeholder, so the
        # backbone name was never interpolated into the message.
        print('Loading base model ({})...'.format(basemodel_name), end='')
        basemodel = torch.hub.load('rwightman/gen-efficientnet-pytorch', basemodel_name, pretrained=True)
        print('Done.')
        # Remove last layer
        basemodel.global_pool = nn.Identity()
        basemodel.classifier = nn.Identity()
        # Building Encoder-Decoder model
        print('Building Encoder-Decoder model..', end='')
        m = cls(basemodel, n_bins=n_bins, **kwargs)
        print('Done.')
        return m
if __name__ == '__main__':
    # Smoke test: build a 100-bin model and run one forward pass on a
    # random 480x640 batch of two images.
    model = UDepth.build(100)
    sample = torch.rand(2, 3, 480, 640)
    bin_edges, depth = model(sample)
    print(bin_edges.shape, depth.shape)