[torch.compile] fix graph break problems partially (#5453)
* fix: controlnet graph?

* fix: sample

* fix:

* remove print

* styling

* fix-copies

* prevent more graph breaks?

* prevent more graph breaks?

* see?

* revert.

* compilation.

* propagate changes to controlnet sdxl pipeline too.

* add: clean version checking.
sayakpaul authored Oct 23, 2023
1 parent 1ade42f commit 48ce118
Showing 6 changed files with 28 additions and 10 deletions.
1 change: 0 additions & 1 deletion src/diffusers/models/controlnet.py
@@ -817,7 +817,6 @@ def forward(
        # 6. scaling
        if guess_mode and not self.config.global_pool_conditions:
            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
-
            scales = scales * conditioning_scale
            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
            mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
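For context on the guess-mode scaling block above, here is a tiny standalone illustration of what torch.logspace produces there. The residual count of 12 is only an example; the real number depends on the UNet/ControlNet configuration.

import torch

# Hypothetical number of down-block residuals; the actual value comes from the model.
num_down_block_res_samples = 12

# As in the snippet above: one scale per down-block residual plus one for the mid block,
# ramping geometrically from 10**-1 = 0.1 up to 10**0 = 1.0.
scales = torch.logspace(-1, 0, num_down_block_res_samples + 1)
print(scales[0].item(), scales[-1].item())  # ~0.1 and 1.0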
8 changes: 5 additions & 3 deletions src/diffusers/models/unet_2d_condition.py
@@ -874,9 +874,11 @@ def forward(
        forward_upsample_size = False
        upsample_size = None

-        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
-            # Forward upsample size to force interpolation output size.
-            forward_upsample_size = True
+        for dim in sample.shape[-2:]:
+            if dim % default_overall_up_factor != 0:
+                # Forward upsample size to force interpolation output size.
+                forward_upsample_size = True
+                break

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
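The change above swaps a generator expression inside any() for an explicit loop with an early exit, presumably because Dynamo traced the generator poorly at the time. A minimal, self-contained sketch of the two equivalent forms follows; the up factor of 8 is an illustrative value, not taken from the model config.

import torch

default_overall_up_factor = 8  # illustrative; the real value is derived from the UNet's upsamplers


def needs_forward_upsample_size(sample: torch.Tensor) -> bool:
    # Old form: any(s % default_overall_up_factor != 0 for s in sample.shape[-2:])
    # New form: an explicit loop over the spatial dims, equivalent but friendlier to torch.compile.
    for dim in sample.shape[-2:]:
        if dim % default_overall_up_factor != 0:
            return True
    return False


print(needs_forward_upsample_size(torch.randn(1, 4, 64, 64)))  # False: both dims divisible by 8
print(needs_forward_upsample_size(torch.randn(1, 4, 64, 63)))  # True: 63 % 8 != 0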
9 changes: 8 additions & 1 deletion src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -35,7 +35,7 @@
    scale_lora_layers,
    unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -976,8 +976,15 @@ def __call__(

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        is_unet_compiled = is_compiled_module(self.unet)
+        is_controlnet_compiled = is_compiled_module(self.controlnet)
+        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
+                # Relevant thread:
+                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+                    torch._inductor.cudagraph_mark_step_begin()
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
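A hedged usage sketch for the new branch in this pipeline: the step marker is only emitted when both self.unet and self.controlnet report as compiled modules and PyTorch is at least 2.1. The checkpoints below are common example IDs, not something this commit prescribes, and the snippet assumes a CUDA machine.

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# Compiling both modules makes is_compiled_module() return True for each, so the
# denoising loop above calls torch._inductor.cudagraph_mark_step_begin() every step.
# "reduce-overhead" is the mode that actually uses CUDA graphs under inductor.
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)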
@@ -36,7 +36,7 @@
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput

@@ -1144,8 +1144,15 @@ def __call__(

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        is_unet_compiled = is_compiled_module(self.unet)
+        is_controlnet_compiled = is_compiled_module(self.controlnet)
+        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
+                # Relevant thread:
+                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+                    torch._inductor.cudagraph_mark_step_begin()
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
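This hunk mirrors the SD ControlNet pipeline change for the SDXL variant. As a rough, self-contained illustration of what the marker is for (requires a CUDA device and PyTorch >= 2.1; denoise_step is a stand-in for the compiled UNet/ControlNet calls, not the pipeline code):

import torch


@torch.compile(mode="reduce-overhead")  # the CUDA-graphs path the marker is meant for
def denoise_step(x: torch.Tensor) -> torch.Tensor:
    return x * 0.5 + 1.0


if torch.cuda.is_available():
    sample = torch.randn(1, 4, 64, 64, device="cuda")
    for _ in range(3):
        # Same call the pipelines now make at the top of every denoising iteration, so
        # CUDA-graph managed buffers from the previous step can be safely reused.
        torch._inductor.cudagraph_mark_step_begin()
        out = denoise_step(sample)
    print(out.mean().item())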
@@ -814,7 +814,8 @@ def __call__(
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_intrablock_additional_residuals=[state.clone() for state in adapter_state],
-                ).sample
+                    return_dict=False,
+                )[0]

                # perform guidance
                if do_classifier_free_guidance:
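The .sample to return_dict=False change above avoids constructing a ModelOutput dataclass inside the traced region and indexes the returned tuple instead. A small sketch of the two call styles on a tiny, randomly initialized UNet; the config values are arbitrary and far smaller than any real checkpoint.

import torch
from diffusers import UNet2DConditionModel

# Tiny throwaway UNet purely for illustration; nothing here matches a released model.
unet = UNet2DConditionModel(
    sample_size=8,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    layers_per_block=1,
    attention_head_dim=8,
    cross_attention_dim=32,
)
sample = torch.randn(1, 4, 8, 8)
encoder_hidden_states = torch.randn(1, 4, 32)

with torch.no_grad():
    out_dataclass = unet(sample, 10, encoder_hidden_states).sample  # ModelOutput + attribute access
    out_tuple = unet(sample, 10, encoder_hidden_states, return_dict=False)[0]  # plain tuple, as in the diff

print(torch.allclose(out_dataclass, out_tuple))  # True: same forward pass, different return style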
@@ -1084,9 +1084,11 @@ def forward(
        forward_upsample_size = False
        upsample_size = None

-        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
-            # Forward upsample size to force interpolation output size.
-            forward_upsample_size = True
+        for dim in sample.shape[-2:]:
+            if dim % default_overall_up_factor != 0:
+                # Forward upsample size to force interpolation output size.
+                forward_upsample_size = True
+                break

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
