
Commit

[train_engine] support fsdp (#2412)
* [train_engine] support fsdp

* [train_engine] support fsdp

* unify scaler and amp

* fp32 && fp16 work in fsdp env

* fix fsdp in cv auto cast

* try to fix wenet.join fsdp

* implementing zero1 under fsdp is almost equivalent to deepspeed's zero1

* fix clip_and_grad_

* fix train summary

* all wenet xxxformer models work (-paraformer -transducer)

* try to fix nan

* add barrier for cv

* add destroy group for end of all train

* refactor wrap methods and ckpt works

* fix ckpt

* fix cv when dtype != float32

* fix ckpt in model mode

* fix bf16 amp

* refactor scaler and autocast, fix fp32 fp16 bf16 for fsdp

* fix fp32 nullcontext to nullcontext()

* modify after review

* fix lint

* fix lint
Mddct authored Apr 7, 2024
1 parent 648fee8 commit b8191ce
Showing 5 changed files with 290 additions and 55 deletions.
20 changes: 10 additions & 10 deletions wenet/bin/train.py
@@ -30,23 +30,24 @@
from wenet.utils.init_model import init_model
from wenet.utils.init_tokenizer import init_tokenizer
from wenet.utils.train_utils import (
add_model_args, add_dataset_args, add_ddp_args, add_deepspeed_args,
add_trace_args, init_distributed, init_dataset_and_dataloader,
check_modify_and_save_config, init_optimizer_and_scheduler,
trace_and_print_model, wrap_cuda_model, init_summarywriter, save_model,
log_per_epoch)
add_fsdp_args, add_model_args, add_dataset_args, add_ddp_args,
add_deepspeed_args, add_trace_args, init_distributed,
init_dataset_and_dataloader, check_modify_and_save_config,
init_optimizer_and_scheduler, init_scaler, trace_and_print_model,
wrap_cuda_model, init_summarywriter, save_model, log_per_epoch)


def get_args():
parser = argparse.ArgumentParser(description='training your network')
parser.add_argument('--train_engine',
default='torch_ddp',
choices=['torch_ddp', 'deepspeed'],
choices=['torch_ddp', 'torch_fsdp', 'deepspeed'],
help='Engine for paralleled training')
parser = add_model_args(parser)
parser = add_dataset_args(parser)
parser = add_ddp_args(parser)
parser = add_deepspeed_args(parser)
parser = add_fsdp_args(parser)
parser = add_trace_args(parser)
args = parser.parse_args()
if args.train_engine == "deepspeed":
@@ -96,7 +97,7 @@ def main():
writer = init_summarywriter(args)

# Dispatch model from cpu to gpu
model, device = wrap_cuda_model(args, model)
model, device = wrap_cuda_model(args, model, configs)

# Get optimizer & scheduler
model, optimizer, scheduler = init_optimizer_and_scheduler(
@@ -118,9 +119,7 @@ def main():
int("step_" in tag))

# Init scaler, used for pytorch amp mixed precision training
scaler = None
if args.use_amp:
scaler = torch.cuda.amp.GradScaler()
scaler = init_scaler(args)

# Start training loop
start_epoch = configs["init_infos"].get('epoch', 0) + int("epoch_" in tag)
@@ -173,6 +172,7 @@ def main():
final_model_path) else None
os.symlink('{}.pt'.format(final_epoch), final_model_path)
writer.close()
dist.destroy_process_group()


if __name__ == '__main__':
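Note: init_scaler is imported from wenet.utils.train_utils, but the diff for that file is not rendered in this view. A minimal sketch of what such a helper plausibly looks like, assuming FSDP training uses PyTorch's ShardedGradScaler while the other engines keep the standard CUDA GradScaler:

import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

def init_scaler(args):
    # Sketch only: return a gradient scaler when mixed precision is enabled;
    # FSDP needs the shard-aware scaler, plain DDP can use the CUDA one.
    scaler = None
    if args.use_amp:
        if args.train_engine == 'torch_fsdp':
            scaler = ShardedGradScaler()
        else:
            scaler = torch.cuda.amp.GradScaler()
    return scaler
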
20 changes: 12 additions & 8 deletions wenet/utils/checkpoint.py
@@ -40,6 +40,17 @@ def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
return configs


def save_state_dict_and_infos(state_dict, path: str, infos=None):
torch.save(state_dict, path)
info_path = re.sub('.pt$', '.yaml', path)
if infos is None:
infos = {}
infos['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
with open(info_path, 'w') as fout:
data = yaml.dump(infos)
fout.write(data)


def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
'''
Args:
@@ -52,14 +63,7 @@ def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
torch.save(state_dict, path)
info_path = re.sub('.pt$', '.yaml', path)
if infos is None:
infos = {}
infos['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
with open(info_path, 'w') as fout:
data = yaml.dump(infos)
fout.write(data)
save_state_dict_and_infos(state_dict, path, infos)


def filter_modules(model_state_dict, modules):
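The refactor above moves the actual writing of the state dict and its companion yaml into save_state_dict_and_infos so that other savers (notably fsdp_save_model below) can reuse it. A small usage sketch with a hypothetical output path:

import os
import torch
from wenet.utils.checkpoint import save_state_dict_and_infos

model = torch.nn.Linear(4, 2)  # stand-in module for illustration
os.makedirs('exp/demo', exist_ok=True)
save_state_dict_and_infos(model.state_dict(), 'exp/demo/epoch_0.pt',
                          infos={'epoch': 0, 'lr': 1.0e-3})
# Writes exp/demo/epoch_0.pt plus exp/demo/epoch_0.yaml containing the
# infos and an added 'save_time' field.
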
11 changes: 8 additions & 3 deletions wenet/utils/executor.py
@@ -68,8 +68,9 @@ def train(self, model, optimizer, scheduler, train_data_loader,
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
if info_dict.get("train_engine", "torch_ddp") == "torch_ddp" and \
(batch_idx + 1) % info_dict["accum_grad"] != 0:
if info_dict.get("train_engine", "torch_ddp") in [
"torch_ddp", "torch_fsdp"
] and (batch_idx + 1) % info_dict["accum_grad"] != 0:
context = model.no_sync
# Used for single gpu training and DDP gradient synchronization
# processes.
@@ -87,6 +88,9 @@
save_interval = info_dict.get('save_interval', sys.maxsize)
if self.step % save_interval == 0 and self.step != 0 \
and (batch_idx + 1) % info_dict["accum_grad"] == 0:
import torch.distributed as dist
# Ensure all ranks start CV at the same time in step mode
dist.barrier()
loss_dict = self.cv(model, cv_data_loader, configs)
model.train()
info_dict.update({
@@ -100,11 +104,12 @@
optimizer.param_groups[0]['lr']
})
save_model(model, info_dict)
# Ensure all ranks start Train at the same time in step mode
dist.barrier()
log_per_step(writer, info_dict, timer=self.train_step_timer)
self.step += 1 if (batch_idx +
1) % info_dict["accum_grad"] == 0 else 0


def cv(self, model, cv_data_loader, configs):
''' Cross validation on
'''
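For context, the condition above now treats torch_fsdp like torch_ddp: gradient synchronization is skipped (model.no_sync) on every micro-batch except the last one of an accumulation window, since deepspeed handles accumulation internally. A condensed sketch of the pattern, using hypothetical helper names outside wenet:

from contextlib import nullcontext

def grad_sync_context(model, train_engine, batch_idx, accum_grad):
    # Skip the all-reduce / reduce-scatter for non-final micro-batches under
    # DDP or FSDP; deepspeed manages gradient accumulation itself.
    if train_engine in ('torch_ddp', 'torch_fsdp') and \
            (batch_idx + 1) % accum_grad != 0:
        return model.no_sync()
    return nullcontext()

# usage inside the training loop (illustrative):
#     with grad_sync_context(model, engine, batch_idx, accum_grad):
#         loss = model(batch)['loss'] / accum_grad
#         loss.backward()
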
115 changes: 115 additions & 0 deletions wenet/utils/fsdp_utils.py
@@ -0,0 +1,115 @@
from functools import partial
import os
from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP,
FullStateDictConfig, StateDictType)

from torch.distributed.fsdp.wrap import (lambda_auto_wrap_policy,
transformer_auto_wrap_policy)
from wenet.branchformer.encoder_layer import BranchformerEncoderLayer
from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer
from wenet.efficient_conformer.encoder_layer import StrideConformerEncoderLayer
from wenet.paraformer.layers import AliParaformerEncoderLayer, SanmDecoderLayer
from wenet.squeezeformer.encoder_layer import SqueezeformerEncoderLayer
from wenet.transformer.encoder_layer import (ConformerEncoderLayer,
TransformerEncoderLayer)
from wenet.transformer.decoder_layer import DecoderLayer
from wenet.utils.checkpoint import save_state_dict_and_infos
from wenet.utils.init_model import WENET_DECODER_CLASSES, WENET_ENCODER_CLASSES

WENET_ENCODER_LAYERS_CLASSES = {
'transformer_encoder_layer': TransformerEncoderLayer,
'conformer_encoder_layer': ConformerEncoderLayer,
'paraformer_encoder_layer': AliParaformerEncoderLayer,
'squeezeformer_encoder_layer': SqueezeformerEncoderLayer,
'ebranchformer_encoder_layer': EBranchformerEncoderLayer,
'efficient_conformer_encoder_layer': StrideConformerEncoderLayer,
'branchformer_encoder_layer': BranchformerEncoderLayer,
}

WENET_DECODER_LAYERS_CLASSES = {
'transformer_decoder_layer': DecoderLayer,
'paraformer_decoder_layer': SanmDecoderLayer,
# TODO(Mddct):
# 1 wrap transducer's predictor and joint
# 2 wrap paraformer's cif and ignore lstm
}


def wenet_fsdp_wrap_policy(mode):
# Different wrap policies; please refer to:
# https://openmmlab.medium.com/its-2023-is-pytorch-s-fsdp-the-best-choice-for-training-large-models-fe8d2848832f # noqa
assert mode in ['no_shard', 'model', 'zero2', 'zero3']
if mode == 'no_shard':
return None
else:
# TODO(Mddct): Support user customization
# see more wrap methods:
# https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/utils/fsdp_utils.py#L13 # noqa
if mode == 'model':
enc_dec_wrap_policy = partial(
lambda_auto_wrap_policy,
lambda_fn=lambda module: isinstance(
module,
tuple(WENET_ENCODER_CLASSES.values()) + tuple(
WENET_DECODER_CLASSES.values())))
return enc_dec_wrap_policy
else:
to_wrap_class = set()
to_wrap_class.update(set(WENET_ENCODER_LAYERS_CLASSES.values()))
to_wrap_class.update(set(WENET_DECODER_LAYERS_CLASSES.values()))
layers_wrap_policy = partial(transformer_auto_wrap_policy,
transformer_layer_cls=to_wrap_class)
return layers_wrap_policy


fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True,
rank0_only=True)


def fsdp_save_model(model, save_model_path, info_dict):
# TODO(Mddct): When the model is large, saving a full checkpoint takes a
# long time. It would suffice to save the shards asynchronously, but this is
# good enough for now. That feature will be added when LLM support lands.

rank = int(os.environ.get('RANK', 0))
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT,
fullstate_save_policy):
state_dict = model.state_dict()
if rank == 0:
save_state_dict_and_infos(state_dict, save_model_path, info_dict)


def check_gradient_checkpoint(model):
ckpt_laye_types = []
if hasattr(model, 'encoder') and hasattr(model.encoder,
'gradient_checkpointing'):
if model.encoder.gradient_checkpointing:
model.encoder.gradient_checkpointing = False
ckpt_laye_types += list(WENET_ENCODER_LAYERS_CLASSES.values())
if hasattr(model, 'decoder') and hasattr(model.decoder,
'gradient_checkpointing'):
if model.decoder.gradient_checkpointing:
model.decoder.gradient_checkpointing = False
ckpt_laye_types += list(WENET_DECODER_LAYERS_CLASSES.values())
return tuple(ckpt_laye_types)


def apply_fsdp_checkpointing(model, ckpt_layer_types: tuple):
# NOTE(Mddct): torch.utils.checkpoint is currently incompatible with
# wenet's model mode, so activation checkpointing is applied this way
# instead. Please refer to:
# https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/policies/activation_checkpointing_functions.py#L21 # noqa
if len(ckpt_layer_types) == 0:
return
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
apply_activation_checkpointing,
)
non_reentrant_wrapper = partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
apply_activation_checkpointing(
model,
checkpoint_wrapper_fn=non_reentrant_wrapper,
check_fn=lambda submodule: isinstance(submodule, ckpt_layer_types))
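
Putting the new helpers together: wrap_cuda_model in wenet/utils/train_utils.py (whose call now also takes configs; its diff is not rendered in this view) is the natural place to apply them. A hedged sketch of that wiring, in which the function name, the mode-to-ShardingStrategy mapping, and the FSDP keyword arguments are assumptions rather than the commit's actual code:

import torch
from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP,
                                    ShardingStrategy)
from wenet.utils.fsdp_utils import (apply_fsdp_checkpointing,
                                    check_gradient_checkpoint,
                                    fsdp_save_model, wenet_fsdp_wrap_policy)

def wrap_model_with_fsdp(model, mode='zero3'):
    # 1. Disable the model's built-in gradient checkpointing and remember
    #    which layer classes used it.
    ckpt_layer_types = check_gradient_checkpoint(model)
    # 2. Shard the model; 'zero3' ~ FULL_SHARD, 'zero2' ~ SHARD_GRAD_OP
    #    (assumed mapping, not shown in this diff).
    strategy = (ShardingStrategy.FULL_SHARD
                if mode == 'zero3' else ShardingStrategy.SHARD_GRAD_OP)
    model = FSDP(model,
                 auto_wrap_policy=wenet_fsdp_wrap_policy(mode),
                 sharding_strategy=strategy,
                 device_id=torch.cuda.current_device())
    # 3. Re-apply activation checkpointing through FSDP-friendly wrappers.
    apply_fsdp_checkpointing(model, ckpt_layer_types)
    return model

# All ranks call fsdp_save_model(model, 'exp/demo/epoch_0.pt', info_dict);
# only rank 0 writes the full (unsharded) checkpoint.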
(The diff for the fifth changed file is not shown in this view.)
