main.py
"""
Main model training training script.
See Makefile `main` to see usage.
"""
import argparse
from pathlib import Path

import torch
from torch import optim
from torchvision import transforms

from tinyfaces import trainer
from tinyfaces.datasets import get_dataloader
from tinyfaces.models.loss import DetectionCriterion
from tinyfaces.models.model import DetectionModel


def arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("traindata")
    parser.add_argument("valdata")
    parser.add_argument("--dataset-root", default="")
    parser.add_argument("--dataset", default="WIDERFace")
    parser.add_argument("--lr", default=1e-4, type=float)
    parser.add_argument("--weight-decay", default=0.0005, type=float)
    parser.add_argument("--momentum", default=0.9, type=float)
    parser.add_argument("--batch_size", default=12, type=int)
    parser.add_argument("--workers", default=8, type=int)
    parser.add_argument("--start-epoch", default=0, type=int)
    parser.add_argument("--epochs", default=50, type=int)
    parser.add_argument("--save-every", default=10, type=int)
    # `--resume` must take a checkpoint path (not be a store_true flag),
    # since its value is passed straight to torch.load() below
    parser.add_argument("--resume", default="", type=str)
    parser.add_argument("--debug", action="store_true")
    return parser.parse_args()


def main():
    args = arguments()

    num_templates = 25  # the number of clusters the anchor templates come from

    # standard ImageNet channel statistics
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    img_transforms = transforms.Compose([transforms.ToTensor(), normalize])

    train_loader, _ = get_dataloader(args.traindata,
                                     args,
                                     num_templates,
                                     img_transforms=img_transforms)

    model = DetectionModel(num_objects=1, num_templates=num_templates)
    loss_fn = DetectionCriterion(num_templates)

    # directory where we'll store model weights
    weights_dir = Path("weights")
    weights_dir.mkdir(exist_ok=True)

    # run on the GPU if available, otherwise fall back to the CPU
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

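    # NOTE: the model is not moved to `device` here; `device` is passed to
    # trainer.train() below, which is assumed to place the model and each
    # batch on it.
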
    # As per Peiyun, SGD is more robust than Adam and works really well
    optimizer = optim.SGD(model.learnable_parameters(args.lr),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    # optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])

        # set the start epoch if it wasn't given on the command line
        if not args.start_epoch:
            args.start_epoch = checkpoint['epoch']

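    # StepLR with its default gamma of 0.1 decays the learning rate 10x every
    # `step_size` epochs; `last_epoch` keeps the schedule in step when
    # resuming from a checkpoint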
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=20,
                                          last_epoch=args.start_epoch - 1)

    # train and evaluate for `epochs` epochs
    for epoch in range(args.start_epoch, args.epochs):
        trainer.train(model,
                      loss_fn,
                      optimizer,
                      train_loader,
                      epoch,
                      device=device)
        scheduler.step()

        if (epoch + 1) % args.save_every == 0:
            trainer.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'batch_size': train_loader.batch_size,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                },
                filename="checkpoint_{0}.pth".format(epoch + 1),
                save_path=weights_dir)


if __name__ == '__main__':
    main()