From 437825785bfa9763c616770059d1c3a54f685704 Mon Sep 17 00:00:00 2001
From: Tejasvi S Tomar <45873379+tejasvi@users.noreply.github.com>
Date: Wed, 27 May 2020 22:38:11 +0530
Subject: [PATCH 1/3] Misleading exception raised during batch scaling

Use batch_size from `model.hparams.batch_size` instead of `model.batch_size`
---
 pytorch_lightning/trainer/training_tricks.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 69c764e06e48f..814b54e0f65ef 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -135,7 +135,9 @@ def scale_batch_size(self,
                 algorithm is terminated

         """
-        if not hasattr(model, batch_arg_name):
+        if not hasattr(model, 'hparams'):
+            raise MisconfigurationException(f'`model.hparams` not found.')
+        elif not hasattr(model.hparams, batch_arg_name):
             raise MisconfigurationException(f'Field {batch_arg_name} not found in `model.hparams`')

         if hasattr(model.train_dataloader, 'patch_loader_code'):
             raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders'
@@ -242,9 +244,9 @@ def _adjust_batch_size(trainer,

     """
     model = trainer.get_model()
-    batch_size = getattr(model, batch_arg_name)
+    batch_size = getattr(model.hparams, batch_arg_name)
     if value:
-        setattr(model, batch_arg_name, value)
+        setattr(model.hparams, batch_arg_name, value)
         new_size = value
         if desc:
             log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
@@ -252,7 +254,7 @@ def _adjust_batch_size(trainer,
         new_size = int(batch_size * factor)
         if desc:
             log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
-        setattr(model, batch_arg_name, new_size)
+        setattr(model.hparams, batch_arg_name, new_size)
     return new_size


From 179f6a59583d76076b78f14da4cdf3115fc4d433 Mon Sep 17 00:00:00 2001
From: Tejasvi S Tomar <45873379+tejasvi@users.noreply.github.com>
Date: Thu, 28 May 2020 13:19:22 +0530
Subject: [PATCH 2/3] Improvements considering #1896

---
 pytorch_lightning/trainer/training_tricks.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 814b54e0f65ef..fd4d33396ec7d 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -135,10 +135,10 @@ def scale_batch_size(self,
                 algorithm is terminated

         """
-        if not hasattr(model, 'hparams'):
-            raise MisconfigurationException(f'`model.hparams` not found.')
-        elif not hasattr(model.hparams, batch_arg_name):
-            raise MisconfigurationException(f'Field {batch_arg_name} not found in `model.hparams`')
+        if not hasattr(model, batch_arg_name):
+            if not hasattr(model, 'hparams') or not hasattr(model.hparams, batch_arg_name):
+                raise MisconfigurationException(f'Neither of `model.batch_size` and'
+                                                f' `model.hparams.batch_size` found.')

         if hasattr(model.train_dataloader, 'patch_loader_code'):
             raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders'
@@ -244,9 +244,15 @@ def _adjust_batch_size(trainer,

     """
     model = trainer.get_model()
-    batch_size = getattr(model.hparams, batch_arg_name)
+    if hasattr(model, batch_arg_name):
+        batch_size = getattr(model, batch_arg_name)
+    else:
+        batch_size = getattr(model.hparams, batch_arg_name)
     if value:
-        setattr(model.hparams, batch_arg_name, value)
+        if hasattr(model, batch_arg_name):
+            setattr(model, batch_arg_name, value)
+        else:
+            setattr(model.hparams, batch_arg_name, value)
         new_size = value
         if desc:
             log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')


From 3d79243abca1267e8b52c61ef882d3c5c7763eb9 Mon Sep 17 00:00:00 2001
From: Tejasvi S Tomar <45873379+tejasvi@users.noreply.github.com>
Date: Mon, 8 Jun 2020 21:47:23 +0530
Subject: [PATCH 3/3] Apply suggestions from code review

Co-authored-by: Jirka Borovec
---
 pytorch_lightning/trainer/training_tricks.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index fd4d33396ec7d..0e234836e2533 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -136,9 +136,10 @@ def scale_batch_size(self,

         """
         if not hasattr(model, batch_arg_name):
-            if not hasattr(model, 'hparams') or not hasattr(model.hparams, batch_arg_name):
-                raise MisconfigurationException(f'Neither of `model.batch_size` and'
-                                                f' `model.hparams.batch_size` found.')
+            if not hasattr(model.hparams, batch_arg_name):
+                raise MisconfigurationException(
+                    'Neither of `model.batch_size` and `model.hparams.batch_size` found.'
+                )

         if hasattr(model.train_dataloader, 'patch_loader_code'):
             raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders'
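
Note: for context, below is a minimal sketch of the two attribute layouts these patches reconcile: `scale_batch_size` first looks for `batch_arg_name` directly on the model and otherwise falls back to `model.hparams`. The module itself is illustrative only; the `LitModel` name, layer sizes, random dataset, and trainer flags are assumptions, not part of the patches.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self, batch_size: int = 2):
        super().__init__()
        # Layout 1: the attribute lives directly on the module; after these
        # patches the scaler checks here first and reads/updates it in place.
        # Layout 2 would store it as `self.hparams.batch_size` instead, which
        # the scaler now falls back to when `model.batch_size` is absent.
        self.batch_size = batch_size
        self.layer = nn.Linear(10, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def train_dataloader(self):
        # The dataloader must read the scaled attribute, otherwise the
        # search has no effect on what actually gets loaded.
        ds = TensorDataset(torch.randn(640, 10), torch.randn(640, 1))
        return DataLoader(ds, batch_size=self.batch_size)


if __name__ == '__main__':
    model = LitModel()
    trainer = pl.Trainer(max_epochs=1)
    # Runs the scaling search from training_tricks.py and writes the
    # largest fitting batch size back onto the model.
    new_size = trainer.scale_batch_size(model)
    print(f'Suggested batch size: {new_size}')
```

With either layout, a model exposing neither `batch_size` nor `hparams.batch_size` now raises the clearer "Neither of `model.batch_size` and `model.hparams.batch_size` found." message instead of the misleading hparams-only error that motivated the first patch.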