diff --git a/uer/initialize.py b/uer/initialize.py
new file mode 100644
index 00000000..26e57785
--- /dev/null
+++ b/uer/initialize.py
@@ -0,0 +1,17 @@
+import torch
+
+
+def init_env(args):
+    if args.dist_train:
+        # Initialize multiprocessing distributed training environment.
+        args.global_rank = args.gpu_ranks[args.local_rank]
+        torch.distributed.init_process_group(backend=args.backend,
+                                             init_method=args.master_ip,
+                                             world_size=args.world_size,
+                                             rank=args.global_rank)
+    elif args.single_gpu:
+        args.global_rank = None
+    else:
+        args.global_rank = None
+
+    return None
diff --git a/uer/trainer.py b/uer/trainer.py
index e80c379f..2aa97b44 100644
--- a/uer/trainer.py
+++ b/uer/trainer.py
@@ -3,6 +3,7 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel
+from uer.initialize import init_env
 from uer.model_loader import load_model
 from uer.model_saver import save_model
 from uer.model_builder import build_model
@@ -13,17 +14,7 @@
 from uer.utils.seed import set_seed
 
 
-def train_and_validate(args):
-    set_seed(args.seed)
-
-    # Load vocabulary.
-    if args.data_processor == "mt":
-        args.tgt_tokenizer = str2tokenizer[args.tgt_tokenizer](args, is_src=False)
-        args.tgt_vocab = args.tgt_tokenizer.vocab
-
-    args.tokenizer = str2tokenizer[args.tokenizer](args)
-    args.vocab = args.tokenizer.vocab
-
+def init_model(args):
     # Build model.
     model = build_model(args)
 
@@ -47,16 +38,54 @@
     for n, p in list(model.named_parameters()):
         if "gamma" not in n and "beta" not in n:
             p.data.normal_(0, 0.02)
+    return model
+
+
+def init_optimizer(args, model):
+    # Build optimizer.
+    param_optimizer = list(model.named_parameters())
+    no_decay = ["bias", "gamma", "beta"]
+    optimizer_grouped_parameters = [
+        {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
+        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
+    ]
+
+    if args.optimizer in ["adamw"]:
+        custom_optimizer = str2optimizer[args.optimizer](optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
+    else:
+        custom_optimizer = str2optimizer[args.optimizer](optimizer_grouped_parameters, lr=args.learning_rate, scale_parameter=False, relative_step=False)
+    if args.scheduler in ["constant"]:
+        custom_scheduler = str2scheduler[args.scheduler](custom_optimizer)
+    elif args.scheduler in ["constant_with_warmup"]:
+        custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup)
+    elif args.scheduler in ["tri_stage"]:
+        custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup, args.total_steps*args.lr_decay, args.total_steps)
+    else:
+        custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup, args.total_steps)
+
+    return custom_optimizer, custom_scheduler
+
+
+def train_and_validate(args):
+    set_seed(args.seed)
+
+    # Load vocabulary.
+    if args.data_processor == "mt":
+        args.tgt_tokenizer = str2tokenizer[args.tgt_tokenizer](args, is_src=False)
+        args.tgt_vocab = args.tgt_tokenizer.vocab
+
+    args.tokenizer = str2tokenizer[args.tokenizer](args)
+    args.vocab = args.tokenizer.vocab
 
     if args.dist_train:
         # Multiprocessing distributed mode.
-        mp.spawn(worker, nprocs=args.ranks_num, args=(args.gpu_ranks, args, model), daemon=False)
+        mp.spawn(worker, nprocs=args.ranks_num, args=(args.gpu_ranks, args), daemon=False)
     elif args.single_gpu:
         # Single GPU mode.
-        worker(args.local_rank, None, args, model)
+        worker(args.local_rank, None, args)
     else:
         # CPU mode.
-        worker(None, None, args, model)
+        worker(None, None, args)
 
 
 class Trainer(object):
@@ -411,7 +440,7 @@ class PrefixlmTrainer(MlmTrainer):
               "bart": BartTrainer,
               "prefixlm": PrefixlmTrainer,
               "cls_mlm": ClsMlmTrainer}
-def worker(local_rank, gpu_ranks, args, model):
+def worker(local_rank, gpu_ranks, args):
     """
     Args:
        local_rank: The id of GPU for single GPU mode;
@@ -423,33 +452,16 @@ def worker(local_rank, gpu_ranks, args, model):
     # Get logger
     args.logger = init_logger(args)
 
-    if args.dist_train:
-        global_rank = gpu_ranks[local_rank]
-    elif args.single_gpu:
-        global_rank = None
-    else:
-        global_rank = None
+    # Env initialize.
+    args.local_rank = local_rank
+    init_env(args)
+    global_rank = args.global_rank
 
-    # Build optimizer.
-    param_optimizer = list(model.named_parameters())
-    no_decay = ["bias", "gamma", "beta"]
-    optimizer_grouped_parameters = [
-        {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
-        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
-    ]
+    # Build model.
+    model = init_model(args)
 
-    if args.optimizer in ["adamw"]:
-        custom_optimizer = str2optimizer[args.optimizer](optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
-    else:
-        custom_optimizer = str2optimizer[args.optimizer](optimizer_grouped_parameters, lr=args.learning_rate, scale_parameter=False, relative_step=False)
-    if args.scheduler in ["constant"]:
-        custom_scheduler = str2scheduler[args.scheduler](custom_optimizer)
-    elif args.scheduler in ["constant_with_warmup"]:
-        custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup)
-    elif args.scheduler in ["tri_stage"]:
-        custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup, args.total_steps*args.lr_decay, args.total_steps)
-    else:
-        custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup, args.total_steps)
+    # Build optimizer.
+    custom_optimizer, custom_scheduler = init_optimizer(args, model)
 
     if local_rank is not None:
         model.cuda(local_rank)
@@ -457,11 +469,6 @@
     scheduler = custom_scheduler
 
     if args.dist_train:
-        # Initialize multiprocessing distributed training environment.
-        dist.init_process_group(backend=args.backend,
-                                init_method=args.master_ip,
-                                world_size=args.world_size,
-                                rank=global_rank)
         model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True)
         args.logger.info("Worker %d is training ... " % global_rank)
     else:
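Note on the refactor: the model is no longer built in the parent process and passed through `mp.spawn`; each spawned worker now builds its own model after `init_env()` has mapped its local rank to a global rank via `args.gpu_ranks`. A minimal sketch of that rank mapping, using hypothetical values (a machine whose two GPUs host global ranks 2 and 3):

```python
from argparse import Namespace

# Hypothetical setup for illustration only: gpu_ranks maps each local rank
# on this machine to its global rank across all machines.
args = Namespace(dist_train=True, single_gpu=False, gpu_ranks=[2, 3], local_rank=1)

# The same lookup init_env() performs before calling init_process_group().
args.global_rank = args.gpu_ranks[args.local_rank] if args.dist_train else None
print(args.global_rank)  # -> 3
```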
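The extracted `init_optimizer()` keeps the usual two-group weight-decay split: any parameter whose name contains "bias", "gamma", or "beta" (UER's LayerNorm parameter names) is excluded from decay, and warmup length is derived as `total_steps * warmup` (e.g. 100000 steps with `warmup=0.1` gives 10000 warmup steps). A self-contained sketch of the same grouping, using `torch.optim.AdamW` as a stand-in for `str2optimizer["adamw"]`:

```python
import torch
import torch.nn as nn

# Stand-in model for illustration; in UER-py the LayerNorm parameters are
# named "gamma"/"beta", which is why those strings appear in no_decay.
model = nn.Linear(8, 8)

no_decay = ["bias", "gamma", "beta"]
grouped = [
    {"params": [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]

# torch.optim.AdamW honors per-group weight_decay overrides, like the
# grouped-parameters dicts passed to UER's optimizers above.
optimizer = torch.optim.AdamW(grouped, lr=1e-4)
```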