Only reset the global microbatch, not entire parallel state
Signed-off-by: John St John <jstjohn@nvidia.com>
jstjohn committed Aug 20, 2024
1 parent ce413ab commit d0258da
Showing 1 changed file with 5 additions and 15 deletions.
tests/collections/llm/test_mnist_model_nemo2.py
@@ -473,29 +473,19 @@ def _reset_microbatch_calculator():
     megatron.core.num_microbatches_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
 
 
-def _teardown_apex_megatron_cuda():
-    """Cleans GPU allocation and model and data parallel settings after usage of a model:
-    - sets the global variables related to model and data parallelism to None in Apex and Megatron:.
-    - releases all unoccupied cached GPU memory currently held by the caching CUDA allocator, see torch.cuda.empty_cache
-    """  # noqa: D205, D415
-    torch.cuda.empty_cache()
-    _reset_microbatch_calculator()
-    parallel_state.destroy_model_parallel()
-
-
 @contextmanager
-def clean_parallel_state_context() -> Iterator[None]:
+def reset_global_microbatch_calculator() -> Iterator[None]:
     """Puts you into a clean parallel state, and again tears it down at the end."""
     try:
-        _teardown_apex_megatron_cuda()
+        _reset_microbatch_calculator()
         yield
     finally:
-        _teardown_apex_megatron_cuda()
+        _reset_microbatch_calculator()
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available")
 def test_train_mnist_litautoencoder_with_megatron_strategy_single_gpu(tmpdir):
-    with clean_parallel_state_context():
+    with reset_global_microbatch_calculator():
         # Configure our custom Checkpointer
         name = "test_experiment"
         checkpoint_callback = nl_callbacks.ModelCheckpoint(
@@ -551,7 +541,7 @@ def test_train_mnist_litautoencoder_with_megatron_strategy_single_gpu(tmpdir):
                 resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
             ),
         )
-    with clean_parallel_state_context():
+    with reset_global_microbatch_calculator():
         pred_strategy = nl.MegatronStrategy(
             tensor_model_parallel_size=1,
             pipeline_model_parallel_size=1,
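
For illustration only, a minimal self-contained sketch of the pattern this commit settles on: the context manager clears just megatron's global microbatch calculator (instead of also destroying model-parallel state and emptying the CUDA cache), and whatever runs inside the block is expected to rebuild it. The init_num_microbatches_calculator / get_num_microbatches calls below are an assumption about the installed megatron-core API and are not part of this commit; in the test itself that setup is handled by the trainer/strategy.

from contextlib import contextmanager
from typing import Iterator

import megatron.core.num_microbatches_calculator as mb_calc


@contextmanager
def reset_global_microbatch_calculator() -> Iterator[None]:
    # Mirrors the helper in the diff above: clear only the global microbatch
    # calculator before and after the block, leaving model-parallel groups alone.
    try:
        mb_calc._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
        yield
    finally:
        mb_calc._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None


if __name__ == "__main__":
    with reset_global_microbatch_calculator():
        # Hypothetical re-initialization, normally performed by the training setup.
        # Args: rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size
        mb_calc.init_num_microbatches_calculator(0, None, 8, 2, 1)
        assert mb_calc.get_num_microbatches() == 4  # 8 / (2 * 1)
    # Outside the block the global calculator has been cleared again.
    assert mb_calc._GLOBAL_NUM_MICROBATCHES_CALCULATOR is None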
