diff --git a/configs/video_warper_trainer.yaml b/configs/video_warper_trainer.yaml
index 18bb6df..7ec24f3 100644
--- a/configs/video_warper_trainer.yaml
+++ b/configs/video_warper_trainer.yaml
@@ -38,7 +38,7 @@ trainer:
 
 
 model:
-  type: models.styleheat.mirror_warper::MirrorWarper
+  type: models.styleheat.warper::VideoWarper
   mode: train_video_warper
   optimized_param: all
   from_scratch_param: all
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000..2a6bd59
--- /dev/null
+++ b/data/__init__.py
@@ -0,0 +1,63 @@
+import importlib
+
+import torch.utils.data
+from utils.distributed import master_only_print as print
+
+
+def find_dataset_using_name(dataset_name):
+    dataset_filename = dataset_name
+    module, target = dataset_name.split('::')
+    datasetlib = importlib.import_module(module)
+    dataset = None
+    for name, cls in datasetlib.__dict__.items():
+        if name == target:
+            dataset = cls
+
+    if dataset is None:
+        raise ValueError("In %s.py, there should be a class "
+                         "with class name that matches %s in lowercase." %
+                         (dataset_filename, target))
+    return dataset
+
+
+def get_option_setter(dataset_name):
+    dataset_class = find_dataset_using_name(dataset_name)
+    return dataset_class.modify_commandline_options
+
+
+def create_dataloader(opt, is_inference):
+    dataset = find_dataset_using_name(opt.type)
+    instance = dataset(opt, is_inference)
+    phase = 'val' if is_inference else 'training'
+    batch_size = opt.val.batch_size if is_inference else opt.train.batch_size
+    print("%s dataset [%s] of size %d was created" %
+          (phase, opt.type, len(instance)))
+    dataloader = torch.utils.data.DataLoader(
+        instance,
+        batch_size=batch_size,
+        sampler=data_sampler(instance, shuffle=not is_inference, distributed=opt.train.distributed),
+        drop_last=not is_inference,
+        num_workers=getattr(opt, 'num_workers', 0),
+    )
+
+    return dataloader
+
+
+def data_sampler(dataset, shuffle, distributed):
+    if distributed:
+        return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
+    if shuffle:
+        return torch.utils.data.RandomSampler(dataset)
+    else:
+        return torch.utils.data.SequentialSampler(dataset)
+
+
+def get_dataloader(opt, is_inference=False):
+    dataset = create_dataloader(opt, is_inference=is_inference)
+    return dataset
+
+
+def get_train_val_dataloader(opt):
+    val_dataset = create_dataloader(opt, is_inference=True)
+    train_dataset = create_dataloader(opt, is_inference=False)
+    return val_dataset, train_dataset
diff --git a/third_part/Deep3DFaceRecon_pytorch/temp/2.png b/third_part/Deep3DFaceRecon_pytorch/temp/2.png
deleted file mode 100644
index edb8dcb..0000000
Binary files a/third_part/Deep3DFaceRecon_pytorch/temp/2.png and /dev/null differ
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..40d48ac
--- /dev/null
+++ b/train.py
@@ -0,0 +1,89 @@
+import os
+import argparse
+import data as Dataset
+
+from configs.config import Config
+from utils.logging import init_logging, make_logging_dir
+from utils.trainer import get_model_optimizer_and_scheduler_with_pretrain, set_random_seed, get_trainer, get_model_optimizer_and_scheduler
+from utils.distributed import init_dist
+from utils.distributed import master_only_print as print
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Training')
+    parser.add_argument('--config', required=True)
+    parser.add_argument('--name', required=True)
+    parser.add_argument('--checkpoints_dir', default='result', help='Dir for saving logs and models.')
+    parser.add_argument('--seed', type=int, default=0,
+                        help='Random seed.')
+    parser.add_argument('--which_iter', type=int, default=None)
+    parser.add_argument('--no_resume', action='store_true')
+    parser.add_argument('--local_rank', type=int, default=0)
+    parser.add_argument('--single_gpu', action='store_true')
+    parser.add_argument('--debug', action='store_true')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    # get training options
+    args = parse_args()
+    set_random_seed(args.seed)
+
+    opt = Config(args.config, args, is_train=True)
+
+    if not args.single_gpu:
+        opt.local_rank = args.local_rank
+        init_dist(opt.local_rank)
+        opt.device = opt.local_rank
+        print('Distributed DataParallel Training.')
+    else:
+        print('Single GPU Training.')
+        opt.device = 'cuda'
+        opt.local_rank = 0
+        opt.distributed = False
+        opt.data.train.distributed = False
+        opt.data.val.distributed = False
+
+    # create a visualizer
+    date_uid, logdir = init_logging(opt)
+    opt.logdir = logdir
+    make_logging_dir(logdir, date_uid)
+    os.system(f'cp {args.config} {opt.logdir}')
+    # create a dataset
+    val_dataset, train_dataset = Dataset.get_train_val_dataloader(opt.data)
+
+    # create a model
+    net_G, net_G_ema, opt_G, sch_G = get_model_optimizer_and_scheduler_with_pretrain(opt)
+
+    trainer = get_trainer(opt, net_G, net_G_ema, opt_G, sch_G, train_dataset)
+    current_epoch, current_iteration = trainer.load_checkpoint(opt, args.which_iter)
+
+    # training flag
+    if args.debug:
+        trainer.test_everything(train_dataset, val_dataset, current_epoch, current_iteration)
+        exit()
+
+    # Start training.
+    for epoch in range(current_epoch, opt.max_epoch):
+        print('Epoch {} ...'.format(epoch))
+        if not args.single_gpu:
+            train_dataset.sampler.set_epoch(current_epoch)
+        trainer.start_of_epoch(current_epoch)
+        for it, data in enumerate(train_dataset):
+            data = trainer.start_of_iteration(data, current_iteration)
+            trainer.optimize_parameters(data)
+            current_iteration += 1
+            trainer.end_of_iteration(data, current_epoch, current_iteration)
+
+            if current_iteration >= opt.max_iter:
+                print('Done with training!!!')
+                break
+        current_epoch += 1
+        trainer.end_of_epoch(data, val_dataset, current_epoch, current_iteration)
+        trainer.test(val_dataset, output_dir=os.path.join(logdir, 'evaluation'), test_limit=10)
+
+
+if __name__ == '__main__':
+    main()
+
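
Note on the lookup convention: both the YAML change above (type: models.styleheat.warper::VideoWarper) and create_dataloader (via opt.type) rely on find_dataset_using_name resolving a "module::ClassName" string with importlib. The snippet below is a minimal standalone sketch of that convention, not code from this patch; the resolve helper is a hypothetical simplification (the real function scans the module dict and raises ValueError on a miss), and 'collections::OrderedDict' is only a stand-in spec so the example runs anywhere.

# Minimal sketch (not part of this patch) of the "module::ClassName" convention.
# 'collections::OrderedDict' is a stand-in spec; a real config would name a dataset
# or model class, e.g. models.styleheat.warper::VideoWarper.
import importlib


def resolve(spec):
    # Split the spec into a module path and a class name, import the module,
    # and return the class object with that name.
    module_name, class_name = spec.split('::')
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


if __name__ == '__main__':
    print(resolve('collections::OrderedDict'))  # -> <class 'collections.OrderedDict'>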