Image dimension checking for ControlNet FLUX #9550
Open
What does this PR do?
Issue
This addresses an issue discussed in two PRs; see #9406 (comment) and #9507 (comment).

The FLUX ControlNet pipeline currently lacks any checks on the shape or number of control images passed (for `np.ndarray`/`torch.Tensor` and `PIL` objects, respectively). I will give a simple example. If you were to run the following code:

you'd get the following error:

This is because the number of control images must match the number of prompts passed -- in this case we passed a control image with batch size 2, but only one prompt. Because we don't catch this, it results in a downstream error related to the packing of the latents.
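To make the constraint concrete, here is a minimal sketch of the rule being violated above (my paraphrase, not the actual diffusers code; the function name is hypothetical):

```python
def check_image_batch_size(image_batch_size: int, prompt_batch_size: int) -> None:
    # Sketch of the consistency rule: the control-image batch size must either
    # be 1 (broadcast across all prompts) or equal the prompt batch size.
    if image_batch_size != 1 and image_batch_size != prompt_batch_size:
        raise ValueError(
            f"If image batch size is not 1, it must equal the prompt batch size: "
            f"got image batch size {image_batch_size} for {prompt_batch_size} prompt(s)."
        )

# The failing case from the example: 2 control images, 1 prompt.
try:
    check_image_batch_size(image_batch_size=2, prompt_batch_size=1)
except ValueError as e:
    print(e)
```

With an explicit check like this, the user sees the mismatch immediately instead of a cryptic downstream error when the latents are packed.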
It turns out SDXL's ControlNet pipeline actually checks that the number of control images is consistent with the number of prompts (I recall that one of the two is also allowed to be a singleton, which is also fine). I essentially ported over the `check_image` method from `StableDiffusionControlNetPipeline` and also modified `check_inputs` to validate the control image. Now if you run the above code you will get the following error instead, which makes it much clearer what the issue is:

This fix should also work for `MultiControlNet`, which means you can do something like this:

i.e. `images` and `images2` are both `torch.Tensor`s with a batch size of 3, and their corresponding ControlNet states (which will effectively have double the batch size due to `num_images_per_prompt=2`) will be summed together.

I have some tests you can copy and paste from here: https://github.com/christopher-beckham/diffusers-tests/blob/4b548f8/controlnet_pipeline_cleaner_api/flux.py
(you can run them with `python -m unittest flux.py`)

Other concerns
There are some questions I have, however. Why do we skip image preprocessing when the image is a `torch.Tensor`? i.e.

diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
Lines 526 to 529 in 9cd3755

This also seems inconsistent with what is done in the SDXL ControlNet code:

diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
Line 857 in 7071b74

It may also lead to unexpected behaviour, because `preprocess` explicitly tries to use `width` and `height` to preprocess the image (if they are `None`, a reasonable default is used instead, depending on the precise model). But this logic gets skipped entirely if a `torch.Tensor` is passed.

Thanks.
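To illustrate the concern concretely, here is a simplified, torch-free sketch of the branch in question (`FakeTensor`, `preprocess`, and `prepare_image` are illustrative stand-ins, not the pipeline's actual code):

```python
class FakeTensor:
    # Stand-in for torch.Tensor so the sketch runs without torch.
    def __init__(self, height, width):
        self.shape = (1, 3, height, width)

def preprocess(image, width, height):
    # Stand-in for the image processor's preprocess: resizes to (height, width).
    return FakeTensor(height, width)

def prepare_image(image, width, height):
    # Paraphrase of the FLUX branch: tensor inputs bypass preprocess()
    # entirely, so the requested width/height are silently ignored for them.
    if isinstance(image, FakeTensor):
        return image
    return preprocess(image, width=width, height=height)

tensor_image = FakeTensor(100, 100)
out = prepare_image(tensor_image, width=512, height=512)
print(out.shape)  # the tensor keeps its original 100x100 size
```

Under this reading, a user passing a tensor along with explicit `width`/`height` arguments gets back an image that was never resized, which is the unexpected behaviour described above.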
Who can review?
@yiyixuxu @wangqixun