
Stop tensorflow from eating all GPU memory #473

Merged
1 commit merged on Jul 16, 2023

Conversation

mattdangerw
Member

By default, tensorflow will consume all available GPU memory: https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth

Say you are running with KERAS_BACKEND=jax, and jax and tf have both been configured with GPU support:

  • keras_core will always import and initialize tensorflow
  • tensorflow will use all available GPU memory
  • jax will then fail immediately on any GPU allocation it attempts

Note this does not happen in colab because colab automatically exports the environment variable:
TF_FORCE_GPU_ALLOW_GROWTH=true

From keras-core, we can attempt to work around it by limiting tensorflow GPU growth. Long term we should work around it by not importing tensorflow on jax and torch backends.
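The short-term workaround described above can be sketched as follows. This is a minimal illustration, not keras-core's actual code; the function name is mine, and it assumes TensorFlow may or may not be installed:

```python
def limit_tf_gpu_memory_growth():
    """Ask TensorFlow to allocate GPU memory on demand instead of all at once.

    Minimal sketch; must run before TensorFlow initializes any GPU.
    Returns True if at least one GPU was configured.
    """
    try:
        import tensorflow as tf
    except ImportError:
        return False  # TensorFlow not installed; nothing to configure.
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        # See https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
        tf.config.experimental.set_memory_growth(gpu, True)
    return bool(gpus)
```

Note that `set_memory_growth` raises a `RuntimeError` if TensorFlow has already initialized the GPUs, which is why the call has to happen as early as possible.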

@mattdangerw mattdangerw requested a review from fchollet July 13, 2023 23:26
@mattdangerw mattdangerw force-pushed the memory-growth-tf branch 2 times, most recently from a2b17b2 to fa85eb3 on July 14, 2023 00:20
@fchollet
Contributor

From keras-core, we can attempt to work around it by limiting tensorflow GPU growth. Long term we should work around it by not importing tensorflow on jax and torch backends.

This is not quite what we are doing. We are making it possible to use Keras without having installed TF -- but if it is installed, it is likely to get imported during the normal course of operations with another backend, because:

  1. We still use gfile when available, e.g. when saving
  2. KPLs (Keras preprocessing layers) still rely on TF

Hypothetically we could fix 1 (not easy though) but not 2.


gpus = tf.config.list_physical_devices("GPU")
if gpus:
    # Stop tensorflow from using all avilable GPU memory. See
Contributor


Typo: available

@@ -6,6 +6,21 @@
# upon import.
import torch

if backend() != "tensorflow":
    import tensorflow as tf
Contributor


As a general policy we should only import TF when requested. Right now it gets imported lazily the first time it's needed, in utils/module_utils.py. We should customize the initialize() method for TF to insert this routine.
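A sketch of that idea follows. The class and hook names here are illustrative, not the actual utils/module_utils.py API: a lazy module proxy that runs a one-time initialize() hook right after the real import happens.

```python
import importlib


class LazyModule:
    """Minimal sketch of a lazily imported module with an initialize() hook,
    loosely modeled on keras-core's utils/module_utils.py (names illustrative).
    """

    def __init__(self, name, initialize=None):
        self.name = name
        self._module = None
        self._initialize = initialize  # Hook run once, right after import.

    def _load(self):
        if self._module is None:
            self._module = importlib.import_module(self.name)
            if self._initialize is not None:
                self._initialize(self._module)
        return self._module

    def __getattr__(self, attr):
        # Triggers the real import on first attribute access.
        return getattr(self._load(), attr)


def _init_tf(tf_module):
    # Example initialize() hook: limit GPU memory growth on first import.
    for gpu in tf_module.config.list_physical_devices("GPU"):
        tf_module.config.experimental.set_memory_growth(gpu, True)


# Nothing is imported until the first attribute access on `tf`.
tf = LazyModule("tensorflow", initialize=_init_tf)
```

The key property is that constructing the proxy is free; TensorFlow (and the GPU configuration hook) only runs the first time an attribute is actually used.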

Member Author


Hmm, I think in this case, we may be better off switching this to just setting os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" in config.py. This is essentially how colab handles the problem.

The issue with making tf a delayed import is that a user script might import tensorflow first, bypass the lazy module entirely, and hit OOMs outside of our control. Setting the environment variable is lightweight, we can do it first thing, and it will not affect tf until it's imported. What do you think?
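The environment-variable approach amounts to a couple of lines. A sketch (using setdefault so any value the user, or Colab, has already exported wins):

```python
import os

# Must happen before `import tensorflow`; the variable is read at TF startup.
# setdefault() keeps any value the user (or Colab) already exported.
os.environ.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
```

Because it only sets an environment variable, this costs nothing on backends where TensorFlow is never imported at all.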

Contributor


Hmm, I think in this case, we may be better off switching this to just setting

SGTM. Lightweight indeed and the Colab precedent shows it's fine.

To note, I have made the changes described above -- now we only import TF if using KPL or if saving to GCS. But many users will likely import it anyway (e.g. tf.data + JAX)
