train.py
import os
import logging

import torch
from torch import nn
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from tqdm import tqdm
import click
import numpy as np

from pspnet import PSPNet

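# Supported backbones; the hard-coded psp_size / deep_features_size values
# are meant to match the channel counts of each backend's feature maps.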
models = {
    'squeezenet': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='squeezenet'),
    'densenet': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=1024, deep_features_size=512, backend='densenet'),
    'resnet18': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet18'),
    'resnet34': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet34'),
    'resnet50': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet50'),
    'resnet101': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet101'),
    'resnet152': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet152')
}

def build_network(snapshot, backend):
    epoch = 0
    backend = backend.lower()
    net = models[backend]()
    net = nn.DataParallel(net)
    if snapshot is not None:
        _, epoch = os.path.basename(snapshot).split('_')
        epoch = int(epoch)
        net.load_state_dict(torch.load(snapshot))
        logging.info("Snapshot for epoch {} loaded from {}".format(epoch, snapshot))
    net = net.cuda()
    return net, epoch

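# Note: build_network() expects snapshot filenames of the form "<name>_<epoch>"
# (e.g. "PSPNet_13"), which is what the torch.save(...) call in train() produces.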
@click.command()
@click.option('--data-path', type=str, help='Path to dataset folder')
@click.option('--models-path', type=str, help='Path for storing model snapshots')
@click.option('--backend', type=str, default='resnet34', help='Feature extractor')
@click.option('--snapshot', type=str, default=None, help='Path to pretrained weights')
@click.option('--crop-x', type=int, default=256, help='Horizontal random crop size')
@click.option('--crop-y', type=int, default=256, help='Vertical random crop size')
@click.option('--batch-size', type=int, default=16)
@click.option('--alpha', type=float, default=1.0, help='Coefficient for classification loss term')
@click.option('--epochs', type=int, default=20, help='Number of training epochs to run')
@click.option('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
@click.option('--start-lr', type=float, default=0.001)
@click.option('--milestones', type=str, default='10,20,30', help='Comma-separated epoch milestones for LR decay')
def train(data_path, models_path, backend, snapshot, crop_x, crop_y, batch_size, alpha, epochs, start_lr, milestones, gpu):
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    net, starting_epoch = build_network(snapshot, backend)
    data_path = os.path.abspath(os.path.expanduser(data_path))
    models_path = os.path.abspath(os.path.expanduser(models_path))
    os.makedirs(models_path, exist_ok=True)

    '''
    To follow this training routine you need a DataLoader that yields tuples of the following format:
    (Bx3xHxW FloatTensor x, BxHxW LongTensor y, BxN LongTensor y_cls), where
    x     - batch of input images,
    y     - batch of ground-truth segmentation maps,
    y_cls - batch of 1D tensors of dimensionality N (N = total number of classes),
            y_cls[i, T] = 1 if class T is present in image i, 0 otherwise.
    A hypothetical Dataset satisfying this contract is sketched after train() below.
    '''
    train_loader, class_weights, n_images = None, None, None  # TODO: plug in your own loader, weights and image count

    optimizer = optim.Adam(net.parameters(), lr=start_lr)
    scheduler = MultiStepLR(optimizer, milestones=[int(x) for x in milestones.split(',')])

    for epoch in range(starting_epoch, starting_epoch + epochs):
        # NLLLoss expects log-probabilities from the segmentation head;
        # BCEWithLogitsLoss takes raw logits from the auxiliary classifier head.
        seg_criterion = nn.NLLLoss(weight=class_weights)
        cls_criterion = nn.BCEWithLogitsLoss(weight=class_weights)
        epoch_losses = []
        train_iterator = tqdm(train_loader, total=n_images // batch_size + 1)
        net.train()
        for x, y, y_cls in train_iterator:
            optimizer.zero_grad()
            # BCEWithLogitsLoss requires float targets, hence the cast on y_cls
            x, y, y_cls = x.cuda(), y.cuda(), y_cls.cuda().float()
            out, out_cls = net(x)
            seg_loss, cls_loss = seg_criterion(out, y), cls_criterion(out_cls, y_cls)
            loss = seg_loss + alpha * cls_loss
            epoch_losses.append(loss.item())
            status = '[{0}] loss = {1:0.5f} avg = {2:0.5f}, LR = {3:0.7f}'.format(
                epoch + 1, loss.item(), np.mean(epoch_losses), scheduler.get_last_lr()[0])
            train_iterator.set_description(status)
            loss.backward()
            optimizer.step()
        scheduler.step()
        torch.save(net.state_dict(), os.path.join(models_path, '_'.join(["PSPNet", str(epoch + 1)])))
        train_loss = np.mean(epoch_losses)

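# ---------------------------------------------------------------------------
# Minimal sketch of a Dataset satisfying the (x, y, y_cls) contract documented
# in train() above. The class name, constructor arguments and n_classes are
# hypothetical placeholders for illustration, not part of this repository.
# ---------------------------------------------------------------------------
from torch.utils.data import Dataset

class ExampleSegDataset(Dataset):
    def __init__(self, images, masks, n_classes):
        self.images = images          # list of 3xHxW float arrays
        self.masks = masks            # list of HxW integer label maps in [0, n_classes)
        self.n_classes = n_classes

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        x = torch.as_tensor(self.images[i], dtype=torch.float32)  # 3xHxW
        y = torch.as_tensor(self.masks[i], dtype=torch.long)      # HxW
        # y_cls[T] = 1 if class T appears anywhere in image i, else 0
        y_cls = (torch.bincount(y.flatten(), minlength=self.n_classes) > 0).long()
        return x, y, y_cls

# A DataLoader built from it would then replace the placeholder in train(), e.g.:
#   train_loader = DataLoader(ExampleSegDataset(imgs, msks, n_classes),
#                             batch_size=batch_size, shuffle=True)
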
if __name__ == '__main__':
    train()
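
# Example invocation (all paths are placeholders):
#   python train.py --data-path ~/datasets/segmentation --models-path ~/snapshots \
#       --backend resnet34 --batch-size 16 --epochs 20 --gpu 0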