
Enable XLA Compatibility for Pretraining BERT with Keras NLP on TensorFlow GPU #1661

Closed
jacob-talroo opened this issue Jun 5, 2024 · 4 comments


@jacob-talroo

Describe the bug
Training BERT with Keras NLP is significantly slower because keras.layers.Embedding is not XLA-compatible by default on the TensorFlow GPU backend. This is similar to an issue reported against Keras at keras-team/keras#19809.

To Reproduce
You can reproduce this issue by following the steps in this Colab Notebook: Link to Notebook

Expected behavior
I expect BERT training using Keras NLP on TensorFlow GPU with XLA to be optimized for performance, similar to native TensorFlow implementations.

Additional context
The lack of XLA compatibility hurts training speed and efficiency on GPU, which is crucial for training scalability and practical use in production environments.
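For reference, a minimal sketch of how XLA compilation is requested on the TensorFlow backend via jit_compile=True in compile(). The toy model and data below are illustrative stand-ins for the BERT backbone and pretraining data, not the code from the notebook:

```python
import numpy as np
import tensorflow as tf

# Toy model standing in for the BERT backbone; jit_compile=True asks
# Keras to compile the train/predict steps with XLA.
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(input_dim=100, output_dim=8),
    tf.keras.layers.GlobalAveragePooling1D(),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse", jit_compile=True)

# Illustrative data: batch of 4 sequences of 16 token ids.
x = np.random.randint(0, 100, size=(4, 16))
y = np.random.rand(4, 1).astype("float32")
history = model.fit(x, y, epochs=1, verbose=0)
```

On GPU, it is this XLA-compiled path through the Embedding layer that hits the slowdown described above.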

Would you like to help us fix it?
Yes, I am willing to contribute to resolving this issue by testing and suggesting implementations that ensure XLA compatibility.

@mattdangerw
Member

I think we would ideally want to solve this at the Keras level, not in KerasNLP.

I played around with always using the one hot approach under a distribution strategy. keras-team/keras@master...mattdangerw:embedding-fix

I think this could work, but I am not sure we would want to do it when XLA is off by default on the TF backend. So the first thing might be to look at enabling XLA with tf.distribute.
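To illustrate the one hot approach, a minimal NumPy sketch (sizes arbitrary): an embedding lookup is a gather, whose sparse gradient is, as I understand it, what causes trouble under XLA/distribution in TF; the same lookup can be rewritten as a one-hot matrix multiply, which lowers to a dense matmul that XLA handles well:

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, embed_dim = 10, 4
table = rng.standard_normal((vocab_size, embed_dim))
ids = np.array([3, 1, 7])

# Default path: a gather (what Embedding does today).
gathered = table[ids]

# One-hot path: the same lookup expressed as a dense matmul.
one_hot = np.eye(vocab_size)[ids]   # shape (3, vocab_size)
matmul = one_hot @ table            # shape (3, embed_dim)

assert np.allclose(gathered, matmul)
```

The trade-off is the extra (batch, vocab_size) one-hot tensor and matmul FLOPs, which is presumably why it would only be enabled under a distribution strategy rather than unconditionally.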

@mattdangerw
Member

Is there a reason training on the JAX backend doesn't work for your use case? It is likely faster, and everything is XLA-compatible since XLA is the only option on JAX.

@jacob-talroo
Author

We have switched to the JAX backend. If there is no desire to close the performance gap between Keras 2 and Keras 3 on the TensorFlow backend, we can close this one out.
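For anyone else landing here, a minimal sketch of the backend switch: Keras 3 reads KERAS_BACKEND at import time, so it must be set before keras is first imported anywhere in the process (the commented imports show the assumed surrounding code):

```python
import os

# Keras 3 selects its backend when keras is first imported, so this
# must run before any `import keras` in the process.
os.environ["KERAS_BACKEND"] = "jax"

# import keras      # now runs on JAX; all ops go through XLA
# import keras_nlp  # KerasNLP follows the same backend
```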


This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
