-
Notifications
You must be signed in to change notification settings - Fork 1
/
build.py
123 lines (90 loc) · 3.08 KB
/
build.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Common libraries
import pandas as pd
import numpy as np
EPS = 1e-9#
import os,glob,numbers
# Image processing
import math,cv2,random
from PIL import Image, ImageFile, ImageOps, ImageFilter
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Image display
from matplotlib import pyplot as plt
plt.rcParams['image.cmap'] = 'gray'
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import functional as f
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import sys
sys.path.append('.')
sys.path.append('..')
from utils import *
from nets import *
from scls import *
#start#
class SeqNet(nn.Module):
    """Supervised contrastive learning segmentation network.

    Chains a feature network (``fcn``) with a segmentation head (``seg``).
    The head is applied to the feature map that the feature network stores
    on itself as ``.feat`` during its forward pass, not to its return value.

    Args:
        type_net: Feature-network class name. It is evaluated in this
            module's namespace and called as ``<type_net>(num_emb=num_emb)``.
        type_seg: Segmentation-head class name. It is evaluated the same way
            and called as ``<type_seg>(inp_c=inp_c)``.
        num_emb: Forwarded to the feature network as ``num_emb``.
        inp_c: Input channel count passed to the segmentation head. The
            default of 32 matches the previously hard-coded value.
            Presumably it must equal the channel count of ``fcn.feat``;
            TODO confirm.
    """
    __name__ = 'scls'

    def __init__(self, type_net, type_seg, num_emb=128, inp_c=32):
        super().__init__()
        # NOTE(review): eval() builds the sub-networks from strings; only pass
        # trusted, hard-coded type names here.
        self.fcn = eval(type_net + '(num_emb=num_emb)')
        self.seg = eval(type_seg + '(inp_c=inp_c)')
        self.__name__ = '{}X{}'.format(self.fcn.__name__, self.seg.__name__)
        # Per-instance scratch dict. A class-level dict would be shared and
        # mutated by every SeqNet instance.
        self.tmp = {}

    def forward(self, x):
        """Run the feature network, then the segmentation head on its features.

        Returns:
            In eval mode, the segmentation output only. In training mode, a
            list ``[pred, aux]``, where ``aux`` is what the feature network
            returned. If ``aux`` is a tuple/list, only its first two elements
            are kept: ``[pred, aux[0], aux[1]]``.
        """
        aux = self.fcn(x)
        # The head consumes the feature map cached by the fcn during its
        # forward pass. It must be read after self.fcn(x) has run.
        self.feat = self.fcn.feat
        self.pred = self.seg(self.feat)
        if not self.training:
            return self.pred
        if isinstance(aux, (tuple, list)):
            return [self.pred, aux[0], aux[1]]
        return [self.pred, aux]
def build_model(type_net='hrdo', type_seg='', type_loss='sim2', type_arch='', num_emb=128):
    """Build a network from its type name(s).

    Args:
        type_net: Network class name. It is evaluated in this module's
            namespace and called as ``<type_net>(num_emb=num_emb)``.
        type_seg: Optional segmentation-head class name. When non-empty, the
            result is a ``SeqNet`` wrapping ``type_net`` and ``type_seg``.
            ``''`` or ``None`` builds the plain ``type_net`` model.
        type_loss: Not used by this function; kept for call-site compatibility.
        type_arch: Not used by this function; kept for call-site compatibility.
        num_emb: Forwarded to the network constructor.

    Returns:
        The constructed model.
    """
    if type_seg:
        # SeqNet instantiates type_net itself. Building it here as well would
        # waste time/memory and consume RNG state for discarded weights.
        return SeqNet(type_net, type_seg, num_emb=num_emb)
    # NOTE(review): eval() on a string; only pass trusted type names.
    return eval(type_net + '(num_emb=num_emb)')
#end#
# Note: move the morphological module into the preceding layer.
import time
if __name__ == '__main__':
    # Smoke test: build one model configuration, time a forward pass on random
    # input, and print output shapes/value ranges and the trainable-parameter
    # count.
    num_emb = 128
    x = torch.rand(8, 1, 128, 128)
    # Alternative configurations (uncomment one to try it):
    # cfg = read_config('configs/siam_unet_unet.ini')
    # print(cfg)
    # net = build_model('sunet', 'munet', 'sim2', 'roma', num_emb)
    # net = build_model('lunet', 'munet', 'sim2', 'siam', num_emb)
    # net = build_model('smf', 'lunet', 'sim2', 'siam', num_emb)
    net = build_model('smf', 'lunet', 'sim2', '', num_emb)
    # net = build_model('dmf32', 'munet', 'sim2', '', num_emb)
    # net = build_model('sunet', 'munet', 'arc', 'siam', num_emb)
    # net = build_model('lunet', 'lunet', '', '', num_emb)
    # net = build_model('spun', '', '', '', num_emb)
    # net.eval()
    # perf_counter is a monotonic, high-resolution clock intended for timing
    # intervals; time.time() follows the wall clock and can jump.
    st = time.perf_counter()
    ys = net(x)
    print(net.__name__, 'Time:', time.perf_counter() - st)
    # In eval mode a model may return a single tensor instead of a list. Wrap
    # it so the loop prints that output rather than iterating over the batch.
    if not isinstance(ys, (list, tuple)):
        ys = [ys]
    for y in ys:
        print(y.shape, y.min().item(), y.max().item())
    # Debug helpers for models that expose tmp/regular/constraint:
    # # net.train()
    # for key, item in net.tmp.items():
    #     print(key, item.shape)
    # sampler = MLPSampler(top=4, low=0, mode='half')
    # # net.train()
    # st = time.time()
    # l = net.regular(sampler, torch.rand_like(x), torch.rand_like(x))
    # # print('Regular:', l.item())
    # print(net.__name__, 'Time:', time.time() - st)
    # print('feat:', net.feat.shape, net.proj.shape)
    # # plot4(emb=net.feat, path_save='emb.png')
    # # plt.show()
    # st = time.time()
    # l = net.constraint(aux=x, fun=nn.MSELoss())
    # print('constraint:', l.item())
    # print(net.__name__, 'Time:', time.time() - st)
    # plot(net.emb)
    print('Params model:', sum(p.numel() for p in net.parameters() if p.requires_grad))