Skip to content

Commit

Permalink
Fix batch_arg_name bug
Browse files Browse the repository at this point in the history
Add `batch_arg_name` to all calls to `_adjust_batch_size`
  • Loading branch information
M-Salti committed Nov 22, 2020
1 parent 8601268 commit 2fb4f6c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pytorch_lightning/tuner/batch_size_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def scale_batch_size(trainer,
trainer.progress_bar_callback.disable()

# Initially we just double in size until an OOM is encountered
new_size = _adjust_batch_size(trainer, value=init_val) # initially set to init_val
new_size = _adjust_batch_size(trainer, batch_arg_name, value=init_val) # initially set to init_val
if mode == 'power':
new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs)
elif mode == 'binsearch':
Expand Down Expand Up @@ -231,7 +231,7 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials,
garbage_collection_cuda()
high = new_size
midval = (high + low) // 2
new_size, _ = _adjust_batch_size(trainer, value=midval, desc='failed')
new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='failed')
if high - low <= 1:
break
else:
Expand Down

0 comments on commit 2fb4f6c

Please sign in to comment.