## Copyright (C) 2019, Huan Zhang <huan@huan-zhang.com>
##                     Hongge Chen <chenhg@mit.edu>
##                     Chaowei Xiao <xiaocw@umich.edu>
##
## This program is licenced under the BSD 2-Clause License,
## contained in the LICENCE file in this directory.
##
import os
import json
import glob
import copy
import importlib
import torch
import numpy as np
from datasets import loaders
# from model_defs import add_feature_subsample, remove_feature_subsample
from model_defs import convert_conv2d_dense, save_checkpoint, load_checkpoint_to_mlpany

# Helper function to find a saved file whose name most closely matches the given ID
def get_file_close(filename, ext, load = True):
    if ext[0] == ".":
        ext = ext[1:]
    if not load:
        return filename + "." + ext
    filelist = glob.glob(filename + "*." + ext + "*")
    if len(filelist) == 0:
        raise OSError("File " + filename + " not found!")
    # FIXME
    if "last" in filelist[0]:
        filelist = filelist[1:]
    if len(filelist) > 1:
        filelist = sorted(filelist, key = len)
        print("Warning! Multiple files match ID {}: {}".format(filename, filelist))
        for f in filelist:
            # return the best model if we have it
            if "best" in f:
                return f
    return filelist[0]
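
# Example (filenames are hypothetical): if "models/cnn_small.pth" and
# "models/cnn_small_best.pth" both exist, get_file_close("models/cnn_small", "pth")
# warns about the multiple matches and returns the "_best" checkpoint; with
# load=False it simply returns "models/cnn_small.pth" without touching the filesystem.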

# Recursively merge dict u into dict d in place; nested dicts are merged key by key
def update_dict(d, u, show_warning = False):
    for k, v in u.items():
        if k not in d and show_warning:
            print("\033[91m Warning: key {} not found in config. Make sure to double check spelling and config option name. \033[0m".format(k))
        if isinstance(v, dict):
            d[k] = update_dict(d.get(k, {}), v, show_warning)
        else:
            d[k] = v
    return d
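
# Minimal sketch of the merge semantics (values are hypothetical):
#   base     = {"training_params": {"lr": 0.001, "epochs": 100}}
#   override = {"training_params": {"lr": 0.01}}
#   update_dict(base, override)
#   # base is now {"training_params": {"lr": 0.01, "epochs": 100}}: nested keys
#   # present only in the base config are preserved, not overwritten wholesale.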

# Load defaults.json, overlay the user config, apply command line overrides,
# and expand/filter the "models" list
def load_config(args):
    print("loading config file: {}".format(args.config))
    with open("defaults.json") as f:
        config = json.load(f)
    with open(args.config) as f:
        update_dict(config, json.load(f))
    if args.overrides_dict:
        print("overriding parameters: \033[93mPlease check these parameters carefully.\033[0m")
        print("\033[93m" + str(args.overrides_dict) + "\033[0m")
        update_dict(config, args.overrides_dict, True)
    # remove ignored models
    subset_models = []
    for model_config in config["models"]:
        if "ignore" in model_config and model_config["ignore"]:
            continue
        else:
            subset_models.append(model_config)
    config["models"] = subset_models
    # repeat models if the "repeats" field is found
    repeated_models = []
    for model_config in config["models"]:
        if "repeats" in model_config:
            for i in range(model_config["repeats"]):
                c = copy.deepcopy(model_config)
                c["repeats_idx"] = i + 1
                for k, v in c.items():
                    if isinstance(v, str):
                        if v == "##":
                            c[k] = i + 1
                        if "@@" in v:
                            c[k] = c[k].replace("@@", str(i + 1))
                repeated_models.append(c)
        else:
            repeated_models.append(model_config)
    config["models"] = repeated_models
    # only use a subset of models, if specified
    if args.model_subset:
        subset_models = []
        for i in args.model_subset:
            subset_models.append(config["models"][i])
        config["models"] = subset_models
    if args.path_prefix:
        config["path_prefix"] = args.path_prefix
    return config
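
# Illustration of the "repeats" expansion above (the config entry is hypothetical):
#   {"model_id": "mlp_@@", "repeats": 3, "seed": "##", ...}
# becomes three entries with model_id "mlp_1", "mlp_2", "mlp_3" and seed 1, 2, 3:
# string values equal to "##" are replaced by the repeat index as an integer, and
# any "@@" substring is replaced by the index as a string.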

# Load dataset loader based on config file
def config_dataloader(config, **kwargs):
    return loaders[config["dataset"]](**kwargs)
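
# Usage sketch; the accepted keyword arguments and the return value depend entirely on
# the loader registered for config["dataset"] in datasets.loaders (batch_size is assumed here):
#   data = config_dataloader(config, batch_size=128)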

# Unified naming rule for model files, bound files, ensemble weights and others
# To change format of saved model names, etc, only change here
def get_path(config, model_id, path_name, **kwargs):
    if path_name == "model":
        model_file = get_file_close(os.path.join(config["path_prefix"], config["models_path"], model_id), "pth", **kwargs)
        os.makedirs(os.path.join(config["path_prefix"], config["models_path"]), exist_ok = True)
        return model_file
    if path_name == "best_model":
        model_file = os.path.join(config["path_prefix"], config["models_path"], model_id + "_best.pth")
        os.makedirs(os.path.join(config["path_prefix"], config["models_path"]), exist_ok = True)
        return model_file
    if path_name == "train_log":
        model_file = get_path(config, model_id, "model", load = False)
        os.makedirs(os.path.join(config["path_prefix"], config["models_path"]), exist_ok = True)
        return model_file.replace(".pth", ".log")
    if path_name == "eval_log":
        model_file = get_path(config, model_id, "model", load = False)
        os.makedirs(os.path.join(config["path_prefix"], config["models_path"]), exist_ok = True)
        return model_file.replace(".pth", "_test.log")
    else:
        raise RuntimeError("Unsupported path " + path_name)
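
# Examples of the resulting naming convention (the prefix and directory are hypothetical):
#   get_path(config, "cnn_small", "best_model")        -> "<path_prefix>/<models_path>/cnn_small_best.pth"
#   get_path(config, "cnn_small", "train_log")         -> "<path_prefix>/<models_path>/cnn_small.log"
#   get_path(config, "cnn_small", "model", load=False) -> "<path_prefix>/<models_path>/cnn_small.pth"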

# Return the config of a single model by its model_id (returns None if no entry matches)
def get_model_config(config, model_id):
    for model_config in config["models"]:
        if model_config["model_id"] == model_id:
            return model_config

# Load all models based on config file
def config_modelloader(config, load_pretrain = False, cuda = False):
    # load the required model definition module
    model_module = importlib.import_module(os.path.splitext(config["model_def"])[0])
    models = []
    model_names = []
    for model_config in config["models"]:
        if "ignore" in model_config and model_config["ignore"]:
            continue
        model_id = model_config["model_id"]
        model_names.append(model_id)
        model_class = getattr(model_module, model_config["model_class"])
        model_params = model_config["model_params"]
        m = model_class(**model_params)
        if cuda:
            m.cuda()
        if load_pretrain:
            model_file = get_path(config, model_id, "model")
            # model_file += "_pretrain"
            print("Loading model file", model_file)
            checkpoint = torch.load(model_file)
            if isinstance(checkpoint["state_dict"], list):
                checkpoint["state_dict"] = checkpoint["state_dict"][0]
            # drop auxiliary state entries whose key contains "prev"
            new_state_dict = {}
            for k in checkpoint["state_dict"].keys():
                if "prev" in k:
                    pass
                else:
                    new_state_dict[k] = checkpoint["state_dict"][k]
            checkpoint["state_dict"] = new_state_dict
            """
            state_dict = m.state_dict()
            state_dict.update(checkpoint["state_dict"])
            m.load_state_dict(state_dict)
            print(checkpoint["state_dict"]["__mask_layer.weight"])
            """
            m.load_state_dict(checkpoint["state_dict"])
        # print(m)
        models.append(m)
    return models, model_names
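
# Typical call pattern (sketch; args comes from the project's argument parser):
#   config = load_config(args)
#   models, model_names = config_modelloader(config, load_pretrain=True, cuda=True)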

def config_modelloader_and_convert2mlp(config, load_pretrain = True):
    # load the required modelfile
    model_module = importlib.import_module(os.path.splitext(config["model_def"])[0])
    models = []
    model_names = []
    for model_config in config["models"]:
        if "ignore" in model_config and model_config["ignore"]:
            continue
        model_id = model_config["model_id"]
        model_names.append(model_id)
        model_class = getattr(model_module, model_config["model_class"])
        model_params = model_config["model_params"]
        m = model_class(**model_params)
        # m.cuda()
        if load_pretrain:
            model_file = get_path(config, model_id, "model")
            # model_file += "_pretrain"
            print("Loading model file", model_file)
            checkpoint = torch.load(model_file)
            if isinstance(checkpoint["state_dict"], list):
                checkpoint["state_dict"] = checkpoint["state_dict"][0]
            new_state_dict = {}
            for k in checkpoint["state_dict"].keys():
                if "prev" in k:
                    pass
                else:
                    new_state_dict[k] = checkpoint["state_dict"][k]
            checkpoint["state_dict"] = new_state_dict
            """
            state_dict = m.state_dict()
            state_dict.update(checkpoint["state_dict"])
            m.load_state_dict(state_dict)
            # print(checkpoint["state_dict"]["__mask_layer.weight"])
            """
            m.load_state_dict(checkpoint["state_dict"])
            print("convert to dense w")
            dense_m = convert_conv2d_dense(m)
            in_dim = model_params["in_dim"]
            in_ch = model_params["in_ch"]
            tmp = dense_m(torch.zeros(1, in_ch, in_dim, in_dim))
            dense_checkpoint_file = model_file.split(".pth")[0] + "_dense.pth"
            print("save dense checkpoint to {}".format(dense_checkpoint_file))
            save_checkpoint(dense_m, dense_checkpoint_file)
            mlp_m = load_checkpoint_to_mlpany(dense_checkpoint_file)
            # print(m)
            # models.append(m)
            models.append(mlp_m)
    return models, model_names
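
# Sketch of the conversion path implemented above: each pretrained model is flattened
# into a dense equivalent with convert_conv2d_dense, saved as "<model>_dense.pth", and
# reloaded as an MLP via load_checkpoint_to_mlpany; each model_params entry must
# therefore provide "in_dim" and "in_ch" for the dummy forward pass.
#   mlp_models, names = config_modelloader_and_convert2mlp(config)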