-
Notifications
You must be signed in to change notification settings - Fork 54
/
configs.py
85 lines (77 loc) · 2.27 KB
/
configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import inspect
from collections import OrderedDict
from dataclasses import dataclass
from src.training.params import get_default_params
@dataclass
class Config:
    """Base training configuration.

    Concrete configs are subclasses that override the class attributes
    below; ``__post_init__`` derives ``name``/``output_dir`` from the
    subclass name and fills in unset optimizer hyper-parameters from
    ``get_default_params(model)``.

    NOTE(review): none of these attributes has a type annotation, so
    ``@dataclass`` registers NO fields — the generated ``__init__()``
    takes no arguments (but still invokes ``__post_init__``), and the
    values remain ordinary class attributes that subclasses override.
    Do not add annotations casually: that would turn them into dataclass
    fields and change how subclass overrides behave.
    """

    # --- data sources (formats depend on dataset_type; see training code) ---
    train_data = None
    val_data = None
    train_num_samples = None
    val_num_samples = None
    dataset_type = "auto"
    dataset_resampled = False
    # CSV dataset column layout.
    csv_separator = "\t"
    csv_img_key = "filepath"
    csv_caption_key = "title"
    # Zero-shot evaluation sets (cluster-specific default path — TODO confirm).
    imagenet_val = "/datasets01/imagenet_full_size/061417/val"
    imagenet_v2 = None

    # --- logging ---
    logs = "./logs/"
    log_local = False
    name = None  # overwritten in __post_init__ with the subclass name

    # --- training schedule ---
    workers = 8
    batch_size = 64
    epochs = 32
    # lr/beta1/beta2/eps default to None and are filled in __post_init__
    # from get_default_params(model) when left unset.
    lr = None
    beta1 = None
    beta2 = None
    eps = None
    wd = 0.2
    warmup = 2000 # 10000
    min_ratio = 0.
    skip_scheduler = False

    # --- checkpointing / evaluation cadence ---
    save_frequency = 1
    save_most_recent = True # False
    zeroshot_frequency = 1
    val_frequency = 1
    resume = None

    # --- model ---
    precision = "amp"
    clip_model = "CLIP"
    model = "RN50"
    pretrained = ''
    pretrained_image = False
    grad_checkpointing = False
    local_loss = False
    gather_with_grad = False
    force_quick_gelu = False
    torchscript = False
    trace = False

    # --- distributed training ---
    dist_url = "env://"
    dist_backend = "nccl"
    debug = False
    report_to = ""
    ddp_static_graph = False
    no_set_device_rank = False
    seed = 0
    norm_gradient_clip = None

    def __post_init__(self):
        args = self
        # The experiment name is the concrete subclass's class name, and
        # outputs go under <logs>/<name>.
        args.name = self.__class__.__name__
        args.output_dir = os.path.join(args.logs, args.name)
        # Fill any optimizer hyper-parameter still set to None with the
        # per-architecture default for args.model.
        for name, val in get_default_params(args.model).items():
            if getattr(args, name) is None:
                setattr(args, name, val)
def search_config(config_name):
    """Locate a config class by name under ``config/`` and instantiate it.

    Scans every ``*.py`` module in the ``config`` directory (relative to the
    current working directory), collects the config classes they define, and
    returns a fresh instance of the one named *config_name*.

    Parameters
    ----------
    config_name : str
        Class name of the config to instantiate.

    Returns
    -------
    object
        A new instance of the matching config class.

    Raises
    ------
    ValueError
        If no module under ``config/`` defines *config_name*.  (The
        original code raised a bare ``KeyError`` from the f-string lookup,
        with no hint about which names exist.)
    """
    import importlib

    # Map each candidate name to the FIRST module that defines it
    # (first-wins on duplicates, matching the original behavior).
    all_configs = {}
    for code in os.listdir("config"):
        if not code.endswith(".py"):
            continue
        module = importlib.import_module(f"config.{code[:-3]}")
        for candidate in dir(module):
            # Skip the shared base class, dunder names, and runner entry
            # points (anything starting with "run_config").
            if (candidate == "Config"
                    or candidate.startswith("__")
                    or candidate.startswith("run_config")):
                continue
            all_configs.setdefault(candidate, module)

    # Fail fast with a helpful message instead of a bare KeyError below.
    if config_name not in all_configs:
        raise ValueError(
            f"unknown config {config_name!r}; available: {sorted(all_configs)}"
        )

    print(f"launching {config_name} from {all_configs[config_name].__file__}")
    config = getattr(all_configs[config_name], config_name)()
    return config