Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch.compile] fix graph break problems partially #5453

Merged
merged 14 commits into from
Oct 23, 2023
1 change: 0 additions & 1 deletion src/diffusers/models/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,6 @@ def forward(
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0

scales = scales * conditioning_scale
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
Expand Down
8 changes: 5 additions & 3 deletions src/diffusers/models/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,9 +874,11 @@ def forward(
forward_upsample_size = False
upsample_size = None

if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
# Forward upsample size to force interpolation output size.
Comment on lines -877 to -878
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.compile() fails to compile these kinds of iterators right now.

forward_upsample_size = True
for dim in sample.shape[-2:]:
if dim % default_overall_up_factor != 0:
# Forward upsample size to force interpolation output size.
forward_upsample_size = True
break

# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
# expects mask of shape:
Expand Down
9 changes: 8 additions & 1 deletion src/diffusers/pipelines/controlnet/pipeline_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
Expand Down Expand Up @@ -976,8 +976,15 @@ def __call__(

# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
is_unet_compiled = is_compiled_module(self.unet)
is_controlnet_compiled = is_compiled_module(self.controlnet)
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput

Expand Down Expand Up @@ -1144,8 +1144,15 @@ def __call__(

# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
is_unet_compiled = is_compiled_module(self.unet)
is_controlnet_compiled = is_compiled_module(self.controlnet)
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,8 @@ def __call__(
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_intrablock_additional_residuals=[state.clone() for state in adapter_state],
).sample
return_dict=False,
sayakpaul marked this conversation as resolved.
Show resolved Hide resolved
)[0]

# perform guidance
if do_classifier_free_guidance:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1084,9 +1084,11 @@ def forward(
forward_upsample_size = False
upsample_size = None

if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
# Forward upsample size to force interpolation output size.
forward_upsample_size = True
for dim in sample.shape[-2:]:
if dim % default_overall_up_factor != 0:
# Forward upsample size to force interpolation output size.
forward_upsample_size = True
break

# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
# expects mask of shape:
Expand Down
Loading