Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix manual backward for DeepSpeed #13882

Merged
merged 13 commits into from
Jul 27, 2022
3 changes: 2 additions & 1 deletion src/pytorch_lightning/plugins/precision/apex_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def backward(
model: "pl.LightningModule",
closure_loss: Tensor,
optimizer: Optional[Optimizer],
optimizer_idx: Optional[int],
*args: Any,
**kwargs: Any,
) -> None:
Expand All @@ -71,7 +72,7 @@ def backward(
"""
opt = optimizer or model.trainer.optimizers
with amp.scale_loss(closure_loss, opt) as closure_loss:
super().backward(model, closure_loss, optimizer, *args, **kwargs)
super().backward(model, closure_loss, optimizer, optimizer_idx, *args, **kwargs)

def optimizer_step(
self,
Expand Down
10 changes: 9 additions & 1 deletion src/pytorch_lightning/plugins/precision/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,15 @@ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optiona
self.amp_type = amp_type
self.amp_level = amp_level

def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any, **kwargs: Any) -> None:
def backward(
self,
model: "pl.LightningModule",
closure_loss: Tensor,
optimizer: Optional[Optimizer],
optimizer_idx: Optional[int],
*args: Any,
**kwargs: Any,
) -> None:
if is_overridden("backward", model):
warning_cache.warn(
"You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/plugins/precision/ipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, precision: int) -> None:
super().__init__()
self.precision = precision

def backward(self, model: "pl.LightningModule", *args: Any, **kwargs: Any) -> None:
def backward(self, model: "pl.LightningModule", *_: Any, **__: Any) -> None:
if is_overridden("backward", model):
warning_cache.warn(
"You have overridden the `LightningModule.backward` hook but it will be ignored since IPUs handle"
Expand Down
3 changes: 2 additions & 1 deletion src/pytorch_lightning/plugins/precision/precision_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def backward(
model: "pl.LightningModule",
closure_loss: Tensor,
optimizer: Optional[Optimizer],
optimizer_idx: Optional[int],
*args: Any,
**kwargs: Any,
) -> None:
Expand All @@ -76,7 +77,7 @@ def backward(
"""
# do backward pass
if model is not None and isinstance(model, pl.LightningModule):
model.backward(closure_loss, optimizer, *args, **kwargs)
model.backward(closure_loss, optimizer, optimizer_idx, *args, **kwargs)
else:
self._run_backward(closure_loss, *args, **kwargs)

Expand Down
11 changes: 9 additions & 2 deletions src/pytorch_lightning/strategies/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,14 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
"""
return optimizer.state_dict()

def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
def backward(
self,
closure_loss: Tensor,
optimizer: Optional[Optimizer],
optimizer_idx: Optional[int],
*args: Any,
**kwargs: Any,
) -> Tensor:
"""Forwards backward-calls to the precision plugin.

Args:
Expand All @@ -181,7 +188,7 @@ def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
assert self.lightning_module is not None
closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss)

self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
self.precision_plugin.backward(self.lightning_module, closure_loss, optimizer, optimizer_idx, *args, **kwargs)

closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss)
self.post_backward(closure_loss)
Expand Down