
Stop tensorflow from eating all GPU memory
By default, tensorflow will consume all available GPU memory:
https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth

Say you are running with KERAS_BACKEND=jax, and both jax and tensorflow
have been configured with GPU support:
- keras_core will always import and initialize tensorflow
- tensorflow will use all available GPU memory
- jax will attempt any GPU allocation and immediately fail
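(As an aside, not part of this commit: jax's own allocator also preallocates most of the GPU by default, so when jax must share a device with another framework its preallocation can likewise be disabled via an environment variable, set before jax first initializes its GPU backend. A minimal sketch:)

```python
import os

# Must run before jax's first GPU use (first device/array call),
# not merely before `import jax`.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
```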

Note this does not happen in colab because colab automatically exports
the environment variable:
TF_FORCE_GPU_ALLOW_GROWTH=true
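(Outside colab the same escape hatch can be set manually; a minimal sketch, noting the variable must be in the environment before tensorflow first initializes a GPU:)

```python
import os

# Equivalent of what colab exports for you; must be set before
# `import tensorflow`, or at least before TF's first GPU use.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
```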

From keras-core, we can attempt to work around it by limiting tensorflow
GPU growth. Long term we should work around it by not importing
tensorflow on jax and torch backends.
mattdangerw committed Jul 13, 2023
1 parent e4dec5a commit a2b17b2
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions keras_core/backend/__init__.py
@@ -6,6 +6,21 @@
# upon import.
import torch

if backend() != "tensorflow":
    import tensorflow as tf

    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        # Stop tensorflow from using all available GPU memory. See
        # https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError:
            # This might fail, e.g. if tensorflow was already imported and
            # initialized before the keras-core import.
            pass

from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.backend.common.keras_tensor import any_symbolic_tensors
from keras_core.backend.common.keras_tensor import is_keras_tensor
