405b recipe
Signed-off-by: Malay Nagda <malayn@nvidia.com>
malay-nagda committed Dec 23, 2024
1 parent e1ba662 commit 8031797
Showing 4 changed files with 169 additions and 0 deletions.
158 changes: 158 additions & 0 deletions scripts/llm/performance/llama3_405b.py
@@ -0,0 +1,158 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
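
"""Performance-focused pre-training launch script for Llama 3.1 405B, run on Slurm via NeMo-Run."""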

from typing import Optional

import nemo_run as run
from utils import get_comm_overlap_callback_idx, hf_tokenizer, parse_cli_args, slurm_executor

from nemo.collections.llm.recipes.llama31_405b import pretrain_recipe
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_with_fp8_mixed
from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback
from nemo.lightning.run.plugins import NsysPlugin, PerfEnvPlugin

NUM_NODES = 72
NUM_GPUS_PER_NODE = 8
MICRO_BATCH_SIZE = 1
GLOBAL_BATCH_SIZE = 252
TP_SIZE = 8
PP_SIZE = 9
CP_SIZE = 2
VP_SIZE = 7
MAX_STEPS = 100
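
# With these defaults: 72 nodes x 8 GPUs = 576 GPUs; TP(8) x PP(9) x CP(2) = 144-way
# model parallelism, so the data-parallel size is 576 / 144 = 4 and each global batch
# of 252 corresponds to 252 / (4 x 1) = 63 micro-batches per data-parallel rank.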

def llama3_405b_performance_recipe(
    log_dir: str,
    compute_dtype: str,
    num_nodes: int,
    num_gpus_per_node: int,
    mbs: int,
    gbs: int,
    tp_size: int,
    pp_size: int,
    cp_size: int,
    vp_size: Optional[int],
    max_steps: int,
):
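    """
    Build a Llama 3.1 405B pre-train recipe tuned for best training performance.

    Starts from the standard pretrain recipe in performance mode and overrides its
    data, parallelism, precision and callback settings with the values passed in.
    """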
    recipe = pretrain_recipe(dir=log_dir, performance_mode=True)

    # data module configs
    recipe.data.micro_batch_size = mbs
    recipe.data.global_batch_size = gbs
    recipe.data.num_train_samples = max_steps * (num_nodes * num_gpus_per_node)  # ensure only 1 epoch for whole run
    recipe.data.tokenizer = hf_tokenizer("meta-llama/Llama-3.1-405B")

    recipe.trainer.max_steps = max_steps
    recipe.trainer.num_nodes = num_nodes
    recipe.trainer.devices = num_gpus_per_node

    # parallelism configs
    recipe.trainer.strategy.tensor_model_parallel_size = tp_size
    recipe.trainer.strategy.pipeline_model_parallel_size = pp_size
    recipe.trainer.strategy.context_parallel_size = cp_size
    recipe.trainer.strategy.virtual_pipeline_model_parallel_size = vp_size
    if tp_size > 1:
        recipe.trainer.strategy.sequence_parallel = True
    else:
        recipe.trainer.strategy.sequence_parallel = False

    comm_overlap_callback_idx = get_comm_overlap_callback_idx(recipe.trainer.callbacks)

    # compute dtype configs
    if compute_dtype.lower() == "fp8":
        recipe.trainer.plugins = bf16_with_fp8_mixed()
        # use FP8 buffers for the proj and fc2 forward-prop tensor-parallel communication overlap
        recipe.trainer.callbacks[comm_overlap_callback_idx].tp_comm_overlap_cfg.proj_fprop.fp8_buf = True
        recipe.trainer.callbacks[comm_overlap_callback_idx].tp_comm_overlap_cfg.fc2_fprop.fp8_buf = True

    recipe.trainer.plugins.grad_reduce_in_fp32 = False  # bf16 grad dtype

    # callback configs
    garbage_collection_callback = run.Config(
        GarbageCollectionCallback,
        gc_interval_train=100,
        gc_interval_val=500,
    )
    recipe.trainer.callbacks.extend(
        [
            garbage_collection_callback,
        ]
    )
    dp_size = (num_nodes * num_gpus_per_node) / (tp_size * pp_size * cp_size)
    if dp_size > 1 and pp_size > 1 and vp_size and vp_size > 1:
        if comm_overlap_callback_idx >= 0:
            recipe.trainer.callbacks[comm_overlap_callback_idx].overlap_param_gather_with_optimizer_step = True

    recipe.log.ckpt = None
    recipe.trainer.enable_checkpointing = False
    recipe.trainer.val_check_interval = MAX_STEPS
    recipe.trainer.log_every_n_steps = 1

    return recipe


if __name__ == "__main__":
    args = parse_cli_args().parse_args()

    exp_name = "_".join(
        [
            "llama3_405b",
            args.compute_dtype,
            f"{NUM_NODES}nodes",
            f"tp{TP_SIZE}_pp{PP_SIZE}_cp{CP_SIZE}_vp{VP_SIZE}",
            f"{MICRO_BATCH_SIZE}mbs_{GLOBAL_BATCH_SIZE}gbs",
        ]
    )

    executor = slurm_executor(
        args.account,
        args.partition,
        args.log_dir,
        NUM_NODES,
        NUM_GPUS_PER_NODE,
        args.time_limit,
        args.container_image,
        custom_mounts=[],
        custom_env_vars={},
        retries=0,
    )

    recipe = llama3_405b_performance_recipe(
        args.log_dir,
        args.compute_dtype,
        NUM_NODES,
        NUM_GPUS_PER_NODE,
        MICRO_BATCH_SIZE,
        GLOBAL_BATCH_SIZE,
        TP_SIZE,
        PP_SIZE,
        CP_SIZE,
        VP_SIZE,
        MAX_STEPS,
    )

    with run.Experiment(exp_name) as exp:
        exp.add(
            recipe,
            executor=executor,
            name=exp_name,
            plugins=[
                PerfEnvPlugin(enable_vboost=True),
                NsysPlugin(start_step=5, end_step=6),
            ],
        )

        if not args.dryrun:
            exp.run(sequential=True, detach=True)
        else:
            exp.dryrun()
5 changes: 5 additions & 0 deletions scripts/llm/performance/llama3_70b.py
@@ -94,6 +94,11 @@ def llama3_70b_performance_recipe(
        if comm_overlap_callback_idx >= 0:
            recipe.trainer.callbacks[comm_overlap_callback_idx].overlap_param_gather_with_optimizer_step = True

    recipe.log.ckpt = None
    recipe.trainer.enable_checkpointing = False
    recipe.trainer.val_check_interval = MAX_STEPS
    recipe.trainer.log_every_n_steps = 1

    return recipe


5 changes: 5 additions & 0 deletions scripts/llm/performance/llama3_8b.py
@@ -91,6 +91,11 @@ def llama3_8b_performance_recipe(
        if comm_overlap_callback_idx >= 0:
            recipe.trainer.callbacks[comm_overlap_callback_idx].overlap_param_gather_with_optimizer_step = True

    recipe.log.ckpt = None
    recipe.trainer.enable_checkpointing = False
    recipe.trainer.val_check_interval = MAX_STEPS
    recipe.trainer.log_every_n_steps = 1

    return recipe


1 change: 1 addition & 0 deletions scripts/llm/performance/utils.py
@@ -47,6 +47,7 @@ def slurm_executor(

    env_vars = {
        "TRANSFORMERS_OFFLINE": "1",
        "TOKENIZERS_PARALLELISM": "False",  # avoid HF tokenizers' fork-related parallelism warnings in data workers
        "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
        "NCCL_NVLS_ENABLE": "0",
        "NVTE_DP_AMAX_REDUCE_INTERVAL": "0",
