
Commit

Merge branch 'master' into bugfix/manual_optimizaiton_loss_nan
tchaton authored Dec 16, 2020
2 parents 2d46843 + b4d926b commit 5bb7406
Showing 11 changed files with 120 additions and 30 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -27,9 +27,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `LightningOptimizer` exposes optimizer attributes ([#5095](https://github.com/PyTorchLightning/pytorch-lightning/pull/5095))


- Fixed the saved filename in `ModelCheckpoint` when it already exists ([#4861](https://github.com/PyTorchLightning/pytorch-lightning/pull/4861))


- Do not warn when the `name` key is used in the `lr_scheduler` dict ([#5057](https://github.com/PyTorchLightning/pytorch-lightning/pull/5057))



## [1.1.0] - 2020-12-09

### Added
2 changes: 1 addition & 1 deletion benchmarks/test_parity.py
@@ -4,8 +4,8 @@
import pytest
import torch

from pytorch_lightning import seed_everything, Trainer
import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer, seed_everything
from tests.base.models import ParityModuleMNIST, ParityModuleRNN


2 changes: 1 addition & 1 deletion benchmarks/test_sharded_parity.py
@@ -6,7 +6,7 @@
import pytest
import torch

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVAILABLE
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -16,7 +16,7 @@ exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|.venv|.svn|_build|buck-out|buil

[tool.isort]
known_first_party = [
"bencharmks",
"benchmarks",
"docs",
"pl_examples",
"pytorch_lightning",
@@ -52,3 +52,5 @@ skip_glob = [
]
profile = "black"
line_length = 120
force_sort_within_sections = "True"
order_by_type = "False"
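
Aside (not part of the committed diff): these two isort options account for the import reshuffling elsewhere in this commit. With force_sort_within_sections, plain `import` and `from ... import` statements in the same section are interleaved and sorted by module name instead of listing plain imports first; with order_by_type = "False", imported names are no longer grouped as constants/classes/functions before sorting. A small illustrative sketch of the effect, based on the files changed below:

    # Default isort: plain imports first, then from-imports.
    import os
    import time
    from pathlib import Path

    # With force_sort_within_sections (see tests/test_profiler.py): sorted purely by module name.
    import os
    from pathlib import Path
    import time

    # With order_by_type = "False" (see pytorch_lightning/setup_tools.py),
    # the constant PROJECT_ROOT no longer sorts ahead of the lowercase names:
    from pytorch_lightning import __homepage__, __version__, PROJECT_ROOT
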
49 changes: 28 additions & 21 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -240,17 +240,14 @@ def save_checkpoint(self, trainer, pl_module):
# what can be monitored
monitor_candidates = self._monitor_candidates(trainer)

# ie: path/val_loss=0.5.ckpt
filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, global_step)

# callback supports multiple simultaneous modes
# here we call each mode sequentially
# Mode 1: save all checkpoints OR only the top k
if self.save_top_k:
self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, filepath)
self._save_top_k_checkpoints(trainer, pl_module, monitor_candidates)

# Mode 2: save the last checkpoint
self._save_last_checkpoint(trainer, pl_module, monitor_candidates, filepath)
self._save_last_checkpoint(trainer, pl_module, monitor_candidates)

def __validate_init_configuration(self):
if self.save_top_k is not None and self.save_top_k < -1:
@@ -444,6 +441,7 @@ def format_checkpoint_name(
)
if ver is not None:
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))

ckpt_name = f"{filename}{self.FILE_EXTENSION}"
return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name

@@ -515,13 +513,20 @@ def _validate_monitor_key(self, trainer):
)
raise MisconfigurationException(m)

def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int):
def _get_metric_interpolated_filepath_name(
self,
ckpt_name_metrics: Dict[str, Any],
epoch: int,
step: int,
del_filepath: Optional[str] = None
) -> str:
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)

version_cnt = 0
while self._fs.exists(filepath):
while self._fs.exists(filepath) and filepath != del_filepath:
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
# this epoch called before
version_cnt += 1

return filepath

def _monitor_candidates(self, trainer):
@@ -531,13 +536,11 @@ def _monitor_candidates(self, trainer):
ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch})
return ckpt_name_metrics

def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath):
def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
should_save_last = self.monitor is None or self.save_last
if not should_save_last:
return

last_filepath = filepath

# when user ALSO asked for the 'last.ckpt' change the name
if self.save_last:
last_filepath = self._format_checkpoint_name(
@@ -548,6 +551,10 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath)
prefix=self.prefix
)
last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
else:
last_filepath = self._get_metric_interpolated_filepath_name(
ckpt_name_metrics, trainer.current_epoch, trainer.global_step
)

accelerator_backend = trainer.accelerator_backend

@@ -568,7 +575,7 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath)
if self.monitor is None:
self.best_model_path = self.last_model_path

def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath):
def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
current = metrics.get(self.monitor)
epoch = metrics.get("epoch")
step = metrics.get("step")
@@ -577,7 +584,7 @@ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath):
current = torch.tensor(current, device=pl_module.device)

if self.check_monitor_top_k(current):
self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module)
self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
elif self.verbose:
rank_zero_info(
f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}"
@@ -588,25 +595,26 @@ def _is_valid_monitor_key(self, metrics):

def _update_best_and_save(
self,
filepath: str,
current: torch.Tensor,
epoch: int,
step: int,
trainer,
pl_module,
ckpt_name_metrics
):
k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k

del_list = []
del_filepath = None
if len(self.best_k_models) == k and k > 0:
delpath = self.kth_best_model_path
self.best_k_models.pop(self.kth_best_model_path)
del_list.append(delpath)
del_filepath = self.kth_best_model_path
self.best_k_models.pop(del_filepath)

# do not save nan, replace with +/- inf
if torch.isnan(current):
current = torch.tensor(float('inf' if self.mode == "min" else '-inf'))

filepath = self._get_metric_interpolated_filepath_name(ckpt_name_metrics, epoch, step, del_filepath)

# save the current score
self.current_score = current
self.best_k_models[filepath] = current
@@ -630,9 +638,8 @@ def _update_best_and_save(
)
self._save_model(filepath, trainer, pl_module)

for cur_path in del_list:
if cur_path != filepath:
self._del_model(cur_path)
if del_filepath is not None and filepath != del_filepath:
self._del_model(del_filepath)

def to_yaml(self, filepath: Optional[Union[str, Path]] = None):
"""
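
Aside (not part of the committed diff): the change in this file moves filename resolution out of `save_checkpoint` and into the individual save paths, and `_get_metric_interpolated_filepath_name` gains a `del_filepath` argument so that the name of the checkpoint about to be dropped from the top-k set is not treated as a collision and can be reused. A minimal sketch of the versioning loop, using a hypothetical `resolve_ckpt_path` helper and `os.path` in place of the callback's filesystem abstraction:

    import os
    from typing import Optional

    def resolve_ckpt_path(dirpath: str, base_name: str, del_filepath: Optional[str] = None) -> str:
        # Start from the interpolated name, e.g. "curr_epoch.ckpt".
        filepath = os.path.join(dirpath, f"{base_name}.ckpt")
        version_cnt = 0
        # Append -v0, -v1, ... until the name is free; a path equal to del_filepath
        # counts as free because that checkpoint is about to be deleted anyway.
        while os.path.exists(filepath) and filepath != del_filepath:
            filepath = os.path.join(dirpath, f"{base_name}-v{version_cnt}.ckpt")
            version_cnt += 1
        return filepath

With `filename='curr_epoch'` this yields `curr_epoch.ckpt` and then `curr_epoch-v0.ckpt`, which is exactly what the new test further down expects.
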
4 changes: 2 additions & 2 deletions pytorch_lightning/setup_tools.py
@@ -14,12 +14,12 @@
# limitations under the License.
import os
import re
import warnings
from typing import Iterable, List
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
import warnings

from pytorch_lightning import PROJECT_ROOT, __homepage__, __version__
from pytorch_lightning import __homepage__, __version__, PROJECT_ROOT

_PATH_BADGES = os.path.join('.', 'docs', 'source', '_images', 'badges')
# badge to download
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/supporters.py
@@ -50,7 +50,7 @@ def __init__(self, window_length: int):

def reset(self) -> None:
"""Empty the accumulator."""
self = TensorRunningAccum(self.window_length)
self.__init__(self.window_length)

def last(self):
"""Get the last added element."""
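
Aside (not part of the committed diff): the one-line fix above addresses a classic Python pitfall. Assigning to `self` only rebinds the local name inside `reset`, so the caller's instance was never actually reset; calling `self.__init__(...)` re-runs initialization on the existing object. A tiny self-contained illustration with a hypothetical class:

    class Counter:
        def __init__(self, start: int = 0):
            self.value = start

        def reset_broken(self):
            self = Counter()      # rebinds the local name only; the instance is untouched

        def reset_fixed(self):
            self.__init__()       # re-initializes this instance in place

    c = Counter(5)
    c.reset_broken()
    assert c.value == 5           # the "reset" did nothing
    c.reset_fixed()
    assert c.value == 0           # actually reset
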
39 changes: 39 additions & 0 deletions tests/checkpointing/test_model_checkpoint.py
@@ -938,3 +938,42 @@ def __init__(self, hparams):
else:
# make sure it's not AttributeDict
assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type


@pytest.mark.parametrize('max_epochs', [3, 4])
@pytest.mark.parametrize(
'save_top_k, expected',
[
(1, ['curr_epoch.ckpt']),
(2, ['curr_epoch.ckpt', 'curr_epoch-v0.ckpt']),
]
)
def test_model_checkpoint_file_already_exists(tmpdir, max_epochs, save_top_k, expected):
"""
Test that version is added to filename if required and it already exists in dirpath.
"""
model_checkpoint = ModelCheckpoint(
dirpath=tmpdir,
filename='curr_epoch',
save_top_k=save_top_k,
monitor='epoch',
mode='max',
)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[model_checkpoint],
max_epochs=max_epochs,
limit_train_batches=2,
limit_val_batches=2,
logger=None,
weights_summary=None,
progress_bar_refresh_rate=0,
)

model = BoringModel()
trainer.fit(model)
ckpt_files = os.listdir(tmpdir)
assert set(ckpt_files) == set(expected)

epochs_in_ckpt_files = [pl_load(os.path.join(tmpdir, f))['epoch'] - 1 for f in ckpt_files]
assert sorted(epochs_in_ckpt_files) == list(range(max_epochs - save_top_k, max_epochs))
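
Aside (not part of the committed diff): the final assertion checks that the surviving files are the checkpoints of the last `save_top_k` epochs, since the monitored quantity is `'epoch'` with `mode='max'`; the stored `'epoch'` value is one ahead of the 0-based epoch index, hence the `- 1` in the comprehension. For example, with `max_epochs=4` and `save_top_k=2`, `list(range(max_epochs - save_top_k, max_epochs))` is `[2, 3]`, so the two kept checkpoints must come from epochs 2 and 3.
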
4 changes: 2 additions & 2 deletions tests/conftest.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import threading
from functools import partial, wraps
from http.server import SimpleHTTPRequestHandler
import sys
import threading

import pytest
import torch.multiprocessing as mp
2 changes: 1 addition & 1 deletion tests/test_profiler.py
@@ -13,8 +13,8 @@
# limitations under the License.

import os
import time
from pathlib import Path
import time

import numpy as np
import pytest
38 changes: 38 additions & 0 deletions tests/trainer/test_supporters.py
@@ -0,0 +1,38 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch

from pytorch_lightning.trainer.supporters import TensorRunningAccum


def test_tensor_running_accum_reset():
""" Test that reset would set all attributes to the initialization state """

window_length = 10

accum = TensorRunningAccum(window_length=window_length)
assert accum.last() is None
assert accum.mean() is None

accum.append(torch.tensor(1.5))
assert accum.last() == torch.tensor(1.5)
assert accum.mean() == torch.tensor(1.5)

accum.reset()
assert accum.window_length == window_length
assert accum.memory is None
assert accum.current_idx == 0
assert accum.last_idx is None
assert not accum.rotated
