# config.py
import torch
from yacs.config import CfgNode as CN

_C = CN()
_C.BASE_CONFIG = "" # path to a base config file; values defined in this config override the base
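# Example (hypothetical path): a YAML that sets BASE_CONFIG: "configs/base.yaml"
# inherits every default from that file; get_config() below resolves such chains
# recursively, merging the base files first so the derived file's values win.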
_C.DEVICE = None
_C.RANK = 0
_C.WORLD_SIZE = 1
_C.LOCAL_RANK = 0
_C.SEED = None
_C.OUTPUT = ""
_C.EVALUATOR = ""
_C.EVALUATION_DOMAIN = None # support 'float'/'bn_merged'/'quant'

_C.MODEL = CN()
_C.MODEL.ARCH = ""
_C.MODEL.CHECKPOINT = ""
_C.MODEL.PRETRAINED = ""
_C.MODEL.NUM_CLASSES = 0
_C.MODEL.INPUTSHAPE = [-1, -1] # h, w

_C.TRAIN = CN()
_C.TRAIN.USE_DDP = False
_C.TRAIN.SYNC_BN = False
_C.TRAIN.LINEAR_EVAL = False
_C.TRAIN.WARMUP_FC = False # warm-start the final fc layer from a linear-evaluation checkpoint
_C.TRAIN.RESUME = ""
_C.TRAIN.EPOCHS = 0
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.PRETRAIN = ""
_C.TRAIN.DATASET = ""
_C.TRAIN.LABEL_SMOOTHING = 0.0
_C.TRAIN.BATCH_SIZE = 1 # per-gpu
_C.TRAIN.NUM_WORKERS = 0
_C.TRAIN.PRINT_FREQ = 1
_C.TRAIN.LR_SCHEDULER = CN()
_C.TRAIN.LR_SCHEDULER.TYPE = None # support cosine / step / multiStep
_C.TRAIN.LR_SCHEDULER.WARMUP_EPOCHS = 0
_C.TRAIN.LR_SCHEDULER.WARMUP_LR = 0.0
_C.TRAIN.LR_SCHEDULER.BASE_LR = 0.0
_C.TRAIN.LR_SCHEDULER.FC_LR = 0.0 # specific learning rate for final fc layer
_C.TRAIN.LR_SCHEDULER.MIN_LR = 0.0
_C.TRAIN.LR_SCHEDULER.SPECIFIC_LRS = []
_C.TRAIN.LR_SCHEDULER.DECAY_MILESTONES = []
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCH = 0
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = ""
_C.TRAIN.OPTIMIZER.EPS = None
_C.TRAIN.OPTIMIZER.BETAS = None # (0.9, 0.999)
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.0
_C.TRAIN.OPTIMIZER.WEIGHT_DECAY = 0.0
_C.TRAIN.LOSS = CN()
_C.TRAIN.LOSS.CRITERION = CN()
_C.TRAIN.LOSS.REGULARIZER = CN()
_C.TRAIN.LOSS.LAMBDA = 0.0 # weight of the regularizer term relative to the criterion
_C.TRAIN.LOSS.CRITERION.NAME = "" # support CrossEntropy / LpLoss
_C.TRAIN.LOSS.CRITERION.LPLOSS = CN()
_C.TRAIN.LOSS.CRITERION.LPLOSS.P = 2.0
_C.TRAIN.LOSS.CRITERION.LPLOSS.REDUCTION = "none"
_C.TRAIN.LOSS.REGULARIZER.NAME = "" # support PACT/SLIMMING
_C.TRAIN.RUNNER = CN()
_C.TRAIN.RUNNER.NAME = "" # support default
_C.TRAIN.METER = CN()
_C.TRAIN.METER.NAME = "" # support ACC / MAP / MIOU
_C.TRAIN.METER.ACC = CN()
_C.TRAIN.METER.ACC.TOPK = [] # e.g. [1, 5] for top-1 / top-5 accuracy
_C.TRAIN.METER.MAP = CN()
_C.TRAIN.METER.MIOU = CN()

_C.AUG = CN()
_C.AUG.TRAIN = CN()
_C.AUG.TRAIN.RANDOMRESIZEDCROP = CN()
_C.AUG.TRAIN.RANDOMRESIZEDCROP.ENABLE = False
# _C.AUG.TRAIN.RANDOMRESIZEDCROP.SIZE = _C.MODEL.INPUTSHAPE
_C.AUG.TRAIN.RANDOMRESIZEDCROP.SCALE = (0.08, 1.0)
_C.AUG.TRAIN.RANDOMRESIZEDCROP.INTERPOLATION = "bilinear"
_C.AUG.TRAIN.RESIZE = CN()
_C.AUG.TRAIN.RESIZE.ENABLE = False
_C.AUG.TRAIN.RESIZE.SIZE = (-1, -1) # h, w
_C.AUG.TRAIN.RESIZE.KEEP_RATIO = True
_C.AUG.TRAIN.RESIZE.INTERPOLATION = "bilinear"
_C.AUG.TRAIN.HORIZONTAL_FLIP = CN()
_C.AUG.TRAIN.HORIZONTAL_FLIP.PROB = 0.0
_C.AUG.TRAIN.VERTICAL_FLIP = CN()
_C.AUG.TRAIN.VERTICAL_FLIP.PROB = 0.0
_C.AUG.TRAIN.RANDOMCROP = CN()
_C.AUG.TRAIN.RANDOMCROP.ENABLE = False
# _C.AUG.TRAIN.RANDOMCROP.SIZE = _C.MODEL.INPUTSHAPE
_C.AUG.TRAIN.RANDOMCROP.PADDING = 0
_C.AUG.TRAIN.CENTERCROP = CN()
_C.AUG.TRAIN.CENTERCROP.ENABLE = False
_C.AUG.TRAIN.COLOR_JITTER = CN()
_C.AUG.TRAIN.COLOR_JITTER.PROB = 0.0
_C.AUG.TRAIN.COLOR_JITTER.BRIGHTNESS = 0.4
_C.AUG.TRAIN.COLOR_JITTER.CONTRAST = 0.4
_C.AUG.TRAIN.COLOR_JITTER.SATURATION = 0.2
_C.AUG.TRAIN.COLOR_JITTER.HUE = 0.1
_C.AUG.TRAIN.AUTO_AUGMENT = CN()
_C.AUG.TRAIN.AUTO_AUGMENT.ENABLE = False
_C.AUG.TRAIN.AUTO_AUGMENT.POLICY = 0.0
_C.AUG.TRAIN.RANDOMERASE = CN()
_C.AUG.TRAIN.RANDOMERASE.PROB = 0.0
_C.AUG.TRAIN.RANDOMERASE.MODE = "const"
_C.AUG.TRAIN.RANDOMERASE.MAX_COUNT = None
_C.AUG.TRAIN.MIX = CN() # mixup & cutmix
_C.AUG.TRAIN.MIX.PROB = 0.0
_C.AUG.TRAIN.MIX.MODE = "batch"
_C.AUG.TRAIN.MIX.SWITCH_MIXUP_CUTMIX_PROB = 0.0
_C.AUG.TRAIN.MIX.MIXUP_ALPHA = 0.0
_C.AUG.TRAIN.MIX.CUTMIX_ALPHA = 0.0
_C.AUG.TRAIN.MIX.CUTMIX_MIXMAX = None
_C.AUG.TRAIN.NORMLIZATION = CN()
_C.AUG.TRAIN.NORMLIZATION.MEAN = []
_C.AUG.TRAIN.NORMLIZATION.STD = []
_C.AUG.EVALUATION = CN()
_C.AUG.EVALUATION.RESIZE = CN()
_C.AUG.EVALUATION.RESIZE.ENABLE = False
_C.AUG.EVALUATION.RESIZE.SIZE = (-1, -1) # h, w
_C.AUG.EVALUATION.RESIZE.KEEP_RATIO = True
_C.AUG.EVALUATION.RESIZE.INTERPOLATION = "bilinear"
_C.AUG.EVALUATION.CENTERCROP = CN()
_C.AUG.EVALUATION.CENTERCROP.ENABLE = False
_C.AUG.EVALUATION.NORMLIZATION = CN()
_C.AUG.EVALUATION.NORMLIZATION.MEAN = []
_C.AUG.EVALUATION.NORMLIZATION.STD = []
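# Typical ImageNet normalization values, shown only as an example (not defaults of
# this repo): MEAN = [0.485, 0.456, 0.406], STD = [0.229, 0.224, 0.225].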

_C.QUANT = CN()
_C.QUANT.TYPE = "" # support 'qat' / 'ptq'
_C.QUANT.BIT_ASSIGNER = CN()
_C.QUANT.BIT_ASSIGNER.NAME = None # support 'HAWQ'
_C.QUANT.BIT_ASSIGNER.W_BIT_CHOICES = [2, 4, 8]
_C.QUANT.BIT_ASSIGNER.A_BIT_CHOICES = [2, 4, 8, 16]
_C.QUANT.BIT_ASSIGNER.HAWQ = CN()
_C.QUANT.BIT_ASSIGNER.HAWQ.EIGEN_TYPE = "avg" # support 'max' / 'avg'
_C.QUANT.BIT_ASSIGNER.HAWQ.SENSITIVITY_CALC_ITER_NUM = 50
_C.QUANT.BIT_ASSIGNER.HAWQ.LIMITATION = CN()
_C.QUANT.BIT_ASSIGNER.HAWQ.LIMITATION.BIT_ASCEND_SORT = False
_C.QUANT.BIT_ASSIGNER.HAWQ.LIMITATION.BIT_WIDTH_COEFF = 1e10
_C.QUANT.BIT_ASSIGNER.HAWQ.LIMITATION.BOPS_COEFF = 1e10
_C.QUANT.BIT_CONFIG = [] # per-layer bit widths: layer_name -> {"w": w_bit, "a": a_bit}
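# Example entry (hypothetical layer name; the exact entry layout is an assumption):
#   QUANT.BIT_CONFIG: [{"backbone.conv1": {"w": 4, "a": 8}}]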
_C.QUANT.FOLD_BN = False
_C.QUANT.W = CN()
_C.QUANT.W.QUANTIZER = None # support "LSQ" / "DOREFA"
_C.QUANT.W.BIT = 8
_C.QUANT.W.BIT_RANGE = [2, 9] # lower bound inclusive, upper bound exclusive, i.e. 2~8 bits
_C.QUANT.W.SYMMETRY = True
_C.QUANT.W.GRANULARITY = (
"channelwise" # support "layerwise"/"channelwise" currently, default is channelwise
)
_C.QUANT.W.OBSERVER_METHOD = CN()
_C.QUANT.W.OBSERVER_METHOD.NAME = (
"MINMAX" # support "MINMAX"/"MSE" currently, default is MINMAX
)
_C.QUANT.W.OBSERVER_METHOD.ALPHA = 0.0001 # only used by percentile-style observers
_C.QUANT.W.OBSERVER_METHOD.BINS = 2049 # only used by KL-histogram observers
_C.QUANT.A = CN()
_C.QUANT.A.BIT = 8
_C.QUANT.A.BIT_RANGE = [4, 9] # lower bound inclusive, upper bound exclusive, i.e. 4~8 bits
_C.QUANT.A.QUANTIZER = None # support "LSQ" / "DOREFA"
_C.QUANT.A.SYMMETRY = False
_C.QUANT.A.GRANULARITY = (
"layerwise" # support "layerwise"/"channelwise" currently, default is layerwise
)
_C.QUANT.A.OBSERVER_METHOD = CN()
_C.QUANT.A.OBSERVER_METHOD.NAME = (
"MINMAX" # support "MINMAX"/"MSE" currently, default is MINMAX
)
_C.QUANT.A.OBSERVER_METHOD.ALPHA = 0.0001 # only used by percentile-style observers
_C.QUANT.A.OBSERVER_METHOD.BINS = 2049 # only used by KL-histogram observers
_C.QUANT.CALIBRATION = CN()
_C.QUANT.CALIBRATION.PATH = ""
_C.QUANT.CALIBRATION.TYPE = "" # support tarfile / python_module
_C.QUANT.CALIBRATION.MODULE_PATH = "" # import path of the calibration dataset module
_C.QUANT.CALIBRATION.SIZE = 0
_C.QUANT.CALIBRATION.BATCHSIZE = 1
_C.QUANT.CALIBRATION.NUM_WORKERS = 0
_C.QUANT.FINETUNE = CN()
_C.QUANT.FINETUNE.ENABLE = False
_C.QUANT.FINETUNE.METHOD = ""
_C.QUANT.FINETUNE.BATCHSIZE = 32
_C.QUANT.FINETUNE.ITERS_W = 0
_C.QUANT.FINETUNE.ITERS_A = 0
_C.QUANT.FINETUNE.BRECQ = CN()
_C.QUANT.FINETUNE.BRECQ.KEEP_GPU = True

_C.SSL = CN()
_C.SSL.TYPE = None # support mocov2
_C.SSL.SETTING = CN()
_C.SSL.SETTING.DIM = 128 # output dimension for the MLP head
_C.SSL.SETTING.HIDDEN_DIM = 2048 # hidden dimension for the MLP head
_C.SSL.SETTING.T = 0.07 # temperature for InfoNCE loss
_C.SSL.SETTING.MOCO_K = 65536 # size of memory bank for MoCo
_C.SSL.SETTING.MOMENTUM = 0.999 # MoCo momentum of updating key encoder
_C.SSL.SETTING.MLP = True # whether to use MLP head, default True
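# A minimal derived YAML consumed by get_config() below might look like this
# (hypothetical file contents; only keys defined above are valid):
#   BASE_CONFIG: "configs/base.yaml"
#   MODEL:
#     ARCH: "resnet50"
#     NUM_CLASSES: 1000
#   QUANT:
#     TYPE: "ptq"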


def get_config(args):
"""Get a yacs CfgNode object with default values."""
# Return a clone so that the defaults will not be altered
# This is for the "local variable" use pattern
config = _C.clone()
config.defrost()
if hasattr(args, "checkpoint"):
config.MODEL.CHECKPOINT = args.checkpoint
if hasattr(args, "pretrained"):
config.MODEL.PRETRAINED = args.pretrained
if hasattr(args, "calibration"):
config.QUANT.CALIBRATION.PATH = args.calibration
if hasattr(args, "batch_size"):
config.QUANT.CALIBRATION.BATCHSIZE = args.batch_size
if hasattr(args, "num_workers"):
config.QUANT.CALIBRATION.NUM_WORKERS = args.num_workers
config.TRAIN.NUM_WORKERS = args.num_workers
if hasattr(args, "eval_domain"):
config.EVALUATION_DOMAIN = args.eval_domain
if hasattr(args, "print_freq"):
config.TRAIN.PRINT_FREQ = args.print_freq
if hasattr(args, "output"):
config.OUTPUT = args.output
config.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    # resolve the chain of BASE_CONFIG files recursively
config_depends = []
tmp_config = _C.clone()
tmp_config.defrost()
tmp_config.merge_from_file(args.config)
config_depends.append(args.config)
while tmp_config.BASE_CONFIG:
next_config = tmp_config.BASE_CONFIG
config_depends.append(next_config)
tmp_config.BASE_CONFIG = ""
tmp_config.merge_from_file(next_config)
    # tmp_config was merged derived-first, so re-merge base-first onto the defaults
for conf_path in reversed(config_depends):
config.merge_from_file(conf_path)
config.freeze()
return config


def update_config(config, key, value):
config.defrost()
keys = key.split(".")
def _set_config_attr(cfg, keys, value):
if len(keys) > 1:
cfg = getattr(cfg, keys[0].upper())
_set_config_attr(cfg, keys[1:], value)
else:
setattr(cfg, keys[0].upper(), value)
_set_config_attr(config, keys, value)
config.freeze()
return config
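

if __name__ == "__main__":
    # Minimal usage sketch (the argparse flags below are assumptions; the repo's
    # real entry points may define a different CLI). Resolve the config chain
    # from a YAML file, then override a single key programmatically.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to a YAML config file")
    parser.add_argument("--output", default="./output")
    args = parser.parse_args()

    cfg = get_config(args)
    cfg = update_config(cfg, "train.batch_size", 64)  # dotted keys are case-insensitive
    print(cfg.TRAIN.BATCH_SIZE, cfg.DEVICE)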