Skip to content

Commit

Permalink
tests for new behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
rnett committed Dec 7, 2020
1 parent cdf2e95 commit ab55ff0
Showing 1 changed file with 32 additions and 4 deletions.
36 changes: 32 additions & 4 deletions tests/utilities/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,26 @@ class TestModel5: # test for datamodule

model5 = TestModel5()

return model1, model2, model3, model4, model5
class TestModel6: # test for datamodule w/ hparams w/o attribute (should use datamodule)
trainer = Trainer
hparams = TestHparamsDict

model6 = TestModel6()

TestHparamsDict2 = {'batch_size': 2}

class TestModel7: # test for datamodule w/ hparams w/ attribute (should use datamodule)
trainer = Trainer
hparams = TestHparamsDict2

model7 = TestModel7()

return model1, model2, model3, model4, model5, model6, model7


def test_lightning_hasattr(tmpdir):
""" Test that the lightning_hasattr works in all cases"""
model1, model2, model3, model4, model5 = _get_test_cases()
model1, model2, model3, model4, model5, model6, model7 = _get_test_cases()
assert lightning_hasattr(model1, 'learning_rate'), \
'lightning_hasattr failed to find namespace variable'
assert lightning_hasattr(model2, 'learning_rate'), \
Expand All @@ -72,6 +86,10 @@ def test_lightning_hasattr(tmpdir):
'lightning_hasattr found variable when it should not'
assert lightning_hasattr(model5, 'batch_size'), \
'lightning_hasattr failed to find batch_size in datamodule'
assert lightning_hasattr(model6, 'batch_size'), \
'lightning_hasattr failed to find batch_size in datamodule w/ hparams present'
assert lightning_hasattr(model7, 'batch_size'), \
'lightning_hasattr failed to find batch_size in hparams w/ datamodule present'


def test_lightning_getattr(tmpdir):
Expand All @@ -81,9 +99,13 @@ def test_lightning_getattr(tmpdir):
value = lightning_getattr(m, 'learning_rate')
assert value == i, 'attribute not correctly extracted'

model5 = models[4]
model5, model6, model7 = models[4:]
assert lightning_getattr(model5, 'batch_size') == 8, \
'batch_size not correctly extracted'
assert lightning_getattr(model6, 'batch_size') == 8, \
'batch_size not correctly extracted'
assert lightning_getattr(model7, 'batch_size') == 8, \
'batch_size not correctly extracted'


def test_lightning_setattr(tmpdir):
Expand All @@ -94,7 +116,13 @@ def test_lightning_setattr(tmpdir):
assert lightning_getattr(m, 'learning_rate') == 10, \
'attribute not correctly set'

model5 = models[4]
model5, model6, model7 = models[4:]
lightning_setattr(model5, 'batch_size', 128)
lightning_setattr(model6, 'batch_size', 128)
lightning_setattr(model7, 'batch_size', 128)
assert lightning_getattr(model5, 'batch_size') == 128, \
'batch_size not correctly set'
assert lightning_getattr(model6, 'batch_size') == 128, \
'batch_size not correctly set'
assert lightning_getattr(model7, 'batch_size') == 128, \
'batch_size not correctly set'

0 comments on commit ab55ff0

Please sign in to comment.