From 00e73c12d39072e480e96517ad18be504780722a Mon Sep 17 00:00:00 2001
From: Adrian Wälchli
Date: Sun, 6 Sep 2020 05:48:28 +0200
Subject: [PATCH] update docs

---
 pytorch_lightning/trainer/training_tricks.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 9da690abbb75c..384688262e1ea 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -161,6 +161,13 @@ def scale_batch_size(self,
                 algorithm is terminated
 
             batch_arg_name: name of the attribute that stores the batch size.
+                It is expected that the user has provided a model or datamodule that has a hyperparameter
+                with that name. We will look for this attribute name in the following places
+
+                - `model`
+                - `model.hparams`
+                - `model.datamodule`
+                - `trainer.datamodule` (the datamodule passed to the tune method)
 
             **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
                 or datamodule.
@@ -263,16 +270,12 @@ def _adjust_batch_size(trainer,
                        factor: float = 1.0,
                        value: Optional[int] = None,
                        desc: str = None) -> Tuple[int, bool]:
-    """ Function for adjusting the batch size. It is expected that the user
-    has provided a model that has a hparam field called `batch_size` i.e.
-    `model.hparams.batch_size` should exist. Additionally there can be a
-    datamodule attached to either Trainer or model, in that case the attribute
-    also gets updated when present.
+    """ Helper function for adjusting the batch size.
 
     Args:
         trainer: instance of pytorch_lightning.Trainer
 
-        batch_arg_name: field where batch_size is stored in `model.hparams`
+        batch_arg_name: name of the field where batch_size is stored.
 
         factor: value which the old batch size is multiplied by to get the
             new batch size
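
Note for readers: the lookup order documented in the new `batch_arg_name` docstring can be illustrated with a short sketch. This is a minimal illustration of the documented search order, not the library's actual implementation; the helper name `find_batch_size_owner` is hypothetical, and it assumes each candidate exposes the batch size as a plain, `hasattr`-visible attribute (true for AttributeDict-style hparams, not for a plain dict).

    # Minimal sketch of the lookup order documented in `scale_batch_size` above.
    # `find_batch_size_owner` is a hypothetical helper for illustration only;
    # it is NOT part of pytorch_lightning.

    def find_batch_size_owner(trainer, model, batch_arg_name: str = "batch_size"):
        """Return the first object that defines ``batch_arg_name``, searching
        the four places listed in the docstring, in order."""
        candidates = (
            model,                                 # 1. attribute on the model itself
            getattr(model, "hparams", None),       # 2. model.hparams
            getattr(model, "datamodule", None),    # 3. datamodule attached to the model
            getattr(trainer, "datamodule", None),  # 4. datamodule passed to the tune method
        )
        for obj in candidates:
            if obj is not None and hasattr(obj, batch_arg_name):
                return obj
        raise AttributeError(
            f"Neither the model, its hparams, nor an attached datamodule "
            f"define {batch_arg_name!r}"
        )

    # `_adjust_batch_size` would then read the current value from whichever
    # object owns it and write back the scaled value, e.g.:
    #
    #   owner = find_batch_size_owner(trainer, model)
    #   old = getattr(owner, "batch_size")
    #   setattr(owner, "batch_size", int(old * 2.0))  # factor = 2.0

Searching `model` before `model.hparams` keeps the common `self.batch_size = batch_size` pattern cheap, while the two datamodule fallbacks cover users who keep the batch size on a LightningDataModule instead of the model.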