[torch.compile] fix graph break problems partially #5453
Changes from 13 commits
```diff
@@ -20,6 +20,7 @@
 import PIL.Image
 import torch
 import torch.nn.functional as F
+from packaging import version
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...image_processor import PipelineImageInput, VaeImageProcessor
```
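The newly imported `packaging.version` is used below to gate on the PyTorch version. A standalone sketch (not part of the PR) of why the double `version.parse(...)` with `base_version` is needed: PyTorch builds often carry local suffixes like `+cu118`, and `base_version` strips local and pre-release segments before comparison.

```python
from packaging import version


def is_torch_at_least(torch_version_str: str, minimum: str = "2.1") -> bool:
    # base_version drops local segments ("+cu118") and pre-release tags
    # ("a0"), so "2.1.0a0+git1234" compares as plain "2.1.0".
    base = version.parse(torch_version_str).base_version
    return version.parse(base) >= version.parse(minimum)


print(is_torch_at_least("2.1.0+cu118"))  # True
print(is_torch_at_least("2.0.1"))        # False
print(is_torch_at_least("2.1.0a0"))      # True: pre-release tag is stripped
```

In the pipeline, `torch_version_str` is `torch.__version__`; the helper name `is_torch_at_least` is hypothetical and only for illustration.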
```diff
@@ -976,8 +977,17 @@ def __call__(

         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        is_unet_compiled = is_compiled_module(self.unet)
+        is_controlnet_compiled = is_compiled_module(self.controlnet)
+        is_torch_higher_equal_than_2_1 = version.parse(version.parse(torch.__version__).base_version) >= version.parse(
+            "2.1"
+        )
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                # Relevant thread:
+                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_than_2_1:
+                    torch._inductor.cudagraph_mark_step_begin()
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
```

Review comments on this hunk (sayakpaul marked the version-check thread as resolved):

- On the version check: "Do we need this really? It's not super pretty and looks like a bug in PT 2.1. Also, are we sure the code works fine with PT 2.0?" Reply: "Discussed internally."
- On `cudagraph_mark_step_begin()`: "Hmm, this is in some sense a breaking change from PT; do we really have to add version-specific code here?" Reply: "There is no other way to support compiled ControlNets otherwise in PT 2.1, sadly."
Review comment on the denoising loop: `torch.compile()` fails to compile these kinds of iterators right now.