Skip to content

Commit

Permalink
Refactor trainer.py (#396)
Browse files Browse the repository at this point in the history
* Refactor trainer.py

* Refactor trainer.py

* Refactor trainer.py

* Refactor trainer.py

* Refactor trainer.py

* Update trainer.py

* Update opts.py
  • Loading branch information
hhou435 authored Oct 30, 2023
1 parent efd6e28 commit 19e1216
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 45 deletions.
17 changes: 17 additions & 0 deletions uer/initialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import torch


def init_env(args):
if args.dist_train:
# Initialize multiprocessing distributed training environment.
args.global_rank = args.gpu_ranks[args.local_rank]
torch.distributed.init_process_group(backend=args.backend,
init_method=args.master_ip,
world_size=args.world_size,
rank=args.global_rank)
elif args.single_gpu:
args.global_rank = None
else:
args.global_rank = None

return None
97 changes: 52 additions & 45 deletions uer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from uer.initialize import init_env
from uer.model_loader import load_model
from uer.model_saver import save_model
from uer.model_builder import build_model
Expand All @@ -13,17 +14,7 @@
from uer.utils.seed import set_seed


def train_and_validate(args):
set_seed(args.seed)

# Load vocabulary.
if args.data_processor == "mt":
args.tgt_tokenizer = str2tokenizer[args.tgt_tokenizer](args, is_src=False)
args.tgt_vocab = args.tgt_tokenizer.vocab

args.tokenizer = str2tokenizer[args.tokenizer](args)
args.vocab = args.tokenizer.vocab

def init_model(args):
# Build model.
model = build_model(args)

Expand All @@ -47,16 +38,54 @@ def train_and_validate(args):
for n, p in list(model.named_parameters()):
if "gamma" not in n and "beta" not in n:
p.data.normal_(0, 0.02)
return model


def init_optimizer(args, model):
# Build optimizer.
param_optimizer = list(model.named_parameters())
no_decay = ["bias", "gamma", "beta"]
optimizer_grouped_parameters = [
{"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
{"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
]

if args.optimizer in ["adamw"]:
custom_optimizer = str2optimizer[args.optimizer](optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
else:
custom_optimizer = str2optimizer[args.optimizer](optimizer_grouped_parameters, lr=args.learning_rate, scale_parameter=False, relative_step=False)
if args.scheduler in ["constant"]:
custom_scheduler = str2scheduler[args.scheduler](custom_optimizer)
elif args.scheduler in ["constant_with_warmup"]:
custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup)
elif args.scheduler in ["tri_stage"]:
custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup, args.total_steps*args.lr_decay, args.total_steps)
else:
custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup, args.total_steps)

return custom_optimizer, custom_scheduler


def train_and_validate(args):
set_seed(args.seed)

# Load vocabulary.
if args.data_processor == "mt":
args.tgt_tokenizer = str2tokenizer[args.tgt_tokenizer](args, is_src=False)
args.tgt_vocab = args.tgt_tokenizer.vocab

args.tokenizer = str2tokenizer[args.tokenizer](args)
args.vocab = args.tokenizer.vocab

if args.dist_train:
# Multiprocessing distributed mode.
mp.spawn(worker, nprocs=args.ranks_num, args=(args.gpu_ranks, args, model), daemon=False)
mp.spawn(worker, nprocs=args.ranks_num, args=(args.gpu_ranks, args), daemon=False)
elif args.single_gpu:
# Single GPU mode.
worker(args.local_rank, None, args, model)
worker(args.local_rank, None, args)
else:
# CPU mode.
worker(None, None, args, model)
worker(None, None, args)


class Trainer(object):
Expand Down Expand Up @@ -411,7 +440,7 @@ class PrefixlmTrainer(MlmTrainer):
"bart": BartTrainer, "prefixlm": PrefixlmTrainer, "cls_mlm": ClsMlmTrainer}


def worker(local_rank, gpu_ranks, args, model):
def worker(local_rank, gpu_ranks, args):
"""
Args:
local_rank: The id of GPU for single GPU mode;
Expand All @@ -423,45 +452,23 @@ def worker(local_rank, gpu_ranks, args, model):
# Get logger
args.logger = init_logger(args)

if args.dist_train:
global_rank = gpu_ranks[local_rank]
elif args.single_gpu:
global_rank = None
else:
global_rank = None
# Env initialize.
args.local_rank = local_rank
init_env(args)
global_rank = args.global_rank

# Build optimizer.
param_optimizer = list(model.named_parameters())
no_decay = ["bias", "gamma", "beta"]
optimizer_grouped_parameters = [
{"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
{"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
]
# Build model.
model = init_model(args)

if args.optimizer in ["adamw"]:
custom_optimizer = str2optimizer[args.optimizer](optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
else:
custom_optimizer = str2optimizer[args.optimizer](optimizer_grouped_parameters, lr=args.learning_rate, scale_parameter=False, relative_step=False)
if args.scheduler in ["constant"]:
custom_scheduler = str2scheduler[args.scheduler](custom_optimizer)
elif args.scheduler in ["constant_with_warmup"]:
custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup)
elif args.scheduler in ["tri_stage"]:
custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup, args.total_steps*args.lr_decay, args.total_steps)
else:
custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup, args.total_steps)
# Build optimizer.
custom_optimizer, custom_scheduler = init_optimizer(args, model)

if local_rank is not None:
model.cuda(local_rank)
optimizer = custom_optimizer
scheduler = custom_scheduler

if args.dist_train:
# Initialize multiprocessing distributed training environment.
dist.init_process_group(backend=args.backend,
init_method=args.master_ip,
world_size=args.world_size,
rank=global_rank)
model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True)
args.logger.info("Worker %d is training ... " % global_rank)
else:
Expand Down

0 comments on commit 19e1216

Please sign in to comment.