-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Malay Nagda <malayn@nvidia.com>
- Loading branch information
1 parent
e1ba662
commit 8031797
Showing
4 changed files
with
169 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Optional | ||
|
||
import nemo_run as run | ||
from utils import get_comm_overlap_callback_idx, hf_tokenizer, parse_cli_args, slurm_executor | ||
|
||
from nemo.collections.llm.recipes.llama31_405b import pretrain_recipe | ||
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_with_fp8_mixed | ||
from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback | ||
from nemo.lightning.run.plugins import NsysPlugin, PerfEnvPlugin | ||
|
||
NUM_NODES = 72 | ||
NUM_GPUS_PER_NODE = 8 | ||
MICRO_BATCH_SIZE = 1 | ||
GLOBAL_BATCH_SIZE = 252 | ||
TP_SIZE = 8 | ||
PP_SIZE = 9 | ||
CP_SIZE = 2 | ||
VP_SIZE = 7 | ||
MAX_STEPS = 100 | ||
|
||
def llama3_405b_performance_recipe( | ||
log_dir: str, | ||
compute_dtype: str, | ||
num_nodes: int, | ||
num_gpus_per_node: int, | ||
mbs: int, | ||
gbs: int, | ||
tp_size: int, | ||
pp_size: int, | ||
cp_size: int, | ||
vp_size: Optional[int], | ||
max_steps: int, | ||
): | ||
recipe = pretrain_recipe(dir=log_dir, performance_mode=True) | ||
|
||
# data module configs | ||
recipe.data.micro_batch_size = mbs | ||
recipe.data.global_batch_size = gbs | ||
recipe.data.num_train_samples = max_steps * (num_nodes * num_gpus_per_node) # ensure only 1 epoch for whole run | ||
recipe.data.tokenizer = hf_tokenizer("meta-llama/Llama-3.1-405B") | ||
|
||
recipe.trainer.max_steps = max_steps | ||
recipe.trainer.num_nodes = num_nodes | ||
recipe.trainer.devices = num_gpus_per_node | ||
|
||
# parallelism configs | ||
recipe.trainer.strategy.tensor_model_parallel_size = tp_size | ||
recipe.trainer.strategy.pipeline_model_parallel_size = pp_size | ||
recipe.trainer.strategy.context_parallel_size = cp_size | ||
recipe.trainer.strategy.virtual_pipeline_model_parallel_size = vp_size | ||
if tp_size > 1: | ||
recipe.trainer.strategy.sequence_parallel = True | ||
else: | ||
recipe.trainer.strategy.sequence_parallel = False | ||
|
||
comm_overlap_callback_idx = get_comm_overlap_callback_idx(recipe.trainer.callbacks) | ||
|
||
# compute dtype configs | ||
if compute_dtype.lower() == "fp8": | ||
recipe.trainer.plugins = bf16_with_fp8_mixed() | ||
recipe.trainer.callbacks[comm_overlap_callback_idx].tp_comm_overlap_cfg.proj_fprop.fp8_buf=True | ||
recipe.trainer.callbacks[comm_overlap_callback_idx].tp_comm_overlap_cfg.fc2_fprop.fp8_buf=True | ||
|
||
recipe.trainer.plugins.grad_reduce_in_fp32 = False # bf16 grad dtype | ||
|
||
# callback configs | ||
garbage_collection_callback = run.Config( | ||
GarbageCollectionCallback, | ||
gc_interval_train=100, | ||
gc_interval_val=500, | ||
) | ||
recipe.trainer.callbacks.extend( | ||
[ | ||
garbage_collection_callback, | ||
] | ||
) | ||
dp_size = (num_nodes * num_gpus_per_node) / (tp_size * pp_size * cp_size) | ||
if dp_size > 1 and pp_size > 1 and vp_size and vp_size > 1: | ||
if comm_overlap_callback_idx >= 0: | ||
recipe.trainer.callbacks[comm_overlap_callback_idx].overlap_param_gather_with_optimizer_step = True | ||
|
||
recipe.log.ckpt = None | ||
recipe.trainer.enable_checkpointing = False | ||
recipe.trainer.val_check_interval=MAX_STEPS | ||
recipe.trainer.log_every_n_steps=1 | ||
|
||
return recipe | ||
|
||
|
||
if __name__ == "__main__": | ||
args = parse_cli_args().parse_args() | ||
|
||
exp_name = "_".join( | ||
[ | ||
f"llama3_405b", | ||
args.compute_dtype, | ||
f"{NUM_NODES}nodes", | ||
f"tp{TP_SIZE}_pp{PP_SIZE}_cp{CP_SIZE}_vp{VP_SIZE}", | ||
f"{MICRO_BATCH_SIZE}mbs_{GLOBAL_BATCH_SIZE}gbs", | ||
] | ||
) | ||
|
||
executor = slurm_executor( | ||
args.account, | ||
args.partition, | ||
args.log_dir, | ||
NUM_NODES, | ||
NUM_GPUS_PER_NODE, | ||
args.time_limit, | ||
args.container_image, | ||
custom_mounts=[], | ||
custom_env_vars={}, | ||
retries=0, | ||
) | ||
|
||
recipe = llama3_405b_performance_recipe( | ||
args.log_dir, | ||
args.compute_dtype, | ||
NUM_NODES, | ||
NUM_GPUS_PER_NODE, | ||
MICRO_BATCH_SIZE, | ||
GLOBAL_BATCH_SIZE, | ||
TP_SIZE, | ||
PP_SIZE, | ||
CP_SIZE, | ||
VP_SIZE, | ||
MAX_STEPS, | ||
) | ||
|
||
with run.Experiment(exp_name) as exp: | ||
exp.add( | ||
recipe, | ||
executor=executor, | ||
name=exp_name, | ||
plugins=[ | ||
PerfEnvPlugin(enable_vboost=True), | ||
NsysPlugin(start_step=5, end_step=6), | ||
], | ||
) | ||
|
||
if not args.dryrun: | ||
exp.run(sequential=True, detach=True) | ||
else: | ||
exp.dryrun() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters