From 91b79eb52754aae0d98c2c4bfb8b0e775e3837cb Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 27 Sep 2024 10:24:29 -0700 Subject: [PATCH] Add ckpt contents check Signed-off-by: Alexandros Koumparoulis --- .../llm/megatron_mixtral_pretraining.py | 230 ++++++++++++++++-- 1 file changed, 213 insertions(+), 17 deletions(-) diff --git a/tests/collections/llm/megatron_mixtral_pretraining.py b/tests/collections/llm/megatron_mixtral_pretraining.py index f973f51bac18..1a668b1d3fef 100644 --- a/tests/collections/llm/megatron_mixtral_pretraining.py +++ b/tests/collections/llm/megatron_mixtral_pretraining.py @@ -1,18 +1,16 @@ import argparse import os -import torch - from pathlib import Path -from nemo.lightning import Trainer, MegatronStrategy -from nemo.collections.llm import PreTrainingDataModule, MixtralConfig8x3B, MixtralModel - -from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule as MegatronOptim, OptimizerConfig +import torch from megatron.core.distributed import DistributedDataParallelConfig as McoreDDPConfig -from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer -from nemo.collections.llm.api import train -from nemo.lightning import NeMoLogger +from nemo.collections.llm import MixtralConfig8x3B, MixtralModel, PreTrainingDataModule +from nemo.collections.llm.api import train +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer +from nemo.lightning import MegatronStrategy, NeMoLogger, Trainer +from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule as MegatronOptim +from nemo.lightning.pytorch.optim.megatron import OptimizerConfig def tokenizer(vocab_path, merges_path): @@ -23,6 +21,32 @@ def tokenizer(vocab_path, merges_path): merges_file=merges_path, ) + +def load_dcp(ckpt_dir, torch_tensor=True): + from pathlib import Path + + import torch + import torch.distributed.checkpoint as dcp + from torch.distributed.checkpoint import FileSystemReader + + if not isinstance(ckpt_dir, Path): + ckpt_dir = Path(ckpt_dir) + fs_reader = FileSystemReader(ckpt_dir) + metadata = fs_reader.read_metadata() + + state_dict = { + k: torch.empty(tp.size, dtype=tp.properties.dtype) + for k, tp in metadata.state_dict_metadata.items() + if type(tp).__name__ == 'TensorStorageMetadata' + } + + dcp.load( + state_dict, + storage_reader=fs_reader, + ) + return state_dict + + def main(args): strategy = MegatronStrategy( expert_model_parallel_size=args.devices, @@ -34,12 +58,12 @@ def main(args): autocast_dtype=torch.float32, precision=torch.bfloat16, ddp=McoreDDPConfig( - grad_reduce_in_fp32=True, - overlap_grad_reduce=False, - use_distributed_optimizer=True, - check_for_nan_in_grad=True, - bucket_size=None, - ) + grad_reduce_in_fp32=True, + overlap_grad_reduce=False, + use_distributed_optimizer=True, + check_for_nan_in_grad=True, + bucket_size=None, + ), ) trainer = Trainer( @@ -80,7 +104,7 @@ def main(args): params_dtype=torch.bfloat16, pipeline_dtype=torch.bfloat16, ) - mixtral_config.overlap_param_gather_with_optimizer_step=True + mixtral_config.overlap_param_gather_with_optimizer_step = True optim_config = OptimizerConfig( fp16=False, @@ -119,6 +143,7 @@ def main(args): optim=opt, ) + # Confirm checkpoint directory structure output_path = Path(args.experiment_dir) / "checkpoints/--None=0.0000-epoch=0/" assert output_path.exists(), f"Expected {output_path} to exist" assert output_path.is_dir(), f"Expected {output_path} to be a directory" @@ -133,11 +158,181 @@ def main(args): for file in 
os.listdir(output_path): assert file in output_files, f"Got unexpected {file} in checkpoint directory" + # Finally confirm checkpoint contents + expected_ckpt = { + "module.embedding.word_embeddings.weight": (torch.Size([50304, 128]), torch.bfloat16, "cpu"), + "module.decoder.layers.self_attention.linear_proj.weight": (torch.Size([2, 128, 128]), torch.bfloat16, "cpu"), + "module.decoder.layers.self_attention.linear_qkv.layer_norm_weight": ( + torch.Size([2, 128]), + torch.bfloat16, + "cpu", + ), + "module.decoder.layers.self_attention.linear_qkv.weight": (torch.Size([2, 384, 128]), torch.bfloat16, "cpu"), + "module.decoder.layers.pre_mlp_layernorm.weight": (torch.Size([2, 128]), torch.bfloat16, "cpu"), + "module.decoder.layers.mlp.router.weight": (torch.Size([2, 8, 128]), torch.bfloat16, "cpu"), + "module.decoder.layers.mlp.experts.experts.linear_fc1.weight": ( + torch.Size([2, 8, 640, 128]), + torch.bfloat16, + "cpu", + ), + "module.decoder.layers.mlp.experts.experts.linear_fc2.weight": ( + torch.Size([2, 8, 128, 320]), + torch.bfloat16, + "cpu", + ), + "module.decoder.final_layernorm.weight": (torch.Size([128]), torch.bfloat16, "cpu"), + "module.output_layer.weight": (torch.Size([50304, 128]), torch.bfloat16, "cpu"), + "optimizer.state.fp32_param.module.output_layer.weight": (torch.Size([1, 1, 6438912]), torch.float32, "cpu"), + "optimizer.state.exp_avg.module.output_layer.weight": (torch.Size([1, 1, 6438912]), torch.float32, "cpu"), + "optimizer.state.exp_avg_sq.module.output_layer.weight": (torch.Size([1, 1, 6438912]), torch.float32, "cpu"), + "optimizer.state.fp32_param.module.decoder.final_layernorm.weight": (torch.Size([128]), torch.float32, "cpu"), + "optimizer.state.exp_avg.module.decoder.final_layernorm.weight": (torch.Size([128]), torch.float32, "cpu"), + "optimizer.state.exp_avg_sq.module.decoder.final_layernorm.weight": (torch.Size([128]), torch.float32, "cpu"), + "optimizer.state.fp32_param.module.decoder.layers.mlp.experts.experts.linear_fc2.weight": ( + torch.Size([2, 8, 1, 1, 40960]), + torch.float32, + "cpu", + ), + "optimizer.state.exp_avg.module.decoder.layers.mlp.experts.experts.linear_fc2.weight": ( + torch.Size([2, 8, 1, 1, 40960]), + torch.float32, + "cpu", + ), + "optimizer.state.exp_avg_sq.module.decoder.layers.mlp.experts.experts.linear_fc2.weight": ( + torch.Size([2, 8, 1, 1, 40960]), + torch.float32, + "cpu", + ), + "optimizer.state.fp32_param.module.decoder.layers.mlp.experts.experts.linear_fc1.weight": ( + torch.Size([2, 8, 2, 1, 40960]), + torch.float32, + "cpu", + ), + "optimizer.state.exp_avg.module.decoder.layers.mlp.experts.experts.linear_fc1.weight": ( + torch.Size([2, 8, 2, 1, 40960]), + torch.float32, + "cpu", + ), + "optimizer.state.exp_avg_sq.module.decoder.layers.mlp.experts.experts.linear_fc1.weight": ( + torch.Size([2, 8, 2, 1, 40960]), + torch.float32, + "cpu", + ), + "optimizer.state.fp32_param.module.decoder.layers.mlp.router.weight": ( + torch.Size([2, 1, 1, 1024]), + torch.float32, + "cpu", + ), + "optimizer.state.exp_avg.module.decoder.layers.mlp.router.weight": ( + torch.Size([2, 1, 1, 1024]), + torch.float32, + "cpu", + ), + "optimizer.state.exp_avg_sq.module.decoder.layers.mlp.router.weight": ( + torch.Size([2, 1, 1, 1024]), + torch.float32, + "cpu", + ), + "optimizer.state.fp32_param.module.decoder.layers.pre_mlp_layernorm.weight": ( + torch.Size([2, 1, 128]), + torch.float32, + "cpu", + ), + "optimizer.state.exp_avg.module.decoder.layers.pre_mlp_layernorm.weight": ( + torch.Size([2, 1, 128]), + torch.float32, + "cpu", + ), + 
"optimizer.state.exp_avg_sq.module.decoder.layers.pre_mlp_layernorm.weight": ( + torch.Size([2, 1, 128]), + torch.float32, + "cpu", + ), + "optimizer.state.fp32_param.module.decoder.layers.self_attention.linear_qkv.weight": ( + torch.Size([2, 1, 1, 49152]), + torch.float32, + "cpu", + ), + "optimizer.state.exp_avg.module.decoder.layers.self_attention.linear_qkv.weight": ( + torch.Size([2, 1, 1, 49152]), + torch.float32, + "cpu", + ), + "optimizer.state.exp_avg_sq.module.decoder.layers.self_attention.linear_qkv.weight": ( + torch.Size([2, 1, 1, 49152]), + torch.float32, + "cpu", + ), + "optimizer.state.fp32_param.module.decoder.layers.self_attention.linear_qkv.layer_norm_weight": ( + torch.Size([2, 1, 128]), + torch.float32, + "cpu", + ), + "optimizer.state.exp_avg.module.decoder.layers.self_attention.linear_qkv.layer_norm_weight": ( + torch.Size([2, 1, 128]), + torch.float32, + "cpu", + ), + "optimizer.state.exp_avg_sq.module.decoder.layers.self_attention.linear_qkv.layer_norm_weight": ( + torch.Size([2, 1, 128]), + torch.float32, + "cpu", + ), + "optimizer.state.fp32_param.module.decoder.layers.self_attention.linear_proj.weight": ( + torch.Size([2, 1, 1, 16384]), + torch.float32, + "cpu", + ), + "optimizer.state.exp_avg.module.decoder.layers.self_attention.linear_proj.weight": ( + torch.Size([2, 1, 1, 16384]), + torch.float32, + "cpu", + ), + "optimizer.state.exp_avg_sq.module.decoder.layers.self_attention.linear_proj.weight": ( + torch.Size([2, 1, 1, 16384]), + torch.float32, + "cpu", + ), + "optimizer.state.fp32_param.module.embedding.word_embeddings.weight": ( + torch.Size([1, 1, 6438912]), + torch.float32, + "cpu", + ), + "optimizer.state.exp_avg.module.embedding.word_embeddings.weight": ( + torch.Size([1, 1, 6438912]), + torch.float32, + "cpu", + ), + "optimizer.state.exp_avg_sq.module.embedding.word_embeddings.weight": ( + torch.Size([1, 1, 6438912]), + torch.float32, + "cpu", + ), + } + ckpt = load_dcp(output_path) + ckpt_keys = set(ckpt.keys()) + expected_keys = set(expected_ckpt.keys()) + assert len(ckpt) == len(expected_ckpt), ( + "Checkpoint length mismatch ", + len(ckpt), + len(expected_ckpt), + ckpt_keys - expected_keys, + ) + for key, (shape, dtype, device) in expected_ckpt.items(): + assert key in ckpt, f"Expected {key} to be in ckpt" + assert isinstance(ckpt[key], torch.Tensor), f"Expected {key} to be a tensor" + assert ckpt[key].shape == shape, f"Expected {key} shapes to match {ckpt[key].shape} & {shape}" + assert ckpt[key].dtype == dtype, f"Expected {key} dtype to match {ckpt[key].dtype} & {dtype}" + assert str(ckpt[key].device) == device, f"Expected {key} device to match {ckpt[key].device} & {device}" + + def parse_args(): parser = argparse.ArgumentParser(description='Train a small Mixtral model using NeMo 2.0') parser.add_argument('--devices', type=int, default=1, help="Number of devices to use for training") parser.add_argument('--max-steps', type=int, default=4, help="Number of steps to train for") - parser.add_argument('--experiment-dir', type=str, default='/tmp/exp_dir', help="directory to write results and checkpoints to") + parser.add_argument( + '--experiment-dir', type=str, default='/tmp/exp_dir', help="directory to write results and checkpoints to" + ) parser.add_argument('--experiment-name', type=str, default='mini_mixtral_test', help="name of experiment") parser.add_argument('--data-path', type=str, help="Path to data file") parser.add_argument('--vocab-path', type=str, default=None, help="Path to vocab file") @@ -145,5 +340,6 @@ def parse_args(): return 
parser.parse_args() + if __name__ == "__main__": main(parse_args())
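For maintainers who need to refresh the golden values after a model or config change, a minimal sketch follows, assuming the load_dcp helper defined above is in scope and that the run used the default --experiment-dir of /tmp/exp_dir (the checkpoint path below is illustrative, not part of the patch). It prints every tensor in the saved distributed checkpoint in the same "key": (shape, dtype, device) layout that expected_ckpt uses, so the output can be pasted back into the test.

    # Hedged sketch, not part of the patch: dump checkpoint tensors in the
    # expected_ckpt entry format. Path and variable names are illustrative.
    ckpt_dir = "/tmp/exp_dir/checkpoints/--None=0.0000-epoch=0/"
    state_dict = load_dcp(ckpt_dir)
    # Each printed line mirrors an expected_ckpt entry, e.g.
    # "module.decoder.final_layernorm.weight": (torch.Size([128]), torch.bfloat16, "cpu"),
    for key in sorted(state_dict):
        t = state_dict[key]
        print(f'"{key}": ({t.shape}, {t.dtype}, "{t.device}"),')

Because load_dcp materializes every tensor into host memory via torch.empty, the reported device is expected to be "cpu" regardless of where training ran, which is why the assertions above compare against "cpu".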