diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 56146498b5394..dd10a726e67a7 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -231,7 +231,24 @@ def forward( pipeline = self.pipeline - use_global_batch_sampler = self.trainer.datamodule.data_sampler.dataloader_type == 'batch' + # FIXME: cleanup the following code block which is here for backwards compatibility with nemo1. The "batch" + # sampler is a nemo1 sampler. It requires some custom code here to use (if use_global_batch_sampler). + # by default we shouldn't use this "batch" sampler probably. + if getattr(self.trainer, "datamodule", None) is not None: + use_global_batch_sampler = self.trainer.datamodule.data_sampler.dataloader_type == 'batch' + elif getattr(self.trainer, "predict_dataloaders", None) is not None: + from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import ( # noqa: I001 + MegatronPretrainingBatchSampler, + ) + + # The batch_sampler gets injected into the dataloader by the data_sampler. When doing predict without a + # datamodule we can look inside the dataloader's batch_sampler to see if it is the nemo1 style sampler + # that we need to handle specially below. + use_global_batch_sampler = isinstance( + self.trainer.predict_dataloaders.batch_sampler, MegatronPretrainingBatchSampler + ) + else: + raise ValueError("Unsure how to check for nemo1 global_batch_sampler status. TODO maybe default to False?") if use_global_batch_sampler: from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split diff --git a/nemo/lightning/pytorch/plugins/data_sampler.py b/nemo/lightning/pytorch/plugins/data_sampler.py index 13a0caa98f0c5..bacb7cb0af5ca 100644 --- a/nemo/lightning/pytorch/plugins/data_sampler.py +++ b/nemo/lightning/pytorch/plugins/data_sampler.py @@ -26,8 +26,10 @@ def __init__( dataloader_type: Literal["single", "cyclic", "batch"] = "single", init_consumed_samples: int = 0, init_global_step: int = 0, + output_log: bool = True, ): self.seq_len = seq_len + self.output_log = output_log self.micro_batch_size = micro_batch_size self.global_batch_size = global_batch_size self.rampup_batch_size = rampup_batch_size @@ -95,12 +97,14 @@ def on_megatron_step_end(self, trainer: pl.Trainer, pl_module: pl.LightningModul self.prev_global_batch_size = self.current_global_batch_size consumed_samples = self.compute_consumed_samples(trainer.global_step + 1 - self.init_global_step) - pl_module.log( - 'consumed_samples', - consumed_samples, - prog_bar=True, - batch_size=1, - ) + if self.output_log: + # You may need to turn off logging, for example when doing trainer.predict(model, data) + pl_module.log( + 'consumed_samples', + consumed_samples, + prog_bar=True, + batch_size=1, + ) self.prev_consumed_samples = consumed_samples @@ -108,12 +112,14 @@ def on_megatron_step_end(self, trainer: pl.Trainer, pl_module: pl.LightningModul consumed_samples=consumed_samples, consistency_check=False, ) - pl_module.log( - "global_batch_size", - self.current_global_batch_size, - prog_bar=True, - batch_size=1, - ) + if self.output_log: + # You may need to turn off logging, for example when doing trainer.predict(model, data) + pl_module.log( + "global_batch_size", + self.current_global_batch_size, + prog_bar=True, + batch_size=1, + ) self.if_first_step = 1 @property diff --git a/tests/collections/llm/test_mnist_model_nemo2.py b/tests/collections/llm/test_mnist_model_nemo2.py new file mode 100644 index 0000000000000..c783062017515 --- /dev/null +++ b/tests/collections/llm/test_mnist_model_nemo2.py @@ -0,0 +1,598 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import subprocess +import sys +import tempfile +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, TypedDict, TypeVar, Union + +import megatron.core.num_microbatches_calculator +import pytest +import pytorch_lightning as pl +import torch +import torch.distributed +from megatron.core import ModelParallelConfig, parallel_state +from megatron.core.optimizer import OptimizerConfig +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.module import MegatronModule +from pytorch_lightning.loggers import TensorBoardLogger +from torch import Tensor, nn +from torch.utils.data import DataLoader +from torchvision import transforms +from torchvision.datasets import MNIST + +from nemo import lightning as nl +from nemo.collections import llm +from nemo.lightning import NeMoLogger, io, resume +from nemo.lightning.megatron_parallel import DataT, MegatronLossReduction, ReductionT +from nemo.lightning.pytorch import callbacks as nl_callbacks +from nemo.lightning.pytorch.optim import MegatronOptimizerModule +from nemo.lightning.pytorch.plugins import MegatronDataSampler + +TokenizerType = Any + +"""This is intended to be a minimal self-container NeMo2 example.""" + + +T = TypeVar("T") + + +@dataclass +class ExampleConfig(ModelParallelConfig): + """ExampleConfig is a dataclass that is used to configure the model. + + Timers from ModelParallelConfig are required for megatron forward compatibility. + """ + + calculate_per_token_loss: bool = False + + def configure_model(self) -> nn.Module: + """This function is called by the strategy to construct the model. + + Note: Must pass self into Model since model requires having a config object. + + Returns: + The model object. + """ + return ExampleModel(self) + + +class MSELossReduction(MegatronLossReduction): + """A class used for calculating the loss, and for logging the reduced loss across micro batches.""" + + def forward(self, batch: DataT, forward_out: Tensor) -> Tuple[Tensor, ReductionT]: + """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU. + + Args: + batch: A batch of data that gets passed to the original forward inside LitAutoEncoder. + forward_out: the output of the forward method inside LitAutoEncoder. + + Returns: + A tuple containing [, ReductionT] where the loss tensor will be used for + backpropagation and the ReductionT will be passed to the reduce method + (which currently only works for logging.). + """ + x = batch["data"] + outputs = forward_out + x_hat = outputs["x_hat"] + # you could also put a latent loss on z here. + xview = x.view(x.size(0), -1) + loss = nn.functional.mse_loss(x_hat, xview) + + return loss, {"avg": loss} + + def reduce(self, losses_reduced_per_micro_batch: Sequence[ReductionT]) -> Tensor: + """Works across micro-batches. (data on single gpu). + + Note: This currently only works for logging and this loss will not be used for backpropagation. + + Args: + losses_reduced_per_micro_batch: a list of the outputs of forward + + Returns: + A tensor that is the mean of the losses. (used for logging). + """ + mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch]) + return mse_losses.mean() + + +def some_first(seq: Iterable[Optional[T]]) -> T: + """Returns the first non-None value from the sequence or fails""" # noqa: D415 + for s in seq: + if s is not None: + return s + raise ValueError("non-None value not found") + + +def get_dtype_device(torch_object) -> Tuple[torch.dtype, torch.device]: # noqa: D103 + match torch_object: + case []: + raise ValueError("Looking up dtype on an empty list") + case {**data} if not data: + raise ValueError("Looking up dtype on an empty dict") + case torch.Tensor(dtype=dtype, device=device): + return dtype, device + case torch.nn.Module() as m: + try: + p = next(m.parameters()) + except StopIteration as e: + raise ValueError("Cannot get dtype on a torch module with no parameters.") from e + return p.dtype, p.device + case dict(keys=_, values=values): + val = some_first(values()) + return get_dtype_device(val) + case list() as l: + val = some_first(l) + return get_dtype_device(val) + case _: + raise TypeError("Got something we didnt expect") + + +# NOTE(SKH): These types are all wrong, but are close. The inner type must always be a torch.Tensor, but the outer container should be generic. +def batch_collator(batches: Optional[Union[Tuple[ReductionT], List[ReductionT]]]) -> Optional[ReductionT]: + """Takes a sequence of batches and collates them into a single batch. + This is distinct from the standard pytorch default_collator since it does + not add the batch dimension, it's assumed the batch + dimension is already present in the input, as would be the case when + parallelizing across minibatches. + + IMPORTANT: The underlying data primitive _must_ be a torch Tensor. The input to this function is a recurisve type, + there can be any amount of nesting between dictionaries, tuples, and lists, as long as the inner type is a n-d torch.Tensor. + + Examples: + Outer container = Dict: + [{'a': torch.tensor([1]), 'b': torch.tensor([2])}, {'a': torch.tensor([2]), 'b': torch.tensor([3])}] -> {'a': torch.tensor([1, 2]), 'b': torch.tensor([2, 3])} + Outer container = List: + [[torch.tensor([1]), torch.tensor([2])], [torch.tensor([2]), torch.tensor([3])]] -> [torch.tensor([1, 2]), torch.tensor([2, 3])] + Outer container = Tuple: + ([torch.tensor([1]), torch.tensor([2])], [torch.tensor([2]), torch.tensor([3])]) -> (torch.tensor([1, 2]), torch.tensor([2, 3])) + + Args: + batches (Optional[Sequence[ReductionT]]): sequence of batches to collate into a single batch. + + Returns: + A single batch of the same type as the elements of your input sequence. + """ # noqa: D205 + match batches: + case [torch.Tensor(), *_]: + return torch.cat(batches, dim=0) + case [dict(), *_]: + return {key: batch_collator([batch[key] for batch in batches]) for key in batches[0]} + case [tuple(), *_]: + return tuple(batch_collator([batch[i] for batch in batches]) for i in range(len(batches[0]))) + case [list(), *_]: + return [batch_collator([batch[i] for batch in batches]) for i in range(len(batches[0]))] + case None: + return None + case []: + raise ValueError("Cannot process an empty sequence") + case _: + raise ValueError("Unsupported input structure in batch_collator") + + +class PassthroughLossReduction(MegatronLossReduction): + """Internally in NeMo2.0 the forward step is always expected to return a loss reduction class, and forward is expected to return a loss. + This class hijacks that mechanism to instead pass through the forward output unperturbed as the loss (to enable inference in the predict step), and then the + reduce method is used to collate the batch of forward outputs into a single batch. This supports the model forward output being a tensor, dict, tuple, + or list of tensors. The inner type _must always be a torch.Tensor_. + """ # noqa: D205 + + def forward(self, batch: DataT, forward_out: DataT) -> Tuple[torch.Tensor, DataT]: + """_summary_ + + Args: + batch (DataT): The batch of data that was passed through the model to generate output. + forward_out (torch.Tensor): The output from your model's forward pass. + + Returns: + Tuple[torch.Tensor, ReductionT]: A tuple containing the loss tensor (dummy in this case) and the forward output (unmodified). + """ # noqa: D415 + dtype, device = get_dtype_device(forward_out) + return torch.zeros(1, device=device, dtype=dtype), forward_out + + def reduce(self, forward_out: List[DataT]) -> DataT: + """This overrides the standard reduce with a simplified version that just takes a list of your model's forward outputs + and collates them togehter into a single output. + + Args: + forward_out (List[ReductionT]): _description_ + + Returns: + ReductionT: _description_ + """ # noqa: D205 + return batch_collator(forward_out) + + +class LitAutoEncoder(pl.LightningModule, io.IOMixin, io.ConnectorMixin): + """A very basic lightning module for testing the megatron strategy and the megatron-nemo2-bionemo contract.""" + + def __init__(self, config): + """Initializes the model. + + Args: + config: a Config object necessary to construct the actual nn.Module (the thing that has the parameters). + """ + super().__init__() + self.config = config + self.optim = MegatronOptimizerModule( + config=OptimizerConfig(lr=1e-4, optimizer="adam", use_distributed_optimizer=True), + ) + # Bind the configure_optimizers method to the model + self.optim.connect(self) + + def forward(self, batch: Dict, batch_idx: Optional[int] = None) -> Any: + """This forward will be called by the megatron scheduler and it will be wrapped. + + !!! note + + The `training_step` defines the training loop and is independent of the `forward` method here. + + Args: + batch: A dictionary of data. + batch_idx: The index of the batch. + + Returns: + The output of the model. + """ + x = batch["data"] + return self.module(x) + + def training_step(self, batch, batch_idx: Optional[int] = None): + """The training step is where the loss is calculated and the backpropagation is done. + + Background: + - NeMo's Strategy overrides this method. + - The strategies' training step will call the forward method of the model. + - That forward method then calls the wrapped forward step of MegatronParallel which wraps the forward method of the model. + - That wrapped forward step is then executed inside the Mcore scheduler, which calls the `_forward_step` method from the + MegatronParallel class. + - Which then calls the training_step function here. + + In this particular use case, we simply call the forward method of this class, the lightning module. + + Args: + batch: A dictionary of data. requires `batch_idx` as default None. + batch_idx: The index of the batch. + """ + return self(batch, batch_idx) + + def training_loss_reduction(self) -> MegatronLossReduction: # noqa: D102 + # This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss + return MSELossReduction() + + def validation_loss_reduction(self) -> MegatronLossReduction: # noqa: D102 + return MSELossReduction() + + def test_loss_reduction(self) -> MegatronLossReduction: # noqa: D102 + return MSELossReduction() + + def predict_loss_reduction(self) -> MegatronLossReduction: # noqa: D102 + # This allows us to do inference (not output the loss) + return PassthroughLossReduction() + + def configure_model(self) -> None: # noqa: D102 + self.module = self.config.configure_model() + + +class ExampleModel(MegatronModule): # noqa: D101 + def __init__(self, config: ModelParallelConfig) -> None: + """Constructor of the model. + + Args: + config: The config object is responsible for telling the strategy what model to create. + """ + super().__init__(config) + self.model_type = ModelType.encoder_or_decoder + self.linear1 = nn.Linear(28 * 28, 64) + self.relu = nn.ReLU() + self.linear2 = nn.Linear(64, 3) + self.linear3 = nn.Linear(3, 64) + self.relu2 = nn.ReLU() + self.linear4 = nn.Linear(64, 28 * 28) + + def forward(self, x: Tensor) -> Dict[str, Tensor]: + """Forward pass of the model. + + Args: + x: The input data. + + Returns: + x_hat: The result of the last linear layer of the network. + """ + x = x.view(x.size(0), -1) + z = self.linear1(x) + z = self.relu(z) + z = self.linear2(z) + x_hat = self.linear3(z) + x_hat = self.relu2(x_hat) + x_hat = self.linear4(x_hat) + return {"x_hat": x_hat, "z": z} + + def set_input_tensor(self, input_tensor: Optional[Tensor]) -> None: + """This is needed because it is a megatron convention. Even if it is a no-op for single GPU testing. + + See megatron.model.transformer.set_input_tensor() + + Note: Currently this is a no-op just to get by an mcore function. + + Args: + input_tensor: Input tensor. + """ + pass + + +class MnistItem(TypedDict): + data: Tensor + label: Tensor + idx: int + + +class MNISTCustom(MNIST): # noqa: D101 + def __getitem__(self, index: int) -> MnistItem: + """Wraps the getitem method of the MNIST dataset such that we return a Dict + instead of a Tuple or tensor. + + Args: + index: The index we want to grab, an int. + + Returns: + A dict containing the data ("x"), label ("y"), and index ("idx"). + """ # noqa: D205 + x, y = super().__getitem__(index) + + return { + "data": x, + "label": y, + "idx": index, + } + + +# TODO: remove this callback after `val` loss is logged by default in training in NeMo2 +class LossLoggingCallback(pl.Callback): # noqa: D101 + def __init__(self): + """Log the loss at the end of each batch. For training do not reduce across the epoch but do so for validation/test.""" + self.val_losses = [] + self.test_losses = [] + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): # noqa: D102 + # Assuming the loss is computed internally and stored in pl_module + if torch.distributed.get_rank() == 0 and parallel_state.is_pipeline_last_stage(): + if isinstance(outputs, dict): + outputs = outputs["loss"] + loss = outputs + pl_module.log("train_loss", loss, on_step=True, prog_bar=True, logger=True, rank_zero_only=True) + + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): # noqa: D102 + if torch.distributed.get_rank() == 0 and parallel_state.is_pipeline_last_stage(): + if isinstance(outputs, dict): + outputs = outputs["loss"] + loss = outputs + self.test_losses.append(loss) + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): # noqa: D102 + # Assuming the loss is computed internally and stored in pl_module + if torch.distributed.get_rank() == 0 and parallel_state.is_pipeline_last_stage(): + if isinstance(outputs, dict): + outputs = outputs["loss"] + loss = outputs + self.val_losses.append(loss) + + def on_validation_epoch_end(self, trainer, pl_module): # noqa: D102 + if torch.distributed.get_rank() == 0 and parallel_state.is_pipeline_last_stage(): + if len(self.val_losses) > 0: + avg_val_loss = torch.stack(self.val_losses).mean() + pl_module.log("val_loss", avg_val_loss, prog_bar=True, logger=True, rank_zero_only=True) + self.val_losses.clear() + + def on_test_epoch_end(self, trainer, pl_module): # noqa: D102 + if torch.distributed.get_rank() == 0 and parallel_state.is_pipeline_last_stage(): + if len(self.test_losses) > 0: + avg_test_loss = torch.stack(self.test_losses).mean() + pl_module.log("test_loss", avg_test_loss, prog_bar=True, logger=True, rank_zero_only=True) + self.test_losses.clear() + + +class MNISTDataModule(pl.LightningDataModule): # noqa: D101 + def __init__(self, data_dir: str = "./", batch_size: int = 32) -> None: # noqa: D107 + super().__init__() + self.data_dir = data_dir + self.batch_size = batch_size + self.micro_batch_size = 8 + self.global_batch_size = 8 + self.max_len = 100 + self.rampup_batch_size = None + + # Note that this sampler is sequential, meaning it does not do any shuffling. Let's wrap our data in a shuffler. + # Wraps the datasampler with the MegatronDataSampler. The MegatronDataSampler is a wrapper that allows the sampler + # to be used with megatron. It sets up the capability to utilize micro-batching and gradient accumulation. It is also + # the place where the global batch size is constructed. + self.data_sampler = MegatronDataSampler( + seq_len=self.max_len, + micro_batch_size=self.micro_batch_size, + global_batch_size=self.global_batch_size, + rampup_batch_size=self.rampup_batch_size, + ) + + def setup(self, stage: str) -> None: + """Sets up the datasets + + Args: + stage: can be one of train / test / predict. + """ # noqa: D415 + self.mnist_test = MNISTCustom(self.data_dir, download=True, transform=transforms.ToTensor(), train=False) + self.mnist_predict = MNISTCustom(self.data_dir, download=True, transform=transforms.ToTensor(), train=False) + mnist_full = MNISTCustom(self.data_dir, download=True, transform=transforms.ToTensor(), train=True) + self.mnist_train, self.mnist_val = torch.utils.data.random_split( + mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42) + ) + + def train_dataloader(self) -> DataLoader: # noqa: D102 + return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=0) + + def val_dataloader(self) -> DataLoader: # noqa: D102 + return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=0) + + def test_dataloader(self) -> DataLoader: # noqa: D102 + return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=0) + + +### Begin model environment related utilities +def _reset_megatron_parallel_state(): + """Resets _GLOBAL_NUM_MICROBATCHES_CALCULATOR in megatron which is used in NeMo to initialized model parallel in + nemo.collections.nlp.modules.common.megatron.megatron_init.initialize_model_parallel_for_nemo + """ # noqa: D205, D415 + megatron.core.num_microbatches_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None + # Clean up any process groups created in testing + torch.cuda.empty_cache() + if parallel_state.is_initialized(): + parallel_state.destroy_model_parallel() + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + +@contextmanager +def reset_megatron_parallel_state() -> Iterator[None]: + """Puts you into a clean parallel state, and again tears it down at the end.""" + try: + _reset_megatron_parallel_state() + yield + finally: + _reset_megatron_parallel_state() + + +@pytest.mark.run_only_on("GPU") +@pytest.mark.integration +def test_train_mnist_litautoencoder_with_megatron_strategy_single_gpu(): + path = os.path.abspath(__file__) + call = f"python {path}" + # Raises a CalledProcessError if there is a failure in the subprocess + subprocess.check_call(call, shell=True, stdout=sys.stdout, stderr=sys.stdout) + + +def run_train_mnist_litautoencoder_with_megatron_strategy_single_gpu(): + """This is the actual test that will get run in a subprocess so it does not contaminate the state of other tests.""" + with tempfile.TemporaryDirectory() as tmpdir_str: + tmpdir = Path(tmpdir_str) + assert tmpdir.exists() + assert tmpdir.is_dir() + with reset_megatron_parallel_state(): + # Configure our custom Checkpointer + name = "test_experiment" + checkpoint_callback = nl_callbacks.ModelCheckpoint( + save_best_model=True, + save_last=True, + monitor="val_loss", + save_top_k=1, + every_n_train_steps=5, + # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe + enable_nemo_ckpt_io=True, + ) + root_dir = tmpdir + save_dir = root_dir / name + tb_logger = TensorBoardLogger(save_dir=str(save_dir), name=name) + # Setup the logger and train the model + nemo_logger = NeMoLogger( + dir=str(root_dir), # WARNING: passing a path in here results in mutating the Path class. + name=name, + tensorboard=tb_logger, + ckpt=checkpoint_callback, + ) + # Needed so that the trainer can find an output directory for the profiler + # nemo_logger.save_dir = tmpdir + + model = LitAutoEncoder(config=ExampleConfig()) + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + ddp="megatron", + find_unused_parameters=True, + enable_nemo_ckpt_io=True, + ) + trainer = nl.Trainer( + accelerator="gpu", + devices=1, + strategy=strategy, + limit_val_batches=5, + val_check_interval=5, + max_steps=20, + num_nodes=1, + log_every_n_steps=5, + callbacks=[io.track_io(LossLoggingCallback)()], + ) + data_module = MNISTDataModule(data_dir=tmpdir) + llm.train( + model=model, + data=data_module, + trainer=trainer, + log=nemo_logger, + resume=resume.AutoResume( + path=None, # Overrides the path found by resume_if_exists when set. + resume_if_exists=True, # Looks for the -last checkpoint to continue training. + resume_ignore_no_checkpoint=True, # When false this will throw an error with no existing checkpoint. + ), + ) + trainer._teardown() + with reset_megatron_parallel_state(): + pred_strategy = nl.MegatronStrategy( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + ddp="megatron", + find_unused_parameters=True, + enable_nemo_ckpt_io=True, + data_sampler=MegatronDataSampler( + seq_len=28 * 28, + micro_batch_size=2, + global_batch_size=2, + output_log=False, # Disable logs to support predict_step + ), + ) + predict_trainer = nl.Trainer( + accelerator="gpu", + devices=1, + strategy=pred_strategy, + default_root_dir=str(root_dir), # WARNING: passing a path in here results in mutating the Path class. + ) + ckpt_path = checkpoint_callback.last_model_path.replace( + ".ckpt", "" + ) # strip .ckpt off the end of the last path + + assert Path( + ckpt_path + ).exists(), f"checkpoint {ckpt_path} not found in {os.listdir(Path(ckpt_path).parent)}" + # FIXME: the below checkpoint loading strategy and manual module unwrapping probably only works in single GPU + # and maybe DDP. + unwrapped_trained_model = trainer.model.module # TODO clean this up. Would be good not to have to unwrap. + forward_output = batch_collator( + predict_trainer.predict( + unwrapped_trained_model, dataloaders=data_module.test_dataloader(), ckpt_path=ckpt_path + ) + ) + assert set(forward_output.keys()) == { + "z", + "x_hat", + }, f"We expect forward output from predit_step, not the loss, got: {forward_output}" + assert forward_output["x_hat"].shape == (len(data_module.mnist_test), 28 * 28) + assert forward_output["z"].shape == (len(data_module.mnist_test), 3) # latent bottleneck in model of dim 3 + predict_trainer._teardown() + + +if __name__ == "__main__": + # Have the test run this one item as a subprocess call + run_train_mnist_litautoencoder_with_megatron_strategy_single_gpu()