
LightningLite not updating DeviceDtypeModuleMixin correctly #10556

Closed
justusschock opened this issue Nov 16, 2021 · 3 comments · Fixed by #10559
justusschock commented Nov 16, 2021

🐛 Bug

When using LightningLite and moving the _LiteModule to CPU, the attributes of DeviceDtypeModuleMixin (such as device) are not updated.

To Reproduce

import torch

from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.lite import LightningLite


class SomeDummy(DeviceDtypeModuleMixin):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(1, 1)


class MyClass(LightningLite):
    def run(self):
        model = SomeDummy()
        model, optimiser = self.setup(model, torch.optim.Adam(model.parameters()))

        # do some stuff

        # now clean up gpu memory for later stages
        model.cpu()
        assert str(model.module.device) == 'cpu'  # fails: still reports the old device


MyClass(accelerator='gpu', devices=1).run()

Expected behavior

model.module.device should be torch.device('cpu')

Additional context

Could probably be solved by using DeviceDtypeModuleMixin as a base class for the _LiteModule, since the root cause is that the to function only calls _apply on all child tensors instead of calling .to on every child module.
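
A minimal sketch of that suggestion (_LiteModuleSketch is a hypothetical name; the real _LiteModule wrapper has additional responsibilities such as precision handling):

import torch

from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin


class _LiteModuleSketch(DeviceDtypeModuleMixin):
    def __init__(self, module: torch.nn.Module):
        super().__init__()
        self.module = module  # the user's model

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


# with the mixin as a base, wrapper.cpu() also updates the tracked
# device attribute on every DeviceDtypeModuleMixin child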

cc @carmocca @justusschock @awaelchli

justusschock added the bug and fabric labels on Nov 16, 2021

awaelchli commented Nov 16, 2021

That's an issue we have even without Lite, e.g., it could occur with torchmetrics too. This example shows it:

import torch

from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin


class SomeDummy(DeviceDtypeModuleMixin):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(1, 1)


class WrapperModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.module = SomeDummy()  # this could be a torchmetric


w = WrapperModule().cuda()
print(w.module.device)  # prints cpu !!! should be cuda:0
w.cpu()
print(w.module.device)  # prints cpu

As you said, the only solution is to add DeviceDtypeModuleMixin to the bases of the wrapper class, because of how to() traverses child modules.
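
A sketch of that fix applied to the example above (FixedWrapperModule is a hypothetical name, reusing SomeDummy from the snippet):

class FixedWrapperModule(DeviceDtypeModuleMixin):
    def __init__(self):
        super().__init__()
        self.module = SomeDummy()  # this could be a torchmetric


w = FixedWrapperModule().cuda()
print(w.module.device)  # now prints cuda:0
w.cpu()
print(w.module.device)  # prints cpu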

awaelchli added this to the 1.5.x milestone on Nov 16, 2021

awaelchli commented Nov 16, 2021

I noticed that in torchmetrics they abandoned the DeviceDtypeModuleMixin and override the _apply method instead:

https://github.com/PyTorchLightning/metrics/blob/93cb842f24d15804dd2e7677ca7fc6631b234773/torchmetrics/metric.py#L466-L490
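
A condensed sketch of the idea behind that override (paraphrased, not the verbatim torchmetrics code): _apply receives the function that .to()/.cpu()/.cuda() would apply to every tensor, so applying it to a probe tensor reveals the target device:

import torch


class MetricLike(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._device = torch.device("cpu")

    @property
    def device(self) -> torch.device:
        return self._device

    def _apply(self, fn):
        this = super()._apply(fn)  # moves parameters/buffers as usual
        # apply fn to a probe tensor to find out where everything moved
        probe = fn(torch.zeros(1, device=self._device))
        this._device = probe.device
        return this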


justusschock commented Nov 16, 2021

@awaelchli yes, but without Lite this is not our concern :D

From our offline chat (for completeness):
For torchmetrics, removing it is fine since they don't nest modules there. And overriding _apply achieves exactly what the mixin does :)
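
For reference, a condensed sketch of how the mixin achieves the same effect (paraphrased from the PL 1.5-era DeviceDtypeModuleMixin, not verbatim): it overrides to() and friends and pushes the parsed target device down to every tracking submodule via Module.apply:

import torch


class DeviceTrackingMixin(torch.nn.Module):
    _device: torch.device = torch.device("cpu")

    def to(self, *args, **kwargs):
        # the same argument parsing torch uses internally for Module.to
        device, *_ = torch._C._nn._parse_to(*args, **kwargs)

        def update(module):
            if isinstance(module, DeviceTrackingMixin) and device is not None:
                module._device = device

        self.apply(update)  # reaches all tracking submodules, however deep
        return super().to(*args, **kwargs)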
