Skip to content

Commit

Permalink
Fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mpariente committed Sep 26, 2020
1 parent c213764 commit ecab426
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/callbacks/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):


@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
def test_model_checkpoint_to_json(tmpdir, save_top_k):
def test_model_checkpoint_to_yaml(tmpdir, save_top_k):
""" Test that None in checkpoint callback is valid and that chkp_path is set correctly """
tutils.reset_seed()
model = EvalModelTemplate()
Expand All @@ -49,8 +49,8 @@ def test_model_checkpoint_to_json(tmpdir, save_top_k):

checkpoint.to_yaml('./best_k_models.yaml')
d = yaml.full_load(open('./best_k_models.yaml', 'r'))
best_k = {k: torch.Tensor(v) for k, v in d.items()}
torch.testing.assert_allclose(best_k, checkpoint.best_k_models)
best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
assert d == best_k


@pytest.mark.parametrize(
Expand Down

0 comments on commit ecab426

Please sign in to comment.