-
Notifications
You must be signed in to change notification settings - Fork 0
/
arguments.py
110 lines (87 loc) · 3.7 KB
/
arguments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import argparse
import os
import torch
import numpy as np
import torch
import random
import re
import yaml
import shutil
import warnings
import datetime
import csv
class Namespace(object):
    """Expose a (possibly nested) mapping through attribute access.

    Every key of ``somedict`` becomes an instance attribute. Values that are
    themselves dicts are wrapped in a new ``Namespace`` recursively; all other
    values (scalars, lists, None, ...) are stored unchanged.

    Raises:
        AssertionError: if a key is not a string or does not start with a
            letter, ``_`` or ``-``.
    """

    def __init__(self, somedict):
        for name, item in somedict.items():
            # Only the first character is checked by re.match here.
            assert isinstance(name, str) and re.match("[A-Za-z_-]", name)
            wrapped = Namespace(item) if isinstance(item, dict) else item
            self.__dict__[name] = wrapped

    def __getattr__(self, attribute):
        # Python only reaches __getattr__ after normal lookup failed, i.e. the
        # requested key was not present in the mapping.
        raise AttributeError(f"Can not find {attribute} in namespace. Please write {attribute} in your config file(xxx.yaml)!")
def set_deterministic(seed):
    """Seed Python, NumPy and PyTorch RNGs and pin cuDNN to deterministic kernels.

    Args:
        seed: integer seed, or None. With None nothing is printed, seeded or
            changed (the original author notes None is the usual default).
    """
    if seed is None:
        return
    print(f"Deterministic with seed = {seed}")
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def _build_parser():
    """Create the argparse parser holding every command-line flag of a run."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config-file', default='configs/spectral_resnet_mlp1000_norelu_cifar10_lr003_mu1.yaml', type=str)
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--log_freq', type=int, default=20)
    parser.add_argument('--workers', type=int, default=32)
    parser.add_argument('--test_bs', type=int, default=80)
    parser.add_argument('--download', action='store_true', help="if can't find dataset, download from web")
    parser.add_argument('--data_dir', type=str, default='PATH_TO_DATASET')
    parser.add_argument('--dist_url', type=str, default='tcp://localhost:10001')
    parser.add_argument('--log_dir', type=str, default='./log/spectral')
    # '~' is expanded in get_args(); os.makedirs() does not expand it itself.
    parser.add_argument('--ckpt_dir', type=str, default='~/.cache/')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--eval_from', type=str, default=None)
    parser.add_argument('--hide_progress', action='store_true')
    parser.add_argument('--vis_freq', type=int, default=10)
    return parser


def _apply_debug_overrides(args):
    """Shrink the train/eval schedule and disable loader workers for a quick smoke run."""
    # Config sections may be absent or null; skip them instead of crashing.
    train_cfg = getattr(args, 'train', None)
    if train_cfg:
        train_cfg.batch_size = 2
        train_cfg.num_epochs = 1
        train_cfg.stop_at_epoch = 1
    eval_cfg = getattr(args, 'eval', None)
    if eval_cfg:
        eval_cfg.batch_size = 2
        eval_cfg.num_epochs = 1
    args.dataset.num_workers = 0


def _open_csv_log(log_dir):
    """Open ``<log_dir>/log.csv`` for writing, emit the header, return (file, writer)."""
    # The csv module requires newline='' so it controls line endings itself
    # (otherwise rows get blank lines between them on Windows).
    log_file = open(os.path.join(log_dir, 'log.csv'), mode='w', newline='')
    log_writer = csv.DictWriter(log_file, fieldnames=['epoch', 'loss', 'lr', 'test_loss'])
    log_writer.writeheader()
    return log_file, log_writer


def get_args():
    """Parse CLI flags, merge the YAML config and prepare the run directories.

    Every top-level key of the YAML file given by ``--config-file`` is copied
    onto the parsed arguments, overriding a CLI flag of the same name. Nested
    mappings become ``Namespace`` objects.

    Side effects:
        * with ``--debug``, shrinks batch sizes/epochs and sets
          ``dataset.num_workers`` to 0;
        * creates ``log_dir/in-progress-<date><name>-log_freq:<n>`` and
          ``ckpt_dir``, with ``~`` expanded in both;
        * copies the config file into that log directory;
        * calls ``set_deterministic(args.seed)``;
        * creates ``log.csv`` in the log directory and writes its header.

    Returns:
        tuple: ``(args, log_file, log_writer)``. ``args`` additionally carries
        ``aug_kwargs``, ``dataset_kwargs`` and ``dataloader_kwargs`` dicts.
        The caller owns ``log_file`` and must close it.

    Raises:
        ValueError: if ``log_dir``, ``data_dir``, ``ckpt_dir`` or ``name`` is
            missing or None after the config is merged.
    """
    args = _build_parser().parse_args()

    # NOTE(review): FullLoader is acceptable for trusted local config files
    # only; do not point --config-file at untrusted input.
    with open(args.config_file, 'r') as f:
        config = Namespace(yaml.load(f, Loader=yaml.FullLoader))
    for key, value in config.__dict__.items():
        vars(args)[key] = value

    if args.debug:
        _apply_debug_overrides(args)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    required = ('log_dir', 'data_dir', 'ckpt_dir', 'name')
    missing = [field for field in required if getattr(args, field, None) is None]
    if missing:
        raise ValueError(f"Missing required setting(s): {', '.join(missing)}")

    # Expand '~' so directories land in the user's home rather than in a
    # literal './~' folder (os.makedirs does not expand it).
    args.ckpt_dir = os.path.expanduser(args.ckpt_dir)
    # NOTE(review): the ':' in this directory name is invalid on Windows; the
    # format is kept unchanged in case other tooling matches on it.
    run_dir = 'in-progress-' + '{}'.format(datetime.date.today()) + args.name + '-log_freq:{}'.format(args.log_freq)
    args.log_dir = os.path.join(os.path.expanduser(args.log_dir), run_dir)
    os.makedirs(args.log_dir, exist_ok=True)
    print(f'creating file {args.log_dir}')
    os.makedirs(args.ckpt_dir, exist_ok=True)
    # Keep a copy of the exact config next to the logs of this run.
    shutil.copy2(args.config_file, args.log_dir)

    set_deterministic(args.seed)

    vars(args)['aug_kwargs'] = {
        'name': args.model.name,
        'image_size': args.dataset.image_size,
    }
    vars(args)['dataset_kwargs'] = {
        'dataset': args.dataset.name,
        'data_dir': args.data_dir,
        'download': args.download,
    }
    vars(args)['dataloader_kwargs'] = {
        'drop_last': True,
        'pin_memory': True,
        'num_workers': args.dataset.num_workers,
    }

    log_file, log_writer = _open_csv_log(args.log_dir)
    return args, log_file, log_writer