Destroy the right sets of state in the lightning trainer test
Signed-off-by: John St John <jstjohn@nvidia.com>
jstjohn committed Aug 20, 2024
1 parent b1ba4d9 commit 0da36a9
Showing 1 changed file with 8 additions and 0 deletions.
tests/collections/llm/test_mnist_model_nemo2.py
@@ -471,6 +471,12 @@ def _reset_microbatch_calculator():
         nemo.collections.nlp.modules.common.megatron.megatron_init.initialize_model_parallel_for_nemo
     """ # noqa: D205, D415
     megatron.core.num_microbatches_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
+    # Clean up any process groups created in testing
+    torch.cuda.empty_cache()
+    if parallel_state.is_initialized():
+        parallel_state.destroy_model_parallel()
+    if torch.distributed.is_initialized():
+        torch.distributed.destroy_process_group()


 @contextmanager
@@ -541,6 +547,7 @@ def test_train_mnist_litautoencoder_with_megatron_strategy_single_gpu(tmpdir):
             resume_ignore_no_checkpoint=True, # When false this will throw an error with no existing checkpoint.
         ),
     )
+    trainer._teardown()
     with reset_global_microbatch_calculator():
         pred_strategy = nl.MegatronStrategy(
             tensor_model_parallel_size=1,
@@ -578,3 +585,4 @@
     }, f"We expect forward output from predict_step, not the loss, got: {forward_output}"
     assert forward_output["x_hat"].shape == (len(data_module.mnist_test), 28 * 28)
     assert forward_output["z"].shape == (len(data_module.mnist_test), 3) # latent bottleneck in model of dim 3
+    predict_trainer._teardown()
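
For context, the cleanup added in the first hunk can also be packaged as a standalone pytest fixture. The following is a minimal sketch using only the calls visible in the diff above (megatron.core.parallel_state, torch.distributed); the fixture name clean_parallel_state is hypothetical, not part of the repository:

import pytest
import torch
import megatron.core.num_microbatches_calculator
from megatron.core import parallel_state


@pytest.fixture
def clean_parallel_state():
    # Run the test body first, then tear down any global state it created.
    yield
    # Reset the global microbatch calculator, as the helper in the diff does.
    megatron.core.num_microbatches_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
    torch.cuda.empty_cache()
    # Destroy Megatron model-parallel groups before the torch.distributed
    # process group they were built from.
    if parallel_state.is_initialized():
        parallel_state.destroy_model_parallel()
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()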

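The two trainer-level _teardown() calls guard the same invariant from the other side: the predict trainer must not be constructed on top of distributed state leaked by the fit trainer. A sketch of a check a test could run between the two trainers (assert_no_stale_process_group is a hypothetical helper, not from this file):

import torch
from megatron.core import parallel_state


def assert_no_stale_process_group() -> None:
    # Fail fast if a previous trainer leaked the state this commit destroys.
    assert not parallel_state.is_initialized(), "Megatron model-parallel state was not destroyed"
    assert not torch.distributed.is_initialized(), "torch.distributed process group was not destroyed"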