Skip to content

Commit

Permalink
add torch.Generator to MultiUpscaler.upscale + make MultiUpscaler.dif…
Browse files Browse the repository at this point in the history
…fuse_targets "stateless"
  • Loading branch information
Laurent2916 committed Sep 26, 2024
1 parent 883a212 commit 5a92285
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch import Tensor
from typing_extensions import TypeVar

from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad
from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion, Size
Expand Down Expand Up @@ -217,13 +217,12 @@ def compute_upscaler_targets(

def diffuse_targets(
self,
noise: torch.Tensor,
targets: Sequence[T],
image: Image.Image,
latent_size: Size,
first_step: int,
autoencoder_tile_length: int,
) -> Image.Image:
noise = torch.randn(size=(1, 4, *latent_size), device=self.device, dtype=self.dtype)
with self.sd.lda.tiled_inference(image, (autoencoder_tile_length, autoencoder_tile_length)):
latents = self.sd.lda.tiled_image_to_latents(image)
x = self.sd.solver.add_noise(x=latents, noise=noise, step=first_step)
Expand All @@ -249,7 +248,7 @@ def upscale(
solver_type: type[Solver] = DPMSolver,
num_inference_steps: int = 18,
autoencoder_tile_length: int = 1024,
seed: int = 37,
generator: torch.Generator | None = None,
) -> Image.Image:
"""
Upscale an image using the multi upscaler.
Expand Down Expand Up @@ -280,10 +279,8 @@ def upscale(
between quality and speed.
autoencoder_tile_length: The length of the autoencoder tiles. It shouldn't affect the end result, but
lowering it can reduce GPU memory usage (but increase computation time).
seed: The seed to use for the random number generator.
generator: The random number generator to use for sampling noise.
"""
manual_seed(seed)

# update controlnet scale
self.controlnet.scale = controlnet_scale
self.controlnet.scale_decay = controlnet_scale_decay
Expand Down Expand Up @@ -323,11 +320,19 @@ def upscale(
clip_text_embedding=clip_text_embedding,
)

# initialize the noise
noise = torch.randn(
size=(1, 4, *latent_size),
device=self.device,
dtype=self.dtype,
generator=generator,
)

# diffuse the tiles
return self.diffuse_targets(
noise=noise,
targets=targets,
image=image,
latent_size=latent_size,
first_step=first_step,
autoencoder_tile_length=autoencoder_tile_length,
)
Expand Down
4 changes: 3 additions & 1 deletion tests/e2e/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2669,7 +2669,9 @@ def test_multi_upscaler(
clarity_example: Image.Image,
expected_multi_upscaler: Image.Image,
) -> None:
predicted_image = multi_upscaler.upscale(clarity_example)
generator = torch.Generator(device=multi_upscaler.device)
generator.manual_seed(37)
predicted_image = multi_upscaler.upscale(clarity_example, generator=generator)
ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99)


Expand Down

0 comments on commit 5a92285

Please sign in to comment.