🐛 Bug

When using LightningLite and transferring the `_LiteModule` to CPU, the attributes of `DeviceDtypeModuleMixin` are not updated.
To Reproduce
```python
import torch
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.lite import LightningLite


class SomeDummy(DeviceDtypeModuleMixin):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(1, 1)


class MyClass(LightningLite):
    def run(self):
        model = SomeDummy()
        model, optimiser = self.setup(model, torch.optim.Adam(model.parameters()))
        # do some stuff
        # now clean up gpu memory for later stages
        model.cpu()
        assert str(model.module.device) == 'cpu'


MyClass(accelerator='gpu', devices=1).run()
```
Expected behavior
`model.module.device` should be `cpu`.
Additional context
This could probably be solved by using `DeviceDtypeModuleMixin` as the base class for `_LiteModule`, since the root cause is that the `to` function only calls `_apply` on all child tensors instead of calling `.to` on every child module.
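To illustrate the mechanism: a minimal sketch (the `TrackedChild`/`Wrapper` names are hypothetical, not part of Lightning) showing that moving a parent module reaches each child through `_apply`, while the child's `to` override is never invoked:

```python
import torch


class TrackedChild(torch.nn.Module):
    """Records which hook fires when a *parent* module is moved."""

    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(1, 1)
        self.to_called = False
        self.apply_called = False

    def to(self, *args, **kwargs):
        self.to_called = True
        return super().to(*args, **kwargs)

    def _apply(self, fn, *args, **kwargs):
        self.apply_called = True
        return super()._apply(fn, *args, **kwargs)


class Wrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.module = TrackedChild()


w = Wrapper()
w.cpu()  # Module.cpu() -> self._apply(fn), which recurses via each child's _apply
print(w.module.to_called)     # False: .to is never invoked on child modules
print(w.module.apply_called)  # True: _apply is what the recursion calls
```

So any device bookkeeping hung off `.to` on a wrapped module is silently skipped; only an `_apply` override sees the move.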
That's an issue we have even without Lite, e.g., it could occur with torchmetrics too. This example shows it:
```python
import torch
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin


class SomeDummy(DeviceDtypeModuleMixin):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(1, 1)


class WrapperModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.module = SomeDummy()  # this could be a torchmetric


w = WrapperModule().cuda()
print(w.module.device)  # prints cpu !!! should be cuda:0
w.cpu()
print(w.module.device)  # prints cpu
```
As you said, the only solution is to add `DeviceDtypeModuleMixin` to the bases of the wrapper class, because of how child modules get called during the move.
@awaelchli yes, but without Lite this is not our concern :D
From our offline chat (for completeness):
For torchmetrics, removing it is fine since they don't nest modules there. And overriding `_apply` is exactly what the mixin does :)
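The `_apply`-override approach can be sketched as follows. This is a simplified stand-in, not the real `DeviceDtypeModuleMixin` (the `DeviceTrackingMixin` name and the parameter-probing logic are illustrative assumptions); it only shows why hooking `_apply` keeps the tracked device in sync even when the module is nested inside a plain wrapper:

```python
import torch


class DeviceTrackingMixin(torch.nn.Module):
    """Sketch: track the module's device by hooking _apply, which IS
    called on every child module when a parent is moved."""

    def __init__(self):
        super().__init__()
        self._device = torch.device('cpu')

    @property
    def device(self):
        return self._device

    def _apply(self, fn, *args, **kwargs):
        out = super()._apply(fn, *args, **kwargs)
        # probe where the parameters ended up after the move
        for p in self.parameters():
            self._device = p.device
            break
        return out


class SomeDummy(DeviceTrackingMixin):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(1, 1)


class WrapperModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.module = SomeDummy()


w = WrapperModule()
w.cpu()  # recursion reaches SomeDummy._apply, so the device stays in sync
print(w.module.device)  # cpu
```

Because `_apply` sits on the nested module itself, the wrapper needs no cooperation; this mirrors why putting the mixin in `_LiteModule`'s bases would fix the reported bug.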
cc @carmocca @justusschock @awaelchli