Skip to content

Commit

Permalink
update
Browse files · Browse the repository at this point in the history
  • Loading branch information
DN6 committed Nov 19, 2024
1 parent 5ee4d3c commit 06b0835
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,6 @@ def prepare_latents(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
Expand Down Expand Up @@ -878,7 +877,6 @@ def __call__(
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

latents, latent_image_ids = self.prepare_latents(
init_image,
latent_timestep,
Expand Down
12 changes: 10 additions & 2 deletions tests/pipelines/controlnet_flux/test_controlnet_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,17 @@ def test_flux_image_output_shape(self):
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)

inputs.update({"height": height, "width": width})
inputs.update(
{
"control_image": randn_tensor(
(1, 3, height, width),
device=torch_device,
dtype=torch.float16,
)
}
)
image = pipe(**inputs).images[0]
output_height, output_width = image.shape
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)


Expand Down
29 changes: 29 additions & 0 deletions tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from diffusers.utils.testing_utils import (
torch_device,
)
from diffusers.utils.torch_utils import randn_tensor

from ..test_pipelines_common import (
PipelineTesterMixin,
Expand Down Expand Up @@ -218,3 +219,31 @@ def test_fused_qkv_projections(self):
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."

def test_flux_image_output_shape(self):
    """Check the img2img pipeline rounds requested sizes down to valid output shapes.

    For each requested (height, width), the pipeline is expected to emit an
    image whose spatial size is the request truncated to a multiple of
    ``2 * vae_scale_factor`` (VAE compression plus latent packing).
    """
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components).to(torch_device)
    inputs = self.get_dummy_inputs(torch_device)

    # One already-aligned size and one that must be rounded down.
    for height, width in [(32, 32), (72, 57)]:
        alignment = pipe.vae_scale_factor * 2
        expected_height = height - height % alignment
        expected_width = width - width % alignment

        tensor_shape = (1, 3, height, width)
        inputs["control_image"] = randn_tensor(
            tensor_shape, device=torch_device, dtype=torch.float16
        )
        inputs["image"] = randn_tensor(
            tensor_shape, device=torch_device, dtype=torch.float16
        )
        inputs["height"] = height
        inputs["width"] = width

        output_image = pipe(**inputs).images[0]
        actual_height, actual_width, _ = output_image.shape
        assert (actual_height, actual_width) == (expected_height, expected_width)
32 changes: 32 additions & 0 deletions tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from diffusers.utils.torch_utils import randn_tensor

from ..test_pipelines_common import PipelineTesterMixin

Expand Down Expand Up @@ -192,3 +194,33 @@ def test_attention_slicing_forward_pass(self):

def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)

def test_flux_image_output_shape(self):
    """Check the inpaint pipeline rounds requested sizes down to valid output shapes.

    Supplies image, control image, and an all-ones mask at each requested
    (height, width); the produced image should be the request truncated to a
    multiple of ``2 * vae_scale_factor``.
    """
    pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
    inputs = self.get_dummy_inputs(torch_device)

    # (32, 32) is already aligned; (72, 56) forces rounding down.
    for height, width in [(32, 32), (72, 56)]:
        divisor = pipe.vae_scale_factor * 2
        expected_height = height - height % divisor
        expected_width = width - width % divisor

        image_shape = (1, 3, height, width)
        inputs["control_image"] = randn_tensor(
            image_shape, device=torch_device, dtype=torch.float16
        )
        inputs["image"] = randn_tensor(
            image_shape, device=torch_device, dtype=torch.float16
        )
        inputs["mask_image"] = torch.ones((1, 1, height, width)).to(torch_device)
        inputs["height"] = height
        inputs["width"] = width

        produced = pipe(**inputs).images[0]
        got_height, got_width, _ = produced.shape
        assert (got_height, got_width) == (expected_height, expected_width)
2 changes: 1 addition & 1 deletion tests/pipelines/flux/test_pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)

height_width_pairs = [(32, 32), (72, 56)]
height_width_pairs = [(32, 32), (72, 57)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
Expand Down
2 changes: 1 addition & 1 deletion tests/pipelines/flux/test_pipeline_flux_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)

height_width_pairs = [(32, 32), (72, 56)]
height_width_pairs = [(32, 32), (72, 57)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
Expand Down
2 changes: 1 addition & 1 deletion tests/pipelines/flux/test_pipeline_flux_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)

height_width_pairs = [(32, 32), (72, 56)]
height_width_pairs = [(32, 32), (72, 57)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
Expand Down

0 comments on commit 06b0835

Please sign in to comment.