From daf25493ad078542fcace6965b538a7db77fb067 Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Fri, 27 Sep 2024 06:25:21 -0700 Subject: [PATCH] add mistral 12b recipe for 24.09 Signed-off-by: dimapihtar --- nemo/collections/llm/__init__.py | 2 + nemo/collections/llm/gpt/model/__init__.py | 2 +- nemo/collections/llm/gpt/model/mistral.py | 6 +- nemo/collections/llm/recipes/__init__.py | 5 +- .../llm/recipes/{mistral.py => mistral_7b.py} | 2 +- .../llm/recipes/mistral_nemo_12b.py | 189 ++++++++++++++++++ 6 files changed, 200 insertions(+), 6 deletions(-) rename nemo/collections/llm/recipes/{mistral.py => mistral_7b.py} (99%) create mode 100644 nemo/collections/llm/recipes/mistral_nemo_12b.py diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index 614af0df400c..4fde396c6af6 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -70,6 +70,7 @@ LlamaModel, MaskedTokenLossReduction, MistralConfig7B, + MistralNeMoConfig12B, MistralModel, MixtralConfig8x3B, MixtralConfig8x7B, @@ -118,6 +119,7 @@ "gpt_data_step", "gpt_forward_step", "MaskedTokenLossReduction", + "MistralNeMoConfig12B", "MistralConfig7B", "MistralModel", "MixtralConfig8x3B", diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index aa3615b3ddfd..ebecc06140fe 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -53,7 +53,7 @@ LlamaConfig, LlamaModel, ) -from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel +from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel, MistralNeMoConfig12B from nemo.collections.llm.gpt.model.mixtral import ( MixtralConfig8x3B, MixtralConfig8x7B, diff --git a/nemo/collections/llm/gpt/model/mistral.py b/nemo/collections/llm/gpt/model/mistral.py index a6415769112a..fce2684e6c5e 100644 --- a/nemo/collections/llm/gpt/model/mistral.py +++ b/nemo/collections/llm/gpt/model/mistral.py @@ -58,7 +58,7 @@ class MistralConfig7B(GPTConfig): @dataclass -class MistralNeMo2407Config12B(MistralConfig7B): +class MistralNeMoConfig12B(MistralConfig7B): """ https://mistral.ai/news/mistral-nemo/ """ @@ -74,7 +74,7 @@ class MistralNeMo2407Config12B(MistralConfig7B): @dataclass -class MistralNeMo2407Config123B(MistralConfig7B): +class MistralNeMoConfig123B(MistralConfig7B): """ https://mistral.ai/news/mistral-large-2407/ """ @@ -324,7 +324,7 @@ def _export_qkv(ctx: io.TransformCTX, linear_qkv): target_key="decoder.layers.*.mlp.linear_fc1.weight", ) def _import_linear_fc1(down, gate): - return torch.cat((down, gate), axis=0).float() + return torch.cat((down, gate), axis=0) @io.state_transform( diff --git a/nemo/collections/llm/recipes/__init__.py b/nemo/collections/llm/recipes/__init__.py index 950ca6db7ac6..303bb4f74bc8 100644 --- a/nemo/collections/llm/recipes/__init__.py +++ b/nemo/collections/llm/recipes/__init__.py @@ -6,6 +6,8 @@ llama3_70b_16k, llama3_70b_64k, mistral, + mistral_7b, + mistral_nemo_12b, mixtral_8x3b, mixtral_8x3b_16k, mixtral_8x3b_64k, @@ -24,7 +26,8 @@ "llama3_70b", "llama3_70b_16k", "llama3_70b_64k", - "mistral", + "mistral_7b", + "mistral_nemo_12b", "mixtral_8x3b", "mixtral_8x3b_16k", "mixtral_8x3b_64k", diff --git a/nemo/collections/llm/recipes/mistral.py b/nemo/collections/llm/recipes/mistral_7b.py similarity index 99% rename from nemo/collections/llm/recipes/mistral.py rename to nemo/collections/llm/recipes/mistral_7b.py index 99b9ef4c9e03..8357c4672ed8 100644 --- a/nemo/collections/llm/recipes/mistral.py +++ b/nemo/collections/llm/recipes/mistral_7b.py @@ -12,7 +12,7 @@ from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing from nemo.collections.llm.utils import Config, Partial, factory -NAME = "mistral" +NAME = "mistral_7b" @factory(name=NAME) diff --git a/nemo/collections/llm/recipes/mistral_nemo_12b.py b/nemo/collections/llm/recipes/mistral_nemo_12b.py new file mode 100644 index 000000000000..6ad99a1f1253 --- /dev/null +++ b/nemo/collections/llm/recipes/mistral_nemo_12b.py @@ -0,0 +1,189 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch +from megatron.core.distributed import DistributedDataParallelConfig +from pytorch_lightning.callbacks.callback import Callback + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.model.mistral import MistralModel, MistralNeMoConfig12B +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.utils.exp_manager import TimingCallback + +NAME = "mistral_nemo_base_12b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Mistral-Nemo-Base-12B model configuration. + Returns: + run.Config[pl.LightningModule]: Configuration for the Mistral-Nemo-Base-12B model. + Examples: + CLI usage: + $ nemo llm pretrain model=mistral_nemo_base_12b ... + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + return run.Config(MistralModel, config=run.Config(MistralNeMoConfig12B)) + + +def trainer( + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 2, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 100, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for Mistral-Nemo-Base-12B model. + This function sets up the distributed training strategy and other training parameters. + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + Examples: + CLI usage: + $ nemo llm pretrain trainer=mistral_nemo_base_12b ... + Python API usage: + >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8) + >>> print(trainer_config) + """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_include_optimizer=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + ), + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + gradient_clip_val=1.0, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, fn=pretrain +) -> run.Partial: + """ + Create a pre-training recipe for Mistral-Nemo-Base-12B model. + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + fn (Callable): The pre-training function to use. + Returns: + run.Partial: Partial configuration for pre-training. + Examples: + CLI usage: + $ nemo llm pretrain --factory mistral_nemo_base_12b + $ nemo llm pretrain --factory "mistral_nemo_base_12b(num_nodes=2, name='my_mistral_pretrain')" + Python API usage: + >>> recipe = pretrain_recipe(name="mistral_pretrain", num_nodes=2) + >>> print(recipe) + """ + return run.Partial( + fn, + model=model(), + trainer=trainer( + tensor_parallelism=2, + pipeline_parallelism=1, + pipeline_parallelism_type=None, + virtual_pipeline_parallelism=None, + context_parallelism=2, + sequence_parallelism=False, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config(MockDataModule, seq_length=4096, global_batch_size=512, micro_batch_size=1), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), + resume=default_resume(), + ) + + +@run.cli.factory(name=NAME + "_hf") +def hf_resume() -> run.Config[nl.AutoResume]: + """ + Configure automatic resumption from a Hugging Face checkpoint for Mistral-Nemo-Base-12B model. + This function sets up the configuration to resume training from a pre-trained + Hugging Face model checkpoint. + More info about the model can be found at: https://huggingface.co/mistralai/Mistral-Nemo-Base-2407 + Returns: + run.Config[nl.AutoResume]: Configuration for resuming from HuggingFace checkpoint. + Note: + This is particularly useful for fine-tuning scenarios where you want to + start from the pre-trained Mistral-Nemo-Base-12B model. + """ + return run.Config( + nl.AutoResume, restore_config=run.Config(nl.RestoreConfig, path="hf://mistralai/Mistral-Nemo-Base-2407") + )