Skip to content

Commit

Permalink
Unify attribute finding logic, fix not using dataloader when hparams …
Browse files Browse the repository at this point in the history
…present (#4559)

* Rebase onto master

* indent fix

* Remove duplicated logic

* Use single return

* Remove extra else

* add `__contains__` to TestHparamsNamespace to fix tests

* Fix lightning_setattr to set all valid attributes

* update doc

* better names

* fix holder order preference

* tests for new behavior

* Comment about using the last holder

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>

(cherry picked from commit eee3b1a)
  • Loading branch information
rnett authored and Borda committed Feb 4, 2021
1 parent 02620f3 commit cfae13c
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 56 deletions.
98 changes: 46 additions & 52 deletions pytorch_lightning/utilities/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,76 +196,70 @@ def __repr__(self):
return out


def lightning_hasattr(model, attribute):
""" Special hasattr for lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule. """
def lightning_get_all_attr_holders(model, attribute):
""" Special attribute finding for lightning. Gets all of the objects or dicts that holds attribute.
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """
trainer = getattr(model, 'trainer', None)

attr = False
holders = []

# Check if attribute in model
if hasattr(model, attribute):
attr = True
holders.append(model)

# Check if attribute in model.hparams, either namespace or dict
elif hasattr(model, 'hparams'):
if isinstance(model.hparams, dict):
attr = attribute in model.hparams
else:
attr = hasattr(model.hparams, attribute)
if hasattr(model, 'hparams'):
if attribute in model.hparams:
holders.append(model.hparams)

# Check if the attribute in datamodule (datamodule gets registered in Trainer)
if not attr and trainer is not None:
attr = hasattr(trainer.datamodule, attribute)
if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
holders.append(trainer.datamodule)

return attr
return holders


def lightning_get_first_attr_holder(model, attribute):
""" Special attribute finding for lightning. Gets the object or dict that holds attribute, or None. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule, returns the last one that has it. """
holders = lightning_get_all_attr_holders(model, attribute)
if len(holders) == 0:
return None
# using the last holder to preserve backwards compatibility
return holders[-1]


def lightning_hasattr(model, attribute):
""" Special hasattr for lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule. """
return lightning_get_first_attr_holder(model, attribute) is not None


def lightning_getattr(model, attribute):
""" Special getattr for lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule. """
trainer = getattr(model, 'trainer', None)
holder = lightning_get_first_attr_holder(model, attribute)
if holder is None:
raise ValueError(f'{attribute} is neither stored in the model namespace'
' nor the `hparams` namespace/dict, nor the datamodule.')

# Check if attribute in model
if hasattr(model, attribute):
attr = getattr(model, attribute)
# Check if attribute in model.hparams, either namespace or dict
elif hasattr(model, 'hparams') and isinstance(model.hparams, dict) and attribute in model.hparams:
attr = model.hparams[attribute]
elif hasattr(model, 'hparams') and hasattr(model.hparams, attribute):
attr = getattr(model.hparams, attribute)
# Check if the attribute in datamodule (datamodule gets registered in Trainer)
elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
attr = getattr(trainer.datamodule, attribute)
else:
raise ValueError(
f'The {attribute} is neither stored in the model namespace nor the `hparams` namespace/dict,'
' nor the datamodule.'
)
return attr
if isinstance(holder, dict):
return holder[attribute]
return getattr(holder, attribute)


def lightning_setattr(model, attribute, value):
""" Special setattr for lightning. Checks for attribute in model namespace
and the old hparams namespace/dict.
Will also set the attribute on datamodule, if it exists.
"""
if not lightning_hasattr(model, attribute):
raise ValueError(
f'The {attribute} is neither stored in the model namespace nor the `hparams` namespace/dict,'
' nor the datamodule.'
)

trainer = getattr(model, 'trainer', None)

# Check if attribute in model
if hasattr(model, attribute):
setattr(model, attribute, value)

# Check if attribute in model.hparams, either namespace or dict
elif hasattr(model, 'hparams'):
if isinstance(model.hparams, dict):
model.hparams[attribute] = value
holders = lightning_get_all_attr_holders(model, attribute)
if len(holders) == 0:
raise ValueError(f'{attribute} is neither stored in the model namespace'
' nor the `hparams` namespace/dict, nor the datamodule.')

for holder in holders:
if isinstance(holder, dict):
holder[attribute] = value
else:
setattr(model.hparams, attribute, value)

# Check if the attribute in datamodule (datamodule gets registered in Trainer)
if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
setattr(trainer.datamodule, attribute, value)
setattr(holder, attribute, value)
39 changes: 35 additions & 4 deletions tests/utilities/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def _get_test_cases():
class TestHparamsNamespace:
learning_rate = 1

def __contains__(self, item):
return item == "learning_rate"

TestHparamsDict = {'learning_rate': 2}

class TestModel1: # test for namespace
Expand Down Expand Up @@ -52,12 +55,26 @@ class TestModel5: # test for datamodule

model5 = TestModel5()

return model1, model2, model3, model4, model5
class TestModel6: # test for datamodule w/ hparams w/o attribute (should use datamodule)
trainer = Trainer
hparams = TestHparamsDict

model6 = TestModel6()

TestHparamsDict2 = {'batch_size': 2}

class TestModel7: # test for datamodule w/ hparams w/ attribute (should use datamodule)
trainer = Trainer
hparams = TestHparamsDict2

model7 = TestModel7()

return model1, model2, model3, model4, model5, model6, model7


def test_lightning_hasattr(tmpdir):
""" Test that the lightning_hasattr works in all cases"""
model1, model2, model3, model4, model5 = _get_test_cases()
model1, model2, model3, model4, model5, model6, model7 = _get_test_cases()
assert lightning_hasattr(model1, 'learning_rate'), \
'lightning_hasattr failed to find namespace variable'
assert lightning_hasattr(model2, 'learning_rate'), \
Expand All @@ -68,6 +85,10 @@ def test_lightning_hasattr(tmpdir):
'lightning_hasattr found variable when it should not'
assert lightning_hasattr(model5, 'batch_size'), \
'lightning_hasattr failed to find batch_size in datamodule'
assert lightning_hasattr(model6, 'batch_size'), \
'lightning_hasattr failed to find batch_size in datamodule w/ hparams present'
assert lightning_hasattr(model7, 'batch_size'), \
'lightning_hasattr failed to find batch_size in hparams w/ datamodule present'


def test_lightning_getattr(tmpdir):
Expand All @@ -77,9 +98,13 @@ def test_lightning_getattr(tmpdir):
value = lightning_getattr(m, 'learning_rate')
assert value == i, 'attribute not correctly extracted'

model5 = models[4]
model5, model6, model7 = models[4:]
assert lightning_getattr(model5, 'batch_size') == 8, \
'batch_size not correctly extracted'
assert lightning_getattr(model6, 'batch_size') == 8, \
'batch_size not correctly extracted'
assert lightning_getattr(model7, 'batch_size') == 8, \
'batch_size not correctly extracted'


def test_lightning_setattr(tmpdir):
Expand All @@ -90,7 +115,13 @@ def test_lightning_setattr(tmpdir):
assert lightning_getattr(m, 'learning_rate') == 10, \
'attribute not correctly set'

model5 = models[4]
model5, model6, model7 = models[4:]
lightning_setattr(model5, 'batch_size', 128)
lightning_setattr(model6, 'batch_size', 128)
lightning_setattr(model7, 'batch_size', 128)
assert lightning_getattr(model5, 'batch_size') == 128, \
'batch_size not correctly set'
assert lightning_getattr(model6, 'batch_size') == 128, \
'batch_size not correctly set'
assert lightning_getattr(model7, 'batch_size') == 128, \
'batch_size not correctly set'

0 comments on commit cfae13c

Please sign in to comment.