Fix manual backward for DeepSpeed #13882

Merged · 13 commits · Jul 27, 2022
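
This PR makes manual optimization work with DeepSpeed by threading `optimizer` and `optimizer_idx` explicitly through `Strategy.backward` and the precision plugins instead of relying on `*args`. For orientation, here is a minimal sketch (illustrative only, not code from this PR) of the user-facing flow the fix targets: a LightningModule running manual optimization under the DeepSpeed strategy, where `manual_backward` must reach DeepSpeed's engine.

# Minimal sketch of the manual-optimization flow this PR fixes (illustrative,
# not code from the PR). The module and trainer configuration are assumptions.
import torch
import pytorch_lightning as pl


class ManualModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.layer(batch).sum()
        self.manual_backward(loss)  # routed through Strategy.backward -> precision plugin -> DeepSpeed engine
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


# trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="deepspeed_stage_2", precision=16)
# trainer.fit(ManualModel(), train_dataloaders=...)
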
3 changes: 2 additions & 1 deletion src/pytorch_lightning/plugins/precision/apex_amp.py
@@ -59,6 +59,7 @@ def backward(
         model: "pl.LightningModule",
         closure_loss: Tensor,
         optimizer: Optional[Optimizer],
+        optimizer_idx: Optional[int],
         *args: Any,
         **kwargs: Any,
     ) -> None:
@@ -71,7 +72,7 @@ def backward(
         """
         opt = optimizer or model.trainer.optimizers
         with amp.scale_loss(closure_loss, opt) as closure_loss:
-            super().backward(model, closure_loss, optimizer, *args, **kwargs)
+            super().backward(model, closure_loss, optimizer, optimizer_idx, *args, **kwargs)

     def optimizer_step(
         self,
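
The apex plugin already scales the loss via `amp.scale_loss` before delegating, so the only change is forwarding the new `optimizer_idx` positionally to keep the base-class argument order intact. For context, a minimal standalone sketch of the apex pattern that `scale_loss` wraps (plain apex usage outside Lightning, assuming apex and a CUDA device are available):

# Standalone NVIDIA apex AMP pattern for context only; Lightning wraps this inside
# ApexMixedPrecisionPlugin.backward as shown in the diff above.
import torch
from apex import amp  # assumes NVIDIA apex is installed

model = torch.nn.Linear(32, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

loss = model(torch.randn(4, 32, device="cuda")).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()  # gradients are computed from the scaled loss
optimizer.step()
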
10 changes: 9 additions & 1 deletion src/pytorch_lightning/plugins/precision/deepspeed.py
@@ -62,7 +62,15 @@ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optiona
         self.amp_type = amp_type
         self.amp_level = amp_level

-    def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any, **kwargs: Any) -> None:
+    def backward(
+        self,
+        model: "pl.LightningModule",
+        closure_loss: Tensor,
+        optimizer: Optional[Optimizer],
+        optimizer_idx: Optional[int],
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
         if is_overridden("backward", model):
             warning_cache.warn(
                 "You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
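
Below the new signature (collapsed in this diff), the plugin warns when `LightningModule.backward` is overridden and hands the loss to the DeepSpeed engine, which performs loss scaling and the backward pass itself. A rough sketch of that delegation pattern; since the method body is not shown here, the exact call is an assumption based on DeepSpeed's documented engine API:

# Sketch of the delegation pattern only (the collapsed body of
# DeepSpeedPrecisionPlugin.backward is not shown in this diff, so details are assumed).
from torch import Tensor
import deepspeed


def backward_via_engine(engine: deepspeed.DeepSpeedEngine, closure_loss: Tensor) -> None:
    # DeepSpeed's engine scales the loss and runs backward, coordinating ZeRO
    # partitioning and gradient accumulation internally, which is why Lightning
    # ignores a user-overridden `LightningModule.backward` here.
    engine.backward(closure_loss)
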
2 changes: 1 addition & 1 deletion src/pytorch_lightning/plugins/precision/ipu.py
@@ -44,7 +44,7 @@ def __init__(self, precision: int) -> None:
         super().__init__()
         self.precision = precision

-    def backward(self, model: "pl.LightningModule", *args: Any, **kwargs: Any) -> None:
+    def backward(self, model: "pl.LightningModule", *_: Any, **__: Any) -> None:
         if is_overridden("backward", model):
             warning_cache.warn(
                 "You have overridden the `LightningModule.backward` hook but it will be ignored since IPUs handle"
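
Because the base class now passes `optimizer` and `optimizer_idx` positionally, a subclass that ignores them must still accept them; the IPU plugin makes that explicit by renaming the catch-alls to `*_`/`**__`. A toy sketch (not Lightning code) of why the positional contract matters when overriding:

# Toy illustration of the positional contract (not Lightning code): an override that
# swallows everything after `model` keeps working when the caller adds positional args.
from typing import Any, Optional


class BasePrecision:
    def backward(self, model: Any, loss: Any, optimizer: Optional[Any],
                 optimizer_idx: Optional[int], *args: Any, **kwargs: Any) -> None:
        print("base backward, optimizer_idx =", optimizer_idx)


class IPULikePrecision(BasePrecision):
    # Mirrors the `*_`/`**__` idiom above: the extra arguments are accepted but unused.
    def backward(self, model: Any, *_: Any, **__: Any) -> None:
        print("backward handled by the accelerator; extra args ignored")


IPULikePrecision().backward(object(), 1.0, None, 0)  # extra positionals are swallowed safely
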
3 changes: 2 additions & 1 deletion src/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -64,6 +64,7 @@ def backward(
         model: "pl.LightningModule",
         closure_loss: Tensor,
         optimizer: Optional[Optimizer],
+        optimizer_idx: Optional[int],
         *args: Any,
         **kwargs: Any,
     ) -> None:
@@ -76,7 +77,7 @@ def backward(
         """
         # do backward pass
         if model is not None and isinstance(model, pl.LightningModule):
-            model.backward(closure_loss, optimizer, *args, **kwargs)
+            model.backward(closure_loss, optimizer, optimizer_idx, *args, **kwargs)
         else:
             self._run_backward(closure_loss, *args, **kwargs)
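
With this change the default plugin calls `model.backward(closure_loss, optimizer, optimizer_idx, ...)`, matching the `LightningModule.backward` hook signature. A hedged sketch of an override that stays compatible with that call (the example module and its behaviour are assumptions, not from this PR):

# Sketch of a LightningModule.backward override compatible with the call above
# (assumes the Lightning 1.x hook signature `backward(loss, optimizer, optimizer_idx)`).
from typing import Any, Optional

import pytorch_lightning as pl
from torch import Tensor
from torch.optim import Optimizer


class CustomBackwardModel(pl.LightningModule):
    def backward(
        self,
        loss: Tensor,
        optimizer: Optional[Optimizer],
        optimizer_idx: Optional[int],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        # e.g. keep the graph alive for a second backward pass elsewhere
        loss.backward(retain_graph=True)
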
11 changes: 9 additions & 2 deletions src/pytorch_lightning/strategies/strategy.py
@@ -171,7 +171,14 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         """
         return optimizer.state_dict()

-    def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
+    def backward(
+        self,
+        closure_loss: Tensor,
+        optimizer: Optional[Optimizer],
+        optimizer_idx: Optional[int],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Tensor:
         """Forwards backward-calls to the precision plugin.

         Args:
@@ -181,7 +188,7 @@ def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
         assert self.lightning_module is not None
         closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss)

-        self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
+        self.precision_plugin.backward(self.lightning_module, closure_loss, optimizer, optimizer_idx, *args, **kwargs)

         closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss)
         self.post_backward(closure_loss)
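
`Strategy.backward` is the single entry point for both automatic optimization and `manual_backward`; it now receives `optimizer` and `optimizer_idx` explicitly (both `Optional`, since manual optimization may not supply them) and forwards them to the precision plugin between the `pre_backward`/`post_backward` hooks. A simplified, runnable sketch of that call order using toy stand-ins rather than the real classes:

# Simplified call order implemented by Strategy.backward above (toy stand-ins,
# not the real Lightning classes).
from typing import Any, Optional

import torch
from torch import Tensor


class ToyPrecisionPlugin:
    def pre_backward(self, module: Any, loss: Tensor) -> Tensor:
        return loss

    def backward(self, module: Any, loss: Tensor, optimizer: Optional[Any],
                 optimizer_idx: Optional[int], *args: Any, **kwargs: Any) -> None:
        loss.backward(*args, **kwargs)

    def post_backward(self, module: Any, loss: Tensor) -> Tensor:
        return loss


def strategy_backward(plugin: ToyPrecisionPlugin, module: Any, loss: Tensor,
                      optimizer: Optional[Any], optimizer_idx: Optional[int]) -> Tensor:
    loss = plugin.pre_backward(module, loss)
    plugin.backward(module, loss, optimizer, optimizer_idx)  # optimizer info now passed explicitly
    return plugin.post_backward(module, loss)


param = torch.ones(1, requires_grad=True)
strategy_backward(ToyPrecisionPlugin(), object(), (param * 2).sum(), optimizer=None, optimizer_idx=None)
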
3 changes: 3 additions & 0 deletions tests/tests_pytorch/lite/test_lite.py
@@ -29,6 +29,7 @@
 from pytorch_lightning.strategies import DeepSpeedStrategy, Strategy
 from pytorch_lightning.utilities import _StrategyType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _RequirementAvailable
 from pytorch_lightning.utilities.seed import pl_worker_init_function
 from tests_pytorch.helpers.runif import RunIf

@@ -405,6 +406,8 @@ def test_autocast():
     lite._precision_plugin.forward_context().__exit__.assert_called()


+# https://github.com/microsoft/DeepSpeed/issues/2139
+@pytest.mark.skipif(_RequirementAvailable("deepspeed>=0.6.5"), reason="Lite does not support 0.6.5")
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
 def test_deepspeed_multiple_models():
     class Lite(LightningLite):
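
The new marker skips the multi-model Lite test when DeepSpeed 0.6.5 or newer is installed, because of the upstream issue linked in the comment; `_RequirementAvailable("deepspeed>=0.6.5")` evaluates truthy only when an installed package satisfies the version specifier. For orientation, a hedged sketch of what a multi-model LightningLite run looks like; the real test body is collapsed in this diff, so the specifics below are assumptions:

# Illustrative sketch of a LightningLite run that sets up more than one model,
# roughly what `test_deepspeed_multiple_models` exercises (the actual test body is
# collapsed in this diff, so the details here are assumptions).
import torch
from pytorch_lightning.lite import LightningLite


class MultiModelLite(LightningLite):
    def run(self):
        model_a, model_b = torch.nn.Linear(32, 2), torch.nn.Linear(32, 2)
        opt_a = torch.optim.SGD(model_a.parameters(), lr=0.1)
        opt_b = torch.optim.SGD(model_b.parameters(), lr=0.1)
        model_a, opt_a = self.setup(model_a, opt_a)  # each setup wraps the model for the strategy
        model_b, opt_b = self.setup(model_b, opt_b)

        x = torch.randn(4, 32, device=self.device)
        self.backward(model_a(x).sum())  # routed through the strategy/precision backward path
        opt_a.step()


# MultiModelLite(strategy="deepspeed", accelerator="gpu", devices=2).run()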