Fix manual backward for DeepSpeed (#13882)
awaelchli authored Jul 27, 2022
1 parent dbafd6e commit fe9803c
Showing 5 changed files with 23 additions and 6 deletions.
src/pytorch_lightning/plugins/precision/apex_amp.py (3 changes: 2 additions & 1 deletion)
@@ -59,6 +59,7 @@ def backward(
        model: "pl.LightningModule",
        closure_loss: Tensor,
        optimizer: Optional[Optimizer],
+        optimizer_idx: Optional[int],
        *args: Any,
        **kwargs: Any,
    ) -> None:
@@ -71,7 +72,7 @@ def backward(
        """
        opt = optimizer or model.trainer.optimizers
        with amp.scale_loss(closure_loss, opt) as closure_loss:
-            super().backward(model, closure_loss, optimizer, *args, **kwargs)
+            super().backward(model, closure_loss, optimizer, optimizer_idx, *args, **kwargs)

    def optimizer_step(
        self,
src/pytorch_lightning/plugins/precision/deepspeed.py (10 changes: 9 additions & 1 deletion)
@@ -62,7 +62,15 @@ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optiona
        self.amp_type = amp_type
        self.amp_level = amp_level

-    def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any, **kwargs: Any) -> None:
+    def backward(
+        self,
+        model: "pl.LightningModule",
+        closure_loss: Tensor,
+        optimizer: Optional[Optimizer],
+        optimizer_idx: Optional[int],
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
        if is_overridden("backward", model):
            warning_cache.warn(
                "You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
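For context, the user-facing path this fix targets is manual optimization under the DeepSpeed strategy: `self.manual_backward(loss)` forwards the loss through the strategy and the precision plugin above. A minimal sketch of such a module, assuming the 1.7-era LightningModule API (the layer size, loss, and learning rate are illustrative):

import torch
import pytorch_lightning as pl


class ManualOptModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # enable manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.layer(batch).sum()
        # routes through Strategy.backward -> DeepSpeedPrecisionPlugin.backward,
        # which now receives the optimizer and optimizer_idx explicitly
        self.manual_backward(loss)
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

Run with a DeepSpeed strategy, e.g. `Trainer(accelerator="gpu", devices=1, strategy="deepspeed")`, this exercises the code path patched above.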
src/pytorch_lightning/plugins/precision/ipu.py (2 changes: 1 addition & 1 deletion)
@@ -44,7 +44,7 @@ def __init__(self, precision: int) -> None:
        super().__init__()
        self.precision = precision

-    def backward(self, model: "pl.LightningModule", *args: Any, **kwargs: Any) -> None:
+    def backward(self, model: "pl.LightningModule", *_: Any, **__: Any) -> None:
        if is_overridden("backward", model):
            warning_cache.warn(
                "You have overridden the `LightningModule.backward` hook but it will be ignored since IPUs handle"
src/pytorch_lightning/plugins/precision/precision_plugin.py (3 changes: 2 additions & 1 deletion)
@@ -64,6 +64,7 @@ def backward(
        model: "pl.LightningModule",
        closure_loss: Tensor,
        optimizer: Optional[Optimizer],
+        optimizer_idx: Optional[int],
        *args: Any,
        **kwargs: Any,
    ) -> None:
@@ -76,7 +77,7 @@ def backward(
        """
        # do backward pass
        if model is not None and isinstance(model, pl.LightningModule):
-            model.backward(closure_loss, optimizer, *args, **kwargs)
+            model.backward(closure_loss, optimizer, optimizer_idx, *args, **kwargs)
        else:
            self._run_backward(closure_loss, *args, **kwargs)

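The `model.backward(...)` call above feeds the user-overridable `LightningModule.backward` hook, which takes the same leading arguments. A minimal sketch of an override, assuming the 1.7-era hook signature:

import pytorch_lightning as pl


class CustomBackwardModel(pl.LightningModule):
    def backward(self, loss, optimizer, optimizer_idx, *args, **kwargs):
        # the default hook simply calls loss.backward(); wrap custom logic around it here
        loss.backward(*args, **kwargs)

As the warnings in the DeepSpeed and IPU plugins note, such an override is ignored when those backends manage the backward pass themselves.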
src/pytorch_lightning/strategies/strategy.py (11 changes: 9 additions & 2 deletions)
@@ -171,7 +171,14 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
        """
        return optimizer.state_dict()

-    def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
+    def backward(
+        self,
+        closure_loss: Tensor,
+        optimizer: Optional[Optimizer],
+        optimizer_idx: Optional[int],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Tensor:
        """Forwards backward-calls to the precision plugin.

        Args:
@@ -181,7 +188,7 @@ def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
        assert self.lightning_module is not None
        closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss)

-        self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
+        self.precision_plugin.backward(self.lightning_module, closure_loss, optimizer, optimizer_idx, *args, **kwargs)

        closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss)
        self.post_backward(closure_loss)
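Taken together, `optimizer` and `optimizer_idx` are now threaded explicitly from the strategy through the precision plugin to the module hook instead of riding along in `*args`. A simplified, self-contained sketch of that chain with stand-in classes (not the actual Lightning implementations):

from typing import Any, Optional

import torch
from torch import Tensor
from torch.optim import Optimizer


class SketchPrecisionPlugin:
    def backward(self, model: Any, closure_loss: Tensor, optimizer: Optional[Optimizer],
                 optimizer_idx: Optional[int], *args: Any, **kwargs: Any) -> None:
        # hand the loss plus the optimizer context to the module's backward hook
        model.backward(closure_loss, optimizer, optimizer_idx, *args, **kwargs)


class SketchStrategy:
    def __init__(self, module: Any) -> None:
        self.lightning_module = module
        self.precision_plugin = SketchPrecisionPlugin()

    def backward(self, closure_loss: Tensor, optimizer: Optional[Optimizer],
                 optimizer_idx: Optional[int], *args: Any, **kwargs: Any) -> Tensor:
        # forward the optimizer context explicitly instead of hiding it in *args
        self.precision_plugin.backward(self.lightning_module, closure_loss,
                                       optimizer, optimizer_idx, *args, **kwargs)
        return closure_loss


class SketchModule:
    def __init__(self) -> None:
        self.param = torch.nn.Parameter(torch.ones(3))

    def backward(self, loss: Tensor, optimizer: Optional[Optimizer],
                 optimizer_idx: Optional[int], *args: Any, **kwargs: Any) -> None:
        loss.backward(*args, **kwargs)


module = SketchModule()
strategy = SketchStrategy(module)
loss = (module.param ** 2).sum()
strategy.backward(loss, None, None)  # in manual optimization both arguments may be None

Both new parameters are typed `Optional` in the diff above for exactly this reason: a manual backward call does not have to supply an optimizer or an index.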
