
Commit

fix a small bug and add a resume test
Signed-off-by: ashors1 <ashors@nvidia.com>
ashors1 committed Sep 27, 2024
1 parent fd727f4 commit 1888f54
Showing 2 changed files with 84 additions and 31 deletions.
9 changes: 7 additions & 2 deletions nemo/lightning/pytorch/callbacks/model_checkpoint.py
@@ -440,8 +440,13 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)

         ## manually update last_model_path so symlink is up-to-date
         ## should only be done when using a symlink
-        if self.save_last == "link" and not str(ckpt_to_dir(filepath)).endswith("last"):
-            self.future_last_model_path = str(ckpt_to_dir(filepath)) + "-last.ckpt"
+
+        ## if we're not creating a symlink, this could go wrong?
+        ## if we never enter this branch, we'll end up with problems
+        if self.save_last == "link":
+            self.future_last_model_path = str(ckpt_to_dir(filepath))
+            if not str(ckpt_to_dir(filepath)).endswith("last"):
+                self.future_last_model_path += "-last.ckpt"
 
         if ema_callback is not None:
             if self.async_save:
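The behavioural difference in the hunk above is easiest to see in isolation: previously, future_last_model_path was only refreshed when the checkpoint directory did not already end in "last", so a "-last" directory left the field stale; now it is refreshed whenever save_last == "link". A minimal sketch of the two code paths, using a simplified, hypothetical stand-in for ckpt_to_dir (the real helper's behaviour may differ):

# Hypothetical stand-in for ckpt_to_dir; illustrative only.
def ckpt_to_dir(filepath: str) -> str:
    return filepath.removesuffix(".ckpt")

def future_last_path_old(filepath: str, save_last: str = "link"):
    # Old logic: skipped when the directory already ends in "last",
    # so future_last_model_path was never refreshed in that case.
    if save_last == "link" and not ckpt_to_dir(filepath).endswith("last"):
        return ckpt_to_dir(filepath) + "-last.ckpt"
    return None  # left stale

def future_last_path_new(filepath: str, save_last: str = "link"):
    # New logic: always refresh when symlinking; only append "-last.ckpt"
    # when the directory name does not already end in "last".
    if save_last == "link":
        path = ckpt_to_dir(filepath)
        if not path.endswith("last"):
            path += "-last.ckpt"
        return path
    return None

# e.g. future_last_path_old("epoch=2-last.ckpt") -> None (stale)
#      future_last_path_new("epoch=2-last.ckpt") -> "epoch=2-last"
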
106 changes: 77 additions & 29 deletions tests/lightning/pytorch/callbacks/test_model_checkpoint.py
@@ -91,16 +91,28 @@ def reduce(self, losses_reduced_per_micro_batch: Sequence[ReductionT]) -> Tensor
 class ExampleModel(pl.LightningModule, IOMixin):
     def __init__(self, *args, **kwargs):
         super().__init__()
-        self.l1 = torch.nn.modules.Linear(in_features=32, out_features=32)
-        self.bn = torch.nn.BatchNorm1d(32)
-        self.model_type = "test"
-        self.validation_step_outputs = []
-
-        class DummyConfig(ModelParallelConfig):
-            calculate_per_token_loss: bool = False
-            fp8: bool = False
+        ## keeps track of number of validation steps
+        self.count = torch.zeros((1,))
 
-        self.config = DummyConfig()
+    def configure_model(self):
+
+        class NestedModel(torch.nn.Module):
+
+            def __init__(self):
+                super().__init__()
+                self.l1 = torch.nn.modules.Linear(in_features=32, out_features=32)
+                self.bn = torch.nn.BatchNorm1d(32)
+                self.model_type = "test"
+                self.validation_step_outputs = []
+
+                class DummyConfig(ModelParallelConfig):
+                    calculate_per_token_loss: bool = False
+                    fp8: bool = False
+
+                self.config = DummyConfig()
+
+        self.module = NestedModel()
 
     def forward(self, batch):
         return self.l1(self.bn(batch)).sum()
@@ -123,9 +135,11 @@ def training_step(self, batch):
         return self(batch)
 
     def validation_step(self, batch):
-        loss = self(batch)
-        self.validation_step_outputs.append(loss)
-        return loss
+        ## use a dummy validation loss to ensure that loss is decreasing at each step
+        ## which guarantees that the -last checkpoints will be symlinks if specified
+        self.count += 1
+        self.validation_step_outputs.append(-self.count)
+        return -self.count
 
     def test_step(self, batch):
         loss = self(batch)
@@ -139,9 +153,6 @@ def on_validation_epoch_end(self):
         self.log("val_loss", torch.stack(self.validation_step_outputs).mean())
         self.validation_step_outputs.clear()  # free memory
 
-    def configure_model(self):
-        self.module = ExampleModel()
-
     def set_input_tensor(self, input_tensor: Optional[Tensor]) -> None:
         pass
@@ -153,22 +164,31 @@ def validation_loss_reduction(self) -> MegatronLossReduction:  # noqa: D102
         return PassThroughLossReduction()
 
 
-def setup_test(path, async_save=False):
+def setup_test(path, async_save=False, max_epochs=3):
     model = ExampleModel()
 
     data = RandomDataset(32, 64)
 
+    resume = nl.AutoResume(
+        resume_if_exists=True,
+        resume_ignore_no_checkpoint=True,
+    )
+
     nemo_logger = nl.NeMoLogger(
         log_dir=path,
        use_datetime_version=False,
     )
 
-    strategy = nl.MegatronStrategy(ckpt_async_save=async_save, replace_progress_bar=False)
+    strategy = nl.MegatronStrategy(
+        ckpt_async_save=async_save,
+        replace_progress_bar=False,
+    )
 
     trainer = nl.Trainer(
-        max_epochs=5,
+        max_epochs=max_epochs,
         devices=1,
-        val_check_interval=5,
+        val_check_interval=6,
         log_every_n_steps=4,
         callbacks=nl.ModelCheckpoint(
             monitor="val_loss",
             save_top_k=3,
@@ -180,9 +200,21 @@ def setup_test(path, async_save=False):
         strategy=strategy,
     )
     nemo_logger.setup(trainer)
+    resume.setup(trainer)
 
 
     return data, model, trainer
 
+
+def get_final_checkpoint(checkpoint_dir):
+    dist_checkpoints = [d for d in list(checkpoint_dir.glob("*")) if d.is_dir()]
+    last_checkpoints = [d for d in dist_checkpoints if d.match("*last")]
+
+    assert len(last_checkpoints) == 1  ## should only have one -last checkpoint
+    final_ckpt = last_checkpoints[0]
+
+    top_k_checkpoints = [d for d in dist_checkpoints if d not in last_checkpoints]
+
+    return final_ckpt, top_k_checkpoints
+
 
 class TestLinkCheckpoint:

@@ -198,13 +230,9 @@ def test_link_ckpt(self, tmpdir):
             trainer.fit(model, data)
 
             checkpoint_dir = Path(tmp_path / "default" / "checkpoints")
-            dist_checkpoints = [d for d in list(checkpoint_dir.glob("*")) if d.is_dir()]
-            last_checkpoints = [d for d in dist_checkpoints if d.match("*last")]
-            assert len(last_checkpoints) == 1  ## should only have one -last checkpoint
-            final_ckpt = last_checkpoints[0]
+            final_ckpt, top_k_checkpoints = get_final_checkpoint(checkpoint_dir)
             assert os.path.islink(final_ckpt)
 
-            top_k_checkpoints = [d for d in dist_checkpoints if d not in last_checkpoints]
             ## make sure we're saving the expected number of checkpoints
             assert len(top_k_checkpoints) == 3

@@ -223,14 +251,34 @@ def test_link_ckpt_async(self, tmpdir):
             trainer.fit(model, data)
 
             checkpoint_dir = Path(tmp_path / "default" / "checkpoints")
-            dist_checkpoints = [d for d in list(checkpoint_dir.glob("*")) if d.is_dir()]
-            last_checkpoints = [d for d in dist_checkpoints if d.match("*last")]
-            assert len(last_checkpoints) == 1  ## should only have one -last checkpoint
-            final_ckpt = last_checkpoints[0]
+            final_ckpt, top_k_checkpoints = get_final_checkpoint(checkpoint_dir)
             assert os.path.islink(final_ckpt)
 
-            top_k_checkpoints = [d for d in dist_checkpoints if d not in last_checkpoints]
             assert len(top_k_checkpoints) == 3
 
             link = final_ckpt.resolve()
             assert str(final_ckpt).replace("-last", "") == str(link)
 
+    @pytest.mark.unit
+    @pytest.mark.run_only_on("GPU")
+    def test_restore_async(self, tmpdir):
+        """Test to ensure that we always keep top_k checkpoints, even after resuming."""
+
+        with reset_megatron_parallel_state():
+            tmp_path = tmpdir / "async_link_ckpt_test"
+            data, model, trainer = setup_test(tmp_path, async_save=True, max_epochs=3)
+
+            trainer.fit(model, data)
+
+            ## reinitialize
+            data, model, trainer = setup_test(tmp_path, async_save=True, max_epochs=6)
+
+            trainer.fit(model, data)
+
+            checkpoint_dir = Path(tmp_path / "default" / "checkpoints")
+            final_ckpt, top_k_checkpoints = get_final_checkpoint(checkpoint_dir)
+            assert os.path.islink(final_ckpt)
+            assert len(top_k_checkpoints) == 3
+
+            epoch = str(final_ckpt).split('epoch=')[1][0]
+            assert int(epoch) == 5  ## make sure we're running the correct number of epochs
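
The symlink assertions these tests share (os.path.islink plus the resolve() comparison) can be illustrated in isolation. A minimal self-contained sketch, using illustrative directory names rather than the tests' actual tmpdir layout:

import os
import tempfile
from pathlib import Path

# Illustrative checkpoint layout; resolve() the temp root so the comparison
# below is not confused by a symlinked temp directory (e.g. on macOS).
root = Path(tempfile.mkdtemp()).resolve() / "checkpoints"
root.mkdir(parents=True)

real = root / "epoch=5-step=35"        # a regular top-k checkpoint directory
real.mkdir()
link = root / "epoch=5-step=35-last"   # the "-last" entry saved as a symlink
link.symlink_to(real.name)             # relative target; the real callback may differ

assert os.path.islink(link)
# resolve() follows the symlink, so stripping "-last" from the link's own
# name should yield the same path, which is the check the tests perform.
assert str(link).replace("-last", "") == str(link.resolve())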
