From 176afc7eb00e533567399f19294b8ddf1ac6fc78 Mon Sep 17 00:00:00 2001 From: Anselm Hahn Date: Fri, 13 May 2022 23:20:55 +0200 Subject: [PATCH] Fixed: #1722 --- autokeras/utils/utils.py | 8 ++++++-- autokeras/utils/utils_test.py | 24 ++++++++++++++---------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/autokeras/utils/utils.py b/autokeras/utils/utils.py index 9552b9e54..2e3222d62 100644 --- a/autokeras/utils/utils.py +++ b/autokeras/utils/utils.py @@ -100,9 +100,13 @@ def run_with_adaptive_batch_size(batch_size, func, **fit_kwargs): try: history = func(x=x, validation_data=validation_data, **fit_kwargs) break - except tf.errors.ResourceExhaustedError as e: + except tf.errors.ResourceExhaustedError: if batch_size == 1: - raise e + print( + "Not enough memory, reduced batch size is already set to 1. " + "Current model will be skipped." + ) + break batch_size //= 2 print( "Not enough memory, reduce batch size to {batch_size}.".format( diff --git a/autokeras/utils/utils_test.py b/autokeras/utils/utils_test.py index 0b3acfcc4..8ce07c66d 100644 --- a/autokeras/utils/utils_test.py +++ b/autokeras/utils/utils_test.py @@ -53,19 +53,23 @@ def test_check_kt_version_error(): ) -def test_run_with_adaptive_batch_size_raise_error(): +def test_run_with_adaptive_batch_size_raise_error(capfd): def func(**kwargs): raise tf.errors.ResourceExhaustedError(0, "", None) - with pytest.raises(tf.errors.ResourceExhaustedError): - utils.run_with_adaptive_batch_size( - batch_size=64, - func=func, - x=tf.data.Dataset.from_tensor_slices(np.random.rand(100, 1)).batch(64), - validation_data=tf.data.Dataset.from_tensor_slices( - np.random.rand(100, 1) - ).batch(64), - ) + utils.run_with_adaptive_batch_size( + batch_size=64, + func=func, + x=tf.data.Dataset.from_tensor_slices(np.random.rand(100, 1)).batch(64), + validation_data=tf.data.Dataset.from_tensor_slices( + np.random.rand(100, 1) + ).batch(64), + ) + _, err = capfd.readouterr() + assert ( + err == "Not enough memory, reduced batch size is already set to 1. " + "Current model will be skipped." + ) def test_get_hyperparameter_with_none_return_hp():