-
Notifications
You must be signed in to change notification settings - Fork 74
/
arguments.py
124 lines (112 loc) · 6.19 KB
/
arguments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import argparse
class ArgParser(object):
def __init__(self):
parser = argparse.ArgumentParser()
# Model related arguments
parser.add_argument('--id', default='',
help="a name for identifying the model")
parser.add_argument('--num_mix', default=2, type=int,
help="number of sounds to mix")
parser.add_argument('--arch_sound', default='unet7',
help="architecture of net_sound")
parser.add_argument('--arch_frame', default='resnet18dilated',
help="architecture of net_frame")
parser.add_argument('--arch_synthesizer', default='linear',
help="architecture of net_synthesizer")
parser.add_argument('--weights_sound', default='',
help="weights to finetune net_sound")
parser.add_argument('--weights_frame', default='',
help="weights to finetune net_frame")
parser.add_argument('--weights_synthesizer', default='',
help="weights to finetune net_synthesizer")
parser.add_argument('--num_channels', default=32, type=int,
help='number of channels')
parser.add_argument('--num_frames', default=1, type=int,
help='number of frames')
parser.add_argument('--stride_frames', default=1, type=int,
help='sampling stride of frames')
parser.add_argument('--img_pool', default='maxpool',
help="avg or max pool image features")
parser.add_argument('--img_activation', default='sigmoid',
help="activation on the image features")
parser.add_argument('--sound_activation', default='no',
help="activation on the sound features")
parser.add_argument('--output_activation', default='sigmoid',
help="activation on the output")
parser.add_argument('--binary_mask', default=1, type=int,
help="whether to use bianry masks")
parser.add_argument('--mask_thres', default=0.5, type=float,
help="threshold in the case of binary masks")
parser.add_argument('--loss', default='l1',
help="loss function to use")
parser.add_argument('--weighted_loss', default=0, type=int,
help="weighted loss")
parser.add_argument('--log_freq', default=1, type=int,
help="log frequency scale")
# Data related arguments
parser.add_argument('--num_gpus', default=1, type=int,
help='number of gpus to use')
parser.add_argument('--batch_size_per_gpu', default=32, type=int,
help='input batch size')
parser.add_argument('--workers', default=32, type=int,
help='number of data loading workers')
parser.add_argument('--num_val', default=-1, type=int,
help='number of images to evalutate')
parser.add_argument('--num_vis', default=40, type=int,
help='number of images to evalutate')
parser.add_argument('--audLen', default=65535, type=int,
help='sound length')
parser.add_argument('--audRate', default=11025, type=int,
help='sound sampling rate')
parser.add_argument('--stft_frame', default=1022, type=int,
help="stft frame length")
parser.add_argument('--stft_hop', default=256, type=int,
help="stft hop length")
parser.add_argument('--imgSize', default=224, type=int,
help='size of input frame')
parser.add_argument('--frameRate', default=8, type=float,
help='video frame sampling rate')
# Misc arguments
parser.add_argument('--seed', default=1234, type=int,
help='manual seed')
parser.add_argument('--ckpt', default='./ckpt',
help='folder to output checkpoints')
parser.add_argument('--disp_iter', type=int, default=20,
help='frequency to display')
parser.add_argument('--eval_epoch', type=int, default=1,
help='frequency to evaluate')
self.parser = parser
def add_train_arguments(self):
parser = self.parser
parser.add_argument('--mode', default='train',
help="train/eval")
parser.add_argument('--list_train',
default='data/train.csv')
parser.add_argument('--list_val',
default='data/val.csv')
parser.add_argument('--dup_trainset', default=100, type=int,
help='duplicate so that one epoch has more iters')
# optimization related arguments
parser.add_argument('--num_epoch', default=100, type=int,
help='epochs to train for')
parser.add_argument('--lr_frame', default=1e-4, type=float, help='LR')
parser.add_argument('--lr_sound', default=1e-3, type=float, help='LR')
parser.add_argument('--lr_synthesizer',
default=1e-3, type=float, help='LR')
parser.add_argument('--lr_steps',
nargs='+', type=int, default=[40, 60],
help='steps to drop LR in epochs')
parser.add_argument('--beta1', default=0.9, type=float,
help='momentum for sgd, beta1 for adam')
parser.add_argument('--weight_decay', default=1e-4, type=float,
help='weights regularizer')
self.parser = parser
def print_arguments(self, args):
print("Input arguments:")
for key, val in vars(args).items():
print("{:16} {}".format(key, val))
def parse_train_arguments(self):
self.add_train_arguments()
args = self.parser.parse_args()
self.print_arguments(args)
return args