Run test in a subprocess to avoid contaminating global state
Signed-off-by: John St John <jstjohn@nvidia.com>
jstjohn committed Aug 21, 2024
1 parent dfdf426 commit a6ff157
Showing 1 changed file with 126 additions and 116 deletions: tests/collections/llm/test_mnist_model_nemo2.py
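The core of the change is the isolation pattern visible in the diff below: the pytest-visible test does nothing except re-launch this same file as a child process, and all Megatron/NeMo global state (process groups, microbatch calculators, and so on) lives and dies inside that child. A minimal sketch of the pattern, assuming pytest as the runner; the names test_runs_in_subprocess and run_isolated are placeholders, and the child is launched via sys.executable rather than the shell string used in the diff:

import subprocess
import sys


def test_runs_in_subprocess():
    # The only thing executed in the main pytest process: spawn this file as a child.
    # check_call raises CalledProcessError if the child exits non-zero, failing the test.
    subprocess.check_call([sys.executable, __file__])


def run_isolated():
    # Build the trainer, train, and predict here. Any global state this touches
    # is confined to the child process and cannot leak into other tests.
    ...


if __name__ == "__main__":
    run_isolated()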
@@ -14,8 +14,13 @@
# limitations under the License.


import os
import subprocess
import sys
import tempfile
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, TypedDict, TypeVar, Union

import megatron.core.num_microbatches_calculator
@@ -87,7 +92,7 @@ def forward(self, batch: DataT, forward_out: Tensor) -> Tuple[Tensor, ReductionT
x = batch["data"]
outputs = forward_out
x_hat = outputs["x_hat"]
# TODO, you could also put a latent loss on Z here.
# you could also put a latent loss on z here.
xview = x.view(x.size(0), -1)
loss = nn.functional.mse_loss(x_hat, xview)

@@ -182,8 +187,6 @@ def batch_collator(batches: Optional[Union[Tuple[ReductionT], List[ReductionT]]]
raise ValueError("Unsupported input structure in batch_collator")


# TODO(@jstjohn): Properly use the Generic for DataT and ReductionT usage. Define our own batch/output types.
# TODO(@skothenhill): Re-think the generics here- the way that `batch_collator` is expressed, `batches` should be a recursive generic type.
class PassthroughLossReduction(MegatronLossReduction):
"""Internally in NeMo2.0 the forward step is always expected to return a loss reduction class, and forward is expected to return a loss.
This class hijacks that mechanism to instead pass through the forward output unperturbed as the loss (to enable inference in the predict step), and then the
@@ -361,7 +364,7 @@ def __getitem__(self, index: int) -> MnistItem:
}


# TODO: remove this callback after `val` loss is logged by default in training.
# TODO: remove this callback after `val` loss is logged by default in training in NeMo2
class LossLoggingCallback(pl.Callback): # noqa: D101
def __init__(self):
"""Log the loss at the end of each batch. For training do not reduce across the epoch but do so for validation/test."""
@@ -371,49 +374,34 @@ def __init__(self):
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): # noqa: D102
# Assuming the loss is computed internally and stored in pl_module
if torch.distributed.get_rank() == 0 and parallel_state.is_pipeline_last_stage():
# TODO(@jstjohn): verify when the outputs are a dictionary of "loss" and when they are just one tensor value.
if isinstance(outputs, dict):
outputs = outputs["loss"]
# torch.distributed.all_reduce(outputs, op=torch.distributed.ReduceOp.AVG)
loss = outputs
pl_module.log("train_loss", loss, on_step=True, prog_bar=True, logger=True, rank_zero_only=True)

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): # noqa: D102
# TODO(@jstjohn): Add a docstring with type hints for this lightning hook
# Assuming the loss is computed internally and stored in pl_module
if torch.distributed.get_rank() == 0 and parallel_state.is_pipeline_last_stage():
# TODO(@jstjohn): verify when the outputs are a dictionary of "loss" and when they are just one tensor value.
if isinstance(outputs, dict):
outputs = outputs["loss"]
# TODO verify that losses are already reduced across ranks
# torch.distributed.all_reduce(outputs, op=torch.distributed.ReduceOp.AVG)
loss = outputs
self.test_losses.append(loss)

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): # noqa: D102
# TODO(@jstjohn): Add a docstring with type hints for this lightning hook
# Assuming the loss is computed internally and stored in pl_module
if torch.distributed.get_rank() == 0 and parallel_state.is_pipeline_last_stage():
# TODO(@jstjohn): verify when the outputs are a dictionary of "loss" and when they are just one tensor value.
if isinstance(outputs, dict):
outputs = outputs["loss"]
# TODO verify that losses are already reduced across ranks
# torch.distributed.all_reduce(outputs, op=torch.distributed.ReduceOp.AVG)
# TODO verify that losses are already reduced across ranks
# torch.distributed.all_reduce(outputs, op=torch.distributed.ReduceOp.AVG)
loss = outputs
self.val_losses.append(loss)

def on_validation_epoch_end(self, trainer, pl_module): # noqa: D102
# TODO(@jstjohn): Add a docstring with type hints for this lightning hook
if torch.distributed.get_rank() == 0 and parallel_state.is_pipeline_last_stage():
if len(self.val_losses) > 0:
avg_val_loss = torch.stack(self.val_losses).mean()
pl_module.log("val_loss", avg_val_loss, prog_bar=True, logger=True, rank_zero_only=True)
self.val_losses.clear()

def on_test_epoch_end(self, trainer, pl_module): # noqa: D102
# TODO(@jstjohn): Add a docstring with type hints for this lightning hook
if torch.distributed.get_rank() == 0 and parallel_state.is_pipeline_last_stage():
if len(self.test_losses) > 0:
avg_test_loss = torch.stack(self.test_losses).mean()
@@ -471,7 +459,7 @@ def _reset_megatron_parallel_state():
nemo.collections.nlp.modules.common.megatron.megatron_init.initialize_model_parallel_for_nemo
""" # noqa: D205, D415
megatron.core.num_microbatches_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
# Clean up any process gorups created in testing
# Clean up any process groups created in testing
torch.cuda.empty_cache()
if parallel_state.is_initialized():
parallel_state.destroy_model_parallel()
@@ -489,100 +477,122 @@ def reset_megatron_parallel_state() -> Iterator[None]:
_reset_megatron_parallel_state()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available")
def test_train_mnist_litautoencoder_with_megatron_strategy_single_gpu(tmpdir):
with reset_megatron_parallel_state():
# Configure our custom Checkpointer
name = "test_experiment"
checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_best_model=True,
save_last=True,
monitor="val_loss", # TODO find out how to get val_loss logged and use "val_loss",
save_top_k=1,
every_n_train_steps=5,
enable_nemo_ckpt_io=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
# async_save=False, # Tries to save asynchronously, previously led to race conditions.
)
root_dir = tmpdir
save_dir = root_dir / name
tb_logger = TensorBoardLogger(save_dir=save_dir, name=name)
# Setup the logger and train the model
nemo_logger = NeMoLogger(
dir=root_dir,
name=name,
tensorboard=tb_logger,
ckpt=checkpoint_callback,
)
# Needed so that the trainer can find an output directory for the profiler
# nemo_logger.save_dir = tmpdir

model = LitAutoEncoder(config=ExampleConfig())
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
ddp="megatron",
find_unused_parameters=True,
enable_nemo_ckpt_io=True,
)
trainer = nl.Trainer(
accelerator="gpu",
devices=1,
strategy=strategy,
limit_val_batches=5,
val_check_interval=5,
max_steps=20,
num_nodes=1,
log_every_n_steps=5,
callbacks=[io.track_io(LossLoggingCallback)()],
)
data_module = MNISTDataModule(data_dir=tmpdir)
llm.train(
model=model,
data=data_module,
trainer=trainer,
log=nemo_logger,
resume=resume.AutoResume(
path=None, # Overrides the path found by resume_if_exists when set.
resume_if_exists=True, # Looks for the -last checkpoint to continue training.
resume_ignore_no_checkpoint=True, # When false this will throw an error with no existing checkpoint.
),
)
trainer._teardown()
with reset_megatron_parallel_state():
pred_strategy = nl.MegatronStrategy(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
ddp="megatron",
find_unused_parameters=True,
enable_nemo_ckpt_io=True,
data_sampler=MegatronDataSampler(
seq_len=28 * 28,
micro_batch_size=2,
global_batch_size=2,
output_log=False, # Disable logs to support predict_step
),
)
predict_trainer = nl.Trainer(
accelerator="gpu",
devices=1,
strategy=pred_strategy,
default_root_dir=root_dir,
)
ckpt_path = checkpoint_callback.last_model_path.replace(
".ckpt", ""
) # strip .ckpt off the end of the last path
# FIXME: the below checkpoint loading strategy and manual module unwrapping probably only works in single GPU
# and maybe DDP.
unwrapped_trained_model = trainer.model.module # TODO clean this up. Would be good not to have to unwrap.
forward_output = batch_collator(
predict_trainer.predict(
unwrapped_trained_model, dataloaders=data_module.test_dataloader(), ckpt_path=ckpt_path
)
)
assert set(forward_output.keys()) == {
"z",
"x_hat",
}, f"We expect forward output from predict_step, not the loss, got: {forward_output}"
assert forward_output["x_hat"].shape == (len(data_module.mnist_test), 28 * 28)
assert forward_output["z"].shape == (len(data_module.mnist_test), 3) # latent bottleneck in model of dim 3
predict_trainer._teardown()


@pytest.mark.run_only_on("GPU")
@pytest.mark.integration
def test_train_mnist_litautoencoder_with_megatron_strategy_single_gpu():
path = os.path.abspath(__file__)
call = f"python {path}"
# Raises a CalledProcessError if there is a failure in the subprocess
subprocess.check_call(call, shell=True, stdout=sys.stdout, stderr=sys.stdout)


def run_train_mnist_litautoencoder_with_megatron_strategy_single_gpu():
"""This is the actual test that will get run in a subprocess so it does not contaminate the state of other tests."""
with tempfile.TemporaryDirectory() as tmpdir_str:
tmpdir = Path(tmpdir_str)
assert tmpdir.exists()
assert tmpdir.is_dir()
with reset_megatron_parallel_state():
# Configure our custom Checkpointer
name = "test_experiment"
checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_best_model=True,
save_last=True,
monitor="val_loss",
save_top_k=1,
every_n_train_steps=5,
# Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
enable_nemo_ckpt_io=True,
)
root_dir = tmpdir
save_dir = root_dir / name
tb_logger = TensorBoardLogger(save_dir=str(save_dir), name=name)
# Setup the logger and train the model
nemo_logger = NeMoLogger(
dir=str(root_dir), # WARNING: passing a path in here results in mutating the Path class.
name=name,
tensorboard=tb_logger,
ckpt=checkpoint_callback,
)
# Needed so that the trainer can find an output directory for the profiler
# nemo_logger.save_dir = tmpdir

model = LitAutoEncoder(config=ExampleConfig())
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
ddp="megatron",
find_unused_parameters=True,
enable_nemo_ckpt_io=True,
)
trainer = nl.Trainer(
accelerator="gpu",
devices=1,
strategy=strategy,
limit_val_batches=5,
val_check_interval=5,
max_steps=20,
num_nodes=1,
log_every_n_steps=5,
callbacks=[io.track_io(LossLoggingCallback)()],
)
data_module = MNISTDataModule(data_dir=tmpdir)
llm.train(
model=model,
data=data_module,
trainer=trainer,
log=nemo_logger,
resume=resume.AutoResume(
path=None, # Overrides the path found by resume_if_exists when set.
resume_if_exists=True, # Looks for the -last checkpoint to continue training.
resume_ignore_no_checkpoint=True, # When false this will throw an error with no existing checkpoint.
),
)
trainer._teardown()
with reset_megatron_parallel_state():
pred_strategy = nl.MegatronStrategy(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
ddp="megatron",
find_unused_parameters=True,
enable_nemo_ckpt_io=True,
data_sampler=MegatronDataSampler(
seq_len=28 * 28,
micro_batch_size=2,
global_batch_size=2,
output_log=False, # Disable logs to support predict_step
),
)
predict_trainer = nl.Trainer(
accelerator="gpu",
devices=1,
strategy=pred_strategy,
default_root_dir=str(root_dir), # WARNING: passing a path in here results in mutating the Path class.
)
ckpt_path = checkpoint_callback.last_model_path.replace(
".ckpt", ""
) # strip .ckpt off the end of the last path

assert Path(
ckpt_path
).exists(), f"checkpoint {ckpt_path} not found in {os.listdir(Path(ckpt_path).parent)}"
# FIXME: the below checkpoint loading strategy and manual module unwrapping probably only works in single GPU
# and maybe DDP.
unwrapped_trained_model = trainer.model.module # TODO clean this up. Would be good not to have to unwrap.
forward_output = batch_collator(
predict_trainer.predict(
unwrapped_trained_model, dataloaders=data_module.test_dataloader(), ckpt_path=ckpt_path
)
)
assert set(forward_output.keys()) == {
"z",
"x_hat",
}, f"We expect forward output from predit_step, not the loss, got: {forward_output}"
assert forward_output["x_hat"].shape == (len(data_module.mnist_test), 28 * 28)
assert forward_output["z"].shape == (len(data_module.mnist_test), 3) # latent bottleneck in model of dim 3
predict_trainer._teardown()


if __name__ == "__main__":
# Have the test run this one item as a subprocess call
run_train_mnist_litautoencoder_with_megatron_strategy_single_gpu()
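
Usage note: with this layout the isolated body can presumably be exercised in two ways — directly, via python tests/collections/llm/test_mnist_model_nemo2.py, which is exactly what the subprocess call in the wrapper test does, or through pytest, where the GPU-only and integration markers gate when the subprocess-spawning wrapper is collected (for example pytest -m integration against this file on a GPU machine).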
