-
Notifications
You must be signed in to change notification settings - Fork 21
/
train.py
80 lines (72 loc) · 2.29 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# Copyright (c) Gorilla-Lab. All rights reserved.
import os
from os.path import join as opj
import torch
import torch.nn as nn
import numpy as np
from gorilla.config import Config
import models
import loss
from utils import *
import argparse
def parse_args(argv=None):
    """Parse the training script's command-line arguments.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, which is the
            behaviour of a plain ``parser.parse_args()`` call.

    Returns:
        argparse.Namespace with ``config`` (required positional) and the
        optional ``work_dir``, ``gpu`` and ``checkpoint``, each ``None``
        when not given on the command line.
    """
    parser = argparse.ArgumentParser(description="Train a model")
    parser.add_argument("config", help="train config file path")
    parser.add_argument("--work_dir", help="the dir to save logs and models")
    parser.add_argument(
        "--gpu",
        type=str,
        default=None,
        # The value is a device-id list (it is exported verbatim to
        # CUDA_VISIBLE_DEVICES and split on ','), not a GPU count.
        help="Comma-separated GPU ids to use, e.g. '0,1' "
             "(overrides training_cfg.gpu in the config)",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="Checkpoint file (.t7 or .pth) to resume from",
    )
    return parser.parse_args(argv)
def _load_checkpoint(model, path):
    """Load model weights from ``path`` into ``model`` in place.

    The format is chosen by file extension: a ``.t7`` file is loaded and
    passed directly to ``load_state_dict``; a ``.pth`` file is loaded as a
    dict and its ``'model_state_dict'`` entry is used.

    Raises:
        ValueError: if the extension is neither ``.t7`` nor ``.pth``.
    """
    ext = os.path.splitext(path)[1]
    if ext not in ('.t7', '.pth'):
        # Fail loudly: silently skipping would train from scratch while the
        # log claims a checkpoint was loaded.
        raise ValueError(
            f"Unsupported checkpoint extension {ext!r} for {path!r}; "
            "expected '.t7' or '.pth'"
        )
    # map_location='cpu' makes loading independent of the device the
    # checkpoint was saved from (which may not be visible under the current
    # CUDA_VISIBLE_DEVICES); load_state_dict copies values onto the
    # model's own parameters.
    state = torch.load(path, map_location='cpu')
    if ext == '.pth':
        state = state['model_state_dict']
    # NOTE(review): the model may already be wrapped in nn.DataParallel here,
    # which expects 'module.'-prefixed keys — confirm checkpoints are saved
    # with the same wrapping they are loaded with.
    model.load_state_dict(state)


def main():
    """Build everything from the config file and run training."""
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # Command-line values take precedence over the config file.
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    if args.gpu is not None:
        cfg.training_cfg.gpu = args.gpu
    print(cfg)

    # Make sure the log directory exists before run.log is opened in it.
    os.makedirs(cfg.work_dir, exist_ok=True)
    logger = IOStream(opj(cfg.work_dir, 'run.log'))

    # gpu is a comma-separated id string such as "0,1"; it is exported before
    # the model is moved to CUDA below so the device restriction applies.
    gpu_ids = cfg.training_cfg.gpu
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids
    num_gpu = len(gpu_ids.split(','))
    logger.cprint('Use %d GPUs: %s' % (num_gpu, gpu_ids))

    if cfg.get('seed', None) is not None:
        set_random_seed(cfg.seed)
        logger.cprint('Set seed to %d' % cfg.seed)

    model = build_model(cfg).cuda()
    if num_gpu > 1:
        model = nn.DataParallel(model)
        logger.cprint('Use DataParallel!')

    if args.checkpoint is not None:
        logger.cprint("Loading checkpoint....")
        _load_checkpoint(model, args.checkpoint)
    else:
        logger.cprint("Training from scratch!")

    dataset_dict = build_dataset(cfg)
    loader_dict = build_loader(cfg, dataset_dict)
    train_loss = build_loss(cfg)
    optim_dict = build_optimizer(cfg, model)
    training = dict(
        model=model,
        dataset_dict=dataset_dict,
        loader_dict=loader_dict,
        loss=train_loss,
        optim_dict=optim_dict,
        logger=logger,
    )
    task_trainer = Trainer(cfg, training)
    task_trainer.run()


if __name__ == "__main__":
    main()