
running on tpu #25

Merged · 1 commit into ml-gde:main on Oct 19, 2024

Conversation

@entrpn (Contributor) commented on Oct 17, 2024

This PR adds the ability to run on TPUs, although it's very slow right now. I'm unsure how this hooks into uv, as I used pip. Please review.

  • Loads encoders on CPU when device_type == 'tpu' (see the sketch below).
  • Can run on 32GB of HBM, i.e., TPUv4.
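
A minimal sketch of that device-placement idea (the function and model IDs here are illustrative, not the PR's actual diff):

```python
from transformers import CLIPTextModel, T5EncoderModel

def load_encoders(device_type: str):
    """Keep the large text encoders in host RAM when targeting a TPU.

    A TPUv4 chip has 32GB of HBM, so offloading the T5/CLIP encoders
    to CPU leaves that memory for the main transformer. Model IDs are
    placeholders.
    """
    clip = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    t5 = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")
    if device_type == "tpu":
        # Text encoding runs once per prompt, so CPU placement trades a
        # little latency for a large HBM saving.
        return clip.to("cpu"), t5.to("cpu")
    return clip.to(device_type), t5.to(device_type)
```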

@SauravMaheshkar linked an issue on Oct 18, 2024 that may be closed by this pull request: add TPU support
@SauravMaheshkar (Collaborator) left a comment

Thanks a lot for the PR @entrpn. I've had a look at the diff and it makes sense to me. I'll just confirm by running this on TPU later today.

@SauravMaheshkar added the feature 🚀 New feature or request label on Oct 18, 2024
@ariG23498 (Collaborator) left a comment

LGTM!

I have asked for some TPU resources. Once I have them, I can check this PR.

On another note, I have the following queries:

  1. We have Flax implementations of T5 and CLIP in Hugging Face; would it be easier for us to load those instead of the PyTorch versions? (A rough sketch of what that could look like follows this list.)
  2. Do you see any place we can apply parallelization to make this implementation faster? I understand this is a denser question that requires one to look into the code; if that is too much to ask at this time, I completely understand if you don't want to answer it now.
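
For illustration, loading the Flax variants through transformers could look roughly like this (a sketch only; the model IDs are placeholders, not necessarily what this repo uses):

```python
from transformers import FlaxCLIPTextModel, FlaxT5EncoderModel

# Illustrative only: load JAX/Flax checkpoints directly instead of the
# PyTorch ones. `from_pt=True` converts PyTorch weights on the fly when
# a checkpoint ships no native Flax weights.
clip = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
t5 = FlaxT5EncoderModel.from_pretrained("google/t5-v1_1-xxl", from_pt=True)
```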

@ariG23498 merged commit 5892764 into ml-gde:main on Oct 19, 2024
@ariG23498 (Collaborator) commented

I checked the implementation on my side on a TPU v4.

Thanks for the great contribution. ❤️

@cataluna84 left a comment

LGTM 🤩

This work to use TPUs looks fantastic.

@entrpn (Contributor, Author) commented on Oct 19, 2024

> LGTM!
>
> I have asked for some TPU resources. Once I have them, I can check this PR.
>
> On another note, I have the following queries:
>
>   1. We have Flax implementations of T5 and CLIP in Hugging Face; would it be easier for us to load those instead of the PyTorch versions?
>   2. Do you see any place we can apply parallelization to make this implementation faster? I understand this is a denser question that requires one to look into the code; if that is too much to ask at this time, I completely understand if you don't want to answer it now.

I will take some time to understand the code better. This should be able to run faster and be parallelizable across all TPU devices.

For reference, I created a PyTorch/XLA version that runs Flux on TPUs and generates 4 images in parallel in under 10 seconds: https://github.com/entrpn/diffusers/tree/flux_ptxla
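
As a sketch of that kind of per-device parallelism (illustrative, not the code at the link above; `load_pipeline` is a hypothetical helper), PyTorch/XLA can spawn one process per TPU device, each generating its own image:

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _generate(index: int, prompt: str = "a photo of an astronaut") -> None:
    # Each spawned process owns one XLA (TPU) device and runs the
    # pipeline independently, so n devices yield n images per step.
    device = xm.xla_device()
    pipe = load_pipeline().to(device)  # hypothetical loader
    image = pipe(prompt).images[0]
    image.save(f"out_{index}.png")

if __name__ == "__main__":
    # By default xmp.spawn uses all available TPU devices.
    xmp.spawn(_generate)
```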

Labels: feature 🚀 New feature or request
Projects: Status: Done
Linked issues (may be closed by merging this pull request): add TPU support
Participants: 5