-
Notifications
You must be signed in to change notification settings - Fork 46
/
main.py
324 lines (284 loc) · 14 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import argparse
import datetime
import json
import random
import time
from pathlib import Path
import numpy as np
import os
import torch
from torch.utils.data import DataLoader
import datasets
import util.misc as utils
import datasets.samplers as samplers
from datasets import build_dataset, get_coco_api_from_dataset
from engine import evaluate, train_one_epoch, viz
from models import build_model
from models.backbone import build_swav_backbone, build_swav_backbone_old
from util.default_args import set_model_defaults, get_args_parser
# Dataset names whose training split is built with build_selfdet in
# get_datasets(); for these, main() may also build a SwAV backbone
# (variant selected by args.obj_embedding_head).
PRETRAINING_DATASETS = [
    'imagenet',
    'imagenet100',
    'coco_pretrain',
    'airbus_pretrain',
]
def _load_checkpoint(path):
    """Load a checkpoint onto the CPU from a local file or an ``https`` URL.

    URLs go through ``torch.hub.load_state_dict_from_url`` with hash checking;
    any other value is passed to ``torch.load``.
    """
    if path.startswith('https'):
        return torch.hub.load_state_dict_from_url(
            path, map_location='cpu', check_hash=True)
    return torch.load(path, map_location='cpu')


def main(args):
    """Train, evaluate or visualise a detection model configured by ``args``.

    Steps: initialise distributed mode, seed the RNGs, build the model,
    criterion and postprocessors (plus a SwAV backbone for pretraining
    datasets), build data loaders, optimizer and StepLR scheduler, optionally
    load frozen / pretrained / resumed weights, then either run ``evaluate``
    (``args.eval``), run ``viz`` (``args.viz``), or run the training loop.
    Checkpoints, ``log.txt`` and evaluation dumps are written under
    ``args.output_dir`` when it is set.
    """
    utils.init_distributed_mode(args)
    print("git:\n {}\n".format(utils.get_sha()))
    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)
    device = torch.device(args.device)

    # fix the seed for reproducibility
    if args.random_seed:
        args.seed = np.random.randint(0, 1000000)
    if args.resume:
        # A resumed run reuses the seed stored in its checkpoint. Load through
        # _load_checkpoint so that https:// resume targets (accepted by the
        # resume block below) work here too; torch.load cannot open a URL.
        checkpoint_args = _load_checkpoint(args.resume)['args']
        args.seed = checkpoint_args.seed
        print("Loaded random seed from checkpoint:", checkpoint_args.seed)
    # Offset by rank so each distributed process gets a different seed.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    print(f"Using random seed: {seed}")

    # Pretraining datasets also build a SwAV backbone (variant chosen by
    # args.obj_embedding_head); it is passed to train_one_epoch below.
    swav_model = None
    if args.dataset in PRETRAINING_DATASETS:
        if args.obj_embedding_head == 'head':
            swav_model = build_swav_backbone(args, device)
        elif args.obj_embedding_head == 'intermediate':
            swav_model = build_swav_backbone_old(args, device)

    model, criterion, postprocessors = build_model(args)
    model.to(device)
    model_without_ddp = model
    n_parameters = sum(p.numel()
                       for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    dataset_train, dataset_val = get_datasets(args)
    if args.distributed:
        if args.cache_mode:
            sampler_train = samplers.NodeDistributedSampler(dataset_train)
            sampler_val = samplers.NodeDistributedSampler(
                dataset_val, shuffle=False)
        else:
            sampler_train = samplers.DistributedSampler(dataset_train)
            sampler_val = samplers.DistributedSampler(
                dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    coco_evaluator = None
    batch_sampler_train = torch.utils.data.BatchSampler(
        sampler_train, args.batch_size, drop_last=True)
    data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn, num_workers=args.num_workers,
                                   pin_memory=True)
    data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                 drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers,
                                 pin_memory=True)

    # lr_backbone_names = ["backbone.0", "backbone.neck", "input_proj", "transformer.encoder"]
    def match_name_keywords(n, name_keywords):
        """Return True if any entry of ``name_keywords`` is a substring of ``n``."""
        return any(b in n for b in name_keywords)

    for n, _ in model_without_ddp.named_parameters():
        print(n)

    # Three optimizer groups over trainable parameters only:
    #   1. everything not matching the backbone / linear-projection keywords -> args.lr
    #   2. names matching args.lr_backbone_names                          -> args.lr_backbone
    #   3. names matching args.lr_linear_proj_names                       -> args.lr * args.lr_linear_proj_mult
    param_dicts = [
        {
            "params":
                [p for n, p in model_without_ddp.named_parameters()
                 if not match_name_keywords(n, args.lr_backbone_names) and not match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad],
            "lr": args.lr,
        },
        {
            "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, args.lr_backbone_names) and p.requires_grad],
            "lr": args.lr_backbone,
        },
        {
            "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad],
            "lr": args.lr * args.lr_linear_proj_mult,
        }
    ]
    if args.sgd:
        optimizer = torch.optim.SGD(param_dicts, lr=args.lr, momentum=0.9,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                      weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.dataset_file == "coco_panoptic":
        # We also evaluate AP during panoptic training, on original coco DS
        coco_val = datasets.coco.build("val", args)
        base_ds = get_coco_api_from_dataset(coco_val)
    elif args.dataset_file == "coco" or args.dataset_file == "airbus":
        base_ds = get_coco_api_from_dataset(dataset_val)
    else:
        base_ds = dataset_val

    # Datasets on which evaluate() is run after resuming and during training.
    eval_datasets = ['coco', 'voc']

    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location='cpu')
        model_without_ddp.detr.load_state_dict(checkpoint['model'])

    output_dir = Path(args.output_dir)
    if args.pretrain:
        print('Initialized from the pre-training model')
        checkpoint = torch.load(args.pretrain, map_location='cpu')
        state_dict = checkpoint['model']
        for k in list(state_dict.keys()):
            # remove useless class embed
            if 'class_embed' in k:
                del state_dict[k]
        msg = model_without_ddp.load_state_dict(state_dict, strict=False)
        print(msg)

    if args.resume:
        checkpoint = _load_checkpoint(args.resume)
        missing_keys, unexpected_keys = model_without_ddp.load_state_dict(
            checkpoint['model'], strict=False)
        # Drop 'total_params' / 'total_ops' keys from the report (presumably
        # left behind by a FLOPs profiler -- not real model weights).
        unexpected_keys = [k for k in unexpected_keys if not (
            k.endswith('total_params') or k.endswith('total_ops'))]
        if len(missing_keys) > 0:
            print('Missing Keys: {}'.format(missing_keys))
        if len(unexpected_keys) > 0:
            print('Unexpected Keys: {}'.format(unexpected_keys))
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            # Restore optimizer state but keep this run's configured learning
            # rates. Only the two scalars per group are remembered; deep-copying
            # param_groups would also duplicate every parameter tensor.
            saved_lrs = [(pg['lr'], pg['initial_lr']) for pg in optimizer.param_groups]
            optimizer.load_state_dict(checkpoint['optimizer'])
            for pg, (lr, initial_lr) in zip(optimizer.param_groups, saved_lrs):
                pg['lr'] = lr
                pg['initial_lr'] = initial_lr
            print(optimizer.param_groups)
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            # todo: this is a hack for doing experiment that resume from checkpoint and also modify lr scheduler (e.g., decrease lr in advance).
            # NOTE: the flag is forced to True, so the override below always runs.
            args.override_resumed_lr_drop = True
            if args.override_resumed_lr_drop:
                print('Warning: (hack) args.override_resumed_lr_drop is set to True, so args.lr_drop would override lr_drop in resumed lr_scheduler.')
                lr_scheduler.step_size = args.lr_drop
                lr_scheduler.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
            lr_scheduler.step(lr_scheduler.last_epoch)
            args.start_epoch = checkpoint['epoch'] + 1
        # check the resumed model
        if not args.eval and not args.viz and args.dataset in eval_datasets:
            test_stats, coco_evaluator = evaluate(
                model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir
            )

    if args.eval:
        test_stats, coco_evaluator = evaluate(model, criterion, postprocessors,
                                              data_loader_val, base_ds, device, args.output_dir)
        if args.output_dir:
            utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
        return

    if args.viz:
        viz(model, criterion, postprocessors,
            data_loader_val, base_ds, device, args.output_dir)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Reshuffle the distributed sampler differently each epoch.
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(
            model, swav_model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm)
        lr_scheduler.step()
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            # extra checkpoint before LR drop and every 5 epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 5 == 0:
                checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master({
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args,
                }, checkpoint_path)

        evaluated_this_epoch = args.dataset in eval_datasets and epoch % args.eval_every == 0
        if evaluated_this_epoch:
            test_stats, coco_evaluator = evaluate(
                model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir
            )
        else:
            test_stats = {}

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")
            # for evaluation logs -- only for epochs that actually ran
            # evaluation; otherwise a stale evaluator from an earlier epoch
            # would be re-saved (possibly under this epoch's number).
            if evaluated_this_epoch and 'imagenet' not in args.dataset and coco_evaluator is not None:
                (output_dir / 'eval').mkdir(exist_ok=True)
                if "bbox" in coco_evaluator.coco_eval:
                    filenames = ['latest.pth']
                    if epoch % 50 == 0:
                        filenames.append(f'{epoch:03}.pth')
                    for name in filenames:
                        torch.save(coco_evaluator.coco_eval["bbox"].eval,
                                   output_dir / "eval" / name)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def get_datasets(args):
    """Return the ``(train, val)`` dataset pair selected by ``args.dataset``.

    * 'coco' / 'airbus': both splits come from ``build_dataset``.
    * pretraining names ('coco_pretrain', 'airbus_pretrain', 'imagenet',
      'imagenet100'): the train split is ``build_selfdet`` over a raw image
      directory; the val split comes from ``build_dataset``.
    * 'voc': ``VOCDetection`` with years 2007+2012 'trainval' for training and
      2007 'test' for validation.

    Raises:
        ValueError: if ``args.dataset`` is not a supported name.
    """
    name = args.dataset
    # Pretraining sets: (attribute on args holding the dataset root, image subdirectory).
    selfdet_sources = {
        'coco_pretrain': ('coco_path', 'train2017'),
        'airbus_pretrain': ('airbus_path', 'train_v2'),
        'imagenet': ('imagenet_path', 'train'),
        'imagenet100': ('imagenet100_path', 'train'),
    }

    if name in ('coco', 'airbus'):
        train_set = build_dataset(image_set='train', args=args)
        val_set = build_dataset(image_set='val', args=args)
    elif name in selfdet_sources:
        from datasets.selfdet import build_selfdet
        root_attr, subdir = selfdet_sources[name]
        image_dir = os.path.join(getattr(args, root_attr), subdir)
        train_set = build_selfdet('train', args=args, p=image_dir)
        val_set = build_dataset(image_set='val', args=args)
    elif name == 'voc':
        from datasets.torchvision_datasets.voc import VOCDetection
        from datasets.coco import make_coco_transforms
        train_set = VOCDetection(
            args.voc_path, ["2007", "2012"], image_sets=['trainval', 'trainval'],
            transforms=make_coco_transforms('train'), filter_pct=args.filter_pct, seed=args.seed)
        val_set = VOCDetection(
            args.voc_path, ["2007"], image_sets=['test'],
            transforms=make_coco_transforms('val'), seed=args.seed)
    else:
        raise ValueError(f"Wrong dataset name: {args.dataset}")
    return train_set, val_set
def set_dataset_path(args):
    """Populate the per-dataset root directories on ``args``.

    Sets ``coco_path``, ``airbus_path``, ``imagenet_path``,
    ``imagenet100_path`` and ``voc_path``, each joined onto ``args.data_root``.
    Mutates ``args`` in place and returns None.
    """
    folder_for_attr = {
        'coco_path': 'MSCoco',
        'airbus_path': 'airbus-ship-detection',
        'imagenet_path': 'ilsvrc',
        'imagenet100_path': 'ilsvrc100',
        'voc_path': 'pascal',
    }
    root = args.data_root
    for attr, folder in folder_for_attr.items():
        setattr(args, attr, os.path.join(root, folder))
if __name__ == '__main__':
    # Build the CLI from the shared argument definitions, then derive
    # dataset paths and model defaults from the parsed values.
    cli = argparse.ArgumentParser(
        'Deformable DETR training and evaluation script',
        parents=[get_args_parser()],
    )
    args = cli.parse_args()
    set_dataset_path(args)
    set_model_defaults(args)
    # Make sure the output directory exists before main() writes into it.
    if args.output_dir:
        out_path = Path(args.output_dir)
        out_path.mkdir(parents=True, exist_ok=True)
    main(args)