From c8afcabe28bfb7e52f361205001d4a580c658748 Mon Sep 17 00:00:00 2001
From: Mohamed Al Salti
Date: Mon, 23 Nov 2020 08:04:11 +0200
Subject: [PATCH] Fix batch_arg_name bug (#4812)

Add `batch_arg_name` to all calls to `_adjust_batch_size`

(cherry picked from commit cd90dd429b46b1ad9035be39a6bea4f0cdf70882)
---
 pytorch_lightning/tuner/batch_size_scaling.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py
index 10de8a2d289e5..b6e33cb958cba 100644
--- a/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/pytorch_lightning/tuner/batch_size_scaling.py
@@ -97,7 +97,7 @@ def scale_batch_size(trainer,
         trainer.progress_bar_callback.disable()
 
     # Initially we just double in size until an OOM is encountered
-    new_size = _adjust_batch_size(trainer, value=init_val)  # initially set to init_val
+    new_size = _adjust_batch_size(trainer, batch_arg_name, value=init_val)  # initially set to init_val
     if mode == 'power':
         new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs)
     elif mode == 'binsearch':
@@ -223,7 +223,7 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials,
                 garbage_collection_cuda()
                 high = new_size
                 midval = (high + low) // 2
-                new_size, _ = _adjust_batch_size(trainer, value=midval, desc='failed')
+                new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='failed')
                 if high - low <= 1:
                     break
             else:
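
For context on why forwarding `batch_arg_name` matters: before this patch, the two calls above omitted the argument, so `_adjust_batch_size` fell back to its default attribute name `'batch_size'`. A user tuning a model whose hyperparameter has a different name (e.g. `train_batch_size`) would then have the wrong attribute looked up or updated. Below is a minimal standalone sketch of that failure mode. It is not the Lightning source: `ToyModel` is a hypothetical stand-in for a `LightningModule`, and the `_adjust_batch_size` body is reduced to just the attribute lookup that the default breaks, assuming the same parameter order as in this file.

```python
from typing import Optional, Tuple


class ToyModel:
    """Stand-in for a model whose batch-size attribute is NOT named 'batch_size'."""

    def __init__(self) -> None:
        self.train_batch_size = 2  # user chose a custom attribute name


def _adjust_batch_size(model: ToyModel,
                       batch_arg_name: str = 'batch_size',
                       factor: float = 1.0,
                       value: Optional[int] = None,
                       desc: Optional[str] = None) -> Tuple[int, bool]:
    """Set the attribute named `batch_arg_name` on `model` and report the new size."""
    batch_size = getattr(model, batch_arg_name)  # raises AttributeError if the name is wrong
    new_size = value if value is not None else int(batch_size * factor)
    setattr(model, batch_arg_name, new_size)
    return new_size, new_size != batch_size


model = ToyModel()

# Fixed call: batch_arg_name is forwarded, so the right attribute is updated.
new_size, changed = _adjust_batch_size(model, 'train_batch_size', value=4)
print(new_size, model.train_batch_size)  # -> 4 4

# Buggy call (pre-patch): batch_arg_name silently falls back to 'batch_size',
# which does not exist on this model, so the lookup fails.
try:
    _adjust_batch_size(model, value=8)
except AttributeError as err:
    print(f"AttributeError: {err}")
```

In the real tuner the lookup goes through the trainer's model rather than a plain object, but the effect is the same: the default only works for models that happen to name the attribute `batch_size`, which is exactly why the patch threads `batch_arg_name` through every `_adjust_batch_size` call.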