diff --git a/src/pytorch_lightning/plugins/precision/apex_amp.py b/src/pytorch_lightning/plugins/precision/apex_amp.py
index 15825dedd2ef6..e18f82dc27f6e 100644
--- a/src/pytorch_lightning/plugins/precision/apex_amp.py
+++ b/src/pytorch_lightning/plugins/precision/apex_amp.py
@@ -59,6 +59,7 @@ def backward(
         model: "pl.LightningModule",
         closure_loss: Tensor,
         optimizer: Optional[Optimizer],
+        optimizer_idx: Optional[int],
         *args: Any,
         **kwargs: Any,
     ) -> None:
@@ -71,7 +72,7 @@ def backward(
         """
         opt = optimizer or model.trainer.optimizers
         with amp.scale_loss(closure_loss, opt) as closure_loss:
-            super().backward(model, closure_loss, optimizer, *args, **kwargs)
+            super().backward(model, closure_loss, optimizer, optimizer_idx, *args, **kwargs)

     def optimizer_step(
         self,
diff --git a/src/pytorch_lightning/plugins/precision/deepspeed.py b/src/pytorch_lightning/plugins/precision/deepspeed.py
index ccc0fff8411dd..594b1f6cc0f46 100644
--- a/src/pytorch_lightning/plugins/precision/deepspeed.py
+++ b/src/pytorch_lightning/plugins/precision/deepspeed.py
@@ -62,7 +62,15 @@ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optiona
         self.amp_type = amp_type
         self.amp_level = amp_level

-    def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any, **kwargs: Any) -> None:
+    def backward(
+        self,
+        model: "pl.LightningModule",
+        closure_loss: Tensor,
+        optimizer: Optional[Optimizer],
+        optimizer_idx: Optional[int],
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
         if is_overridden("backward", model):
             warning_cache.warn(
                 "You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
diff --git a/src/pytorch_lightning/plugins/precision/ipu.py b/src/pytorch_lightning/plugins/precision/ipu.py
index 329a8b8978e50..89f544575f63f 100644
--- a/src/pytorch_lightning/plugins/precision/ipu.py
+++ b/src/pytorch_lightning/plugins/precision/ipu.py
@@ -44,7 +44,7 @@ def __init__(self, precision: int) -> None:
         super().__init__()
         self.precision = precision

-    def backward(self, model: "pl.LightningModule", *args: Any, **kwargs: Any) -> None:
+    def backward(self, model: "pl.LightningModule", *_: Any, **__: Any) -> None:
         if is_overridden("backward", model):
             warning_cache.warn(
                 "You have overridden the `LightningModule.backward` hook but it will be ignored since IPUs handle"
diff --git a/src/pytorch_lightning/plugins/precision/precision_plugin.py b/src/pytorch_lightning/plugins/precision/precision_plugin.py
index cbf18b8c4fa41..02d343a0876b4 100644
--- a/src/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/src/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -64,6 +64,7 @@ def backward(
         model: "pl.LightningModule",
         closure_loss: Tensor,
         optimizer: Optional[Optimizer],
+        optimizer_idx: Optional[int],
         *args: Any,
         **kwargs: Any,
     ) -> None:
@@ -76,7 +77,7 @@ def backward(
         """
         # do backward pass
         if model is not None and isinstance(model, pl.LightningModule):
-            model.backward(closure_loss, optimizer, *args, **kwargs)
+            model.backward(closure_loss, optimizer, optimizer_idx, *args, **kwargs)
         else:
             self._run_backward(closure_loss, *args, **kwargs)
diff --git a/src/pytorch_lightning/strategies/strategy.py b/src/pytorch_lightning/strategies/strategy.py
index f47afc890bcbb..0de904ccbd283 100644
--- a/src/pytorch_lightning/strategies/strategy.py
+++ b/src/pytorch_lightning/strategies/strategy.py
@@ -171,7 +171,14 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         """
         return optimizer.state_dict()

-    def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
+    def backward(
+        self,
+        closure_loss: Tensor,
+        optimizer: Optional[Optimizer],
+        optimizer_idx: Optional[int],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Tensor:
         """Forwards backward-calls to the precision plugin.

         Args:
@@ -181,7 +188,7 @@ def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
         assert self.lightning_module is not None
         closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss)

-        self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
+        self.precision_plugin.backward(self.lightning_module, closure_loss, optimizer, optimizer_idx, *args, **kwargs)

         closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss)
         self.post_backward(closure_loss)
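
After this change, `optimizer` and `optimizer_idx` travel the whole call chain explicitly (`Strategy.backward` -> `PrecisionPlugin.backward` -> `LightningModule.backward`) instead of riding along in `*args`. A minimal sketch of a user override that lines up with the new positional forwarding; the module, loss handling, and branching logic below are hypothetical and not part of this diff:

from typing import Any, Optional

import pytorch_lightning as pl
from torch import Tensor, nn
from torch.optim import Optimizer


class SketchModule(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def backward(
        self,
        loss: Tensor,
        optimizer: Optional[Optimizer],
        optimizer_idx: Optional[int],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        # `PrecisionPlugin.backward` now passes `optimizer` and `optimizer_idx`
        # positionally, so an override can branch on the optimizer index
        # directly rather than digging it out of *args.
        if optimizer_idx == 0:
            # hypothetical: keep the graph alive for a second optimizer's pass
            loss.backward(retain_graph=True)
        else:
            loss.backward()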