-
Notifications
You must be signed in to change notification settings - Fork 50
/
Copy pathmain.py
96 lines (81 loc) · 3.56 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# ------------------------------------------------------------------------
# Modified from MGMatting (https://github.com/yucornetto/MGMatting)
# ------------------------------------------------------------------------
import os
import toml
import argparse
from pprint import pprint
import torch
from torch.utils.data import DataLoader
import utils
from utils import CONFIG
from trainer import Trainer
from dataloader.image_file import ImageFileTrain
from dataloader.data_generator import DataGenerator
from dataloader.prefetcher import Prefetcher
import wandb
import warnings
warnings.filterwarnings("ignore")
def main():
# Train or Test
if CONFIG.phase.lower() == "train":
# set distributed training
if CONFIG.dist:
CONFIG.gpu = CONFIG.local_rank
torch.cuda.set_device(CONFIG.gpu)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
CONFIG.world_size = torch.distributed.get_world_size()
# Create directories if not exist.
if CONFIG.local_rank == 0:
utils.make_dir(CONFIG.log.logging_path)
utils.make_dir(CONFIG.log.tensorboard_path)
utils.make_dir(CONFIG.log.checkpoint_path)
if CONFIG.wandb:
wandb.init(project="mam", config=CONFIG, name=CONFIG.version)
# Create a logger
logger, tb_logger = utils.get_logger(CONFIG.log.logging_path,
CONFIG.log.tensorboard_path,
logging_level=CONFIG.log.logging_level)
train_dataset = DataGenerator(phase='train')
if CONFIG.dist:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
else:
train_sampler = None
train_dataloader = DataLoader(train_dataset,
batch_size=CONFIG.model.batch_size,
shuffle=(train_sampler is None),
num_workers=CONFIG.data.workers,
pin_memory=True,
sampler=train_sampler,
drop_last=True)
train_dataloader = Prefetcher(train_dataloader)
trainer = Trainer(train_dataloader=train_dataloader,
test_dataloader=None,
logger=logger,
tb_logger=tb_logger)
trainer.train()
else:
raise NotImplementedError("Unknown Phase: {}".format(CONFIG.phase))
if __name__ == '__main__':
print('Torch Version: ', torch.__version__)
parser = argparse.ArgumentParser()
parser.add_argument('--phase', type=str, default='train')
parser.add_argument('--config', type=str, default='config/gca-dist.toml')
parser.add_argument('--local_rank', type=int, default=0)
# Parse configuration
args = parser.parse_args()
with open(args.config) as f:
utils.load_config(toml.load(f))
# Check if toml config file is loaded
if CONFIG.is_default:
raise ValueError("No .toml config loaded.")
CONFIG.phase = args.phase
CONFIG.log.logging_path = os.path.join(CONFIG.log.logging_path, CONFIG.version)
CONFIG.log.tensorboard_path = os.path.join(CONFIG.log.tensorboard_path, CONFIG.version)
CONFIG.log.checkpoint_path = os.path.join(CONFIG.log.checkpoint_path, CONFIG.version)
if args.local_rank == 0:
print('CONFIG: ')
pprint(CONFIG)
CONFIG.local_rank = args.local_rank
# Train
main()