From 151e1c1744a14eb8e525bbdb54e2251852548f61 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 19 Oct 2023 15:04:08 +0530 Subject: [PATCH 01/13] fix: controlnet graph? --- src/diffusers/models/unet_2d_condition.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 0ce2e04ad99a..ed186e08e279 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -846,9 +846,11 @@ def forward( forward_upsample_size = False upsample_size = None - if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - # Forward upsample size to force interpolation output size. - forward_upsample_size = True + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break # ensure attention_mask is a bias, and give it a singleton query_tokens dimension # expects mask of shape: From ff13002d1dc75790216ff59ff9a7c2d8880df142 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 19 Oct 2023 15:36:47 +0530 Subject: [PATCH 02/13] fix: sample --- .../pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index dca9e5fc3de2..44e078f03ba1 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -814,7 +814,8 @@ def __call__( encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, down_intrablock_additional_residuals=[state.clone() for state in adapter_state], - ).sample + return_dict=False + )[0] # perform guidance if do_classifier_free_guidance: From ea8c05b9a772b12292dd626418a2d9e1cf4e5031 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 19 Oct 2023 16:12:09 +0530 Subject: [PATCH 03/13] fix: --- src/diffusers/models/adapter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 64d64d07bf77..b88d4f47289d 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -316,6 +316,7 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]: capturing information at a different stage of processing within the FullAdapter model. The number of feature tensors in the list is determined by the number of downsample blocks specified during initialization. """ + print(f"Shape before unshuffle: {x.shape}") x = self.unshuffle(x) x = self.conv_in(x) From 1cc6cfd5c5ca7e34927dc9b1ffc07107388f5831 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 19 Oct 2023 16:13:49 +0530 Subject: [PATCH 04/13] remove print --- src/diffusers/models/adapter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index b88d4f47289d..64d64d07bf77 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -316,7 +316,6 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]: capturing information at a different stage of processing within the FullAdapter model. The number of feature tensors in the list is determined by the number of downsample blocks specified during initialization. """ - print(f"Shape before unshuffle: {x.shape}") x = self.unshuffle(x) x = self.conv_in(x) From 17d87d076207b776fb1f2fb331d2b00bf1303020 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 19 Oct 2023 18:46:02 +0530 Subject: [PATCH 05/13] styling --- src/diffusers/models/unet_2d_condition.py | 2 +- .../pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index ed186e08e279..9a6236e8fc87 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -850,7 +850,7 @@ def forward( if dim % default_overall_up_factor != 0: # Forward upsample size to force interpolation output size. forward_upsample_size = True - break + break # ensure attention_mask is a bias, and give it a singleton query_tokens dimension # expects mask of shape: diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index 44e078f03ba1..285ad6caabfe 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -814,7 +814,7 @@ def __call__( encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, down_intrablock_additional_residuals=[state.clone() for state in adapter_state], - return_dict=False + return_dict=False, )[0] # perform guidance From a7dad1f313349ae8d29662671bf4c3b42a4c451e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 19 Oct 2023 18:48:13 +0530 Subject: [PATCH 06/13] fix-copies --- .../pipelines/versatile_diffusion/modeling_text_unet.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 717db3bbdb34..f97becf25987 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1057,9 +1057,11 @@ def forward( forward_upsample_size = False upsample_size = None - if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - # Forward upsample size to force interpolation output size. - forward_upsample_size = True + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break # ensure attention_mask is a bias, and give it a singleton query_tokens dimension # expects mask of shape: From b2a0529bccba37ffc369dd6978dfaa14a3a31284 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 19 Oct 2023 22:39:39 +0530 Subject: [PATCH 07/13] prevent more graph breaks? --- src/diffusers/models/controlnet.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index c0d2da9b8c5f..59d9cceee705 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -815,20 +815,26 @@ def forward( mid_block_res_sample = self.controlnet_mid_block(sample) # 6. scaling + down_block_res_samples = [] if guess_mode and not self.config.global_pool_conditions: scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 - scales = scales * conditioning_scale - down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + # down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + for sample, scale in zip(down_block_res_samples, scales): + down_block_res_samples.append(sample * scale) mid_block_res_sample = mid_block_res_sample * scales[-1] # last one else: - down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + # down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + for sample in down_block_res_samples: + down_block_res_samples.append(sample * conditioning_scale) mid_block_res_sample = mid_block_res_sample * conditioning_scale if self.config.global_pool_conditions: - down_block_res_samples = [ - torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples - ] + # down_block_res_samples = [ + # torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples + # ] + for sample in down_block_res_samples: + down_block_res_samples.append(torch.mean(sample, dim=(2, 3), keepdim=True)) mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) if not return_dict: From a92003a0947350f6674190fe1058bb30b1421ac3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 19 Oct 2023 22:42:10 +0530 Subject: [PATCH 08/13] prevent more graph breaks? --- src/diffusers/models/controlnet.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 59d9cceee705..78aef2d3f8a9 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -815,8 +815,8 @@ def forward( mid_block_res_sample = self.controlnet_mid_block(sample) # 6. scaling - down_block_res_samples = [] if guess_mode and not self.config.global_pool_conditions: + down_block_res_samples = [] scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 scales = scales * conditioning_scale # down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] @@ -825,6 +825,7 @@ def forward( mid_block_res_sample = mid_block_res_sample * scales[-1] # last one else: # down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + down_block_res_samples = [] for sample in down_block_res_samples: down_block_res_samples.append(sample * conditioning_scale) mid_block_res_sample = mid_block_res_sample * conditioning_scale @@ -833,6 +834,7 @@ def forward( # down_block_res_samples = [ # torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples # ] + down_block_res_samples = [] for sample in down_block_res_samples: down_block_res_samples.append(torch.mean(sample, dim=(2, 3), keepdim=True)) mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) From 9b0ed31e3f53eff58f7cbf7f3cacd45b65a85233 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 19 Oct 2023 22:46:01 +0530 Subject: [PATCH 09/13] see? --- src/diffusers/models/controlnet.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 78aef2d3f8a9..bc13832f3ea7 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -816,27 +816,18 @@ def forward( # 6. scaling if guess_mode and not self.config.global_pool_conditions: - down_block_res_samples = [] scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 scales = scales * conditioning_scale - # down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] - for sample, scale in zip(down_block_res_samples, scales): - down_block_res_samples.append(sample * scale) + down_block_res_samples = [(sample * scale) + 0 for sample, scale in zip(down_block_res_samples, scales)] mid_block_res_sample = mid_block_res_sample * scales[-1] # last one else: - # down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] - down_block_res_samples = [] - for sample in down_block_res_samples: - down_block_res_samples.append(sample * conditioning_scale) + down_block_res_samples = [(sample * conditioning_scale) + 0 for sample in down_block_res_samples] mid_block_res_sample = mid_block_res_sample * conditioning_scale if self.config.global_pool_conditions: - # down_block_res_samples = [ - # torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples - # ] - down_block_res_samples = [] - for sample in down_block_res_samples: - down_block_res_samples.append(torch.mean(sample, dim=(2, 3), keepdim=True)) + down_block_res_samples = [ + (torch.mean(sample, dim=(2, 3), keepdim=True)) + 0 for sample in down_block_res_samples + ] mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) if not return_dict: From 5fb44f289d170a2ce7b784a3a4f19963e3b6cb5d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 19 Oct 2023 22:48:42 +0530 Subject: [PATCH 10/13] revert. --- src/diffusers/models/controlnet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index bc13832f3ea7..052335f6c5cd 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -818,15 +818,15 @@ def forward( if guess_mode and not self.config.global_pool_conditions: scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 scales = scales * conditioning_scale - down_block_res_samples = [(sample * scale) + 0 for sample, scale in zip(down_block_res_samples, scales)] + down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] mid_block_res_sample = mid_block_res_sample * scales[-1] # last one else: - down_block_res_samples = [(sample * conditioning_scale) + 0 for sample in down_block_res_samples] + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] mid_block_res_sample = mid_block_res_sample * conditioning_scale if self.config.global_pool_conditions: down_block_res_samples = [ - (torch.mean(sample, dim=(2, 3), keepdim=True)) + 0 for sample in down_block_res_samples + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples ] mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) From 0c3cd364fd770d50950199284e2b66791724a428 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 20 Oct 2023 07:45:29 +0530 Subject: [PATCH 11/13] compilation. --- .../pipelines/controlnet/pipeline_controlnet.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index f52b222ee129..25fe2d5997de 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -20,6 +20,7 @@ import PIL.Image import torch import torch.nn.functional as F +from packaging import version from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import PipelineImageInput, VaeImageProcessor @@ -976,8 +977,17 @@ def __call__( # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_than_2_1 = version.parse(version.parse(torch.__version__).base_version) >= version.parse( + "2.1" + ) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_than_2_1: + torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) From 1c710bc2ab568dc03676ca4856134e03c6568946 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 20 Oct 2023 08:32:35 +0530 Subject: [PATCH 12/13] rpopagate changes to controlnet sdxl pipeline too. --- .../pipelines/controlnet/pipeline_controlnet.py | 2 +- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 25fe2d5997de..562ce76240b0 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -984,7 +984,7 @@ def __call__( ) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - # Relevant thread: + # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_than_2_1: torch._inductor.cudagraph_mark_step_begin() diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 59573665867e..6248ebb5b816 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -20,6 +20,7 @@ import PIL.Image import torch import torch.nn.functional as F +from packaging import version from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from diffusers.utils.import_utils import is_invisible_watermark_available @@ -1144,8 +1145,17 @@ def __call__( # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_than_2_1 = version.parse(version.parse(torch.__version__).base_version) >= version.parse( + "2.1" + ) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_than_2_1: + torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) From 35c533d450a89fd63db93cdd31c59d878b396fbd Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 21 Oct 2023 22:57:15 +0530 Subject: [PATCH 13/13] add: clean version checking. --- .../pipelines/controlnet/pipeline_controlnet.py | 9 +++------ .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 9 +++------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 16749e0740d0..6944d9331253 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -20,7 +20,6 @@ import PIL.Image import torch import torch.nn.functional as F -from packaging import version from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import PipelineImageInput, VaeImageProcessor @@ -36,7 +35,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -979,14 +978,12 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order is_unet_compiled = is_compiled_module(self.unet) is_controlnet_compiled = is_compiled_module(self.controlnet) - is_torch_higher_equal_than_2_1 = version.parse(version.parse(torch.__version__).base_version) >= version.parse( - "2.1" - ) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_than_2_1: + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 3c749811a217..d6278c4f046a 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -20,7 +20,6 @@ import PIL.Image import torch import torch.nn.functional as F -from packaging import version from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from diffusers.utils.import_utils import is_invisible_watermark_available @@ -37,7 +36,7 @@ from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers -from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput @@ -1147,14 +1146,12 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order is_unet_compiled = is_compiled_module(self.unet) is_controlnet_compiled = is_compiled_module(self.controlnet) - is_torch_higher_equal_than_2_1 = version.parse(version.parse(torch.__version__).base_version) >= version.parse( - "2.1" - ) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_than_2_1: + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents