-
Notifications
You must be signed in to change notification settings - Fork 316
/
Copy pathtrain_amp.py
210 lines (169 loc) · 6.12 KB
/
train_amp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import sys
sys.path.insert(0, '.')
import os
import os.path as osp
import random
import logging
import time
import argparse
import numpy as np
from tabulate import tabulate
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader
import torch.cuda.amp as amp
from lib.models import model_factory
from configs import set_cfg_from_file
from lib.get_dataloader import get_data_loader
from evaluate import eval_model
from lib.ohem_ce_loss import OhemCELoss
from lib.lr_scheduler import WarmupPolyLrScheduler
from lib.meters import TimeMeter, AvgMeter
from lib.logger import setup_logger, print_log_msg
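## fix all random seeds: uncomment the lines below for reproducible runs
## (for full determinism also set torch.backends.cudnn.benchmark = False;
## benchmark mode may select non-deterministic cuDNN algorithms)
# torch.manual_seed(123)
# torch.cuda.manual_seed(123)
# np.random.seed(123)
# random.seed(123)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = True
# torch.multiprocessing.set_sharing_strategy('file_system')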
def parse_args():
    """Parse the command-line options of the training script.

    Options:
        --local_rank / --local-rank: rank of this process on the local node,
            as passed by the distributed launcher (-1 when absent).
        --port: TCP port of the ``tcp://127.0.0.1`` rendezvous used by main().
        --config: path of the config file loaded with ``set_cfg_from_file``.
        --finetune-from: optional checkpoint whose state dict is loaded into
            the model before training.

    Returns:
        argparse.Namespace with ``local_rank``, ``port``, ``config`` and
        ``finetune_from`` attributes.
    """
    parse = argparse.ArgumentParser()
    # PyTorch >= 2.0's torch.distributed.launch passes the dashed
    # ``--local-rank``; older versions pass ``--local_rank``. Accept both.
    parse.add_argument('--local_rank', '--local-rank', dest='local_rank',
                       type=int, default=-1,)
    parse.add_argument('--port', dest='port', type=int, default=44554,)
    parse.add_argument('--config', dest='config', type=str,
                       default='configs/bisenetv2.py',)
    parse.add_argument('--finetune-from', type=str, default=None,)
    return parse.parse_args()
# Parsed once at import time, so importing this module consumes sys.argv.
# `args` and `cfg` are read as module globals by the functions that follow
# (e.g. set_model(), set_optimizer(), train(), main()).
args = parse_args()
cfg = set_cfg_from_file(args.config)
def set_model():
    """Build the segmentation network and its loss functions.

    Reads the module-level ``cfg`` (``model_type``, ``n_cats``,
    ``use_sync_bn``, ``num_aux_heads``) and ``args.finetune_from``.

    Returns:
        (net, criteria_pre, criteria_aux): the model moved to the current
        CUDA device and put in train mode, the OHEM loss for the main output,
        and a list holding one OHEM loss per auxiliary head.
    """
    logger = logging.getLogger()
    net = model_factory[cfg.model_type](cfg.n_cats)
    if args.finetune_from is not None:
        logger.info(f'load pretrained weights from {args.finetune_from}')
        # map to CPU so loading does not depend on the device the checkpoint
        # was saved from; net.cuda() below moves the weights to the GPU.
        net.load_state_dict(torch.load(args.finetune_from, map_location='cpu'))
    if cfg.use_sync_bn:
        net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net.cuda()
    net.train()
    # 0.7 is presumably the OHEM probability threshold — TODO confirm in lib.ohem_ce_loss
    criteria_pre = OhemCELoss(0.7)
    criteria_aux = [OhemCELoss(0.7) for _ in range(cfg.num_aux_heads)]
    return net, criteria_pre, criteria_aux
def set_optimizer(model):
    """Create the SGD optimizer with per-group weight decay / learning rate.

    If ``model`` provides ``get_params()``, its four parameter groups are
    used:
    - the two "lr_mul" groups use 10x ``cfg.lr_start``;
    - the two "nowd" groups use weight decay ``wd_val`` (currently 0);
    - groups without explicit values inherit the optimizer defaults below.

    Otherwise every parameter is split by tensor rank: 0-/1-D tensors get no
    weight decay, and all higher-rank tensors get ``cfg.weight_decay``.

    Returns:
        torch.optim.SGD with momentum 0.9, base lr ``cfg.lr_start`` and
        default weight decay ``cfg.weight_decay``.
    """
    if hasattr(model, 'get_params'):
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = model.get_params()
        # wd_val = cfg.weight_decay
        wd_val = 0
        params_list = [
            {'params': wd_params, },
            {'params': nowd_params, 'weight_decay': wd_val},
            {'params': lr_mul_wd_params, 'lr': cfg.lr_start * 10},
            {'params': lr_mul_nowd_params, 'weight_decay': wd_val, 'lr': cfg.lr_start * 10},
        ]
    else:
        wd_params, non_wd_params = [], []
        for param in model.parameters():
            # Every parameter must land in exactly one group. A whitelist of
            # ranks (only 1-, 2- and 4-D) would silently leave e.g. 0-D
            # scalars or Conv1d/Conv3d weights out of the optimizer.
            if param.dim() <= 1:
                non_wd_params.append(param)
            else:
                wd_params.append(param)
        params_list = [
            {'params': wd_params, },
            {'params': non_wd_params, 'weight_decay': 0},
        ]
    optim = torch.optim.SGD(
        params_list,
        lr=cfg.lr_start,
        momentum=0.9,
        weight_decay=cfg.weight_decay,
    )
    return optim
def set_model_dist(net):
    """Wrap ``net`` in DistributedDataParallel on this process's GPU.

    Requires an initialised default process group and that this process's
    CUDA device was already selected (main() calls torch.cuda.set_device).

    Returns:
        The DDP-wrapped model; the original module is ``.module``.
    """
    # device_ids must be the *local* GPU index. dist.get_rank() is the global
    # rank and only matches it on a single node; the current CUDA device is
    # the one selected for this process.
    local_rank = torch.cuda.current_device()
    net = nn.parallel.DistributedDataParallel(
        net,
        device_ids=[local_rank, ],
        # find_unused_parameters=True,
        output_device=local_rank
    )
    return net
def set_meters():
    """Create the progress meters used by the training loop.

    Returns:
        (time_meter, loss_meter, loss_pre_meter, loss_aux_meters), where
        ``loss_aux_meters`` contains one AvgMeter per auxiliary head
        (``cfg.num_aux_heads`` of them), named 'loss_aux0', 'loss_aux1', ...
    """
    timer = TimeMeter(cfg.max_iter)
    total_loss = AvgMeter('loss')
    main_loss = AvgMeter('loss_prem')
    aux_losses = []
    for head_idx in range(cfg.num_aux_heads):
        aux_losses.append(AvgMeter(f'loss_aux{head_idx}'))
    return timer, total_loss, main_loss, aux_losses
def train():
    """Run the mixed-precision DDP training loop, then save and evaluate.

    Uses the module-level ``cfg`` and requires an initialised default process
    group (set_model_dist() and the final dist.get_rank() call need it).

    The loop consumes every batch yielded by the train data loader; presumably
    the loader yields ``cfg.max_iter`` batches — TODO confirm in
    lib.get_dataloader. The LR scheduler is stepped once per iteration.

    After the loop:
    - only rank 0 writes ``<cfg.respth>/model_final.pth``;
    - every rank then runs eval_model() on the unwrapped network and logs
      the resulting mIoU table.
    """
    logger = logging.getLogger()
    is_dist = dist.is_initialized()

    ## dataset
    dl = get_data_loader(cfg, mode='train', distributed=is_dist)

    ## model
    net, criteria_pre, criteria_aux = set_model()

    ## optimizer
    optim = set_optimizer(net)

    ## mixed precision training
    scaler = amp.GradScaler()

    ## ddp training
    net = set_model_dist(net)

    ## meters
    time_meter, loss_meter, loss_pre_meter, loss_aux_meters = set_meters()

    ## lr scheduler
    lr_schdr = WarmupPolyLrScheduler(optim, power=0.9,
        max_iter=cfg.max_iter, warmup_iter=cfg.warmup_iters,
        warmup_ratio=0.1, warmup='exp', last_epoch=-1,)

    ## train loop
    for it, (im, lb) in enumerate(dl):
        im = im.cuda()
        lb = lb.cuda()
        # drop dim 1 of the labels — assumes lb is (N, 1, H, W); TODO confirm
        lb = torch.squeeze(lb, 1)

        optim.zero_grad()
        # forward pass and losses under autocast (half precision only when
        # cfg.use_fp16 is set)
        with amp.autocast(enabled=cfg.use_fp16):
            logits, *logits_aux = net(im)
            loss_pre = criteria_pre(logits, lb)
            loss_aux = [crit(lgt, lb) for crit, lgt in zip(criteria_aux, logits_aux)]
            loss = loss_pre + sum(loss_aux)
        # backward and step outside autocast; the scaler unscales the
        # gradients and skips the step if they contain inf/NaN
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
        # wait for queued GPU work so time_meter measures real iteration time
        torch.cuda.synchronize()

        time_meter.update()
        loss_meter.update(loss.item())
        loss_pre_meter.update(loss_pre.item())
        for mter, lss in zip(loss_aux_meters, loss_aux):
            mter.update(lss.item())

        ## print training log message
        if (it + 1) % 100 == 0:
            lr = lr_schdr.get_lr()
            # mean LR over all param groups
            lr = sum(lr) / len(lr)
            print_log_msg(
                it, cfg.max_iter, lr, time_meter, loss_meter,
                loss_pre_meter, loss_aux_meters)
        lr_schdr.step()

    ## dump the final model and evaluate the result
    save_pth = osp.join(cfg.respth, 'model_final.pth')
    logger.info('\nsave models to {}'.format(save_pth))
    # save the unwrapped module's weights so they load without DDP
    state = net.module.state_dict()
    if dist.get_rank() == 0:
        torch.save(state, save_pth)

    logger.info('\nevaluating the final model')
    torch.cuda.empty_cache()
    heads, mious = eval_model(cfg, net.module)
    logger.info(tabulate([mious, ], headers=heads, tablefmt='orgtbl'))
def main():
    """Initialise single-node NCCL training (one process per GPU) and train.

    Local rank resolution:
    - use ``--local_rank`` / ``--local-rank`` when it is given (>= 0);
    - otherwise fall back to the ``LOCAL_RANK`` environment variable, which
      ``torchrun`` exports instead of passing an argument.

    The world size is the number of visible GPUs, and the rendezvous is
    ``tcp://127.0.0.1:<--port>``, so all ranks must run on this host.
    """
    local_rank = args.local_rank
    if local_rank < 0:
        local_rank = int(os.environ.get('LOCAL_RANK', local_rank))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://127.0.0.1:{}'.format(args.port),
        world_size=torch.cuda.device_count(),
        rank=local_rank
    )
    # Every rank reaches this point concurrently. exist_ok=True avoids the
    # check-then-create race in which a slower rank's makedirs raises
    # FileExistsError.
    os.makedirs(cfg.respth, exist_ok=True)
    setup_logger(f'{cfg.model_type}-{cfg.dataset.lower()}-train', cfg.respth)
    train()


if __name__ == "__main__":
    main()