forked from Pixelvision-VIP/VIPCUP2023_OLIVES_edit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
83 lines (69 loc) · 3.71 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import argparse
import math
import os
def parse_option(args=None):
    """Parse training command-line options and derive output paths.

    Args:
        args: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) reads the real
            command line, exactly as before.

    Returns:
        argparse.Namespace holding every option defined below, plus:
        ``data_folder`` set to ``'./datasets/'`` when not supplied,
        ``model_path`` (``'./save/<dataset>_models'``),
        ``lr_decay_epochs`` converted from a comma-separated string to a
        list of ints, ``model_name`` (built from model, learning rate,
        weight decay, batch size and temperature) and ``save_folder``
        (``model_path`` joined with ``model_name``).

    Side effects:
        Creates ``save_folder`` (and any missing parents) on disk.

    Raises:
        SystemExit: raised by argparse for invalid arguments, including
            ``--dataset path`` without all of ``--data_folder``, ``--mean``
            and ``--std``, or a non-integer entry in ``--lr_decay_epochs``.
    """
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=10,
                        help='print frequency')
    parser.add_argument('--save_freq', type=int, default=50,
                        help='save frequency')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='batch_size')
    parser.add_argument('--num_workers', type=int, default=8,
                        help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of training epochs')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='device string, e.g. cuda:0 or cpu')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.05,
                        help='learning rate')
    # The exact role of these two coefficients is defined by the training
    # code that reads them; this module only exposes them.
    parser.add_argument('--patient_lambda', type=float, default=1,
                        help='weighting coefficient for the patient term')
    parser.add_argument('--cluster_lambda', type=float, default=1,
                        help='weighting coefficient for the cluster term')
    parser.add_argument('--lr_decay_epochs', type=str, default='100',
                        help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1,
                        help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum')

    # Data locations. The defaults are placeholders; pass real paths.
    parser.add_argument('--train_csv_path', type=str, default='train data csv',
                        help='training csv path')
    parser.add_argument('--test_csv_path', type=str, default='test data csv',
                        help='test csv path')
    parser.add_argument('--train_image_path', type=str, default='train data csv',
                        help='training image path')
    parser.add_argument('--test_image_path', type=str, default='test data csv',
                        help='test image path')
    parser.add_argument('--parallel', type=int, default=1, help='data parallel')
    parser.add_argument('--ncls', type=int, default=6, help='Number of Classes')

    # model dataset
    parser.add_argument('--model', type=str, default='resnet50')
    # argparse does not validate string defaults against `choices`, so the
    # default and the 'path' mode handled below must be listed explicitly
    # or they can never be selected from the command line.
    parser.add_argument('--dataset', type=str, default='TREX_DME',
                        choices=['OLIVES', 'TREX_DME', 'path'], help='dataset')
    parser.add_argument('--mean', type=str, help='mean of dataset in path in form of str tuple')
    parser.add_argument('--std', type=str, help='std of dataset in path in form of str tuple')
    parser.add_argument('--data_folder', type=str, default=None, help='path to custom dataset')
    parser.add_argument('--size', type=int, default=128, help='parameter for RandomResizedCrop')

    # temperature
    parser.add_argument('--temp', type=float, default=0.07,
                        help='temperature for loss function')

    opt = parser.parse_args(args)

    # A dataset given by path needs its location and normalisation stats.
    # parser.error (not assert) keeps the check under `python -O` and prints
    # a usage message instead of a bare traceback.
    if opt.dataset == 'path':
        required = (('--data_folder', opt.data_folder),
                    ('--mean', opt.mean),
                    ('--std', opt.std))
        missing = [flag for flag, value in required if value is None]
        if missing:
            parser.error('--dataset path requires ' + ', '.join(missing))

    # set the path according to the environment
    if opt.data_folder is None:
        opt.data_folder = './datasets/'
    opt.model_path = './save/{}_models'.format(opt.dataset)

    # '60,80' -> [60, 80]; blank pieces (e.g. from a trailing comma) are skipped.
    try:
        opt.lr_decay_epochs = [int(epoch)
                               for epoch in opt.lr_decay_epochs.split(',')
                               if epoch.strip()]
    except ValueError:
        parser.error('--lr_decay_epochs must be comma-separated integers, '
                     'got {!r}'.format(opt.lr_decay_epochs))

    opt.model_name = '{}_lr_{}_decay_{}_bsz_{}_temp_{}'.format(
        opt.model, opt.learning_rate, opt.weight_decay, opt.batch_size, opt.temp)
    opt.save_folder = os.path.join(opt.model_path, opt.model_name)
    # exist_ok avoids the race between an isdir() check and the creation.
    os.makedirs(opt.save_folder, exist_ok=True)
    return opt