Apply isort and black reformatting
Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>
akoumpa committed Sep 27, 2024
1 parent 706e9c1 commit db987c8
Showing 1 changed file with 20 additions and 17 deletions.
37 changes: 20 additions & 17 deletions tests/collections/llm/megatron_mixtral_pretraining.py
@@ -1,18 +1,16 @@
import argparse
import os
-import torch
-
from pathlib import Path

-from nemo.lightning import Trainer, MegatronStrategy
-from nemo.collections.llm import PreTrainingDataModule, MixtralConfig8x3B, MixtralModel
-
-from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule as MegatronOptim, OptimizerConfig
+import torch
from megatron.core.distributed import DistributedDataParallelConfig as McoreDDPConfig
-from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
-from nemo.collections.llm.api import train
-from nemo.lightning import NeMoLogger

+from nemo.collections.llm import MixtralConfig8x3B, MixtralModel, PreTrainingDataModule
+from nemo.collections.llm.api import train
+from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
+from nemo.lightning import MegatronStrategy, NeMoLogger, Trainer
+from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule as MegatronOptim
+from nemo.lightning.pytorch.optim.megatron import OptimizerConfig


def tokenizer(vocab_path, merges_path):
@@ -23,6 +21,7 @@ def tokenizer(vocab_path, merges_path):
merges_file=merges_path,
)

+
def main(args):
strategy = MegatronStrategy(
expert_model_parallel_size=args.devices,
@@ -34,12 +33,12 @@ def main(args):
autocast_dtype=torch.float32,
precision=torch.bfloat16,
ddp=McoreDDPConfig(
-grad_reduce_in_fp32=True,
-overlap_grad_reduce=False,
-use_distributed_optimizer=True,
-check_for_nan_in_grad=True,
-bucket_size=None,
-)
+grad_reduce_in_fp32=True,
+overlap_grad_reduce=False,
+use_distributed_optimizer=True,
+check_for_nan_in_grad=True,
+bucket_size=None,
+),
)

trainer = Trainer(
@@ -80,7 +79,7 @@ def main(args):
params_dtype=torch.bfloat16,
pipeline_dtype=torch.bfloat16,
)
-mixtral_config.overlap_param_gather_with_optimizer_step=True
+mixtral_config.overlap_param_gather_with_optimizer_step = True

optim_config = OptimizerConfig(
fp16=False,
@@ -133,17 +132,21 @@ def main(args):
for file in os.listdir(output_path):
assert file in output_files, f"Got unexpected {file} in checkpoint directory"

+
def parse_args():
parser = argparse.ArgumentParser(description='Train a small Mixtral model using NeMo 2.0')
parser.add_argument('--devices', type=int, default=1, help="Number of devices to use for training")
parser.add_argument('--max-steps', type=int, default=4, help="Number of steps to train for")
-parser.add_argument('--experiment-dir', type=str, default='/tmp/exp_dir', help="directory to write results and checkpoints to")
+parser.add_argument(
+'--experiment-dir', type=str, default='/tmp/exp_dir', help="directory to write results and checkpoints to"
+)
parser.add_argument('--experiment-name', type=str, default='mini_mixtral_test', help="name of experiment")
parser.add_argument('--data-path', type=str, help="Path to data file")
parser.add_argument('--vocab-path', type=str, default=None, help="Path to vocab file")
parser.add_argument('--merges-path', type=str, default=None, help="Path to merges file")

return parser.parse_args()

+
if __name__ == "__main__":
main(parse_args())
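
For readers who want to reproduce this kind of commit locally, below is a minimal sketch of applying the same two tools programmatically. It is an illustration only, not the invocation used by NeMo's automation: the file path is taken from this diff, while the isort settings and black mode are library defaults and may differ from the repository's own configuration.

# Minimal sketch: run isort, then black, on the file touched by this commit.
# Assumes both packages are installed; the repository's real isort/black settings
# (line length, known first-party packages, etc.) may differ from the defaults here.
from pathlib import Path

import black
import isort

path = Path("tests/collections/llm/megatron_mixtral_pretraining.py")
source = path.read_text()

# Sort the imports first, then let black normalize spacing, indentation,
# trailing commas, and blank lines, as seen in the hunks above.
sorted_source = isort.code(source)
formatted_source = black.format_str(sorted_source, mode=black.Mode())

path.write_text(formatted_source)

Running isort before black matches the order implied by the commit title and mirrors how such formatting passes are usually chained.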
