
Latest FairScale + Sharded Training crashes using default trainer parameters #6876

Closed
SeanNaren opened this issue Apr 7, 2021 · 1 comment · Fixed by #6877
Assignees: SeanNaren
Labels: 3rd party, bug, help wanted

Comments

@SeanNaren (Contributor)

🐛 Bug

When validation runs during training (the default with the boring model), sharded training crashes. Internally, SDP relies on knowing the training state of the model; when the validation sanity check runs, we do not set eval mode on the SDP wrapper itself, so it waits for gradients to be reduced because the wrapped module is still in train mode.
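
To illustrate the mode mismatch with a hypothetical stand-in for the ShardedDataParallel wrapper (not the actual SDP code): Module.eval() / Module.train() only propagate down the module tree, so switching the inner LightningModule does not flip the wrapper's own training flag.

import torch

inner = torch.nn.Linear(32, 2)
wrapper = torch.nn.Sequential(inner)  # stand-in for the SDP wrapper around the LightningModule

inner.eval()                 # what effectively happens to the inner module during the sanity check
print(inner.training)        # False
print(wrapper.training)      # True -- the wrapper still behaves as if it were training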

import os

import torch
from torch.utils.data import Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    """
    >>> RandomDataset(size=10, length=20)  # doctest: +ELLIPSIS
    <...bug_report_model.RandomDataset object at ...>
    """

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    """
    >>> BoringModel()  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    BoringModel(
      (layer): Linear(...)
    )
    """

    def __init__(self):
        """
        Testing PL Module

        Use as follows:
        - subclass
        - modify the behavior for what you want

        class TestModel(BaseTestModel):
            def training_step(...):
                # do your own thing

        or:

        model = BaseTestModel()
        model.training_epoch_end = None

        """
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def step(self, x):
        x = self.layer(x)
        out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
        return out

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"x": loss}

    def test_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"y": loss}

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]


def test_run():
    # fake data
    train_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
    val_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
    test_data = torch.utils.data.DataLoader(RandomDataset(32, 64))

    # model
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        max_epochs=1,
        plugins='ddp_sharded',
        gpus=1,
        weights_summary=None,
    )
    trainer.fit(model, train_data, val_data)
    trainer.test(test_dataloaders=test_data)


if __name__ == '__main__':
    test_run()
SeanNaren added the 3rd party, bug, and help wanted labels on Apr 7, 2021
SeanNaren self-assigned this on Apr 7, 2021
@ananthsub (Contributor)

One workaround is setting num_sanity_val_steps=0, but that's very much a short-term fix.
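
For reference, here is that workaround applied to the Trainer call from the repro above (a short-term sketch only; it skips the sanity check rather than fixing the underlying state handling):

trainer = Trainer(
    default_root_dir=os.getcwd(),
    limit_train_batches=1,
    limit_val_batches=1,
    max_epochs=1,
    plugins='ddp_sharded',
    gpus=1,
    weights_summary=None,
    num_sanity_val_steps=0,  # skip the pre-fit validation sanity check that triggers the crash
)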
