-
Notifications
You must be signed in to change notification settings - Fork 38
/
train.py
319 lines (271 loc) · 10.2 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
import argparse
import itertools
from tensorboardX import SummaryWriter
import torch
from torch import nn, optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
import yaml
from bisect import bisect
from visdialch.data.dataset import VisDialDataset
from visdialch.encoders import Encoder
from visdialch.decoders import Decoder
from visdialch.metrics import SparseGTMetrics, NDCG
from visdialch.model import EncoderDecoderModel
from visdialch.utils.checkpointing import CheckpointManager, load_checkpoint
# Command-line interface of the training script.
#
# Arguments that define the experiment (config and data paths) are added to
# the parser directly. Runtime-only options and checkpointing options are
# added to named argument groups so that ``--help`` shows them under their
# own headings.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config-yml",
    default="configs/lf_disc_faster_rcnn_x101.yml",
    help="Path to a config file listing reader, model and solver parameters.",
)
parser.add_argument(
    "--train-json",
    default="data/visdial_1.0_train.json",
    help="Path to json file containing VisDial v1.0 training data.",
)
parser.add_argument(
    "--val-json",
    default="data/visdial_1.0_val.json",
    help="Path to json file containing VisDial v1.0 validation data.",
)
parser.add_argument(
    "--val-dense-json",
    default="data/visdial_1.0_val_dense_annotations.json",
    help="Path to json file containing VisDial v1.0 validation dense ground "
    "truth annotations.",
)

# Keep a handle on each group: an argument is only listed under a group
# heading when it is added through the group object. Adding it through the
# parser leaves the group empty and its title out of ``--help``.
runtime_group = parser.add_argument_group(
    "Arguments independent of experiment reproducibility"
)
runtime_group.add_argument(
    "--gpu-ids",
    nargs="+",
    type=int,
    # ``nargs="+"`` produces a list, so the default is a list as well.
    default=[0],
    help="List of ids of GPUs to use.",
)
runtime_group.add_argument(
    "--cpu-workers",
    type=int,
    default=4,
    help="Number of CPU workers for dataloader.",
)
runtime_group.add_argument(
    "--overfit",
    action="store_true",
    help="Overfit model on 5 examples, meant for debugging.",
)
runtime_group.add_argument(
    "--validate",
    action="store_true",
    help="Whether to validate on val split after every epoch.",
)
runtime_group.add_argument(
    "--in-memory",
    action="store_true",
    help="Load the whole dataset and pre-extracted image features in memory. "
    "Use only in presence of large RAM, atleast few tens of GBs.",
)

checkpoint_group = parser.add_argument_group("Checkpointing related arguments")
checkpoint_group.add_argument(
    "--save-dirpath",
    default="checkpoints/",
    help="Path of directory to create checkpoint directory and save "
    "checkpoints.",
)
checkpoint_group.add_argument(
    "--load-pthpath",
    default="",
    help="To continue training, path to .pth file of saved checkpoint.",
)
# Pin every source of randomness used here so that repeated runs match.
# Background: https://pytorch.org/docs/stable/notes/randomness.html
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.cuda.manual_seed_all(0)
torch.manual_seed(0)
# =============================================================================
#   INPUT ARGUMENTS AND CONFIG
# =============================================================================

args = parser.parse_args()

# keys: {"dataset", "model", "solver"}
# ``safe_load`` needs no explicit Loader and only builds plain Python
# containers; the ``with`` block closes the config file once it is parsed.
with open(args.config_yml) as config_file:
    config = yaml.safe_load(config_file)

# ``--gpu-ids`` is a bare int when its default is an int rather than a list;
# normalise it so the rest of the script can always index it.
if isinstance(args.gpu_ids, int):
    args.gpu_ids = [args.gpu_ids]

# A negative first id selects the CPU; otherwise the first id is the primary
# GPU.
if args.gpu_ids[0] >= 0:
    device = torch.device("cuda", args.gpu_ids[0])
    # Only meaningful for CUDA devices; passing a CPU device here raises.
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

# Print config and args.
print(yaml.dump(config, default_flow_style=False))
for arg in vars(args):
    print("{:<20}: {}".format(arg, getattr(args, arg)))
# =============================================================================
#   SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER, SCHEDULER
# =============================================================================

# Values consulted repeatedly below.
decoder_type = config["model"]["decoder"]
is_discriminative = decoder_type == "disc"
batch_size = config["solver"]["batch_size"]

# Training split: answer options are returned only for the "disc" decoder,
# boundary tokens are added only for the other decoder types.
train_dataset = VisDialDataset(
    config["dataset"],
    args.train_json,
    overfit=args.overfit,
    in_memory=args.in_memory,
    num_workers=args.cpu_workers,
    return_options=is_discriminative,
    add_boundary_toks=not is_discriminative,
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=args.cpu_workers,
    shuffle=True,
)

# Validation split: options are always returned; it also receives the dense
# annotation file.
val_dataset = VisDialDataset(
    config["dataset"],
    args.val_json,
    args.val_dense_json,
    overfit=args.overfit,
    in_memory=args.in_memory,
    num_workers=args.cpu_workers,
    return_options=True,
    add_boundary_toks=not is_discriminative,
)
# Non-"disc" decoders validate with a fixed batch size of 5.
val_batch_size = batch_size if is_discriminative else 5
val_dataloader = DataLoader(
    val_dataset,
    batch_size=val_batch_size,
    num_workers=args.cpu_workers,
)

# Both halves of the model get the training vocabulary (used to construct the
# embedding layer).
encoder = Encoder(config["model"], train_dataset.vocabulary)
decoder = Decoder(config["model"], train_dataset.vocabulary)
print("Encoder: {}".format(config["model"]["encoder"]))
print("Decoder: {}".format(config["model"]["decoder"]))

# Encoder and decoder share a single word embedding.
decoder.word_embed = encoder.word_embed

# Combine both into one model; use DataParallel unless -1 (CPU) was requested.
model = EncoderDecoderModel(encoder, decoder).to(device)
if -1 not in args.gpu_ids:
    model = nn.DataParallel(model, args.gpu_ids)

# Loss function: plain cross entropy for "disc", padding-ignoring cross
# entropy for "gen"; any other decoder type is rejected.
if decoder_type not in ("disc", "gen"):
    raise NotImplementedError
criterion_kwargs = (
    {}
    if is_discriminative
    else {"ignore_index": train_dataset.vocabulary.PAD_INDEX}
)
criterion = nn.CrossEntropyLoss(**criterion_kwargs)

# Iterations per epoch, used to convert iteration counts into epochs.
# NOTE: ``// batch_size + 1`` yields one more than the number of full-or-partial
# batches when the example count is an exact multiple of the batch size.
num_training_examples = len(train_dataset)
if config["solver"]["training_splits"] == "trainval":
    num_training_examples += len(val_dataset)
iterations = num_training_examples // batch_size + 1
def lr_lambda_fun(current_iteration: int) -> float:
    """Return the learning-rate multiplier for ``current_iteration``.

    The iteration index is converted to a fractional epoch by dividing it by
    the module-level ``iterations`` value. All other settings are read from
    ``config["solver"]``:

    * Up to and including ``warmup_epochs``, the multiplier moves linearly
      from ``warmup_factor`` (at epoch 0) to 1.0 (at ``warmup_epochs``).
    * After that, the multiplier is ``lr_gamma`` raised to the number of
      entries in ``lr_milestones`` that are less than or equal to the
      current epoch.

    A ``warmup_epochs`` of 0 disables warmup; without the guard below the
    first call (iteration 0) would divide zero by zero.

    Args:
        current_iteration: Zero-based count of optimisation steps.

    Returns:
        The factor by which the base learning rate is scaled.
    """
    solver = config["solver"]
    warmup_epochs = solver["warmup_epochs"]
    current_epoch = float(current_iteration) / iterations

    if warmup_epochs > 0 and current_epoch <= warmup_epochs:
        # Fraction of the warmup phase that has elapsed, in [0, 1].
        alpha = current_epoch / float(warmup_epochs)
        return solver["warmup_factor"] * (1.0 - alpha) + alpha

    # ``bisect`` returns how many milestones are <= current_epoch.
    idx = bisect(solver["lr_milestones"], current_epoch)
    return pow(solver["lr_gamma"], idx)
# Adamax at the configured base rate; LambdaLR rescales that rate with the
# multiplier returned by ``lr_lambda_fun``.
optimizer = optim.Adamax(
    model.parameters(), lr=config["solver"]["initial_lr"]
)
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_fun)

# =============================================================================
#   SETUP BEFORE TRAINING LOOP
# =============================================================================

# Tensorboard events and checkpoints both go to ``--save-dirpath``.
summary_writer = SummaryWriter(log_dir=args.save_dirpath)
checkpoint_manager = CheckpointManager(
    model, optimizer, args.save_dirpath, config=config
)
sparse_metrics = SparseGTMetrics()
ndcg = NDCG()

# Resume from a checkpoint when a path was given, otherwise start at epoch 0.
if args.load_pthpath != "":
    # "path/to/checkpoint_xx.pth" -> "xx.pth" -> xx
    checkpoint_suffix = args.load_pthpath.rsplit("_", 1)[-1]
    start_epoch = int(checkpoint_suffix[:-4])

    model_state_dict, optimizer_state_dict = load_checkpoint(args.load_pthpath)
    # Under DataParallel the wrapped network lives in ``.module``; load the
    # weights into that rather than into the wrapper.
    weights_target = (
        model.module if isinstance(model, nn.DataParallel) else model
    )
    weights_target.load_state_dict(model_state_dict)
    optimizer.load_state_dict(optimizer_state_dict)
    print("Loaded model from {}".format(args.load_pthpath))
else:
    start_epoch = 0
# =============================================================================
#   TRAINING LOOP
# =============================================================================

# Forever increasing counter to keep track of iterations (for tensorboard log).
# When resuming, it starts at the iteration matching the resumed epoch.
global_iteration_step = start_epoch * iterations

# The ``finally`` clause closes the summary writer however the loop ends
# (completion, exception or Ctrl-C) so buffered events are not lost.
try:
    for epoch in range(start_epoch, config["solver"]["num_epochs"]):

        # ---------------------------------------------------------------------
        #   ON EPOCH START (combine dataloaders if training on train + val)
        # ---------------------------------------------------------------------
        if config["solver"]["training_splits"] == "trainval":
            combined_dataloader = itertools.chain(
                train_dataloader, val_dataloader
            )
        else:
            combined_dataloader = itertools.chain(train_dataloader)

        print(f"\nTraining for epoch {epoch}:")
        for batch in tqdm(combined_dataloader):
            # Move every entry of the batch to the training device.
            for key in batch:
                batch[key] = batch[key].to(device)

            optimizer.zero_grad()
            output = model(batch)
            # "disc" decoders are trained against ``ans_ind``, every other
            # decoder type against ``ans_out``.
            target = (
                batch["ans_ind"]
                if config["model"]["decoder"] == "disc"
                else batch["ans_out"]
            )
            # Flatten all leading dimensions so the criterion receives
            # (N, last_dim) scores and (N,) targets.
            batch_loss = criterion(
                output.view(-1, output.size(-1)), target.view(-1)
            )
            batch_loss.backward()
            optimizer.step()

            summary_writer.add_scalar(
                "train/loss", batch_loss, global_iteration_step
            )
            summary_writer.add_scalar(
                "train/lr",
                optimizer.param_groups[0]["lr"],
                global_iteration_step,
            )

            # NOTE(review): passing an explicit step index to
            # ``scheduler.step`` is reportedly deprecated in newer PyTorch
            # releases; confirm before upgrading.
            scheduler.step(global_iteration_step)
            global_iteration_step += 1
        torch.cuda.empty_cache()

        # ---------------------------------------------------------------------
        #   ON EPOCH END  (checkpointing and validation)
        # ---------------------------------------------------------------------
        checkpoint_manager.step()

        # Validate and report automatic metrics.
        if args.validate:

            # Switch dropout, batchnorm etc to the correct mode.
            model.eval()

            print(f"\nValidation after epoch {epoch}:")
            for batch in tqdm(val_dataloader):
                for key in batch:
                    batch[key] = batch[key].to(device)
                with torch.no_grad():
                    output = model(batch)
                sparse_metrics.observe(output, batch["ans_ind"])
                if "gt_relevance" in batch:
                    # For each example keep only the scores at index
                    # ``round_id - 1`` along the second dimension before
                    # passing them to NDCG.
                    output = output[
                        torch.arange(output.size(0)),
                        batch["round_id"] - 1,
                        :,
                    ]
                    ndcg.observe(output, batch["gt_relevance"])

            # Collect both metric families, resetting them for the next epoch.
            all_metrics = {}
            all_metrics.update(sparse_metrics.retrieve(reset=True))
            all_metrics.update(ndcg.retrieve(reset=True))
            for metric_name, metric_value in all_metrics.items():
                print(f"{metric_name}: {metric_value}")
            summary_writer.add_scalars(
                "metrics", all_metrics, global_iteration_step
            )

            # Back to training mode for the next epoch.
            model.train()
            torch.cuda.empty_cache()
finally:
    summary_writer.close()