Unify attribute finding logic, fix not using dataloader when hparams present #4559

Merged
merged 21 commits into from Jan 26, 2021
Changes from 11 commits
85 changes: 41 additions & 44 deletions pytorch_lightning/utilities/parsing.py
@@ -197,72 +197,69 @@ def __repr__(self):
return out


def lightning_hasattr(model, attribute):
""" Special hasattr for lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule. """
def lightning_get_all_attr_holders(model, attribute):
Borda marked this conversation as resolved.
Member:
be careful, this is an API change
we are free to change anything that is protected (starting with _), but the rest is public API
so even when we are adding new features, we should decide what is meant to be part of the public vs. protected API
cc: @PyTorchLightning/core-contributors

Contributor:
lightning_hasattr is still there, the diff is just not rendering properly

""" Special attribute finding for lightning. Gets all of the objects or dicts that holds attribute.
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """
trainer = getattr(model, 'trainer', None)

attr = False
holders = []

# Check if attribute in model
if hasattr(model, attribute):
attr = True
holders.append(model)

# Check if attribute in model.hparams, either namespace or dict
elif hasattr(model, 'hparams'):
if isinstance(model.hparams, dict):
attr = attribute in model.hparams
else:
attr = hasattr(model.hparams, attribute)
if hasattr(model, 'hparams'):
Member:
hparams is a property, so in what case would it be missing? some old model versions?

Contributor:
This function is just collecting attr candidates. If there is no hparams, one of the previous/next candidates will be used.

if attribute in model.hparams:
holders.append(model.hparams)

# Check if the attribute in datamodule (datamodule gets registered in Trainer)
if not attr and trainer is not None:
attr = hasattr(trainer.datamodule, attribute)
if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
Member:
datamodule is also quite new

Suggested change
if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
if trainer is not None and hasattr(trainer, 'datamodule') is not None and hasattr(trainer.datamodule, attribute):

Contributor:
I'm not against being extra-safe, but this should not fail unless a user is using this code with their own Trainer that does not inherit from PL's.

Member:
@Borda I think we don't need this, since we only expect this code to be used with the current version of the trainer and not some old one.

Plus, if you want to add this, instead of hasattr(trainer, 'datamodule') you'd have to do getattr(trainer, 'datamodule', None), since `hasattr` only returns True or False.
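For reference, a minimal standalone sketch of that distinction (illustrative only, not part of this diff):

```python
class Trainer:
    pass  # deliberately has no `datamodule` attribute


trainer = Trainer()

# `hasattr` always returns a bool, so comparing its result to None is always True:
print(hasattr(trainer, 'datamodule') is not None)  # True, even though the attribute is missing

# `getattr` with a default returns the attribute value or the fallback,
# which is what a None-check actually needs:
print(getattr(trainer, 'datamodule', None))  # None
```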

holders.append(trainer.datamodule)

return attr
return holders


def lightning_get_first_attr_holder(model, attribute):
Member:
do we really need this? it is also not very consistent, as the name says "first" but you take the last one with [-1]

Contributor:
As mentioned in line 229, the last one is taken to preserve backwards compatibility, in case somebody is relying on the attribute candidate order.

""" Special attribute finding for lightning. Gets the object or dict that holds attribute, or None. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule, returns the last one that has it. """
holders = lightning_get_all_attr_holders(model, attribute)
if len(holders) == 0:
return None
return holders[-1]
Contributor:
So if an attribute key is in both pl_module and datamodule, the datamodule attr will be used? Is this intended?

Contributor:
I can't say for sure if that was the original intention but it's not a bad idea. Usually, if a datamodule is given it's because you want to use it.

Contributor:
The edge case is where the pl_module and datamodule contain the same key but different values. I guess this ties in with #3792
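For reference, a small self-contained sketch of the precedence this implies (toy classes, illustrative only; only the import path comes from this PR's file):

```python
from pytorch_lightning.utilities.parsing import lightning_getattr


class DataModule:
    batch_size = 64


class Trainer:
    datamodule = DataModule()


class Model:
    trainer = Trainer()
    hparams = {'batch_size': 32}  # hparams also defines the key, with a different value


model = Model()

# Holders are collected as [hparams, datamodule] and the last one wins,
# so the datamodule value is returned:
assert lightning_getattr(model, 'batch_size') == 64
```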



def lightning_hasattr(model, attribute):
""" Special hasattr for lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule. """
return lightning_get_first_attr_holder(model, attribute) is not None
Comment on lines +234 to +237
Member:
not sure of the usage

Contributor:
Can you elaborate?



def lightning_getattr(model, attribute):
""" Special getattr for lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule. """
trainer = getattr(model, 'trainer', None)

# Check if attribute in model
if hasattr(model, attribute):
attr = getattr(model, attribute)
# Check if attribute in model.hparams, either namespace or dict
elif hasattr(model, 'hparams') and isinstance(model.hparams, dict) and attribute in model.hparams:
attr = model.hparams[attribute]
elif hasattr(model, 'hparams') and hasattr(model.hparams, attribute):
attr = getattr(model.hparams, attribute)
# Check if the attribute in datamodule (datamodule gets registered in Trainer)
elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
attr = getattr(trainer.datamodule, attribute)
else:
holder = lightning_get_first_attr_holder(model, attribute)
if holder is None:
raise ValueError(f'{attribute} is neither stored in the model namespace'
' nor the `hparams` namespace/dict, nor the datamodule.')
return attr

if isinstance(holder, dict):
return holder[attribute]
return getattr(holder, attribute)


def lightning_setattr(model, attribute, value):
""" Special setattr for lightning. Checks for attribute in model namespace
and the old hparams namespace/dict.
Will also set the attribute on datamodule, if it exists.
"""
if not lightning_hasattr(model, attribute):
holders = lightning_get_all_attr_holders(model, attribute)
if len(holders) == 0:
raise ValueError(f'{attribute} is neither stored in the model namespace'
' nor the `hparams` namespace/dict, nor the datamodule.')

trainer = getattr(model, 'trainer', None)

# Check if attribute in model
if hasattr(model, attribute):
setattr(model, attribute, value)

# Check if attribute in model.hparams, either namespace or dict
elif hasattr(model, 'hparams'):
if isinstance(model.hparams, dict):
model.hparams[attribute] = value
for holder in holders:
if isinstance(holder, dict):
holder[attribute] = value
else:
setattr(model.hparams, attribute, value)

# Check if the attribute in datamodule (datamodule gets registered in Trainer)
if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
setattr(trainer.datamodule, attribute, value)
setattr(holder, attribute, value)
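To make the behavioural change concrete, a hedged sketch of what the rewritten lightning_setattr does with the same style of toy classes as above (illustrative only, not part of this diff):

```python
from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_setattr


class DataModule:
    batch_size = 8


class Trainer:
    datamodule = DataModule()


class Model:
    trainer = Trainer()
    hparams = {'batch_size': 2}


model = Model()

# The new implementation writes to *every* holder of the attribute,
# so both the hparams dict and the datamodule end up updated:
lightning_setattr(model, 'batch_size', 128)
assert model.hparams['batch_size'] == 128
assert model.trainer.datamodule.batch_size == 128
assert lightning_getattr(model, 'batch_size') == 128
```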
39 changes: 35 additions & 4 deletions tests/utilities/test_parsing.py
@@ -20,6 +20,9 @@ def _get_test_cases():
class TestHparamsNamespace:
learning_rate = 1

def __contains__(self, item):
Member:
mind adding context for this state? (as a docstring)

return item == "learning_rate"

TestHparamsDict = {'learning_rate': 2}

class TestModel1: # test for namespace
@@ -53,12 +56,26 @@ class TestModel5: # test for datamodule

model5 = TestModel5()

return model1, model2, model3, model4, model5
class TestModel6: # test for datamodule w/ hparams w/o attribute (should use datamodule)
trainer = Trainer
hparams = TestHparamsDict

model6 = TestModel6()

TestHparamsDict2 = {'batch_size': 2}

class TestModel7: # test for datamodule w/ hparams w/ attribute (should use datamodule)
trainer = Trainer
hparams = TestHparamsDict2

model7 = TestModel7()

return model1, model2, model3, model4, model5, model6, model7


def test_lightning_hasattr(tmpdir):
""" Test that the lightning_hasattr works in all cases"""
model1, model2, model3, model4, model5 = _get_test_cases()
model1, model2, model3, model4, model5, model6, model7 = _get_test_cases()
assert lightning_hasattr(model1, 'learning_rate'), \
'lightning_hasattr failed to find namespace variable'
assert lightning_hasattr(model2, 'learning_rate'), \
@@ -69,6 +86,10 @@ def test_lightning_hasattr(tmpdir):
'lightning_hasattr found variable when it should not'
assert lightning_hasattr(model5, 'batch_size'), \
'lightning_hasattr failed to find batch_size in datamodule'
assert lightning_hasattr(model6, 'batch_size'), \
'lightning_hasattr failed to find batch_size in datamodule w/ hparams present'
assert lightning_hasattr(model7, 'batch_size'), \
'lightning_hasattr failed to find batch_size in hparams w/ datamodule present'


def test_lightning_getattr(tmpdir):
@@ -78,9 +99,13 @@ def test_lightning_getattr(tmpdir):
value = lightning_getattr(m, 'learning_rate')
assert value == i, 'attribute not correctly extracted'

model5 = models[4]
model5, model6, model7 = models[4:]
assert lightning_getattr(model5, 'batch_size') == 8, \
'batch_size not correctly extracted'
assert lightning_getattr(model6, 'batch_size') == 8, \
'batch_size not correctly extracted'
assert lightning_getattr(model7, 'batch_size') == 8, \
'batch_size not correctly extracted'


def test_lightning_setattr(tmpdir):
@@ -91,7 +116,13 @@ def test_lightning_setattr(tmpdir):
assert lightning_getattr(m, 'learning_rate') == 10, \
'attribute not correctly set'

model5 = models[4]
model5, model6, model7 = models[4:]
lightning_setattr(model5, 'batch_size', 128)
lightning_setattr(model6, 'batch_size', 128)
lightning_setattr(model7, 'batch_size', 128)
assert lightning_getattr(model5, 'batch_size') == 128, \
'batch_size not correctly set'
assert lightning_getattr(model6, 'batch_size') == 128, \
'batch_size not correctly set'
assert lightning_getattr(model7, 'batch_size') == 128, \
'batch_size not correctly set'