Apply isort and black reformatting
Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>
akoumpa committed Sep 27, 2024
1 parent d37c723 commit daf0cf7
Showing 1 changed file with 166 additions and 47 deletions.
213 changes: 166 additions & 47 deletions tests/collections/llm/megatron_mixtral_pretraining.py
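For context, a reformatting change like this one is typically produced by running isort and then black over the file; the minimal sketch below shows the equivalent steps through their Python APIs (the exact tool versions and any repository-specific isort/black configuration are not part of this commit and are assumptions here).

# Hypothetical sketch: apply isort, then black, to the test file programmatically.
# Repository-specific isort/black settings are not reflected; library defaults are assumed.
from pathlib import Path

import black
import isort

path = Path("tests/collections/llm/megatron_mixtral_pretraining.py")
source = path.read_text()

sorted_source = isort.code(source)  # sort and group imports
formatted = black.format_str(sorted_source, mode=black.Mode())  # apply black code style
path.write_text(formatted)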
@@ -1,18 +1,16 @@
import argparse
import os
import torch

from pathlib import Path

from nemo.lightning import Trainer, MegatronStrategy
from nemo.collections.llm import PreTrainingDataModule, MixtralConfig8x3B, MixtralModel

from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule as MegatronOptim, OptimizerConfig
import torch
from megatron.core.distributed import DistributedDataParallelConfig as McoreDDPConfig
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
from nemo.collections.llm.api import train
from nemo.lightning import NeMoLogger

from nemo.collections.llm import MixtralConfig8x3B, MixtralModel, PreTrainingDataModule
from nemo.collections.llm.api import train
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
from nemo.lightning import MegatronStrategy, NeMoLogger, Trainer
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule as MegatronOptim
from nemo.lightning.pytorch.optim.megatron import OptimizerConfig


def tokenizer(vocab_path, merges_path):
@@ -23,12 +21,16 @@ def tokenizer(vocab_path, merges_path):
merges_file=merges_path,
)


def load_dcp(ckpt_dir, torch_tensor=True):
from pathlib import Path
from torch.distributed.checkpoint import FileSystemReader

import torch
import torch.distributed.checkpoint as dcp
if not isinstance(ckpt_dir, Path): ckpt_dir = Path(ckpt_dir)
from torch.distributed.checkpoint import FileSystemReader

if not isinstance(ckpt_dir, Path):
ckpt_dir = Path(ckpt_dir)
fs_reader = FileSystemReader(ckpt_dir)
metadata = fs_reader.read_metadata()

@@ -51,6 +53,7 @@ def load_dcp(ckpt_dir, torch_tensor=True):
state_dict[k] = v.numpy()
return state_dict


def main(args):
strategy = MegatronStrategy(
expert_model_parallel_size=args.devices,
@@ -62,12 +65,12 @@ def main(args):
autocast_dtype=torch.float32,
precision=torch.bfloat16,
ddp=McoreDDPConfig(
grad_reduce_in_fp32=True,
overlap_grad_reduce=False,
use_distributed_optimizer=True,
check_for_nan_in_grad=True,
bucket_size=None,
)
grad_reduce_in_fp32=True,
overlap_grad_reduce=False,
use_distributed_optimizer=True,
check_for_nan_in_grad=True,
bucket_size=None,
),
)

trainer = Trainer(
@@ -108,7 +111,7 @@ def main(args):
params_dtype=torch.bfloat16,
pipeline_dtype=torch.bfloat16,
)
mixtral_config.overlap_param_gather_with_optimizer_step=True
mixtral_config.overlap_param_gather_with_optimizer_step = True

optim_config = OptimizerConfig(
fp16=False,
@@ -166,12 +169,24 @@ def main(args):
expected_ckpt = {
"module.embedding.word_embeddings.weight": (torch.Size([50304, 128]), torch.bfloat16, "cpu"),
"module.decoder.layers.self_attention.linear_proj.weight": (torch.Size([2, 128, 128]), torch.bfloat16, "cpu"),
"module.decoder.layers.self_attention.linear_qkv.layer_norm_weight": (torch.Size([2, 128]), torch.bfloat16, "cpu"),
"module.decoder.layers.self_attention.linear_qkv.layer_norm_weight": (
torch.Size([2, 128]),
torch.bfloat16,
"cpu",
),
"module.decoder.layers.self_attention.linear_qkv.weight": (torch.Size([2, 384, 128]), torch.bfloat16, "cpu"),
"module.decoder.layers.pre_mlp_layernorm.weight": (torch.Size([2, 128]), torch.bfloat16, "cpu"),
"module.decoder.layers.mlp.router.weight": (torch.Size([2, 8, 128]), torch.bfloat16, "cpu"),
"module.decoder.layers.mlp.experts.experts.linear_fc1.weight": (torch.Size([2, 8, 640, 128]), torch.bfloat16, "cpu"),
"module.decoder.layers.mlp.experts.experts.linear_fc2.weight": (torch.Size([2, 8, 128, 320]), torch.bfloat16, "cpu"),
"module.decoder.layers.mlp.experts.experts.linear_fc1.weight": (
torch.Size([2, 8, 640, 128]),
torch.bfloat16,
"cpu",
),
"module.decoder.layers.mlp.experts.experts.linear_fc2.weight": (
torch.Size([2, 8, 128, 320]),
torch.bfloat16,
"cpu",
),
"module.decoder.final_layernorm.weight": (torch.Size([128]), torch.bfloat16, "cpu"),
"module.output_layer.weight": (torch.Size([50304, 128]), torch.bfloat16, "cpu"),
"optimizer.state.fp32_param.module.output_layer.weight": (torch.Size([1, 1, 6438912]), torch.float32, "cpu"),
@@ -180,35 +195,136 @@ def main(args):
"optimizer.state.fp32_param.module.decoder.final_layernorm.weight": (torch.Size([128]), torch.float32, "cpu"),
"optimizer.state.exp_avg.module.decoder.final_layernorm.weight": (torch.Size([128]), torch.float32, "cpu"),
"optimizer.state.exp_avg_sq.module.decoder.final_layernorm.weight": (torch.Size([128]), torch.float32, "cpu"),
"optimizer.state.fp32_param.module.decoder.layers.mlp.experts.experts.linear_fc2.weight": (torch.Size([2, 8, 1, 1, 40960]), torch.float32, "cpu"),
"optimizer.state.exp_avg.module.decoder.layers.mlp.experts.experts.linear_fc2.weight": (torch.Size([2, 8, 1, 1, 40960]), torch.float32, "cpu"),
"optimizer.state.exp_avg_sq.module.decoder.layers.mlp.experts.experts.linear_fc2.weight": (torch.Size([2, 8, 1, 1, 40960]), torch.float32, "cpu"),
"optimizer.state.fp32_param.module.decoder.layers.mlp.experts.experts.linear_fc1.weight": (torch.Size([2, 8, 2, 1, 40960]), torch.float32, "cpu"),
"optimizer.state.exp_avg.module.decoder.layers.mlp.experts.experts.linear_fc1.weight": (torch.Size([2, 8, 2, 1, 40960]), torch.float32, "cpu"),
"optimizer.state.exp_avg_sq.module.decoder.layers.mlp.experts.experts.linear_fc1.weight": (torch.Size([2, 8, 2, 1, 40960]), torch.float32, "cpu"),
"optimizer.state.fp32_param.module.decoder.layers.mlp.router.weight": (torch.Size([2, 1, 1, 1024]), torch.float32, "cpu"),
"optimizer.state.exp_avg.module.decoder.layers.mlp.router.weight": (torch.Size([2, 1, 1, 1024]), torch.float32, "cpu"),
"optimizer.state.exp_avg_sq.module.decoder.layers.mlp.router.weight": (torch.Size([2, 1, 1, 1024]), torch.float32, "cpu"),
"optimizer.state.fp32_param.module.decoder.layers.pre_mlp_layernorm.weight": (torch.Size([2, 1, 128]), torch.float32, "cpu"),
"optimizer.state.exp_avg.module.decoder.layers.pre_mlp_layernorm.weight": (torch.Size([2, 1, 128]), torch.float32, "cpu"),
"optimizer.state.exp_avg_sq.module.decoder.layers.pre_mlp_layernorm.weight": (torch.Size([2, 1, 128]), torch.float32, "cpu"),
"optimizer.state.fp32_param.module.decoder.layers.self_attention.linear_qkv.weight": (torch.Size([2, 1, 1, 49152]), torch.float32, "cpu"),
"optimizer.state.exp_avg.module.decoder.layers.self_attention.linear_qkv.weight": (torch.Size([2, 1, 1, 49152]), torch.float32, "cpu"),
"optimizer.state.exp_avg_sq.module.decoder.layers.self_attention.linear_qkv.weight": (torch.Size([2, 1, 1, 49152]), torch.float32, "cpu"),
"optimizer.state.fp32_param.module.decoder.layers.self_attention.linear_qkv.layer_norm_weight": (torch.Size([2, 1, 128]), torch.float32, "cpu"),
"optimizer.state.exp_avg.module.decoder.layers.self_attention.linear_qkv.layer_norm_weight": (torch.Size([2, 1, 128]), torch.float32, "cpu"),
"optimizer.state.exp_avg_sq.module.decoder.layers.self_attention.linear_qkv.layer_norm_weight": (torch.Size([2, 1, 128]), torch.float32, "cpu"),
"optimizer.state.fp32_param.module.decoder.layers.self_attention.linear_proj.weight": (torch.Size([2, 1, 1, 16384]), torch.float32, "cpu"),
"optimizer.state.exp_avg.module.decoder.layers.self_attention.linear_proj.weight": (torch.Size([2, 1, 1, 16384]), torch.float32, "cpu"),
"optimizer.state.exp_avg_sq.module.decoder.layers.self_attention.linear_proj.weight": (torch.Size([2, 1, 1, 16384]), torch.float32, "cpu"),
"optimizer.state.fp32_param.module.embedding.word_embeddings.weight": (torch.Size([1, 1, 6438912]), torch.float32, "cpu"),
"optimizer.state.exp_avg.module.embedding.word_embeddings.weight": (torch.Size([1, 1, 6438912]), torch.float32, "cpu"),
"optimizer.state.exp_avg_sq.module.embedding.word_embeddings.weight": (torch.Size([1, 1, 6438912]), torch.float32, "cpu"),
"optimizer.state.fp32_param.module.decoder.layers.mlp.experts.experts.linear_fc2.weight": (
torch.Size([2, 8, 1, 1, 40960]),
torch.float32,
"cpu",
),
"optimizer.state.exp_avg.module.decoder.layers.mlp.experts.experts.linear_fc2.weight": (
torch.Size([2, 8, 1, 1, 40960]),
torch.float32,
"cpu",
),
"optimizer.state.exp_avg_sq.module.decoder.layers.mlp.experts.experts.linear_fc2.weight": (
torch.Size([2, 8, 1, 1, 40960]),
torch.float32,
"cpu",
),
"optimizer.state.fp32_param.module.decoder.layers.mlp.experts.experts.linear_fc1.weight": (
torch.Size([2, 8, 2, 1, 40960]),
torch.float32,
"cpu",
),
"optimizer.state.exp_avg.module.decoder.layers.mlp.experts.experts.linear_fc1.weight": (
torch.Size([2, 8, 2, 1, 40960]),
torch.float32,
"cpu",
),
"optimizer.state.exp_avg_sq.module.decoder.layers.mlp.experts.experts.linear_fc1.weight": (
torch.Size([2, 8, 2, 1, 40960]),
torch.float32,
"cpu",
),
"optimizer.state.fp32_param.module.decoder.layers.mlp.router.weight": (
torch.Size([2, 1, 1, 1024]),
torch.float32,
"cpu",
),
"optimizer.state.exp_avg.module.decoder.layers.mlp.router.weight": (
torch.Size([2, 1, 1, 1024]),
torch.float32,
"cpu",
),
"optimizer.state.exp_avg_sq.module.decoder.layers.mlp.router.weight": (
torch.Size([2, 1, 1, 1024]),
torch.float32,
"cpu",
),
"optimizer.state.fp32_param.module.decoder.layers.pre_mlp_layernorm.weight": (
torch.Size([2, 1, 128]),
torch.float32,
"cpu",
),
"optimizer.state.exp_avg.module.decoder.layers.pre_mlp_layernorm.weight": (
torch.Size([2, 1, 128]),
torch.float32,
"cpu",
),
"optimizer.state.exp_avg_sq.module.decoder.layers.pre_mlp_layernorm.weight": (
torch.Size([2, 1, 128]),
torch.float32,
"cpu",
),
"optimizer.state.fp32_param.module.decoder.layers.self_attention.linear_qkv.weight": (
torch.Size([2, 1, 1, 49152]),
torch.float32,
"cpu",
),
"optimizer.state.exp_avg.module.decoder.layers.self_attention.linear_qkv.weight": (
torch.Size([2, 1, 1, 49152]),
torch.float32,
"cpu",
),
"optimizer.state.exp_avg_sq.module.decoder.layers.self_attention.linear_qkv.weight": (
torch.Size([2, 1, 1, 49152]),
torch.float32,
"cpu",
),
"optimizer.state.fp32_param.module.decoder.layers.self_attention.linear_qkv.layer_norm_weight": (
torch.Size([2, 1, 128]),
torch.float32,
"cpu",
),
"optimizer.state.exp_avg.module.decoder.layers.self_attention.linear_qkv.layer_norm_weight": (
torch.Size([2, 1, 128]),
torch.float32,
"cpu",
),
"optimizer.state.exp_avg_sq.module.decoder.layers.self_attention.linear_qkv.layer_norm_weight": (
torch.Size([2, 1, 128]),
torch.float32,
"cpu",
),
"optimizer.state.fp32_param.module.decoder.layers.self_attention.linear_proj.weight": (
torch.Size([2, 1, 1, 16384]),
torch.float32,
"cpu",
),
"optimizer.state.exp_avg.module.decoder.layers.self_attention.linear_proj.weight": (
torch.Size([2, 1, 1, 16384]),
torch.float32,
"cpu",
),
"optimizer.state.exp_avg_sq.module.decoder.layers.self_attention.linear_proj.weight": (
torch.Size([2, 1, 1, 16384]),
torch.float32,
"cpu",
),
"optimizer.state.fp32_param.module.embedding.word_embeddings.weight": (
torch.Size([1, 1, 6438912]),
torch.float32,
"cpu",
),
"optimizer.state.exp_avg.module.embedding.word_embeddings.weight": (
torch.Size([1, 1, 6438912]),
torch.float32,
"cpu",
),
"optimizer.state.exp_avg_sq.module.embedding.word_embeddings.weight": (
torch.Size([1, 1, 6438912]),
torch.float32,
"cpu",
),
}
ckpt = load_dcp(output_path)
ckpt_keys = set(ckpt.keys())
expected_keys = set(expected_ckpt.keys())
assert len(ckpt) == len(expected_ckpt), ("Checkpoint length mismatch ", len(ckpt), len(expected_ckpt), ckpt_keys - expected_keys)
assert len(ckpt) == len(expected_ckpt), (
"Checkpoint length mismatch ",
len(ckpt),
len(expected_ckpt),
ckpt_keys - expected_keys,
)
for key, (shape, dtype, device) in expected_ckpt.items():
assert key in ckpt, f"Expected {key} to be in ckpt"
assert isinstance(ckpt[key], torch.Tensor), f"Expected {key} to be a tensor"
@@ -221,13 +337,16 @@ def parse_args():
parser = argparse.ArgumentParser(description='Train a small Mixtral model using NeMo 2.0')
parser.add_argument('--devices', type=int, default=1, help="Number of devices to use for training")
parser.add_argument('--max-steps', type=int, default=4, help="Number of steps to train for")
parser.add_argument('--experiment-dir', type=str, default='/tmp/exp_dir', help="directory to write results and checkpoints to")
parser.add_argument(
'--experiment-dir', type=str, default='/tmp/exp_dir', help="directory to write results and checkpoints to"
)
parser.add_argument('--experiment-name', type=str, default='mini_mixtral_test', help="name of experiment")
parser.add_argument('--data-path', type=str, help="Path to data file")
parser.add_argument('--vocab-path', type=str, default=None, help="Path to vocab file")
parser.add_argument('--merges-path', type=str, default=None, help="Path to merges file")

return parser.parse_args()


if __name__ == "__main__":
main(parse_args())
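
For reference, the reformatted script keeps the same command-line interface; a minimal, hypothetical way to drive it directly from Python (all dataset, vocab, and merges paths below are placeholders, and importing the module assumes tests/collections/llm is on PYTHONPATH) is:

# Hypothetical sketch: call the test's main() with an argparse.Namespace mirroring parse_args().
# None of these path values come from this commit; they are illustrative placeholders.
import argparse

from megatron_mixtral_pretraining import main

args = argparse.Namespace(
    devices=1,
    max_steps=4,
    experiment_dir="/tmp/exp_dir",
    experiment_name="mini_mixtral_test",
    data_path="/path/to/dataset_text_document",  # placeholder
    vocab_path="/path/to/vocab.json",  # placeholder
    merges_path="/path/to/merges.txt",  # placeholder
)
main(args)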
