[torch.compile] fix graph break problems partially #5453

Merged (14 commits, Oct 23, 2023)
1 change: 0 additions & 1 deletion src/diffusers/models/controlnet.py

@@ -817,7 +817,6 @@ def forward(
         # 6. scaling
         if guess_mode and not self.config.global_pool_conditions:
             scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
-
             scales = scales * conditioning_scale
             down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
             mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
8 changes: 5 additions & 3 deletions src/diffusers/models/unet_2d_condition.py

@@ -846,9 +846,11 @@ def forward(
         forward_upsample_size = False
         upsample_size = None

-        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
-            # Forward upsample size to force interpolation output size.
-            forward_upsample_size = True
+        for dim in sample.shape[-2:]:
+            if dim % default_overall_up_factor != 0:
+                # Forward upsample size to force interpolation output size.
+                forward_upsample_size = True
+                break

         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
         # expects mask of shape:

Member Author, commenting on the removed lines (-877 to -878): torch.compile() fails to compile these kinds of iterators right now.
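Below is a minimal, hypothetical check (not part of this PR) of whether the generator-expression form trips Dynamo, using `torch._dynamo.explain`. The `explain(fn)(...)` call style and the `graph_break_count` field assume the PyTorch >= 2.1 API, and whether the `any(...)` variant still breaks depends on the exact release; the helper names and the up factor value are made up for the sketch.

```python
# Hypothetical repro, assuming PyTorch >= 2.1 for the torch._dynamo.explain API.
import torch

default_overall_up_factor = 8  # stand-in value; in the model it is 2**self.num_upsamplers


def needs_forward_upsample_any(sample: torch.Tensor) -> bool:
    # old pattern: generator expression inside any()
    return any(s % default_overall_up_factor != 0 for s in sample.shape[-2:])


def needs_forward_upsample_loop(sample: torch.Tensor) -> bool:
    # pattern introduced by this PR: explicit loop with an early break
    forward_upsample_size = False
    for dim in sample.shape[-2:]:
        if dim % default_overall_up_factor != 0:
            forward_upsample_size = True
            break
    return forward_upsample_size


sample = torch.randn(1, 4, 60, 64)  # height not divisible by the up factor
print(torch._dynamo.explain(needs_forward_upsample_any)(sample).graph_break_count)
print(torch._dynamo.explain(needs_forward_upsample_loop)(sample).graph_break_count)
```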
10 changes: 10 additions & 0 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet.py

@@ -20,6 +20,7 @@
 import PIL.Image
 import torch
 import torch.nn.functional as F
+from packaging import version
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...image_processor import PipelineImageInput, VaeImageProcessor

@@ -976,8 +977,17 @@ def __call__(

         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        is_unet_compiled = is_compiled_module(self.unet)
+        is_controlnet_compiled = is_compiled_module(self.controlnet)
+        is_torch_higher_equal_than_2_1 = version.parse(version.parse(torch.__version__).base_version) >= version.parse(
+            "2.1"
+        )
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                # Relevant thread:
+                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_than_2_1:
+                    torch._inductor.cudagraph_mark_step_begin()
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

Contributor:

Suggested change:
-                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_than_2_1:
+                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:

Do we need this, really? It's not super pretty and looks like a bug in PT 2.1. Also, are we sure the code works fine with PT 2.0?

Member Author: Discussed internally.

Contributor: Hmm, this is in some sense a breaking change from PT; do we really have to add version-specific code here?

Member Author: There is no other way to support compiled ControlNets in PT 2.1, sadly.
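As a side note on the discussion above, here is a self-contained sketch of the pattern this hunk adds. The toy `Linear` modules stand in for the compiled UNet/ControlNet, `is_compiled_module` is a simplified stand-in for diffusers' helper of the same name, and the script assumes a CUDA device and PyTorch >= 2.1 with the `reduce-overhead` (cudagraphs) mode.

```python
# Sketch only: toy modules stand in for the pipeline's compiled UNet/ControlNet.
import torch
from packaging import version


def is_compiled_module(module) -> bool:
    # simplified stand-in for diffusers.utils.is_compiled_module
    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)


is_torch_higher_equal_than_2_1 = version.parse(
    version.parse(torch.__version__).base_version
) >= version.parse("2.1")

# "reduce-overhead" enables Inductor's cudagraphs, which is what makes the step marker necessary
unet = torch.compile(torch.nn.Linear(8, 8).cuda(), mode="reduce-overhead")
controlnet = torch.compile(torch.nn.Linear(8, 8).cuda(), mode="reduce-overhead")

x = torch.randn(4, 8, device="cuda")
for _ in range(3):  # mimics the denoising loop
    if is_compiled_module(unet) and is_compiled_module(controlnet) and is_torch_higher_equal_than_2_1:
        # Tells the cudagraph runner that a new iteration starts, so outputs from the
        # previous iteration are not read after their buffers have been reused.
        torch._inductor.cudagraph_mark_step_begin()
    residual = controlnet(x)
    out = unet(x + residual)
```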
10 changes: 10 additions & 0 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

@@ -20,6 +20,7 @@
 import PIL.Image
 import torch
 import torch.nn.functional as F
+from packaging import version
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

 from diffusers.utils.import_utils import is_invisible_watermark_available

@@ -1144,8 +1145,17 @@ def __call__(

         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        is_unet_compiled = is_compiled_module(self.unet)
+        is_controlnet_compiled = is_compiled_module(self.controlnet)
+        is_torch_higher_equal_than_2_1 = version.parse(version.parse(torch.__version__).base_version) >= version.parse(
+            "2.1"
+        )
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                # Relevant thread:
+                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_than_2_1:
+                    torch._inductor.cudagraph_mark_step_begin()
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -814,7 +814,8 @@ def __call__(
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
                     down_intrablock_additional_residuals=[state.clone() for state in adapter_state],
-                ).sample
+                    return_dict=False,
+                )[0]

                 # perform guidance
                 if do_classifier_free_guidance:
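A small illustration of the change above, reading the prediction from the plain tuple returned with `return_dict=False` instead of the `.sample` field of the output dataclass; skipping the dataclass construction is friendlier to torch.compile. The tiny UNet config here is an assumption modeled on diffusers' test-sized configs, not the pipeline's real model.

```python
# Illustration with an assumed tiny UNet2DConditionModel config (not from this PR).
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

sample = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 32)

# before: dataclass output, read via the .sample attribute
noise_pred = unet(sample, timestep, encoder_hidden_states).sample

# after: plain tuple output, read via indexing, so no output dataclass is built
noise_pred = unet(sample, timestep, encoder_hidden_states, return_dict=False)[0]
```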
@@ -1057,9 +1057,11 @@ def forward(
         forward_upsample_size = False
         upsample_size = None

-        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
-            # Forward upsample size to force interpolation output size.
-            forward_upsample_size = True
+        for dim in sample.shape[-2:]:
+            if dim % default_overall_up_factor != 0:
+                # Forward upsample size to force interpolation output size.
+                forward_upsample_size = True
+                break

         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
         # expects mask of shape: