diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 69c764e06e48f..0e234836e2533 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -136,7 +136,10 @@ def scale_batch_size(self,
 
         """
         if not hasattr(model, batch_arg_name):
-            raise MisconfigurationException(f'Field {batch_arg_name} not found in `model.hparams`')
+            if not hasattr(model.hparams, batch_arg_name):
+                raise MisconfigurationException(
+                    'Neither of `model.batch_size` and `model.hparams.batch_size` found.'
+                )
 
         if hasattr(model.train_dataloader, 'patch_loader_code'):
             raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders'
@@ -242,9 +245,15 @@ def _adjust_batch_size(trainer,
 
     """
     model = trainer.get_model()
-    batch_size = getattr(model, batch_arg_name)
+    if hasattr(model, batch_arg_name):
+        batch_size = getattr(model, batch_arg_name)
+    else:
+        batch_size = getattr(model.hparams, batch_arg_name)
     if value:
-        setattr(model, batch_arg_name, value)
+        if hasattr(model, batch_arg_name):
+            setattr(model, batch_arg_name, value)
+        else:
+            setattr(model.hparams, batch_arg_name, value)
         new_size = value
         if desc:
             log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
@@ -252,7 +261,7 @@ def _adjust_batch_size(trainer,
         new_size = int(batch_size * factor)
         if desc:
             log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
-        setattr(model, batch_arg_name, new_size)
+        setattr(model.hparams, batch_arg_name, new_size)
     return new_size
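
For context, a minimal usage sketch of what this change enables: a model that stores `batch_size` only under `hparams`, which previously hit the `MisconfigurationException`. It assumes a Lightning version where `Trainer.scale_batch_size` exists (as in the patched file); the `LitModel`, layer sizes, and dummy data below are illustrative placeholders, not code from the repository.

```python
from argparse import Namespace

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams  # batch size lives only in hparams, no `self.batch_size`
        self.layer = torch.nn.Linear(32, 2)

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(256, 32), torch.randint(0, 2, (256,)))
        # read the (possibly rescaled) value back from hparams
        return DataLoader(dataset, batch_size=self.hparams.batch_size)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


model = LitModel(Namespace(batch_size=2))
trainer = pl.Trainer(max_epochs=1)
# `_adjust_batch_size` now falls back to `model.hparams.batch_size` instead of raising
new_size = trainer.scale_batch_size(model)
```

A model that defines `self.batch_size` directly keeps working as before: the patched code checks the direct attribute first and only then falls back to `hparams`.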