Skip to content

Commit

Permalink
Add fixme comment around proper checkpoint nemo2 handling
Browse files Browse the repository at this point in the history
Signed-off-by: John St John <jstjohn@nvidia.com>
  • Loading branch information
jstjohn committed Aug 14, 2024
1 parent b0c47bb commit fb07b24
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions tests/collections/llm/test_mnist_model_nemo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,9 +568,16 @@ def test_train_mnist_litautoencoder_with_megatron_strategy_single_gpu(tmpdir):
devices=1,
strategy=pred_strategy,
)
unwrapped_trained_model = trainer.model.module # TODO clean this up. Maybe supply the checkpoint path instead?
ckpt_path = checkpoint_callback.last_model_path.replace(
".ckpt", ""
) # strip .ckpt off the end of the last path
# FIXME: the below checkpoint loading strategy and manual module unwrapping probably only works in single GPU
# and maybe DDP.
unwrapped_trained_model = trainer.model.module # TODO clean this up. Would be good not to have to unwrap.
forward_output = batch_collator(
predict_trainer.predict(unwrapped_trained_model, dataloaders=data_module.test_dataloader())
predict_trainer.predict(
unwrapped_trained_model, dataloaders=data_module.test_dataloader(), ckpt_path=ckpt_path
)
)
assert set(forward_output.keys()) == {
"z",
Expand Down

0 comments on commit fb07b24

Please sign in to comment.