Refactor trainer.py #396

Merged · 7 commits · Oct 30, 2023
17 changes: 17 additions & 0 deletions uer/initialize.py
@@ -0,0 +1,17 @@
import torch


def init_env(args):
if args.dist_train:
# Initialize multiprocessing distributed training environment.
args.global_rank = args.gpu_ranks[args.local_rank]
torch.distributed.init_process_group(backend=args.backend,
init_method=args.master_ip,
world_size=args.world_size,
rank=args.global_rank)
elif args.single_gpu:
args.global_rank = None
else:
args.global_rank = None

return None
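
The new uer/initialize.py centralizes the environment setup that previously lived inline in worker() (see the trainer.py diff below): in distributed mode it looks up the global rank and calls torch.distributed.init_process_group, otherwise it leaves args.global_rank as None. A minimal usage sketch, assuming the uer package is importable and using a plain argparse.Namespace as a stand-in for UER's real argument object; only the CPU branch is exercised here, since the distributed branch needs a live process group:

from argparse import Namespace

from uer.initialize import init_env

# Hypothetical arguments; the real ones come from UER's argparse setup.
args = Namespace(dist_train=False, single_gpu=False, local_rank=None)

init_env(args)
print(args.global_rank)  # None in CPU (and single-GPU) mode
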
97 changes: 52 additions & 45 deletions uer/trainer.py
@@ -3,6 +3,7 @@
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from uer.initialize import init_env
from uer.model_loader import load_model
from uer.model_saver import save_model
from uer.model_builder import build_model
@@ -13,17 +14,7 @@
from uer.utils.seed import set_seed


def train_and_validate(args):
set_seed(args.seed)

# Load vocabulary.
if args.data_processor == "mt":
args.tgt_tokenizer = str2tokenizer[args.tgt_tokenizer](args, is_src=False)
args.tgt_vocab = args.tgt_tokenizer.vocab

args.tokenizer = str2tokenizer[args.tokenizer](args)
args.vocab = args.tokenizer.vocab

def init_model(args):
# Build model.
model = build_model(args)

@@ -47,16 +38,54 @@ def train_and_validate(args):
for n, p in list(model.named_parameters()):
if "gamma" not in n and "beta" not in n:
p.data.normal_(0, 0.02)
return model


def init_optimizer(args, model):
# Build optimizer.
param_optimizer = list(model.named_parameters())
no_decay = ["bias", "gamma", "beta"]
optimizer_grouped_parameters = [
{"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
{"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
]

if args.optimizer in ["adamw"]:
custom_optimizer = str2optimizer[args.optimizer](optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
else:
custom_optimizer = str2optimizer[args.optimizer](optimizer_grouped_parameters, lr=args.learning_rate, scale_parameter=False, relative_step=False)
if args.scheduler in ["constant"]:
custom_scheduler = str2scheduler[args.scheduler](custom_optimizer)
elif args.scheduler in ["constant_with_warmup"]:
custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup)
elif args.scheduler in ["tri_stage"]:
custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup, args.total_steps*args.lr_decay, args.total_steps)
else:
custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup, args.total_steps)

return custom_optimizer, custom_scheduler


def train_and_validate(args):
set_seed(args.seed)

# Load vocabulary.
if args.data_processor == "mt":
args.tgt_tokenizer = str2tokenizer[args.tgt_tokenizer](args, is_src=False)
args.tgt_vocab = args.tgt_tokenizer.vocab

args.tokenizer = str2tokenizer[args.tokenizer](args)
args.vocab = args.tokenizer.vocab

if args.dist_train:
# Multiprocessing distributed mode.
mp.spawn(worker, nprocs=args.ranks_num, args=(args.gpu_ranks, args, model), daemon=False)
mp.spawn(worker, nprocs=args.ranks_num, args=(args.gpu_ranks, args), daemon=False)
elif args.single_gpu:
# Single GPU mode.
worker(args.local_rank, None, args, model)
worker(args.local_rank, None, args)
else:
# CPU mode.
worker(None, None, args, model)
worker(None, None, args)


class Trainer(object):
@@ -411,7 +440,7 @@ class PrefixlmTrainer(MlmTrainer):
"bart": BartTrainer, "prefixlm": PrefixlmTrainer, "cls_mlm": ClsMlmTrainer}


def worker(local_rank, gpu_ranks, args, model):
def worker(local_rank, gpu_ranks, args):
"""
Args:
local_rank: The id of GPU for single GPU mode;
@@ -423,45 +452,23 @@ def worker(local_rank, gpu_ranks, args, model):
# Get logger
args.logger = init_logger(args)

if args.dist_train:
global_rank = gpu_ranks[local_rank]
elif args.single_gpu:
global_rank = None
else:
global_rank = None
# Env initialize.
args.local_rank = local_rank
init_env(args)
global_rank = args.global_rank

# Build optimizer.
param_optimizer = list(model.named_parameters())
no_decay = ["bias", "gamma", "beta"]
optimizer_grouped_parameters = [
{"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
{"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
]
# Build model.
model = init_model(args)

if args.optimizer in ["adamw"]:
custom_optimizer = str2optimizer[args.optimizer](optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
else:
custom_optimizer = str2optimizer[args.optimizer](optimizer_grouped_parameters, lr=args.learning_rate, scale_parameter=False, relative_step=False)
if args.scheduler in ["constant"]:
custom_scheduler = str2scheduler[args.scheduler](custom_optimizer)
elif args.scheduler in ["constant_with_warmup"]:
custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup)
elif args.scheduler in ["tri_stage"]:
custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup, args.total_steps*args.lr_decay, args.total_steps)
else:
custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup, args.total_steps)
# Build optimizer.
custom_optimizer, custom_scheduler = init_optimizer(args, model)

if local_rank is not None:
model.cuda(local_rank)
optimizer = custom_optimizer
scheduler = custom_scheduler

if args.dist_train:
# Initialize multiprocessing distributed training environment.
dist.init_process_group(backend=args.backend,
init_method=args.master_ip,
world_size=args.world_size,
rank=global_rank)
model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True)
args.logger.info("Worker %d is training ... " % global_rank)
else:
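
For reference, the parameter grouping inside the new init_optimizer splits parameters by substring match against no_decay, so any parameter whose name contains "bias", "gamma", or "beta" gets weight_decay=0.0 and everything else gets 0.01. A small standalone sketch of that selection logic, using plain PyTorch and a toy model rather than UER's own modules:

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
no_decay = ["bias", "gamma", "beta"]

# Mirror the grouping in init_optimizer: names are matched by substring.
decay = [n for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)]
skip = [n for n, p in model.named_parameters() if any(nd in n for nd in no_decay)]

print(decay)  # ['0.weight', '1.weight'] -> weight_decay=0.01
print(skip)   # ['0.bias', '1.bias']     -> weight_decay=0.0

In UER's own layers the LayerNorm parameters are named gamma and beta, which is why those substrings appear in no_decay; with torch.nn.LayerNorm they would be named weight and bias instead.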