Commit 176afc7

Anselmoo committed May 13, 2022
1 parent c51da2d commit 176afc7
Showing 2 changed files with 20 additions and 12 deletions.
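Summary (inferred from the diff below; no commit message is shown on this page): when the adaptive batch size has already been reduced to 1, run_with_adaptive_batch_size now prints a warning and skips the current model instead of re-raising tf.errors.ResourceExhaustedError, and the test is updated to assert the printed message via pytest's capfd fixture rather than expecting an exception.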
8 changes: 6 additions & 2 deletions autokeras/utils/utils.py
@@ -100,9 +100,13 @@ def run_with_adaptive_batch_size(batch_size, func, **fit_kwargs):
         try:
             history = func(x=x, validation_data=validation_data, **fit_kwargs)
             break
-        except tf.errors.ResourceExhaustedError as e:
+        except tf.errors.ResourceExhaustedError:
             if batch_size == 1:
-                raise e
+                print(
+                    "Not enough memory, reduced batch size is already set to 1. "
+                    "Current model will be skipped."
+                )
+                break
             batch_size //= 2
             print(
                 "Not enough memory, reduce batch size to {batch_size}.".format(
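For orientation, here is a minimal sketch of how the whole retry loop might read after this change. Everything outside the hunk above (the argument unpacking, the while loop, the dataset re-batching, and the history initializer) is an assumption for illustration, not taken from this diff; the actual autokeras implementation may differ.

import tensorflow as tf


def run_with_adaptive_batch_size(batch_size, func, **fit_kwargs):
    # Sketch only: lines outside the committed hunk are assumed, not verbatim.
    x = fit_kwargs.pop("x", None)
    validation_data = fit_kwargs.pop("validation_data", None)
    history = None  # assumed initializer so the skip branch has a value to return
    while batch_size > 0:
        try:
            history = func(x=x, validation_data=validation_data, **fit_kwargs)
            break
        except tf.errors.ResourceExhaustedError:
            if batch_size == 1:
                # New behavior in this commit: warn and skip instead of raising.
                print(
                    "Not enough memory, reduced batch size is already set to 1. "
                    "Current model will be skipped."
                )
                break
            batch_size //= 2
            print(
                "Not enough memory, reduce batch size to {batch_size}.".format(
                    batch_size=batch_size
                )
            )
            # Assumed re-batching step: rebuild both datasets at the halved size.
            x = x.unbatch().batch(batch_size)
            if validation_data is not None:
                validation_data = validation_data.unbatch().batch(batch_size)
    return history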
24 changes: 14 additions & 10 deletions autokeras/utils/utils_test.py
@@ -53,19 +53,23 @@ def test_check_kt_version_error():
     )


-def test_run_with_adaptive_batch_size_raise_error():
+def test_run_with_adaptive_batch_size_raise_error(capfd):
     def func(**kwargs):
         raise tf.errors.ResourceExhaustedError(0, "", None)

-    with pytest.raises(tf.errors.ResourceExhaustedError):
-        utils.run_with_adaptive_batch_size(
-            batch_size=64,
-            func=func,
-            x=tf.data.Dataset.from_tensor_slices(np.random.rand(100, 1)).batch(64),
-            validation_data=tf.data.Dataset.from_tensor_slices(
-                np.random.rand(100, 1)
-            ).batch(64),
-        )
+    utils.run_with_adaptive_batch_size(
+        batch_size=64,
+        func=func,
+        x=tf.data.Dataset.from_tensor_slices(np.random.rand(100, 1)).batch(64),
+        validation_data=tf.data.Dataset.from_tensor_slices(
+            np.random.rand(100, 1)
+        ).batch(64),
+    )
+    _, err = capfd.readouterr()
+    assert (
+        err == "Not enough memory, reduced batch size is already set to 1. "
+        "Current model will be skipped."
+    )


 def test_get_hyperparameter_with_none_return_hp():
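The updated test relies on pytest's built-in capfd fixture, which captures output written to the stdout and stderr file descriptors; capfd.readouterr() returns an (out, err) pair. A self-contained sketch of the pattern, with a hypothetical emit_warning helper standing in for the autokeras code path:

# Minimal sketch of the capfd pattern used in the updated test.
# emit_warning is a hypothetical stand-in, not part of this commit.
def emit_warning():
    print("Not enough memory, reduced batch size is already set to 1.")


def test_emit_warning(capfd):
    emit_warning()
    out, err = capfd.readouterr()
    # print() writes to stdout, so the message (plus a trailing newline)
    # lands in `out`; `err` captures anything written to stderr.
    assert "Not enough memory" in out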
