diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 9272f8e2be924..24c8be9dc9b37 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -203,6 +203,10 @@ def __run_eval_epoch_end(self, num_dataloaders):
 
         # with a single dataloader don't pass an array
        outputs = self.outputs
+
+        # free memory
+        self.outputs = []
+
         eval_results = outputs
         if num_dataloaders == 1:
             eval_results = outputs[0]
diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py
index 23b2fcbb52235..0e63fc29d49b1 100644
--- a/tests/callbacks/test_pruning.py
+++ b/tests/callbacks/test_pruning.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from collections import OrderedDict
 from logging import INFO
 
@@ -22,7 +21,7 @@
 from torch.nn import Sequential
 
 from pytorch_lightning import seed_everything, Trainer
-from pytorch_lightning.callbacks import ModelPruning, ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint, ModelPruning
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -274,6 +273,7 @@ def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog):
     seed_everything(0)
 
     class TestPruning(ModelPruning):
+
         def on_save_checkpoint(self, trainer, pl_module, checkpoint):
             super().on_save_checkpoint(trainer, pl_module, checkpoint)
             assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
diff --git a/tests/trainer/logging_/test_eval_loop_logging_1_0.py b/tests/trainer/logging_/test_eval_loop_logging_1_0.py
index e5cf596a78eca..72084454ba10d 100644
--- a/tests/trainer/logging_/test_eval_loop_logging_1_0.py
+++ b/tests/trainer/logging_/test_eval_loop_logging_1_0.py
@@ -126,6 +126,7 @@ def validation_step_end(self, acc):
             def validation_epoch_end(self, outputs):
                 self.log('g', torch.tensor(2, device=self.device), on_epoch=True)
                 self.validation_epoch_end_called = True
+                assert len(self.trainer.evaluation_loop.outputs) == 0
 
             def backward(self, loss, optimizer, optimizer_idx):
                 return LightningModule.backward(self, loss, optimizer, optimizer_idx)