Skip to content

Commit

Permalink
fix tmpdir (#1012)
Browse files Browse the repository at this point in the history
* fix tmpdir

* just str path
  • Loading branch information
Borda committed Mar 12, 2020
1 parent 2b3f443 commit 1d5f062
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def test_dp_output_reduce():
assert reduced['b']['c'] == out['b']['c']


def test_model_checkpoint_options(tmp_path):
def test_model_checkpoint_options(tmpdir):
"""Test ModelCheckpoint options."""
def mock_save_function(filepath):
open(filepath, 'a').close()
Expand All @@ -258,8 +258,8 @@ def mock_save_function(filepath):
_ = LightningTestModel(hparams)

# simulated losses
save_dir = tmp_path / "1"
save_dir.mkdir()
save_dir = os.path.join(tmpdir, '1')
os.mkdir(save_dir)
losses = [10, 9, 2.8, 5, 2.5]

# -----------------
Expand All @@ -286,8 +286,8 @@ def mock_save_function(filepath):
'epoch=0.ckpt'}:
assert fname in file_lists

save_dir = tmp_path / "2"
save_dir.mkdir()
save_dir = os.path.join(tmpdir, '2')
os.mkdir(save_dir)

# -----------------
# CASE K=0 (none)
Expand All @@ -305,8 +305,8 @@ def mock_save_function(filepath):

assert len(file_lists) == 0, "Should save 0 models when save_top_k=0"

save_dir = tmp_path / "3"
save_dir.mkdir()
save_dir = os.path.join(tmpdir, '3')
os.mkdir(save_dir)

# -----------------
# CASE K=1 (2.5, epoch 4)
Expand All @@ -325,8 +325,8 @@ def mock_save_function(filepath):
assert len(file_lists) == 1, "Should save 1 model when save_top_k=1"
assert 'test_prefix_epoch=4.ckpt' in file_lists

save_dir = tmp_path / "4"
save_dir.mkdir()
save_dir = os.path.join(tmpdir, '4')
os.mkdir(save_dir)

# -----------------
# CASE K=2 (2.5 epoch 4, 2.8 epoch 2)
Expand All @@ -351,8 +351,8 @@ def mock_save_function(filepath):
'other_file.ckpt'}:
assert fname in file_lists

save_dir = tmp_path / "5"
save_dir.mkdir()
save_dir = os.path.join(tmpdir, '5')
os.mkdir(save_dir)

# -----------------
# CASE K=4 (save all 4 models)
Expand All @@ -372,8 +372,8 @@ def mock_save_function(filepath):

assert len(file_lists) == 4, 'Should save all 4 models when save_top_k=4 within same epoch'

save_dir = tmp_path / "6"
save_dir.mkdir()
save_dir = os.path.join(tmpdir, '6')
os.mkdir(save_dir)

# -----------------
# CASE K=3 (save the 2nd, 3rd, 4th model)
Expand Down

0 comments on commit 1d5f062

Please sign in to comment.