stateful dataloaders do not load their state_dict if self.trainer.estimated_stepping_batches is called beforehand #20550

ItamarKanter commented Jan 16, 2025

Bug description

Stateful dataloaders do not load their state_dict and restore their state if trainer.estimated_stepping_batches is called beforehand.
This situation comes up when using torch.optim.lr_scheduler.OneCycleLR, which requires total_steps.
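
For context, the typical pattern that hits this looks like the following (a minimal sketch; the optimizer choice, learning rate, and max_lr are placeholders):

import torch
from lightning.pytorch.demos.boring_classes import BoringModel


class OneCycleModel(BoringModel):
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        # OneCycleLR needs the total step count up front, so
        # estimated_stepping_batches is queried before training starts.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=0.1,
            total_steps=int(self.trainer.estimated_stepping_batches),
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }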

What version are you seeing the problem on?

v2.5

How to reproduce the bug

This code is adapted from the Lightning test test_resume_mid_epoch_warning:

from pathlib import Path
import torch
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.utilities import CombinedLoader
from lightning.pytorch import Trainer


class NotStatefulIterable:
    def __init__(self, start=0):
        self.index = start

    def __iter__(self):
        for i in range(self.index, len(self)):
            self.index = i
            yield self.index

    def __len__(self):
        return 10


class StatefulIterable(NotStatefulIterable):
    def state_dict(self):
        return {"index": self.index}

    def load_state_dict(self, state_dict):
        self.index = state_dict["index"] + 1


# Single stateful DataLoader
train_dataloader_factory = lambda: CombinedLoader(StatefulIterable())
has_state = True
batches_before = [0, 1]
batches_after = [2, 3]

tmp_path = Path(".")


class DummyModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.seen_data = []

    def configure_optimizers(self):
        # Calling estimated_stepping_batches here, as one would to obtain
        # total_steps for torch.optim.lr_scheduler.OneCycleLR, is what
        # triggers the bug. No optimizer is needed for this repro.
        total_steps = int(self.trainer.estimated_stepping_batches)

    def training_step(self, batch, batch_idx):
        # Only record which batches the loop yields; no loss is returned.
        self.seen_data.append(batch)
        print(batch)

    def train_dataloader(self):
        return train_dataloader_factory()


trainer_kwargs = {
    "default_root_dir": tmp_path,
    "accelerator": "cpu",
    "enable_checkpointing": False,
    "enable_model_summary": False,
    "enable_progress_bar": False,
    "logger": False,
    "num_sanity_val_steps": 0,
}

# Train for 2 steps
model = DummyModel()
trainer = Trainer(**trainer_kwargs, max_steps=2, max_epochs=10)
trainer.fit(model)
assert model.seen_data == batches_before

# Save a checkpoint
trainer.save_checkpoint(tmp_path / "checkpoint.ckpt")
checkpoint = torch.load(tmp_path / "checkpoint.ckpt", weights_only=False)
if has_state:
    assert checkpoint["loops"]["fit_loop"]["state_dict"]["combined_loader"]
else:
    assert "combined_loader" not in checkpoint["loops"]["fit_loop"]["state_dict"]

# Restore training from step 2 and continue 2 more steps
model = DummyModel()
trainer = Trainer(**trainer_kwargs, max_steps=4, max_epochs=10)
trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt"))
assert model.seen_data == batches_after

Error messages and logs

assert model.seen_data == batches_after
AssertionError

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):

More info

It has to do with trainer.estimated_stepping_batches invoking self.fit_loop.setup_data() during strategy.setup; then, when self.fit_loop.setup_data() is invoked again in self._run_stage(), it skips the state_dict loading.
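
A minimal, self-contained sketch of the suspected mechanism (FitLoopSketch, load_checkpoint, and the guard condition are hypothetical stand-ins for the internals named above, not the actual Lightning implementation):

class FitLoopSketch:
    """Illustrative stand-in for the fit loop's data setup logic."""

    def __init__(self):
        self._combined_loader = None
        self._restored_state = None

    def load_checkpoint(self, state):
        # State read from the checkpoint, to be applied on setup.
        self._restored_state = state

    def setup_data(self):
        if self._combined_loader is not None:
            return  # already set up -> the restore branch below never runs
        self._combined_loader = {"index": 0}
        if self._restored_state is not None:
            self._combined_loader.update(self._restored_state)


loop = FitLoopSketch()
loop.setup_data()                   # premature call via estimated_stepping_batches
loop.load_checkpoint({"index": 2})  # checkpoint state arrives afterwards
loop.setup_data()                   # second call in _run_stage(): early return
assert loop._combined_loader["index"] == 0  # restored index was never applied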
