Adding PyTorch XLA support for SDXL inference #5273
Conversation
Added mark_step for sdxl to run with pytorch xla. Also updated README with instructions for xla
Hey @ssusie, nice addition; generally I'm OK with having this. However, could you maybe add a soft dependency on torch_xla?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Thanks for the feedback, Patrick. Added the dependency in …
Looks good to me!
@ssusie could you run `make style`?
thanks!
Thanks everyone for the review and comments. I ran `make style` and added the changes to the PR.
> … speedup, we need to call the pipe again on the input with the same length as the original prompt to reuse the optimized graph and get the performance …
Does the prompt length really matter here, given that the embeddings always have the same shape?
Thanks for the great question. With the current setup, XLA needs the input to have the same length; otherwise it will recompile the whole graph. We could potentially cut the graph with xm.mark_step after the embeddings are calculated, but that is not addressed in this PR.
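For reference, the core pattern this PR adds is a guarded `xm.mark_step()` call at the end of each denoising step, behind a soft torch_xla dependency. A rough sketch of that pattern (simplified from the actual pipeline code; the helper `denoising_loop` and the `XLA_AVAILABLE` flag here are illustrative, not the exact diff):

```python
# Soft dependency: only import torch_xla when it is installed.
from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def denoising_loop(timesteps, step_fn):
    for t in timesteps:
        step_fn(t)  # UNet forward + scheduler step for one timestep
        if XLA_AVAILABLE:
            # Cut the lazily traced graph here so XLA compiles and executes
            # one step at a time instead of accumulating one huge graph
            # for the whole loop.
            xm.mark_step()
```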
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
Thanks for the reviews and comments. Is there anything else to change or address, or can this be merged?
Let's merge it - great job @ssusie! We don't have tests here yet, but I think this is fine to begin with :-)
* Added mark_step for sdxl to run with pytorch xla. Also updated README with instructions for xla
* adding soft dependency on torch_xla
* fix some styling

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
What does this PR do?
Enables users to run SDXL inference with PyTorch XLA. The README_sdxl.md guide is also updated.
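A minimal usage sketch of what this enables (the model id and prompt below are illustrative; README_sdxl.md in this PR has the authoritative instructions):

```python
import torch_xla.core.xla_model as xm
from diffusers import StableDiffusionXLPipeline

# Place the pipeline on the XLA device (e.g. a TPU core).
device = xm.xla_device()
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0"
)
pipe.to(device)

prompt = "A majestic lion jumping from a big stone at night"

# The first call traces and compiles the XLA graph, so it is slow.
image = pipe(prompt).images[0]

# A second call with a same-length prompt reuses the compiled graph
# and runs at full speed.
image = pipe(prompt).images[0]
```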
@patrickvonplaten @sayakpaul
Thanks.