-
Notifications
You must be signed in to change notification settings - Fork 323
/
Copy pathmain_multi_gpu.py
485 lines (434 loc) · 20.9 KB
/
main_multi_gpu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
# Copyright (c) 2021 PPViT Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ViT train and eval using multiple GPU """
import sys
import os
import time
import argparse
import random
import math
import numpy as np
import paddle
from datasets import get_dataloader
from datasets import get_dataset
from config import get_config
from config import update_config
from utils import AverageMeter
from utils import get_logger
from utils import write_log
from utils import all_reduce_mean
from vit import build_vit as build_model
def get_arguments():
    """Parse the command-line options used to override the config.

    The config is assembled in order (1) defaults in config.py, (2) the yaml
    file given by ``-cfg``, (3) these arguments. Every value option defaults
    to None and every switch to False; presumably ``update_config`` only
    applies options that were actually given (TODO confirm in config.py).

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    parser = argparse.ArgumentParser('ViT')
    # (flag, type) pairs for options that take a value
    value_options = (
        ('-cfg', str),
        ('-dataset', str),
        ('-data_path', str),
        ('-output', str),
        ('-batch_size', int),
        ('-batch_size_eval', int),
        ('-image_size', int),
        ('-accum_iter', int),
        ('-pretrained', str),
        ('-resume', str),
        ('-last_epoch', int),
    )
    for flag, value_type in value_options:
        parser.add_argument(flag, type=value_type, default=None)
    # boolean switches: present -> True, absent -> False
    for switch in ('-eval', '-amp'):
        parser.add_argument(switch, action='store_true')
    return parser.parse_args()
def train(dataloader,
          model,
          optimizer,
          criterion,
          epoch,
          total_epochs,
          total_batches,
          debug_steps=100,
          accum_iter=1,
          amp_grad_scaler=None,
          local_logger=None,
          master_logger=None):
    """Train the model for one epoch and gather metrics across processes.

    Args:
        dataloader: paddle.io.DataLoader yielding (images, label, ...) batches
        model: nn.Layer, a ViT model
        optimizer: paddle optimizer whose step()/clear_grad() apply the update
        criterion: loss layer called as criterion(output, label)
        epoch: int, current epoch (used in log messages)
        total_epochs: int, total num of epochs (used in log messages)
        total_batches: int, total num of batches for one epoch (used in log messages)
        debug_steps: int, log every `debug_steps` iterations and on the last one, default: 100
        accum_iter: int, num of iters to accumulate gradients before an update, default: 1
        amp_grad_scaler: GradScaler; when not None, AMP autocast + loss scaling is used, default: None
        local_logger: logger for local process/gpu, default: None
        master_logger: logger for main process, default: None
    Returns:
        tuple of (local avg loss, local avg acc@1, all-process avg loss,
        all-process avg acc@1, elapsed training time in seconds)
    """
    start_time = time.time()
    local_loss_meter = AverageMeter()
    local_acc_meter = AverageMeter()
    global_loss_meter = AverageMeter()
    global_acc_meter = AverageMeter()
    use_amp = amp_grad_scaler is not None
    num_batches = len(dataloader)

    model.train()
    optimizer.clear_grad()

    for step, batch in enumerate(dataloader):
        images = batch[0]
        label = batch[1]
        local_batch_size = images.shape[0]

        # forward (autocast is a no-op when AMP is disabled)
        with paddle.amp.auto_cast(use_amp):
            output = model(images)
            loss = criterion(output, label)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # divide so gradients summed over `accum_iter` micro-batches average out
        loss = loss / accum_iter
        # update on every `accum_iter`-th micro-batch, and always flush on the last one
        is_update_step = ((step + 1) % accum_iter == 0) or (step + 1 == num_batches)

        # backward and step
        if use_amp:
            scaled_loss = amp_grad_scaler.scale(loss)
            scaled_loss.backward()
            if is_update_step:
                # amp for param group reference: https://github.com/PaddlePaddle/Paddle/issues/37188
                amp_grad_scaler.step(optimizer)
                amp_grad_scaler.update()
                optimizer.clear_grad()
        else:
            loss.backward()
            if is_update_step:
                optimizer.step()
                optimizer.clear_grad()

        probs = paddle.nn.functional.softmax(output)
        acc = paddle.metric.accuracy(probs, label.unsqueeze(1)).item()

        # collective calls: every rank issues them in the same order
        global_loss = all_reduce_mean(loss_value)
        global_acc = all_reduce_mean(acc)
        global_batch_size = all_reduce_mean(local_batch_size)
        global_loss_meter.update(global_loss, global_batch_size)
        global_acc_meter.update(global_acc, global_batch_size)
        local_loss_meter.update(loss_value, local_batch_size)
        local_acc_meter.update(acc, local_batch_size)

        if step % debug_steps == 0 or step + 1 == num_batches:
            prefix = (f"Epoch[{epoch:03d}/{total_epochs:03d}], "
                      f"Step[{step:04d}/{total_batches:04d}], "
                      f"Lr: {optimizer.get_lr():04f}, ")
            local_message = (prefix +
                             f"Loss: {loss_value:.4f} ({local_loss_meter.avg:.4f}), "
                             f"Avg Acc: {local_acc_meter.avg:.4f}")
            master_message = (prefix +
                              f"Loss: {global_loss:.4f} ({global_loss_meter.avg:.4f}), "
                              f"Avg Acc: {global_acc_meter.avg:.4f}")
            write_log(local_logger, master_logger, local_message, master_message)

    paddle.distributed.barrier()
    elapsed = time.time() - start_time
    return (local_loss_meter.avg,
            local_acc_meter.avg,
            global_loss_meter.avg,
            global_acc_meter.avg,
            elapsed)
@paddle.no_grad()
def validate(dataloader,
             model,
             criterion,
             total_batches,
             debug_steps=100,
             local_logger=None,
             master_logger=None):
    """Validation for the whole dataset

    Args:
        dataloader: paddle.io.DataLoader, dataloader instance
        model: nn.Layer, a ViT model
        criterion: loss layer called as criterion(output, label)
        total_batches: int, total num of batches for one epoch (used in log messages)
        debug_steps: int, log every `debug_steps` iterations and on the last one, default: 100
        local_logger: logger for local process/gpu, default: None
        master_logger: logger for main process, default: None
    Returns:
        val_loss_meter.avg: float, average loss on current process/gpu
        val_acc1_meter.avg: float, average top1 accuracy on current process/gpu
        val_acc5_meter.avg: float, average top5 accuracy on current process/gpu
        master_loss_meter.avg: float, average loss on all processes/gpus
        master_acc1_meter.avg: float, average top1 accuracy on all processes/gpus
        master_acc5_meter.avg: float, average top5 accuracy on all processes/gpus
        val_time: float, validation time
    """
    model.eval()
    val_loss_meter = AverageMeter()
    val_acc1_meter = AverageMeter()
    val_acc5_meter = AverageMeter()
    master_loss_meter = AverageMeter()
    master_acc1_meter = AverageMeter()
    master_acc5_meter = AverageMeter()
    num_batches = len(dataloader)
    time_st = time.time()

    for batch_id, data in enumerate(dataloader):
        # get data
        images = data[0]
        label = data[1]
        batch_size = images.shape[0]

        output = model(images)
        loss = criterion(output, label)
        loss_value = loss.item()

        pred = paddle.nn.functional.softmax(output)
        acc1 = paddle.metric.accuracy(pred, label.unsqueeze(1)).item()
        acc5 = paddle.metric.accuracy(pred, label.unsqueeze(1), k=5).item()

        # sync from other gpus for overall loss and acc
        master_loss = all_reduce_mean(loss_value)
        master_acc1 = all_reduce_mean(acc1)
        master_acc5 = all_reduce_mean(acc5)
        master_batch_size = all_reduce_mean(batch_size)
        master_loss_meter.update(master_loss, master_batch_size)
        master_acc1_meter.update(master_acc1, master_batch_size)
        master_acc5_meter.update(master_acc5, master_batch_size)
        val_loss_meter.update(loss_value, batch_size)
        val_acc1_meter.update(acc1, batch_size)
        val_acc5_meter.update(acc5, batch_size)

        # also log the final step, consistent with train()
        if batch_id % debug_steps == 0 or batch_id + 1 == num_batches:
            local_message = (f"Step[{batch_id:04d}/{total_batches:04d}], "
                             f"Avg Loss: {val_loss_meter.avg:.4f}, "
                             f"Avg Acc@1: {val_acc1_meter.avg:.4f}, "
                             f"Avg Acc@5: {val_acc5_meter.avg:.4f}")
            master_message = (f"Step[{batch_id:04d}/{total_batches:04d}], "
                              f"Avg Loss: {master_loss_meter.avg:.4f}, "
                              f"Avg Acc@1: {master_acc1_meter.avg:.4f}, "
                              f"Avg Acc@5: {master_acc5_meter.avg:.4f}")
            write_log(local_logger, master_logger, local_message, master_message)

    paddle.distributed.barrier()
    val_time = time.time() - time_st
    return (val_loss_meter.avg,
            val_acc1_meter.avg,
            val_acc5_meter.avg,
            master_loss_meter.avg,
            master_acc1_meter.avg,
            master_acc5_meter.avg,
            val_time)
def main_worker(*args):
    """Run training (or evaluation only) inside one spawned process.

    Args:
        *args: ``(config, dataset_train, dataset_val)`` as passed by
            ``paddle.distributed.spawn`` in ``main``. ``dataset_train`` is
            only read when ``config.EVAL`` is false.

    Raises:
        FileNotFoundError: if ``config.MODEL.PRETRAINED`` or
            ``config.MODEL.RESUME`` is set but does not name an existing file.
    """
    # STEP 0: Preparation
    paddle.device.set_device('gpu')
    paddle.distributed.init_parallel_env()
    world_size = paddle.distributed.get_world_size()
    local_rank = paddle.distributed.get_rank()
    config = args[0]
    last_epoch = config.TRAIN.LAST_EPOCH
    # offset the seed by rank so each process draws a different random stream
    seed = config.SEED + local_rank
    paddle.seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    local_logger, master_logger = get_logger(config.SAVE)
    message = (f'----- world_size = {world_size}, local_rank = {local_rank} \n'
               f'----- {config}')
    write_log(local_logger, master_logger, message)

    # STEP 1: Create model
    model = build_model(config)

    # STEP 2: Create train and val dataloader
    if not config.EVAL:
        dataset_train = args[1]
        dataloader_train = get_dataloader(config, dataset_train, True, True)
        total_batch_train = len(dataloader_train)
        message = f'----- Total # of train batch (single gpu): {total_batch_train}'
        write_log(local_logger, master_logger, message)

    dataset_val = args[2]
    dataloader_val = get_dataloader(config, dataset_val, False, True)
    total_batch_val = len(dataloader_val)
    message = f'----- Total # of val batch (single gpu): {total_batch_val}'
    write_log(local_logger, master_logger, message)

    # STEP 3: Define loss/criterion
    criterion = paddle.nn.CrossEntropyLoss()

    # STEP 4: Define optimizer and lr_scheduler
    # These are only built for training. Pre-binding them to None lets the
    # resume step below run in eval mode instead of failing with NameError.
    amp_grad_scaler = None
    lr_scheduler = None
    optimizer = None
    if not config.EVAL:
        # define scaler for amp training
        amp_grad_scaler = paddle.amp.GradScaler() if config.AMP else None
        # warmup + cosine lr scheduler, stepped once per epoch in the train loop
        if config.TRAIN.WARMUP_EPOCHS > 0:
            cosine_lr_scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
                learning_rate=config.TRAIN.BASE_LR,
                T_max=config.TRAIN.NUM_EPOCHS - config.TRAIN.WARMUP_EPOCHS,
                eta_min=config.TRAIN.END_LR,
                last_epoch=-1)  # do not set last epoch, handled in warmup sched get_lr()
            lr_scheduler = paddle.optimizer.lr.LinearWarmup(
                learning_rate=cosine_lr_scheduler,  # use cosine lr sched after warmup
                warmup_steps=config.TRAIN.WARMUP_EPOCHS,  # only supports positive integers
                start_lr=config.TRAIN.WARMUP_START_LR,
                end_lr=config.TRAIN.BASE_LR,
                last_epoch=config.TRAIN.LAST_EPOCH)
        else:
            lr_scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
                learning_rate=config.TRAIN.BASE_LR,
                T_max=config.TRAIN.NUM_EPOCHS,
                eta_min=config.TRAIN.END_LR,
                last_epoch=config.TRAIN.LAST_EPOCH)
        # set gradient clip (global-norm clipping; disabled when GRAD_CLIP is falsy)
        if config.TRAIN.GRAD_CLIP:
            clip = paddle.nn.ClipGradByGlobalNorm(config.TRAIN.GRAD_CLIP)
        else:
            clip = None
        # set optimizer
        optimizer = paddle.optimizer.AdamW(
            parameters=model.parameters(),
            learning_rate=lr_scheduler,  # set to scheduler
            beta1=config.TRAIN.OPTIMIZER.BETAS[0],
            beta2=config.TRAIN.OPTIMIZER.BETAS[1],
            weight_decay=config.TRAIN.WEIGHT_DECAY,
            epsilon=config.TRAIN.OPTIMIZER.EPS,
            grad_clip=clip)

    # STEP 5: (Optional) Load pretrained model weights for evaluation or finetuning
    if config.MODEL.PRETRAINED:
        if not os.path.isfile(config.MODEL.PRETRAINED):
            raise FileNotFoundError(
                f"Pretrained weights not found: {config.MODEL.PRETRAINED}")
        model_state = paddle.load(config.MODEL.PRETRAINED)
        if 'model' in model_state:  # load state_dict with multi items: model, optimizer, and epoch
            # pretrain only loads model weights, opt and epoch are ignored
            model_state = model_state['model']
        model.set_state_dict(model_state)
        message = f"----- Pretrained: Load model state from {config.MODEL.PRETRAINED}"
        write_log(local_logger, master_logger, message)

    # STEP 6: (Optional) Load model weights and status for resume training
    if config.MODEL.RESUME:
        if not os.path.isfile(config.MODEL.RESUME):
            raise FileNotFoundError(
                f"Resume checkpoint not found: {config.MODEL.RESUME}")
        model_state = paddle.load(config.MODEL.RESUME)
        if 'model' in model_state:  # load state_dict with multi items: model, optimizer, and epoch
            model.set_state_dict(model_state['model'])
            # training state is only restored when it was built (i.e. not in eval mode)
            if 'optimizer' in model_state and optimizer is not None:
                optimizer.set_state_dict(model_state['optimizer'])
            if 'epoch' in model_state:
                config.TRAIN.LAST_EPOCH = model_state['epoch']
                last_epoch = model_state['epoch']
            if 'lr_scheduler' in model_state and lr_scheduler is not None:
                lr_scheduler.set_state_dict(model_state['lr_scheduler'])
            if 'amp_grad_scaler' in model_state and amp_grad_scaler is not None:
                amp_grad_scaler.load_state_dict(model_state['amp_grad_scaler'])
            if lr_scheduler is not None:
                lr_scheduler.step(last_epoch + 1)
            message = (f"----- Resume Training: Load model from {config.MODEL.RESUME}, w/t "
                       f"opt = [{'optimizer' in model_state}], "
                       f"lr_scheduler = [{'lr_scheduler' in model_state}], "
                       f"epoch = [{model_state.get('epoch', -1)}], "
                       f"amp_grad_scaler = [{'amp_grad_scaler' in model_state}]")
            write_log(local_logger, master_logger, message)
        else:  # direct load pdparams without other items
            message = f"----- Resume Training: Load {config.MODEL.RESUME}, w/o opt/epoch/scaler"
            # The 4th positional arg of write_log is the master message (see the
            # other call sites); passing 'warning' there made the master log print
            # the literal word "warning". Log the same message on both loggers.
            write_log(local_logger, master_logger, message)
            model.set_state_dict(model_state)
            if lr_scheduler is not None:
                lr_scheduler.step(last_epoch + 1)

    # STEP 7: Enable model data parallelism on multi processes
    model = paddle.DataParallel(model)

    # STEP 8: (Optional) Run evaluation and return
    if config.EVAL:
        write_log(local_logger, master_logger, "----- Start Validation")
        val_loss, val_acc1, val_acc5, avg_loss, avg_acc1, avg_acc5, val_time = validate(
            dataloader=dataloader_val,
            model=model,
            criterion=criterion,
            total_batches=total_batch_val,
            debug_steps=config.REPORT_FREQ,
            local_logger=local_logger,
            master_logger=master_logger)
        local_message = ("----- Validation: " +
                         f"Validation Loss: {val_loss:.4f}, " +
                         f"Validation Acc@1: {val_acc1:.4f}, " +
                         f"Validation Acc@5: {val_acc5:.4f}, " +
                         f"time: {val_time:.2f}")
        master_message = ("----- Validation: " +
                          f"Validation Loss: {avg_loss:.4f}, " +
                          f"Validation Acc@1: {avg_acc1:.4f}, " +
                          f"Validation Acc@5: {avg_acc5:.4f}, " +
                          f"time: {val_time:.2f}")
        write_log(local_logger, master_logger, local_message, master_message)
        return

    # STEP 9: Run training
    write_log(local_logger, master_logger, f"----- Start training from epoch {last_epoch+1}.")
    for epoch in range(last_epoch + 1, config.TRAIN.NUM_EPOCHS + 1):
        # Train one epoch
        write_log(local_logger, master_logger, f"Train epoch {epoch}. LR={optimizer.get_lr():.6e}")
        train_loss, train_acc, avg_loss, avg_acc, train_time = train(
            dataloader=dataloader_train,
            model=model,
            optimizer=optimizer,
            criterion=criterion,
            epoch=epoch,
            total_epochs=config.TRAIN.NUM_EPOCHS,
            total_batches=total_batch_train,
            debug_steps=config.REPORT_FREQ,
            accum_iter=config.TRAIN.ACCUM_ITER,
            amp_grad_scaler=amp_grad_scaler,
            local_logger=local_logger,
            master_logger=master_logger)

        # update lr (scheduler is stepped per epoch, not per iteration)
        lr_scheduler.step()

        # LR uses scientific notation: typical values (<1e-4) print as 0.0000 with .4f
        general_message = (f"----- Epoch[{epoch:03d}/{config.TRAIN.NUM_EPOCHS:03d}], "
                           f"Lr: {optimizer.get_lr():.6e}, "
                           f"time: {train_time:.2f}, ")
        local_message = (general_message +
                         f"Train Loss: {train_loss:.4f}, "
                         f"Train Acc: {train_acc:.4f}")
        master_message = (general_message +
                          f"Train Loss: {avg_loss:.4f}, "
                          f"Train Acc: {avg_acc:.4f}")
        write_log(local_logger, master_logger, local_message, master_message)

        # Evaluation (optional)
        if epoch % config.VALIDATE_FREQ == 0 or epoch == config.TRAIN.NUM_EPOCHS:
            write_log(local_logger, master_logger, f'----- Validation after Epoch: {epoch}')
            val_loss, val_acc1, val_acc5, avg_loss, avg_acc1, avg_acc5, val_time = validate(
                dataloader=dataloader_val,
                model=model,
                criterion=criterion,
                total_batches=total_batch_val,
                debug_steps=config.REPORT_FREQ,
                local_logger=local_logger,
                master_logger=master_logger)
            local_message = (f"----- Epoch[{epoch:03d}/{config.TRAIN.NUM_EPOCHS:03d}], " +
                             f"Validation Loss: {val_loss:.4f}, " +
                             f"Validation Acc@1: {val_acc1:.4f}, " +
                             f"Validation Acc@5: {val_acc5:.4f}, " +
                             f"time: {val_time:.2f}")
            master_message = (f"----- Epoch[{epoch:03d}/{config.TRAIN.NUM_EPOCHS:03d}], " +
                              f"Validation Loss: {avg_loss:.4f}, " +
                              f"Validation Acc@1: {avg_acc1:.4f}, " +
                              f"Validation Acc@5: {avg_acc5:.4f}, " +
                              f"time: {val_time:.2f}")
            write_log(local_logger, master_logger, local_message, master_message)

        # Save model weights and training status (rank 0 only, to avoid duplicate writes)
        if local_rank == 0:
            if epoch % config.SAVE_FREQ == 0 or epoch == config.TRAIN.NUM_EPOCHS:
                # NOTE: avg_loss is the all-process validation loss when validation
                # ran this epoch, otherwise the all-process training loss.
                model_path = os.path.join(
                    config.SAVE, f"Epoch-{epoch}-Loss-{avg_loss}.pdparams")
                state_dict = dict()
                state_dict['model'] = model.state_dict()
                state_dict['optimizer'] = optimizer.state_dict()
                state_dict['epoch'] = epoch
                if lr_scheduler is not None:
                    state_dict['lr_scheduler'] = lr_scheduler.state_dict()
                if amp_grad_scaler is not None:
                    state_dict['amp_grad_scaler'] = amp_grad_scaler.state_dict()
                paddle.save(state_dict, model_path)
                message = (f"----- Save model: {model_path}")
                write_log(local_logger, master_logger, message)
def main():
    """Build the config and datasets, then spawn one worker process per GPU.

    The config is updated in order: (1) default in config.py, (2) yaml file,
    (3) command-line arguments. Outputs go to a fresh timestamped sub-folder
    of ``config.SAVE`` named ``<eval|train>-<YYYYmmdd-HH-MM>``.
    """
    base_config = get_config()
    config = update_config(base_config, get_arguments())

    # set output folder
    run_mode = 'eval' if config.EVAL else 'train'
    run_stamp = time.strftime('%Y%m%d-%H-%M')
    config.SAVE = os.path.join(config.SAVE, f"{run_mode}-{run_stamp}")
    if not os.path.exists(config.SAVE):
        os.makedirs(config.SAVE, exist_ok=True)

    # the train split is only needed when training; the val split is always used
    if config.EVAL:
        dataset_train = None
    else:
        dataset_train = get_dataset(config, is_train=True)
    dataset_val = get_dataset(config, is_train=False)

    # dist spawn launch: use CUDA_VISIBLE_DEVICES to set available gpus
    worker_args = (config, dataset_train, dataset_val)
    paddle.distributed.spawn(main_worker, args=worker_args)


if __name__ == "__main__":
    main()