diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 66fa0d7d20ab28..b8b5a97664d0dc 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -208,15 +208,12 @@ def lightning_get_attr_holder(model, attribute): # Check if attribute in model.hparams, either namespace or dict if hasattr(model, 'hparams'): - if isinstance(model.hparams, dict) and attribute in model.hparams: - return model.hparams - elif hasattr(model.hparams, attribute): + if attribute in model.hparams: return model.hparams # Check if the attribute in datamodule (datamodule gets registered in Trainer) if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute): - if getattr(trainer.datamodule, attribute): - return trainer.datamodule + return trainer.datamodule return None