
FSDP integration #6152

Closed · wants to merge 77 commits into master from feat/fsdp
Conversation

@SeanNaren (Contributor) commented Feb 23, 2021

What does this PR do?

Integrates fully sharded (ZeRO Stage 3) parallelism as seen in https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html

This also deprecates the pipe plugin and stops CI from running its tests by updating the fairscale installation, as we move towards a full replacement, primarily for elegance and FSDP's long-term future.

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing, make sure you have read the Review guidelines. In short, see the following bullet list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

@SeanNaren added the "feature" (Is an improvement or enhancement), "distributed" (Generic distributed-related topic), and "3rd party" (Related to a 3rd-party) labels on Feb 23, 2021
@SeanNaren self-assigned this on Feb 23, 2021
@SeanNaren (Contributor, Author) commented:

A script to test CPU offload and to help me fix it:

from argparse import ArgumentParser
import torch
from torch.nn import functional as F

import pytorch_lightning as pl
from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule
from pytorch_lightning.plugins import FullShardedPlugin


class LitClassifier(pl.LightningModule):

    def __init__(self, hidden_dim=128, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.trainer.model.parameters(), lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--hidden_dim', type=int, default=128)
        parser.add_argument('--learning_rate', type=float, default=0.0001)
        return parser


def cli_main():
    pl.seed_everything(1234)
    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = LitClassifier.add_model_specific_args(parser)
    parser = MNISTDataModule.add_argparse_args(parser)
    args = parser.parse_args()

    dm = MNISTDataModule.from_argparse_args(args)
    model = LitClassifier(args.hidden_dim, args.learning_rate)

    trainer = pl.Trainer.from_argparse_args(args, plugins=FullShardedPlugin(cpu_offload=True), precision=16, gpus=1,
                                            max_epochs=1)
    trainer.fit(model, datamodule=dm)


if __name__ == '__main__':
    cli_main()

@Lightning-AI deleted a comment from codecov bot on Feb 23, 2021

codecov bot commented Feb 23, 2021

Codecov Report

Merging #6152 (78d52b5) into master (29357ba) will decrease coverage by 1%.
The diff coverage is 44%.

@@           Coverage Diff           @@
##           master   #6152    +/-   ##
=======================================
- Coverage      87%     86%    -1%     
=======================================
  Files         200     202     +2     
  Lines       12857   13049   +192     
=======================================
+ Hits        11224   11273    +49     
- Misses       1633    1776   +143     

@carmocca (Contributor) left a comment:

Some shallow comments

Review comments (outdated, resolved) on:
  • pytorch_lightning/plugins/precision/precision_plugin.py
  • pytorch_lightning/plugins/training_type/full_sharded.py
  • pytorch_lightning/plugins/training_type/full_sharded.py
  • pytorch_lightning/utilities/enums.py
  • requirements/extra.txt
@pep8speaks commented Feb 24, 2021

Hello @SeanNaren! Thanks for updating this PR.

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2021-05-04 16:27:51 UTC

Review comments (outdated, resolved) on:
  • pytorch_lightning/overrides/fairscale.py
  • pytorch_lightning/overrides/fairscale.py
  • pytorch_lightning/plugins/precision/native_amp.py
  • pytorch_lightning/plugins/precision/native_amp.py
  • pytorch_lightning/plugins/training_type/full_sharded.py
  • pytorch_lightning/plugins/training_type/rpc_sequential.py
  • pytorch_lightning/utilities/enums.py
@@ -47,7 +47,7 @@ def test_invalid_apex_sharded(tmpdir):
     """

     model = BoringModel()
-    with pytest.raises(MisconfigurationException, match='Sharded Plugin is not supported with Apex AMP'):
+    with pytest.raises(MisconfigurationException, match='Sharded Plugins are not supported with Apex AMP'):
A Member commented:

Can we also extend the test case here, so that we test with all sharded plugins?

@SeanNaren (Contributor, Author) commented:

Thanks for the extensive reviews, it's much appreciated :) To give a status update:

Issues

CPU offload currently doesn't work with PyTorch native AMP (no idea about Apex) when using the AMP grad scaler. This is tracked in facebookresearch/fairscale#421. Whilst we discuss a long-term upstream solution, we can either use @tchaton's fix of delaying the grad moves (performance to be confirmed, 59dbb83) or disable this functionality for now. I think @tchaton's solution may be worthwhile, but we might be able to move the logic into the FSDP class. Need to POC this out.

Flatten parameters removes all parameters from the LightningModule and moves them into a contiguous tensor. This means that when building your optimizer you should refer to the wrapped model, not to individual layers or to the LightningModule via self. This is being tracked in facebookresearch/fairscale#430 but will be tricky to fix, as it requires a rewrite of the base bucketing (which seems to be planned).
For now, I think an iterative solution is to define a property on the LightningModule that accesses the wrapped model, and document how to set up your optimizer using this property (a rough sketch follows below). We don't really have a choice here, given Myles' benchmarks: facebookresearch/fairscale#430 (comment)
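A minimal sketch of what that property idea could look like (the property name wrapped_model is hypothetical, and it assumes self.trainer.model holds the FSDP-wrapped module by the time configure_optimizers runs):

import torch
import pytorch_lightning as pl


class MyFSDPModel(pl.LightningModule):

    @property
    def wrapped_model(self) -> torch.nn.Module:
        # Hypothetical convenience property: after wrapping, the flattened
        # parameters live on the wrapped model held by the trainer, not on
        # the individual layers of this LightningModule.
        return self.trainer.model

    def configure_optimizers(self):
        # Build the optimizer from the wrapped model's (flattened) parameters.
        return torch.optim.Adam(self.wrapped_model.parameters(), lr=1e-3)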

Help Needed!

To get the best benefits (of scaling to ridiculous parameter sizes), we need to recommend wrapping child modules in the FSDP wrapper, like suggested in the main issue facebookresearch/fairscale#413.

Because the FSDP wrapper requires torch.distributed to be initialized, we need to delay the wrapping. I've gotten past this in the past by doing something like the below:

import torch
import pytorch_lightning as pl
from pytorch_lightning.plugins.fully_sharded import ShardedModule


class MyModel(pl.LightningModule):

    def __init__(self):
        ...
        self.linear = ShardedModule(torch.nn.Linear(5, 5))

ShardedModule will then do something like this:

from torch import nn
from fairscale.nn import FullyShardedDataParallel


class ShardedModule(nn.Module):

    def __init__(self, module):
        ...
        self.module = module

    def init_module(self):
        # Called once torch.distributed has been initialized.
        self.module = FullyShardedDataParallel(self.module)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

Any thoughts here?

@myleott commented Feb 24, 2021

> To get the best benefits (of scaling to ridiculous parameter sizes), we need to recommend wrapping child modules in the FSDP wrapper, like suggested in the main issue facebookresearch/fairscale#413.

Note that this will be needed for large models (20B+) even after we add automatic wrapping, since initializing models of this size on CPU can cause users to run out of system RAM. Thus you want to be wrapping layers with FSDP as you initialize the layers, so that they get sharded in place and free CPU memory.
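To illustrate this point, a minimal sketch (assuming fairscale's FullyShardedDataParallel is available and torch.distributed is already initialized) of wrapping each layer as it is constructed, so it is sharded in place and the full CPU copy can be freed immediately:

import torch
from torch import nn
from fairscale.nn import FullyShardedDataParallel as FSDP


class HugeModel(nn.Module):
    """Each block is wrapped as soon as it is built, so its parameters are
    sharded across ranks immediately instead of keeping the whole model in CPU RAM."""

    def __init__(self, num_layers: int = 48, hidden: int = 1024):
        super().__init__()
        self.blocks = nn.ModuleList()
        for _ in range(num_layers):
            block = nn.Sequential(nn.Linear(hidden, hidden), nn.GELU())
            # Wrap (and shard) this block now; only the local shard stays resident.
            self.blocks.append(FSDP(block))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return x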

@SeanNaren (Contributor, Author) commented:

> To get the best benefits (of scaling to ridiculous parameter sizes), we need to recommend wrapping child modules in the FSDP wrapper, like suggested in the main issue facebookresearch/fairscale#413.
>
> Note that this will be needed for large models (20B+) even after we add automatic wrapping, since initializing models of this size on CPU can cause users to run out of system RAM. Thus you want to be wrapping layers with FSDP as you initialize the layers, so that they get sharded in place and free CPU memory.

Ah, thanks for this, I did not see that! It brings up a great point: the lazy init I suggested won't solve this case, so we'll need to rethink further.

@SeanNaren (Contributor, Author) commented Apr 26, 2021

For some reason, while the typing initiative was going on, the model argument was removed from the precision plugin's clip_gradients signature. My guess is that it had no use in the library at that stage, but it may have been added as an incremental step towards getting FSDP integrated.

f29ecbf#diff-3facc0e73962d7c559c4257f0845ee7de30191a51017643c0f8f83bb0edb8a12L79 cc @carmocca in case there was another reason it was removed.

I've added this back in this PR and specified that it can be a torch.nn.Module as well, since the model could be wrapped.

@shuyingsunshine21 glad it fixed the issue! We'll definitely highlight in the docs. Let me know how your XLM experiments go!

@shuyingsunshine21 (Contributor) commented:

> Let me know how your XLM experiments go!

You might have tagged the wrong person. Will let you know (testing now).

For the clip_gradients signature part, I was also about to raise that; I found it when I rebased.

@shuyingsunshine21 (Contributor) commented Apr 27, 2021

I have another question: since FSDP wraps the whole LightningModule, when we set up metrics (where the metric class has tensor fields) inside the module and use FP16, those tensors are also converted to float16. This causes problems when the metric computation uses them alongside other values computed on the fly that are float32.

Is there a workaround to avoid casting those to float16?

import torch
import torch.nn.functional as F
import torchmetrics as metrics

# from fairscale.nn import wrap
from pytorch_lightning import Trainer, LightningModule, LightningDataModule
from pytorch_lightning.plugins import FullyShardedPlugin
from torch.nn import Linear, ReLU, Sequential, BCEWithLogitsLoss
from torch.utils.data import DataLoader, Dataset


class RandomDataTensor:
    def __init__(self, size, length, num_classes):
        self.data = torch.randn(length, size)
        # multi-class label
        self.label = torch.zeros([length, num_classes])
        for i in range(length):
            self.label[i][torch.randint(num_classes, (3,))] = 1


class RandomDataset(Dataset):
    def __init__(self, random_data_tensor):
        self.len = random_data_tensor.data.shape[0]
        self.data = random_data_tensor.data
        self.label = random_data_tensor.label

    def __getitem__(self, index):
        return {"sample": self.data[index], "label": self.label[index]}

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.num_classes = 10
        self.model = Sequential(
            Linear(32, 32),
            ReLU(),
            Linear(32, self.num_classes),
        )
        self.loss = BCEWithLogitsLoss()
        self.model_unrelated_parameter = torch.ones(3, 5, dtype=torch.float32)  # <--- similar to a Metric class owning such a tensor

    def training_step(self, batch, batch_idx):
        logits = self(batch)
        loss = self.loss(logits, batch["label"])
        return {"loss": loss}

    def configure_optimizers(self):
        self.optimizer = torch.optim.SGD(
            self.trainer.model.parameters(),
            lr=0.1,
        )
        lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1)
        return [self.optimizer], [lr_scheduler]

    def on_train_start(self):
        assert self.model_unrelated_parameter.dtype == torch.float32  # <--- the model_unrelated_parameter dtype is float16 under ddp_fully_sharded

model = BoringModel()
data_module = LightningDataModule.from_datasets(
    train_dataset=RandomDataset(dataset_tensors["train"]),
    val_dataset=RandomDataset(dataset_tensors["val"]),
    test_dataset=RandomDataset(dataset_tensors["test"]),
    batch_size=16,
)
trainer = Trainer(
    gpus=1, max_epochs=1, precision=16, accelerator="ddp_fully_sharded"
)
trainer.fit(model, datamodule=data_module)

As pointed out above, the assertion in on_train_start fails.

@mergify bot removed the "has conflicts" label on Apr 27, 2021
@SeanNaren (Contributor, Author) commented Apr 27, 2021

@shuyingsunshine21 I'm not able to reproduce this; there were a few missing definitions in your example, so I made a new one:

import torch
from pytorch_lightning import Trainer, LightningModule
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Sequential(
            torch.nn.Linear(32, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 2)
        )
        self.model_unrelated_parameter = torch.ones(3, 5, dtype=torch.float32)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        output = self(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        output = self(batch)
        loss = self.loss(batch, output)
        return {"x": loss}

    def test_step(self, batch, batch_idx):
        output = self(batch)
        loss = self.loss(batch, output)
        return {"y": loss}

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64))

    def val_dataloader(self):
        return DataLoader(RandomDataset(32, 64))

    def test_dataloader(self):
        return DataLoader(RandomDataset(32, 64))

    def on_train_start(self):
        assert self.model_unrelated_parameter.dtype == torch.float32


if __name__ == '__main__':
    model = BoringModel()
    trainer = Trainer(
        max_epochs=1,
        gpus=1,
        precision=16,
        plugins='ddp_fully_sharded'
    )
    trainer.fit(model)
    trainer.test(model)

This runs fine and the assertion passes; is there something from your example that I omitted?

EDIT: I also see that sync batch norm is supported in FSDP, so we should remove the guard. Is that right, @min-xu-ai?

@shuyingsunshine21 (Contributor) commented Apr 27, 2021

@SeanNaren, my bad, you are right, the above example is not correct. Let me re-paste; it just adds TestMetric on top of your example:

from torchmetrics import Metric
import torch
from pytorch_lightning import Trainer, LightningModule
from torch.utils.data import DataLoader, Dataset

class TestMetric(Metric):

    thresholds: torch.Tensor

    def __init__(
        self,
        num_thresholds: int = 100,
        compute_on_step: bool = False,
        **kwargs
    ) -> None:
        super().__init__(compute_on_step=compute_on_step, **kwargs)
        self.num_thresholds = num_thresholds
        thresholds = torch.arange(num_thresholds) / num_thresholds
        self.register_buffer("thresholds", thresholds)
        assert self.thresholds.dtype == torch.float32 # <- this part is fine

    def update(self, output: torch.Tensor) -> None:
        assert self.thresholds.dtype == torch.float32 # <- this breaks for fully sharded
        self.predictions = torch.rand((1 , self.num_thresholds), device=output.device)


    def compute(self) -> torch.Tensor:
        assert self.thresholds.dtype == torch.float32 # <- this breaks for fully sharded
        condition = self.predictions >= 0.5  # <- this is float32
        thresholds_at_p = (
            torch.where(
                condition, self.thresholds, torch.scalar_tensor(1e6, device=condition.device)
            )
            .min(dim=1)
            .values
        ) # <- as a result, this computation would fail for fully sharded
        return thresholds_at_p


class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Sequential(
            torch.nn.Linear(32, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 2)
        )
        self.val_test_metric = TestMetric(num_thresholds=100)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        output = self(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        output = self(batch)
        loss = self.loss(batch, output)
        self.val_test_metric(output)
        return {"x": loss}

    def on_validation_epoch_end(self):
        self.val_test_metric.compute()

    def test_step(self, batch, batch_idx):
        output = self(batch)
        loss = self.loss(batch, output)

        return {"y": loss}

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64))

    def val_dataloader(self):
        return DataLoader(RandomDataset(32, 64))

    def test_dataloader(self):
        return DataLoader(RandomDataset(32, 64))


if __name__ == '__main__':
    model = BoringModel()
    trainer = Trainer(
        max_epochs=1,
        gpus=1,
        precision=16,
        plugins='ddp_fully_sharded'
    )
    trainer.fit(model)
    trainer.test(model)

Note: I tested the above with ddp and ddp_sharded, and it works.

@SeanNaren (Contributor, Author) commented Apr 28, 2021

Thanks @shuyingsunshine21 we can iterate on this!

Another bug I've run into, which heavily blocks this PR, is that the parameters are not kept when going from trainer.fit(model) to trainer.test(model) in the same session.

This is because with flatten_parameters we create a FlattenParamsWrapper that contains all the parameters, and when we move to the next stage we do not bring the FlattenParamsWrapper's parameters with us, since we're still referring to the old model.

I think the cleanest fix would be to unflatten the parameters back onto the original model at teardown, or after training has finished. Is this doable via _unflatten_params, @min-xu-ai? I haven't experimented much with it, but it seems to do what I'm suggesting!

EDIT: a high-level repro of this issue:

import os
import torch
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        gpus=1,
        plugins='ddp_fully_sharded',
        weights_summary=None,
    )
    trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data)

    # Passes
    assert len(trainer.accelerator.training_type_plugin.model.state_dict()) > 0

    trainer.test(model, test_dataloaders=test_data)
    # Fails
    assert len(trainer.accelerator.training_type_plugin.model.state_dict()) > 0


if __name__ == '__main__':
    run()

There are no parameters left in the model, since we assumed the parameters were stored in the model after fit, when they were actually in the wrapper class that is now lost.

Comment on lines +161 to +166

        if self.automatic_module_wrap and not self._model_has_nested_fsdp():
            self.model = auto_wrap(LightningFullyShardedModule(self.model))
            if not isinstance(self.model, FullyShardedDataParallel):
                self.model = wrap(self.model)
        else:
            self.model = wrap(LightningFullyShardedModule(self.model))
A Contributor commented:

If manually wrapping contents inside the LightningModule, is this final outer-layer wrap needed? Or could we defer this to the user in the LightningModule too?

Then we could avoid wrapping the model in the dummy LightningFullyShardedModule used to map forward to one of the step functions. Would it also mean users don't have to refer to self.trainer.model inside the LightningModule?

Would this avoid the parameter-flattening issue across stages?

@SeanNaren (Contributor, Author) replied:

Ah, I understand! Theoretically I think so, since then we're just using the LightningModule as a wrapper. The cases I see:

  1. User wraps nothing, expects module to be wrapped by Lightning, and potentially auto_wrap to handle recursive wrapping
  2. User wraps some of the layers in configure_sharded_model but then expects all other layers to be included in a higher wrapper class (wrap the entire LM)
  3. User wraps all of the layers in configure_sharded_model, doesn't require any high level wrapping

Solutions

  1. This should be the default behaviour, i.e. plugins=fsdp or plugins=fsdp_auto_wrap
  2. This should be the same as 1., i.e. plugins=fsdp or plugins=fsdp_auto_wrap
  3. This could be plugins=fsdp_manual, where we do not wrap the highest-level module, allowing the user to do whatever they'd like in configure_optimizers.

In either case, it's important to fix the flattening issue for 1. and 2., which I think will be the first step for most users trying this out. Thoughts, @ananthsub? (A rough sketch of case 3 follows below.)
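A rough sketch of case 3 from the user's side (names here are assumptions: a hypothetical plugins='fsdp_manual' mode, with the plugin calling configure_sharded_model inside fairscale's enable_wrap context so that wrap() applies FSDP):

import torch
from torch import nn
import pytorch_lightning as pl
from fairscale.nn import wrap  # no-op unless called inside an enable_wrap context


class ManuallyWrappedModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        # Layers are only defined here; sharding happens later, once
        # torch.distributed is initialized and the plugin enters its wrap context.
        self.layer_1 = nn.Linear(32, 32)
        self.layer_2 = nn.Linear(32, 2)

    def configure_sharded_model(self):
        # Case 3: the user wraps every layer explicitly, so the plugin would
        # skip the outer wrap of the whole LightningModule.
        self.layer_1 = wrap(self.layer_1)
        self.layer_2 = wrap(self.layer_2)

    def forward(self, x):
        return self.layer_2(torch.relu(self.layer_1(x)))

    def training_step(self, batch, batch_idx):
        return self(batch).sum()

    def configure_optimizers(self):
        # Whether self.parameters() suffices here, or self.trainer.model is
        # still required (see the discussion below), is an open question.
        return torch.optim.SGD(self.parameters(), lr=0.1)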

A Contributor replied:

Yes, exactly those 3 cases!

To unblock the initial integration, I was wondering if we should start with option #3 to unblock power users in the release candidates, with the caveat that they are responsible for the full wrapping. Maybe this could be an option on the plugin as to whether the outer wrap of the LightningModule needs to be applied, in order to distinguish between cases 2 and 3.

I completely agree that most users will opt for cases 1 and 2, so we'll need to figure out the parameter flattening, whether in Lightning or fairscale, but I wanted to offer this as one way we could sequence these cases.

@SeanNaren (Contributor, Author) replied:

I started making changes to see whether any issues arise with case 3, and have a few observations:

The user still has to define a single model, be it a Module containing sub-modules in a sequential wrapper, or their own model structure with a forward function. This means self.model will probably still be required in every case for FSDP to work in configure_optimizers.

I also ran into an issue with clipping grad norms, which in manual mode cannot be handled automatically, as we do not wrap the model:

class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
    """Mixed Precision for Full Sharded Training"""

    def clip_gradients(
        self,
        optimizer: 'Optimizer',
        clip_val: Union[int, float],
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
        model: Optional[Module] = None
    ) -> None:
        # Model manages clipping of gradients
        model = cast(FullyShardedDataParallel, model)
        # todo: expose norm type once precision plugin supports this.
        model.clip_grad_norm_(clip_val, norm_type=2.0) # This breaks

A potential solution, albeit not as elegant as I'd like, would be to go through the immediate children of the LightningModule, find the root FSDP module, and call clip_grad_norm_ on it (a rough sketch below). I assume this adds negligible cost on top of the training loop, but what are your thoughts, @ananthsub?
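A minimal sketch of that idea (the helper names are hypothetical; fairscale's FullyShardedDataParallel exposes clip_grad_norm_, so we only need to locate the outermost FSDP child):

from typing import Optional, Union

from torch import nn
from fairscale.nn import FullyShardedDataParallel


def find_root_fsdp_module(module: nn.Module) -> Optional[FullyShardedDataParallel]:
    """Hypothetical helper: return the module itself if it is FSDP, otherwise
    the first FSDP instance found among its immediate children."""
    if isinstance(module, FullyShardedDataParallel):
        return module
    for child in module.children():
        if isinstance(child, FullyShardedDataParallel):
            return child
    return None


def clip_gradients_manual(lightning_module: nn.Module, clip_val: Union[int, float]) -> None:
    # In manual-wrap mode the plugin never wraps the whole LightningModule,
    # so gradient clipping is delegated to the root FSDP child instead.
    root = find_root_fsdp_module(lightning_module)
    if root is not None:
        root.clip_grad_norm_(clip_val, norm_type=2.0)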

@ananthsub mentioned this pull request on May 5, 2021
Comment on lines +173 to +175
self.model.to(self.root_device)
# ensure we update the device type in the lightning module
self.lightning_module.to(self.root_device)
A Contributor commented:

we might need to be cautious about this, as fsdp_module.to(device) will summon full parameters first: https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L348-L367

and when we perform teardown for GPU memory cleanup, we have self.lightning_module.cpu()

https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/accelerators/gpu.py#L50-L55

        return unwrap_lightning_module_fully_sharded(self.model)

    def on_save(self, checkpoint: dict) -> dict:
        state_dict = self.collate_state_dict()
@shuyingsunshine21 (Contributor) commented:

@SeanNaren, after getting detailed memory usage, I finally figured out why the full model originally fits on one GPU but OOMs when checkpointing.

It's because in checkpoint_connector (https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L270-L277) we have:

  model = self.trainer.lightning_module

  checkpoint = {
      'epoch': current_epoch,
      'global_step': global_step,
      'pytorch-lightning_version': pytorch_lightning.__version__,
      'state_dict': model.state_dict(),
  }

Here we try to collect the full state dict again, which doubles the size.

One easy workaround for now is to add

del checkpoint['state_dict']

but this is not ideal; we summon the full parameters twice, which is unnecessary.

I feel we should modify that file to let the training type plugin control this, with something like trainer.accelerator.training_type_plugin.state_dict(),

especially when we would like to collect only the sharded state dict in the future. (A rough sketch of this idea follows below.)

cc @ananthsub

@min-xu-ai, I think this is the root cause of the OOM; facebookresearch/fairscale#658 should not be the problem (for setting state_dict_device=torch.device("cpu"), the CPU OOM should be a similar problem, as we also double the model storage in CPU).
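A minimal, self-contained sketch of that proposal (class and method names are assumptions; the idea is that the checkpoint connector asks the training type plugin for the state dict instead of calling model.state_dict() itself, so FSDP can decide how to gather parameters exactly once):

from typing import Any, Dict

from torch import nn


class TrainingTypePluginSketch:
    """Hypothetical base hook the checkpoint connector could call."""

    def __init__(self, lightning_module: nn.Module):
        self.lightning_module = lightning_module

    def state_dict(self) -> Dict[str, Any]:
        # Default behaviour: what checkpoint_connector effectively does today.
        return self.lightning_module.state_dict()


class FullyShardedPluginSketch(TrainingTypePluginSketch):

    def state_dict(self) -> Dict[str, Any]:
        # FSDP-aware override: gather the full parameters exactly once
        # (standing in for this PR's collate_state_dict), so the connector
        # never summons them a second time.
        return self.collate_state_dict()

    def collate_state_dict(self) -> Dict[str, Any]:
        # Placeholder for the real gathering logic in the PR.
        return self.lightning_module.state_dict()


# The checkpoint connector would then build the checkpoint with:
#   'state_dict': trainer.accelerator.training_type_plugin.state_dict()
plugin = FullyShardedPluginSketch(nn.Linear(4, 2))
checkpoint = {"state_dict": plugin.state_dict()}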

@SeanNaren (Contributor, Author) replied:

Thanks @shuyingsunshine21 for your help here! This makes sense, since we're allocating new memory.

I agree with allowing the training type plugin to return the state dict; we already rely on the accelerator to dump the optimizer dicts. I'm happy to make the change!

@shuyingsunshine21 (Contributor) replied:

@SeanNaren, thanks, no worries. If you have not already made the change, I could help by sending a small PR for that.

@SeanNaren (Contributor, Author) commented:

I am closing this in favour of #7487

What remains is the ability to auto-wrap the model so users do not have to manually annotate layers. This will come in follow-up PRs once we figure out that case :)

@SeanNaren closed this May 14, 2021
@Borda deleted the feat/fsdp branch June 17, 2021 16:55