Skip to content

Commit

Permalink
Fix typo and rename state resetting functions
Browse files Browse the repository at this point in the history
Signed-off-by: John St John <jstjohn@nvidia.com>
  • Loading branch information
jstjohn committed Aug 20, 2024
1 parent 0da36a9 commit 236f257
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions tests/collections/llm/test_mnist_model_nemo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,8 +466,8 @@ def test_dataloader(self) -> DataLoader: # noqa: D102


### Begin model environment related utilities
def _reset_microbatch_calculator():
"""Resets _GLOBAL_NUM_MICROBATCHES_CALCULATOR in megatron which is used in NeMo to initilised model parallel in
def _reset_megatron_parallel_state():
"""Resets _GLOBAL_NUM_MICROBATCHES_CALCULATOR in megatron which is used in NeMo to initialized model parallel in
nemo.collections.nlp.modules.common.megatron.megatron_init.initialize_model_parallel_for_nemo
""" # noqa: D205, D415
megatron.core.num_microbatches_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
Expand All @@ -480,18 +480,18 @@ def _reset_microbatch_calculator():


@contextmanager
def reset_global_microbatch_calculator() -> Iterator[None]:
def reset_megatron_parallel_state() -> Iterator[None]:
"""Puts you into a clean parallel state, and again tears it down at the end."""
try:
_reset_microbatch_calculator()
_reset_megatron_parallel_state()
yield
finally:
_reset_microbatch_calculator()
_reset_megatron_parallel_state()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available")
def test_train_mnist_litautoencoder_with_megatron_strategy_single_gpu(tmpdir):
with reset_global_microbatch_calculator():
with reset_megatron_parallel_state():
# Configure our custom Checkpointer
name = "test_experiment"
checkpoint_callback = nl_callbacks.ModelCheckpoint(
Expand Down Expand Up @@ -548,7 +548,7 @@ def test_train_mnist_litautoencoder_with_megatron_strategy_single_gpu(tmpdir):
),
)
trainer._teardown()
with reset_global_microbatch_calculator():
with reset_megatron_parallel_state():
pred_strategy = nl.MegatronStrategy(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
Expand Down

0 comments on commit 236f257

Please sign in to comment.