add mistral 12b recipe for 24.09
Signed-off-by: dimapihtar <dpihtar@gmail.com>
dimapihtar committed Sep 27, 2024
1 parent 4c54e2a commit daf2549
Showing 6 changed files with 200 additions and 6 deletions.
2 changes: 2 additions & 0 deletions nemo/collections/llm/__init__.py
@@ -70,6 +70,7 @@
     LlamaModel,
     MaskedTokenLossReduction,
     MistralConfig7B,
+    MistralNeMoConfig12B,
     MistralModel,
     MixtralConfig8x3B,
     MixtralConfig8x7B,
@@ -118,6 +119,7 @@
     "gpt_data_step",
     "gpt_forward_step",
     "MaskedTokenLossReduction",
+    "MistralNeMoConfig12B",
     "MistralConfig7B",
     "MistralModel",
     "MixtralConfig8x3B",
2 changes: 1 addition & 1 deletion nemo/collections/llm/gpt/model/__init__.py
@@ -53,7 +53,7 @@
     LlamaConfig,
     LlamaModel,
 )
-from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel
+from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel, MistralNeMoConfig12B
 from nemo.collections.llm.gpt.model.mixtral import (
     MixtralConfig8x3B,
     MixtralConfig8x7B,
6 changes: 3 additions & 3 deletions nemo/collections/llm/gpt/model/mistral.py
@@ -58,7 +58,7 @@ class MistralConfig7B(GPTConfig):


 @dataclass
-class MistralNeMo2407Config12B(MistralConfig7B):
+class MistralNeMoConfig12B(MistralConfig7B):
     """
     https://mistral.ai/news/mistral-nemo/
     """
@@ -74,7 +74,7 @@ class MistralNeMo2407Config12B(MistralConfig7B):


 @dataclass
-class MistralNeMo2407Config123B(MistralConfig7B):
+class MistralNeMoConfig123B(MistralConfig7B):
     """
     https://mistral.ai/news/mistral-large-2407/
     """
@@ -324,7 +324,7 @@ def _export_qkv(ctx: io.TransformCTX, linear_qkv):
     target_key="decoder.layers.*.mlp.linear_fc1.weight",
 )
 def _import_linear_fc1(down, gate):
-    return torch.cat((down, gate), axis=0).float()
+    return torch.cat((down, gate), axis=0)


 @io.state_transform(
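The rename drops the release suffix: MistralNeMo2407Config12B becomes MistralNeMoConfig12B (and likewise for the 123B config). A minimal sketch of the config under its new name, mirroring how the new recipe below wires it up through run.Config (not part of this commit; the config's field overrides are collapsed in the diff above):

```python
# Minimal usage sketch of the renamed 12B config (illustrative only).
from nemo.collections.llm.gpt.model.mistral import MistralModel, MistralNeMoConfig12B

config = MistralNeMoConfig12B()      # dataclass; inherits MistralConfig7B and overrides 12B-specific fields
model = MistralModel(config=config)  # same config= wiring the recipe uses via run.Config
```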
5 changes: 4 additions & 1 deletion nemo/collections/llm/recipes/__init__.py
@@ -6,6 +6,8 @@
     llama3_70b_16k,
     llama3_70b_64k,
     mistral,
+    mistral_7b,
+    mistral_nemo_12b,
     mixtral_8x3b,
     mixtral_8x3b_16k,
     mixtral_8x3b_64k,
@@ -24,7 +26,8 @@
     "llama3_70b",
     "llama3_70b_16k",
     "llama3_70b_64k",
-    "mistral",
+    "mistral_7b",
+    "mistral_nemo_12b",
     "mixtral_8x3b",
     "mixtral_8x3b_16k",
     "mixtral_8x3b_64k",
2 changes: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
 from nemo.collections.llm.utils import Config, Partial, factory

-NAME = "mistral"
+NAME = "mistral_7b"


 @factory(name=NAME)
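The existing Mistral recipe is now registered under the name mistral_7b, and the new mistral_nemo_12b module added below follows the same pattern. A short sketch of selecting the 12B recipe from Python, assuming the modules registered in recipes/__init__.py above:

```python
# Hypothetical usage of the newly registered recipe module (not part of this commit).
from nemo.collections.llm import recipes

recipe = recipes.mistral_nemo_12b.pretrain_recipe(
    name="mistral_nemo_12b_pretrain",
    num_nodes=1,
    num_gpus_per_node=8,
)
```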
189 changes: 189 additions & 0 deletions nemo/collections/llm/recipes/mistral_nemo_12b.py
@@ -0,0 +1,189 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional

import nemo_run as run
import pytorch_lightning as pl
import torch
from megatron.core.distributed import DistributedDataParallelConfig
from pytorch_lightning.callbacks.callback import Callback

from nemo import lightning as nl
from nemo.collections.llm.api import finetune, pretrain
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.model.mistral import MistralModel, MistralNeMoConfig12B
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.utils.exp_manager import TimingCallback

NAME = "mistral_nemo_base_12b"


@run.cli.factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
    """
    Factory function to create a Mistral-Nemo-Base-12B model configuration.
    Returns:
        run.Config[pl.LightningModule]: Configuration for the Mistral-Nemo-Base-12B model.
    Examples:
        CLI usage:
            $ nemo llm pretrain model=mistral_nemo_base_12b ...
        Python API usage:
            >>> model_config = model()
            >>> print(model_config)
    """
    return run.Config(MistralModel, config=run.Config(MistralNeMoConfig12B))


def trainer(
    tensor_parallelism: int = 1,
    pipeline_parallelism: int = 1,
    pipeline_parallelism_type: Optional[torch.dtype] = None,
    virtual_pipeline_parallelism: Optional[int] = None,
    context_parallelism: int = 2,
    sequence_parallelism: bool = False,
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    max_steps: int = 100,
    callbacks: Optional[list[run.Config[Callback]]] = None,
) -> run.Config[nl.Trainer]:
    """
    Configure the NeMo Lightning Trainer for Mistral-Nemo-Base-12B model.
    This function sets up the distributed training strategy and other training parameters.
    Args:
        tensor_parallelism (int): Degree of tensor model parallelism.
        pipeline_parallelism (int): Degree of pipeline model parallelism.
        pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism.
        virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism.
        context_parallelism (int): Degree of context parallelism.
        sequence_parallelism (bool): Whether to use sequence parallelism.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        max_steps (int): Maximum number of training steps.
        callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations.
    Returns:
        run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer.
    Examples:
        CLI usage:
            $ nemo llm pretrain trainer=mistral_nemo_base_12b ...
        Python API usage:
            >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8)
            >>> print(trainer_config)
    """
    strategy = run.Config(
        nl.MegatronStrategy,
        tensor_model_parallel_size=tensor_parallelism,
        pipeline_model_parallel_size=pipeline_parallelism,
        pipeline_dtype=pipeline_parallelism_type,
        virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism,
        context_parallel_size=context_parallelism,
        sequence_parallel=sequence_parallelism,
        gradient_as_bucket_view=True,
        ckpt_include_optimizer=True,
        ckpt_async_save=True,
        ckpt_parallel_load=True,
        ddp=run.Config(
            DistributedDataParallelConfig,
            check_for_nan_in_grad=True,
            grad_reduce_in_fp32=True,
            overlap_grad_reduce=True,
            overlap_param_gather=True,
        ),
    )

    trainer = run.Config(
        nl.Trainer,
        accelerator="gpu",
        accumulate_grad_batches=1,
        callbacks=callbacks,
        devices=num_gpus_per_node,
        gradient_clip_val=1.0,
        limit_test_batches=50,
        limit_val_batches=32,
        log_every_n_steps=10,
        max_steps=max_steps,
        num_nodes=num_nodes,
        plugins=bf16_mixed(),
        strategy=strategy,
        use_distributed_sampler=False,
        val_check_interval=2000,
    )

    return trainer
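The trainer defaults above (tensor_parallelism=1, pipeline_parallelism=1, context_parallelism=2 on a single 8-GPU node) leave a data-parallel size of 4, since in a Megatron-style layout the product TP x PP x CP must divide the world size. A quick sanity-check sketch of that assumption:

```python
# Back-of-the-envelope check for the parallelism defaults above (assumes the
# Megatron-style rule: data-parallel size = world_size / (TP * PP * CP)).
def data_parallel_size(num_nodes: int = 1, num_gpus_per_node: int = 8,
                       tp: int = 1, pp: int = 1, cp: int = 2) -> int:
    world_size = num_nodes * num_gpus_per_node
    model_parallel = tp * pp * cp
    assert world_size % model_parallel == 0, "TP * PP * CP must divide the total GPU count"
    return world_size // model_parallel


print(data_parallel_size())  # -> 4 for the defaults (1 node x 8 GPUs, TP=1, PP=1, CP=2)
```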


@run.cli.factory(target=pretrain, name=NAME)
def pretrain_recipe(
    dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
    """
    Create a pre-training recipe for Mistral-Nemo-Base-12B model.
    This function sets up a complete configuration for pre-training, including
    model, trainer, data, logging, optimization, and resumption settings.
    Args:
        dir (Optional[str]): Directory for saving logs and checkpoints.
        name (str): Name of the pre-training run.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        fn (Callable): The pre-training function to use.
    Returns:
        run.Partial: Partial configuration for pre-training.
    Examples:
        CLI usage:
            $ nemo llm pretrain --factory mistral_nemo_base_12b
            $ nemo llm pretrain --factory "mistral_nemo_base_12b(num_nodes=2, name='my_mistral_pretrain')"
        Python API usage:
            >>> recipe = pretrain_recipe(name="mistral_pretrain", num_nodes=2)
            >>> print(recipe)
    """
    return run.Partial(
        fn,
        model=model(),
        trainer=trainer(
            tensor_parallelism=2,
            pipeline_parallelism=1,
            pipeline_parallelism_type=None,
            virtual_pipeline_parallelism=None,
            context_parallelism=2,
            sequence_parallelism=False,
            num_nodes=num_nodes,
            num_gpus_per_node=num_gpus_per_node,
            callbacks=[run.Config(TimingCallback)],
        ),
        data=run.Config(MockDataModule, seq_length=4096, global_batch_size=512, micro_batch_size=1),
        log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)),
        optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4),
        resume=default_resume(),
    )
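The recipe itself is only a configuration object; it still has to be launched. A hedged sketch of a local smoke test, assuming nemo_run's run.run and LocalExecutor APIs and that fields on the returned run.Partial can be overridden before launch:

```python
# Hypothetical launch script (not part of this commit); adjust the executor to your cluster.
import nemo_run as run

from nemo.collections.llm.recipes import mistral_nemo_12b

recipe = mistral_nemo_12b.pretrain_recipe(name="mistral_nemo_12b_smoke", num_nodes=1, num_gpus_per_node=8)
recipe.trainer.max_steps = 10        # shrink the run for a quick check
recipe.data.global_batch_size = 32   # assumed override; pick a value that fits your GPUs

run.run(recipe, executor=run.LocalExecutor())
```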


@run.cli.factory(name=NAME + "_hf")
def hf_resume() -> run.Config[nl.AutoResume]:
    """
    Configure automatic resumption from a Hugging Face checkpoint for Mistral-Nemo-Base-12B model.
    This function sets up the configuration to resume training from a pre-trained
    Hugging Face model checkpoint.
    More info about the model can be found at: https://huggingface.co/mistralai/Mistral-Nemo-Base-2407
    Returns:
        run.Config[nl.AutoResume]: Configuration for resuming from HuggingFace checkpoint.
    Note:
        This is particularly useful for fine-tuning scenarios where you want to
        start from the pre-trained Mistral-Nemo-Base-12B model.
    """
    return run.Config(
        nl.AutoResume, restore_config=run.Config(nl.RestoreConfig, path="hf://mistralai/Mistral-Nemo-Base-2407")
    )
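hf_resume() is registered under the factory name with an "_hf" suffix so it can be picked up from the CLI; from Python it can simply be attached to a recipe so training starts from the converted Hugging Face weights rather than a random initialization. A small sketch, assuming the recipe's resume field accepts this AutoResume config:

```python
# Hypothetical fine-tuning-style setup (not part of this commit).
from nemo.collections.llm.recipes import mistral_nemo_12b

recipe = mistral_nemo_12b.pretrain_recipe(name="mistral_nemo_12b_from_hf", num_nodes=1)
recipe.resume = mistral_nemo_12b.hf_resume()  # restore from hf://mistralai/Mistral-Nemo-Base-2407
```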
