Commit: Add ckpt contents check

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
akoumpa committed Sep 27, 2024
1 parent 706e9c1 commit 91b79eb
Showing 1 changed file with 213 additions and 17 deletions.
230 changes: 213 additions & 17 deletions tests/collections/llm/megatron_mixtral_pretraining.py
@@ -1,18 +1,16 @@
 import argparse
 import os
 import torch
 
 from pathlib import Path
 
-from nemo.lightning import Trainer, MegatronStrategy
-from nemo.collections.llm import PreTrainingDataModule, MixtralConfig8x3B, MixtralModel
-
-from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule as MegatronOptim, OptimizerConfig
-import torch
 from megatron.core.distributed import DistributedDataParallelConfig as McoreDDPConfig
-from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
-from nemo.collections.llm.api import train
-from nemo.lightning import NeMoLogger
 
+from nemo.collections.llm import MixtralConfig8x3B, MixtralModel, PreTrainingDataModule
+from nemo.collections.llm.api import train
+from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
+from nemo.lightning import MegatronStrategy, NeMoLogger, Trainer
+from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule as MegatronOptim
+from nemo.lightning.pytorch.optim.megatron import OptimizerConfig
 
 
 def tokenizer(vocab_path, merges_path):
@@ -23,6 +21,32 @@ def tokenizer(vocab_path, merges_path):
         merges_file=merges_path,
     )
 
+
+def load_dcp(ckpt_dir, torch_tensor=True):
+    from pathlib import Path
+
+    import torch
+    import torch.distributed.checkpoint as dcp
+    from torch.distributed.checkpoint import FileSystemReader
+
+    if not isinstance(ckpt_dir, Path):
+        ckpt_dir = Path(ckpt_dir)
+    fs_reader = FileSystemReader(ckpt_dir)
+    metadata = fs_reader.read_metadata()
+
+    state_dict = {
+        k: torch.empty(tp.size, dtype=tp.properties.dtype)
+        for k, tp in metadata.state_dict_metadata.items()
+        if type(tp).__name__ == 'TensorStorageMetadata'
+    }
+
+    dcp.load(
+        state_dict,
+        storage_reader=fs_reader,
+    )
+    return state_dict
+
+
 def main(args):
     strategy = MegatronStrategy(
         expert_model_parallel_size=args.devices,
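
Note on the new helper: load_dcp materializes every tensor recorded in a torch.distributed.checkpoint (DCP) directory into an ordinary in-memory state dict, by pre-allocating empty CPU tensors from the checkpoint metadata and letting dcp.load fill them. A minimal standalone sketch of how it might be exercised, assuming an existing DCP checkpoint at the illustrative path below (newer PyTorch releases should allow dcp.load in a single process without an initialized process group):

    # Illustrative path; point this at any existing DCP checkpoint directory.
    ckpt_dir = "/tmp/exp_dir/checkpoints/--None=0.0000-epoch=0/"
    state_dict = load_dcp(ckpt_dir)
    for name, tensor in state_dict.items():
        # Each value is a plain torch.Tensor, allocated on CPU by torch.empty
        # and filled in by dcp.load.
        print(name, tuple(tensor.shape), tensor.dtype, tensor.device)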
@@ -34,12 +58,12 @@ def main(args):
         autocast_dtype=torch.float32,
         precision=torch.bfloat16,
         ddp=McoreDDPConfig(
-                grad_reduce_in_fp32=True,
-                overlap_grad_reduce=False,
-                use_distributed_optimizer=True,
-                check_for_nan_in_grad=True,
-                bucket_size=None,
-        )
+            grad_reduce_in_fp32=True,
+            overlap_grad_reduce=False,
+            use_distributed_optimizer=True,
+            check_for_nan_in_grad=True,
+            bucket_size=None,
+        ),
     )
 
     trainer = Trainer(
@@ -80,7 +104,7 @@ def main(args):
         params_dtype=torch.bfloat16,
         pipeline_dtype=torch.bfloat16,
     )
-    mixtral_config.overlap_param_gather_with_optimizer_step=True
+    mixtral_config.overlap_param_gather_with_optimizer_step = True
 
     optim_config = OptimizerConfig(
         fp16=False,
@@ -119,6 +143,7 @@ def main(args):
         optim=opt,
     )
 
+    # Confirm checkpoint directory structure
     output_path = Path(args.experiment_dir) / "checkpoints/--None=0.0000-epoch=0/"
     assert output_path.exists(), f"Expected {output_path} to exist"
     assert output_path.is_dir(), f"Expected {output_path} to be a directory"
@@ -133,17 +158,188 @@ def main(args):
     for file in os.listdir(output_path):
         assert file in output_files, f"Got unexpected {file} in checkpoint directory"
 
+    # Finally confirm checkpoint contents
+    expected_ckpt = {
+        "module.embedding.word_embeddings.weight": (torch.Size([50304, 128]), torch.bfloat16, "cpu"),
+        "module.decoder.layers.self_attention.linear_proj.weight": (torch.Size([2, 128, 128]), torch.bfloat16, "cpu"),
+        "module.decoder.layers.self_attention.linear_qkv.layer_norm_weight": (
+            torch.Size([2, 128]),
+            torch.bfloat16,
+            "cpu",
+        ),
+        "module.decoder.layers.self_attention.linear_qkv.weight": (torch.Size([2, 384, 128]), torch.bfloat16, "cpu"),
+        "module.decoder.layers.pre_mlp_layernorm.weight": (torch.Size([2, 128]), torch.bfloat16, "cpu"),
+        "module.decoder.layers.mlp.router.weight": (torch.Size([2, 8, 128]), torch.bfloat16, "cpu"),
+        "module.decoder.layers.mlp.experts.experts.linear_fc1.weight": (
+            torch.Size([2, 8, 640, 128]),
+            torch.bfloat16,
+            "cpu",
+        ),
+        "module.decoder.layers.mlp.experts.experts.linear_fc2.weight": (
+            torch.Size([2, 8, 128, 320]),
+            torch.bfloat16,
+            "cpu",
+        ),
+        "module.decoder.final_layernorm.weight": (torch.Size([128]), torch.bfloat16, "cpu"),
+        "module.output_layer.weight": (torch.Size([50304, 128]), torch.bfloat16, "cpu"),
+        "optimizer.state.fp32_param.module.output_layer.weight": (torch.Size([1, 1, 6438912]), torch.float32, "cpu"),
+        "optimizer.state.exp_avg.module.output_layer.weight": (torch.Size([1, 1, 6438912]), torch.float32, "cpu"),
+        "optimizer.state.exp_avg_sq.module.output_layer.weight": (torch.Size([1, 1, 6438912]), torch.float32, "cpu"),
+        "optimizer.state.fp32_param.module.decoder.final_layernorm.weight": (torch.Size([128]), torch.float32, "cpu"),
+        "optimizer.state.exp_avg.module.decoder.final_layernorm.weight": (torch.Size([128]), torch.float32, "cpu"),
+        "optimizer.state.exp_avg_sq.module.decoder.final_layernorm.weight": (torch.Size([128]), torch.float32, "cpu"),
+        "optimizer.state.fp32_param.module.decoder.layers.mlp.experts.experts.linear_fc2.weight": (
+            torch.Size([2, 8, 1, 1, 40960]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.exp_avg.module.decoder.layers.mlp.experts.experts.linear_fc2.weight": (
+            torch.Size([2, 8, 1, 1, 40960]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.exp_avg_sq.module.decoder.layers.mlp.experts.experts.linear_fc2.weight": (
+            torch.Size([2, 8, 1, 1, 40960]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.fp32_param.module.decoder.layers.mlp.experts.experts.linear_fc1.weight": (
+            torch.Size([2, 8, 2, 1, 40960]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.exp_avg.module.decoder.layers.mlp.experts.experts.linear_fc1.weight": (
+            torch.Size([2, 8, 2, 1, 40960]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.exp_avg_sq.module.decoder.layers.mlp.experts.experts.linear_fc1.weight": (
+            torch.Size([2, 8, 2, 1, 40960]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.fp32_param.module.decoder.layers.mlp.router.weight": (
+            torch.Size([2, 1, 1, 1024]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.exp_avg.module.decoder.layers.mlp.router.weight": (
+            torch.Size([2, 1, 1, 1024]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.exp_avg_sq.module.decoder.layers.mlp.router.weight": (
+            torch.Size([2, 1, 1, 1024]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.fp32_param.module.decoder.layers.pre_mlp_layernorm.weight": (
+            torch.Size([2, 1, 128]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.exp_avg.module.decoder.layers.pre_mlp_layernorm.weight": (
+            torch.Size([2, 1, 128]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.exp_avg_sq.module.decoder.layers.pre_mlp_layernorm.weight": (
+            torch.Size([2, 1, 128]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.fp32_param.module.decoder.layers.self_attention.linear_qkv.weight": (
+            torch.Size([2, 1, 1, 49152]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.exp_avg.module.decoder.layers.self_attention.linear_qkv.weight": (
+            torch.Size([2, 1, 1, 49152]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.exp_avg_sq.module.decoder.layers.self_attention.linear_qkv.weight": (
+            torch.Size([2, 1, 1, 49152]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.fp32_param.module.decoder.layers.self_attention.linear_qkv.layer_norm_weight": (
+            torch.Size([2, 1, 128]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.exp_avg.module.decoder.layers.self_attention.linear_qkv.layer_norm_weight": (
+            torch.Size([2, 1, 128]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.exp_avg_sq.module.decoder.layers.self_attention.linear_qkv.layer_norm_weight": (
+            torch.Size([2, 1, 128]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.fp32_param.module.decoder.layers.self_attention.linear_proj.weight": (
+            torch.Size([2, 1, 1, 16384]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.exp_avg.module.decoder.layers.self_attention.linear_proj.weight": (
+            torch.Size([2, 1, 1, 16384]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.exp_avg_sq.module.decoder.layers.self_attention.linear_proj.weight": (
+            torch.Size([2, 1, 1, 16384]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.fp32_param.module.embedding.word_embeddings.weight": (
+            torch.Size([1, 1, 6438912]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.exp_avg.module.embedding.word_embeddings.weight": (
+            torch.Size([1, 1, 6438912]),
+            torch.float32,
+            "cpu",
+        ),
+        "optimizer.state.exp_avg_sq.module.embedding.word_embeddings.weight": (
+            torch.Size([1, 1, 6438912]),
+            torch.float32,
+            "cpu",
+        ),
+    }
+    ckpt = load_dcp(output_path)
+    ckpt_keys = set(ckpt.keys())
+    expected_keys = set(expected_ckpt.keys())
+    assert len(ckpt) == len(expected_ckpt), (
+        "Checkpoint length mismatch ",
+        len(ckpt),
+        len(expected_ckpt),
+        ckpt_keys - expected_keys,
+    )
+    for key, (shape, dtype, device) in expected_ckpt.items():
+        assert key in ckpt, f"Expected {key} to be in ckpt"
+        assert isinstance(ckpt[key], torch.Tensor), f"Expected {key} to be a tensor"
+        assert ckpt[key].shape == shape, f"Expected {key} shapes to match {ckpt[key].shape} & {shape}"
+        assert ckpt[key].dtype == dtype, f"Expected {key} dtype to match {ckpt[key].dtype} & {dtype}"
+        assert str(ckpt[key].device) == device, f"Expected {key} device to match {ckpt[key].device} & {device}"
 
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Train a small Mixtral model using NeMo 2.0')
     parser.add_argument('--devices', type=int, default=1, help="Number of devices to use for training")
     parser.add_argument('--max-steps', type=int, default=4, help="Number of steps to train for")
-    parser.add_argument('--experiment-dir', type=str, default='/tmp/exp_dir', help="directory to write results and checkpoints to")
+    parser.add_argument(
+        '--experiment-dir', type=str, default='/tmp/exp_dir', help="directory to write results and checkpoints to"
+    )
     parser.add_argument('--experiment-name', type=str, default='mini_mixtral_test', help="name of experiment")
     parser.add_argument('--data-path', type=str, help="Path to data file")
     parser.add_argument('--vocab-path', type=str, default=None, help="Path to vocab file")
     parser.add_argument('--merges-path', type=str, default=None, help="Path to merges file")
 
     return parser.parse_args()
 
 
 if __name__ == "__main__":
     main(parse_args())
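
For reference, one plausible way to invoke this test locally (all paths are illustrative: --data-path expects a preprocessed Megatron dataset prefix, and --vocab-path/--merges-path expect GPT-2 BPE vocabulary and merges files):

    python tests/collections/llm/megatron_mixtral_pretraining.py \
        --devices 1 \
        --max-steps 4 \
        --experiment-dir /tmp/exp_dir \
        --data-path /path/to/dataset_prefix \
        --vocab-path /path/to/vocab.json \
        --merges-path /path/to/merges.txt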
