Image dimension checking for ControlNet FLUX #9550

Open
wants to merge 1 commit into base: main
Conversation

@christopher-beckham (Contributor) commented Sep 28, 2024

What does this PR do?

Issue

This addresses an issue discussed in two earlier PRs; see #9406 (comment) and #9507 (comment).

The FLUX ControlNet pipeline currently lacks any checks on the shape or batch size of the control images passed in (whether as np.ndarray, torch.Tensor, or PIL.Image objects).

I will give a simple example. If you were to run the following code:

pipe = FluxControlNetPipeline.from_pretrained(
    base_model, vae=vae, controlnet=multi_controlnet, torch_dtype=torch.bfloat16
).to("cuda")
# image_t is a torch tensor of shape (2,3,h,w)
pipe(
    prompt=["test"],
    control_image=image_t,
    control_mode=0,
    num_images_per_prompt=1,
    num_inference_steps=2,
)

you'd get the following error:

Traceback (most recent call last):
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers-tests/controlnet_pipeline_cleaner_api/flux.py", line 67, in test_torch_batched_ctrl_wrong_1ipp
    self.pipe(
  File "/home/beckhamc/envs/diffusers/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 742, in __call__
    control_image = self._pack_latents(
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 458, in _pack_latents
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
RuntimeError: shape '[1, 16, 32, 2, 32, 2]' is invalid for input of size 131072

This happens because the number of control images must match the number of prompts passed: here the control image has batch size 2 while only one prompt is given. Because we don't catch this, _pack_latents is later called with batch_size=1 even though the latents hold two images' worth of elements (131072 = 2 × 65536), and the view fails with the opaque downstream error above.

It turns out SDXL's ControlNet pipeline already checks that the number of control images is consistent with the number of prompts (as I recall, one of the two is also allowed to be a singleton, which is fine). I essentially ported over the check_image method from StableDiffusionControlNetPipeline and modified check_inputs to validate the control image as well. Now if you run the code above you get the following error instead, which makes it much clearer what the issue is:

Traceback (most recent call last):
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers-tests/controlnet_pipeline_cleaner_api/flux.py", line 67, in test_torch_batched_ctrl_wrong_1ipp
    self.pipe(
  File "/home/beckhamc/envs/diffusers/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 742, in __call__
    self.check_inputs(
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 475, in check_inputs
    self.check_image(image, prompt, prompt_embeds)
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 427, in check_image
    raise ValueError(
ValueError: If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: 2, prompt batch size: 1
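
For reference, the core of the ported check boils down to the batch-size comparison below. This is a condensed sketch in the spirit of StableDiffusionControlNetPipeline.check_image, not the exact diff in this PR:

import PIL.Image
import torch

def check_image(image, prompt, prompt_embeds):
    # A single PIL image counts as batch size 1; tensors, arrays and lists use their leading length.
    if isinstance(image, PIL.Image.Image):
        image_batch_size = 1
    else:
        image_batch_size = len(image)

    # check_inputs guarantees that either prompt or prompt_embeds is provided.
    if prompt is not None and isinstance(prompt, str):
        prompt_batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        prompt_batch_size = len(prompt)
    elif prompt_embeds is not None:
        prompt_batch_size = prompt_embeds.shape[0]

    # A singleton control image broadcasts over the prompt batch; otherwise the sizes must match.
    if image_batch_size != 1 and image_batch_size != prompt_batch_size:
        raise ValueError(
            "If image batch size is not 1, image batch size must be same as prompt batch size. "
            f"image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
        )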

This fix should also work for MultiControlNet, which means you can do something like this:

multi_controlnet = FluxMultiControlNetModel([controlnet] * 2)
pipe = FluxControlNetPipeline.from_pretrained(
    base_model, vae=vae, controlnet=multi_controlnet, torch_dtype=torch.bfloat16
).to("cuda")
images = pipe(
    prompt=["1","2","3"],
    control_image=[images1, images2], 
    controlnet_conditioning_scale=[0.6, 0.6],
    control_mode=0,
    num_images_per_prompt=2
)

i.e. images1 and images2 are both torch.Tensors with a batch size of 3, and their corresponding ControlNet outputs (which will effectively have double the batch size due to num_images_per_prompt=2) are summed together.
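
Conceptually, the multi-ControlNet path then reduces to something like the following sketch, where each ControlNet's block residuals are computed from its own control image and added together; the function and argument names here are illustrative only, not the actual FluxMultiControlNetModel signature:

def multi_controlnet_blocks(nets, control_images, scales, **kwargs):
    summed = None
    for net, image, scale in zip(nets, control_images, scales):
        # Each ControlNet sees its own (already batch-expanded) control image.
        block_samples = net(controlnet_cond=image, conditioning_scale=scale, **kwargs)
        # Residuals from the different ControlNets are combined element-wise.
        summed = block_samples if summed is None else [a + b for a, b in zip(summed, block_samples)]
    return summed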

I have some tests you can copy and paste from here: https://github.com/christopher-beckham/diffusers-tests/blob/4b548f8/controlnet_pipeline_cleaner_api/flux.py

(you can run them with python -m unittest flux.py)
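
As a rough illustration (not copied verbatim from that repo), a minimal regression test for the new check could look like the following; the checkpoint ids are placeholders:

import unittest

import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline

class TestFluxControlImageChecks(unittest.TestCase):
    def test_mismatched_control_batch_raises(self):
        # Placeholder checkpoint ids; substitute whatever you test against locally.
        controlnet = FluxControlNetModel.from_pretrained(
            "path/to/flux-controlnet", torch_dtype=torch.bfloat16
        )
        pipe = FluxControlNetPipeline.from_pretrained(
            "path/to/flux-base", controlnet=controlnet, torch_dtype=torch.bfloat16
        ).to("cuda")

        # Two control images but only one prompt: check_inputs should now reject this
        # up front with a ValueError instead of failing later in _pack_latents.
        control_image = torch.rand(2, 3, 512, 512)
        with self.assertRaises(ValueError):
            pipe(
                prompt=["test"],
                control_image=control_image,
                control_mode=0,
                num_images_per_prompt=1,
                num_inference_steps=2,
            )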

Other concerns

There are some questions I have, however. Why do we skip image preprocessing entirely when the image is a torch.Tensor? i.e.

if isinstance(image, torch.Tensor):
    pass
else:
    image = self.image_processor.preprocess(image, height=height, width=width)

This also seems inconsistent with what is done in the SDXL ControlNet code:

image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

It may also lead to unexpected behaviour, because preprocess explicitly uses width and height to resize the image (if they are None, a reasonable default is chosen depending on the model), but this logic is skipped entirely when a torch.Tensor is passed.
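
One possible direction, closer to the SDXL behaviour quoted above, would be to always route the input through the image processor. This is only a sketch of the idea, not something this PR changes, and it assumes the processor handles tensor inputs the way we want:

# Sketch: treat tensors the same as arrays/PIL images, so the requested
# height/width are always applied (or at least validated) by the image processor.
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)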

Thanks.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Did you read our philosophy doc (important for complex PRs)?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests? (Yes, but in my own standalone repo, which I linked to above.)

Who can review?

@yiyixuxu @wangqixun

@@ -389,10 +390,49 @@ def encode_prompt(

return prompt_embeds, pooled_prompt_embeds, text_ids

# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
@yiyixuxu (Collaborator) commented Sep 30, 2024

So, I think this method was written before we introduced the image processor, which sets a standard image input format accepted across all our pipelines and already checks whether the input is in a valid format there:

if not is_valid_image_imagelist(image):
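
For context, a minimal sketch of the kind of validation this enables, assuming the is_valid_image_imagelist helper from diffusers.image_processor (the exact integration into check_inputs is up to the maintainers):

from diffusers.image_processor import is_valid_image_imagelist

def validate_control_image(image):
    # Reject anything that is not a single image (PIL / ndarray / tensor) or a list of such images.
    if not is_valid_image_imagelist(image):
        raise ValueError(
            f"control_image must be an image or a list of images, got {type(image)}"
        )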
