[bugfix] Resolve memory leak for evaluation #6326

Merged · 5 commits · Mar 5, 2021
Changes from 2 commits
pytorch_lightning/trainer/evaluation_loop.py (4 additions, 0 deletions)

@@ -203,6 +203,10 @@ def __run_eval_epoch_end(self, num_dataloaders):
 
         # with a single dataloader don't pass an array
         outputs = self.outputs
+
+        # free memory
+        self.outputs = []
+
        eval_results = outputs
        if num_dataloaders == 1:
            eval_results = outputs[0]
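The core of the fix is a copy-then-clear handoff: the accumulated per-batch outputs are bound to a local name, then the loop's attribute is reset so the trainer stops holding references to the epoch's tensors once the epoch-end hook has consumed them. A minimal, hypothetical sketch of why this releases memory (the Loop class below is illustrative, not Lightning's actual EvaluationLoop):

import weakref

import torch


class Loop:
    """Toy stand-in for an evaluation loop that buffers per-batch outputs."""

    def __init__(self):
        self.outputs = []

    def run_epoch_end(self, clear):
        outputs = self.outputs
        if clear:
            self.outputs = []  # the fix: drop the loop's reference
        return outputs


loop = Loop()
tensor = torch.randn(1024, 1024)  # ~4 MB we want freed after the epoch
loop.outputs.append(tensor)
ref = weakref.ref(tensor)

results = loop.run_epoch_end(clear=True)
del results, tensor

# With clear=True nothing references the tensor any more, so CPython's
# reference counting frees it immediately; without the fix, loop.outputs
# would keep it alive until the next evaluation epoch repopulated the list.
assert ref() is None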
tests/callbacks/test_pruning.py (2 additions, 2 deletions)

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from collections import OrderedDict
 from logging import INFO
 
@@ -22,7 +21,7 @@
 from torch.nn import Sequential
 
 from pytorch_lightning import seed_everything, Trainer
-from pytorch_lightning.callbacks import ModelPruning, ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint, ModelPruning
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -274,6 +273,7 @@ def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog):
     seed_everything(0)
 
     class TestPruning(ModelPruning):
+
         def on_save_checkpoint(self, trainer, pl_module, checkpoint):
             super().on_save_checkpoint(trainer, pl_module, checkpoint)
             assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
File renamed without changes.
@@ -110,6 +110,7 @@ def training_step(self, batch, batch_idx):
         return acc
 
     def validation_step(self, batch, batch_idx):
+        self.batch_idx = batch_idx
         acc = self.step(batch, batch_idx)
         acc = acc + batch_idx
         self.log('c', acc)
@@ -126,6 +127,7 @@ def validation_step_end(self, acc):
     def validation_epoch_end(self, outputs):
         self.log('g', torch.tensor(2, device=self.device), on_epoch=True)
         self.validation_epoch_end_called = True
+        assert len(self.trainer.evaluation_loop.outputs) == 0
 
     def backward(self, loss, optimizer, optimizer_idx):
         return LightningModule.backward(self, loss, optimizer, optimizer_idx)
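The assertion added to validation_epoch_end pins down the ordering this PR relies on: the trainer clears its outputs buffer before invoking the epoch-end hook, so user code observes an empty buffer while still receiving the collected results as the hook's argument. A small standalone sketch of that ordering (the MiniLoop class and on_epoch_end callback are hypothetical, not Lightning APIs):

class MiniLoop:
    """Hypothetical loop illustrating the clear-before-hook ordering."""

    def __init__(self):
        self.outputs = []

    def run(self, batches, step, on_epoch_end):
        for batch in batches:
            self.outputs.append(step(batch))
        outputs, self.outputs = self.outputs, []  # clear before the hook runs
        on_epoch_end(outputs)


loop = MiniLoop()


def on_epoch_end(outputs):
    # mirrors `assert len(self.trainer.evaluation_loop.outputs) == 0`
    assert len(loop.outputs) == 0
    assert outputs == [2, 4, 6]  # the hook still receives every batch result


loop.run([1, 2, 3], step=lambda b: 2 * b, on_epoch_end=on_epoch_end)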