diff --git a/tests/collections/llm/test_mnist_model_nemo2.py b/tests/collections/llm/test_mnist_model_nemo2.py
index 58ac1039085f..52fde4c26af9 100644
--- a/tests/collections/llm/test_mnist_model_nemo2.py
+++ b/tests/collections/llm/test_mnist_model_nemo2.py
@@ -471,6 +471,12 @@ def _reset_microbatch_calculator():
     nemo.collections.nlp.modules.common.megatron.megatron_init.initialize_model_parallel_for_nemo
     """  # noqa: D205, D415
     megatron.core.num_microbatches_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
+    # Clean up any process groups created in testing
+    torch.cuda.empty_cache()
+    if parallel_state.is_initialized():
+        parallel_state.destroy_model_parallel()
+    if torch.distributed.is_initialized():
+        torch.distributed.destroy_process_group()
 
 
 @contextmanager
@@ -541,6 +547,7 @@ def test_train_mnist_litautoencoder_with_megatron_strategy_single_gpu(tmpdir):
             resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
         ),
     )
+    trainer._teardown()
     with reset_global_microbatch_calculator():
         pred_strategy = nl.MegatronStrategy(
             tensor_model_parallel_size=1,
@@ -578,3 +585,4 @@ def test_train_mnist_litautoencoder_with_megatron_strategy_single_gpu(tmpdir):
         }, f"We expect forward output from predit_step, not the loss, got: {forward_output}"
     assert forward_output["x_hat"].shape == (len(data_module.mnist_test), 28 * 28)
     assert forward_output["z"].shape == (len(data_module.mnist_test), 3)  # latent bottleneck in model of dim 3
+    predict_trainer._teardown()