forked from garvita-tiwari/PoseNDF
-
Notifications
You must be signed in to change notification settings - Fork 0
/
trainer.py
37 lines (30 loc) · 1.14 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import argparse
from configs.config import load_config
# General config
#from model_quat import train_manifold2 as train_manifold
from model.train_posendf import PoseNDF_trainer
import shutil
from data.data_splits import amass_splits
def train(opt, config_file, val_every=100):
    """Train (and optionally run inference with) a PoseNDF model.

    Builds a ``PoseNDF_trainer`` from ``opt`` and copies ``config_file`` into
    the experiment directory. If test mode is enabled it runs inference once,
    then trains epochs ``trainer.ep`` .. ``opt['train']['max_epoch'] - 1``.

    Args:
        opt: Parsed configuration mapping. Reads
            ``opt['experiment']['root_dir']``, ``opt['experiment']['val']``,
            ``opt['experiment']['test']`` and ``opt['train']['max_epoch']``.
        config_file: Path of the config file to copy to
            ``<root_dir>/<exp_name>/config.yaml``.
        val_every: When ``opt['experiment']['val']`` is truthy,
            ``trainer.validate`` runs on every epoch index divisible by this
            value. Defaults to 100, the previously hard-coded interval.

    Raises:
        ValueError: If ``val_every`` is not a positive integer.
    """
    # Fail fast, before the (potentially expensive) trainer is built; a
    # non-positive interval would otherwise divide by zero or never match.
    if val_every <= 0:
        raise ValueError('val_every must be a positive integer, got {}'.format(val_every))

    trainer = PoseNDF_trainer(opt)

    # Snapshot the config next to the experiment outputs.
    # NOTE(review): assumes PoseNDF_trainer has already created
    # <root_dir>/<exp_name>; confirm.
    copy_config = '{}/{}/{}'.format(opt['experiment']['root_dir'], trainer.exp_name, 'config.yaml')
    try:
        shutil.copyfile(config_file, copy_config)
    except shutil.SameFileError:
        # Resuming from the saved snapshot itself: the copy already exists.
        pass

    val = opt['experiment']['val']
    test = opt['experiment']['test']
    if test:
        # NOTE(review): training still proceeds after inference; confirm
        # whether a test-only run is meant to return here instead.
        trainer.inference(trainer.ep)

    for epoch in range(trainer.ep, opt['train']['max_epoch']):
        # train_model's return values were never used by this script.
        trainer.train_model(epoch)
        if val and epoch % val_every == 0:
            trainer.validate(epoch)
if __name__ == '__main__':
    # Command-line entry point: load the YAML config and start training.
    parser = argparse.ArgumentParser(
        description='Train PoseNDF.'
    )
    parser.add_argument('--config', '-c', default='configs/amass.yaml', type=str, help='Path to config file.')
    parser.add_argument('--test', '-t', action="store_true",
                        help='Enable test mode (run inference before training); overrides experiment.test in the config.')
    args = parser.parse_args()
    opt = load_config(args.config)
    # --test used to be parsed but ignored. It now switches test mode on;
    # without the flag, the config's experiment.test value is used unchanged.
    # NOTE(review): assumes load_config returns a mutable mapping; confirm.
    if args.test:
        opt['experiment']['test'] = True
    # train() also copies the config file into the experiment directory.
    train(opt, args.config)