diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 00611d87d7f35..6a83b7b1f8637 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -105,6 +105,7 @@ def __init__(self, *args, **kwargs): self._current_dataloader_idx = None self.running_stage = None self._automatic_optimization: bool = True + self._param_requires_grad_state = dict() def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: if use_pl_optimizer: @@ -1295,7 +1296,7 @@ def untoggle_optimizer(self, optimizer_idx: int): if param in self._param_requires_grad_state: param.requires_grad = self._param_requires_grad_state[param] # save memory - del self._param_requires_grad_state + self._param_requires_grad_state = dict() def optimizer_step( self,