Commit f141149
add trainer callbacks
Signed-off-by: dimapihtar <dpihtar@gmail.com>
dimapihtar committed Sep 27, 2024
1 parent 9ea966a commit f141149
Showing 1 changed file with 10 additions and 1 deletion.
nemo/collections/llm/recipes/mistral_nemo_12b.py (10 additions, 1 deletion)
@@ -28,6 +28,7 @@
 from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
 from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
 from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed_plugin
+from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
 from nemo.utils.exp_manager import TimingCallback

 NAME = "mistral_nemo_base_12b"
@@ -150,7 +151,7 @@ def pretrain_recipe(
         >>> recipe = pretrain_recipe(name="mistral_pretrain", num_nodes=2)
         >>> print(recipe)
     """
-    return run.Partial(
+    recipe = run.Partial(
         fn,
         model=model(),
         trainer=trainer(
@@ -170,6 +171,14 @@ def pretrain_recipe(
         resume=default_resume(),
     )

+    recipe.trainer.callbacks.append(
+        run.Config(
+            MegatronCommOverlapCallback,
+            tp_comm_overlap=True,
+        )
+    )
+
+    return recipe

 @run.cli.factory(name=NAME + "_hf")
 def hf_resume() -> run.Config[nl.AutoResume]:
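
For context, a minimal usage sketch (not part of the commit) showing how the appended callback surfaces once the recipe is built. It follows the docstring example in the diff above and assumes the nemo and nemo_run packages are installed; with tp_comm_overlap=True the callback enables tensor-parallel communication/compute overlap in Megatron-based training.

# Hedged sketch, based on the docstring example above; assumes NeMo 2.0
# recipes and nemo_run are importable in the current environment.
from nemo.collections.llm.recipes.mistral_nemo_12b import pretrain_recipe

recipe = pretrain_recipe(name="mistral_pretrain", num_nodes=2)

# After this commit, the trainer's callback list ends with the config
# appended by pretrain_recipe:
#   run.Config(MegatronCommOverlapCallback, tp_comm_overlap=True)
for cb in recipe.trainer.callbacks:
    print(cb)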
