Skip to content

Commit

Permalink
fix DPM-Solver with bfloat16
Browse files Browse the repository at this point in the history
  • Loading branch information
catwell committed Sep 25, 2024
1 parent 283bf45 commit 83b9312
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 48 deletions.
12 changes: 3 additions & 9 deletions src/refiners/foundationals/latent_diffusion/solvers/dpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,6 @@ def _generate_timesteps(self) -> torch.Tensor:
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
return torch.tensor(np_space).flip(0)

def _generate_sigmas(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Generate the sigmas used by the solver."""
assert self.params.sigma_schedule is not None, "sigma_schedule must be set for the DPM solver"
sigmas = self.noise_std / self.cumulative_scale_factors
sigmas = sigmas.flip(0)
rescaled_sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule)
rescaled_sigmas = torch.cat([rescaled_sigmas, torch.tensor([0.0])])
return sigmas, rescaled_sigmas

def _rescale_sigmas(self, sigmas: torch.Tensor, sigma_schedule: NoiseSchedule | None) -> torch.Tensor:
"""Rescale the sigmas according to the sigma schedule."""
match sigma_schedule:
Expand All @@ -140,9 +131,12 @@ def _rescale_sigmas(self, sigmas: torch.Tensor, sigma_schedule: NoiseSchedule |
case NoiseSchedule.KARRAS:
rho = 7
case None:
if sigmas.dtype == torch.bfloat16:
sigmas = sigmas.to(torch.float32)
return torch.tensor(
np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()),
device=self.device,
dtype=self.dtype,
)

linear_schedule = torch.linspace(0, 1, steps=self.num_inference_steps, device=self.device)
Expand Down
86 changes: 47 additions & 39 deletions tests/foundationals/latent_diffusion/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from warnings import warn

import pytest
from torch import Generator, Tensor, allclose, device as Device, equal, isclose, randn, tensor
import torch
from torch import Tensor, device as Device

from refiners.fluxion import manual_seed
from refiners.foundationals.latent_diffusion.solvers import (
Expand All @@ -27,7 +28,7 @@ def test_ddpm_diffusers():
diffusers_scheduler = DDPMScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012)
diffusers_scheduler.set_timesteps(1000)
solver = DDPM(num_inference_steps=1000)
assert equal(diffusers_scheduler.timesteps, solver.timesteps)
assert torch.equal(diffusers_scheduler.timesteps, solver.timesteps)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -58,10 +59,10 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool, sde_var
sigma_schedule=NoiseSchedule.KARRAS if use_karras_sigmas else None,
),
)
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
assert torch.equal(solver.timesteps, diffusers_scheduler.timesteps)

sample = randn(1, 3, 32, 32)
predicted_noise = randn(1, 3, 32, 32)
sample = torch.randn(1, 3, 32, 32)
predicted_noise = torch.randn(1, 3, 32, 32)

manual_seed(37)
diffusers_outputs: list[Tensor] = [
Expand All @@ -74,7 +75,7 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool, sde_var

atol = 1e-4 if use_karras_sigmas else 1e-6
for step, (diffusers_output, refiners_output) in enumerate(zip(diffusers_outputs, refiners_outputs)):
assert allclose(diffusers_output, refiners_output, rtol=0.01, atol=atol), f"outputs differ at step {step}"
assert torch.allclose(diffusers_output, refiners_output, rtol=0.01, atol=atol), f"outputs differ at step {step}"


def test_ddim_diffusers():
Expand All @@ -92,16 +93,16 @@ def test_ddim_diffusers():
)
diffusers_scheduler.set_timesteps(30)
solver = DDIM(num_inference_steps=30)
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
assert torch.equal(solver.timesteps, diffusers_scheduler.timesteps)

sample = randn(1, 4, 32, 32)
predicted_noise = randn(1, 4, 32, 32)
sample = torch.randn(1, 4, 32, 32)
predicted_noise = torch.randn(1, 4, 32, 32)

for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
refiners_output = solver(x=sample, predicted_noise=predicted_noise, step=step)

assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
assert torch.allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"


@pytest.mark.parametrize("model_prediction_type", [ModelPredictionType.NOISE, ModelPredictionType.SAMPLE])
Expand All @@ -122,20 +123,20 @@ def test_euler_diffusers(model_prediction_type: ModelPredictionType):
)
diffusers_scheduler.set_timesteps(30)
solver = Euler(num_inference_steps=30, params=SolverParams(model_prediction_type=model_prediction_type))
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
assert torch.equal(solver.timesteps, diffusers_scheduler.timesteps)

sample = randn(1, 4, 32, 32)
predicted_noise = randn(1, 4, 32, 32)
sample = torch.randn(1, 4, 32, 32)
predicted_noise = torch.randn(1, 4, 32, 32)

ref_init_noise_sigma = diffusers_scheduler.init_noise_sigma # type: ignore
assert isinstance(ref_init_noise_sigma, Tensor)
assert isclose(ref_init_noise_sigma, solver.init_noise_sigma), "init_noise_sigma differ"
assert torch.isclose(ref_init_noise_sigma, solver.init_noise_sigma), "init_noise_sigma differ"

for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
refiners_output = solver(x=sample, predicted_noise=predicted_noise, step=step)

assert allclose(diffusers_output, refiners_output, rtol=0.02), f"outputs differ at step {step}"
assert torch.allclose(diffusers_output, refiners_output, rtol=0.02), f"outputs differ at step {step}"


def test_franken_diffusers():
Expand All @@ -157,21 +158,21 @@ def test_franken_diffusers():

diffusers_scheduler_2 = EulerDiscreteScheduler(**params) # type: ignore
solver = FrankenSolver(lambda: diffusers_scheduler_2, num_inference_steps=30)
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
assert torch.equal(solver.timesteps, diffusers_scheduler.timesteps)

sample = randn(1, 4, 32, 32)
predicted_noise = randn(1, 4, 32, 32)
sample = torch.randn(1, 4, 32, 32)
predicted_noise = torch.randn(1, 4, 32, 32)

ref_init_noise_sigma = diffusers_scheduler.init_noise_sigma # type: ignore
assert isinstance(ref_init_noise_sigma, Tensor)
init_noise_sigma = solver.scale_model_input(tensor(1), step=-1)
assert equal(ref_init_noise_sigma, init_noise_sigma), "init_noise_sigma differ"
init_noise_sigma = solver.scale_model_input(torch.tensor(1), step=-1)
assert torch.equal(ref_init_noise_sigma, init_noise_sigma), "init_noise_sigma differ"

for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
refiners_output = solver(x=sample, predicted_noise=predicted_noise, step=step)

assert equal(diffusers_output, refiners_output), f"outputs differ at step {step}"
assert torch.equal(diffusers_output, refiners_output), f"outputs differ at step {step}"


def test_lcm_diffusers():
Expand All @@ -180,16 +181,16 @@ def test_lcm_diffusers():
manual_seed(0)

# LCMScheduler is stochastic, make sure we use identical generators
diffusers_generator = Generator().manual_seed(42)
refiners_generator = Generator().manual_seed(42)
diffusers_generator = torch.Generator().manual_seed(42)
refiners_generator = torch.Generator().manual_seed(42)

diffusers_scheduler = LCMScheduler()
diffusers_scheduler.set_timesteps(4)
solver = LCMSolver(num_inference_steps=4)
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
assert torch.equal(solver.timesteps, diffusers_scheduler.timesteps)

sample = randn(1, 4, 32, 32)
predicted_noise = randn(1, 4, 32, 32)
sample = torch.randn(1, 4, 32, 32)
predicted_noise = torch.randn(1, 4, 32, 32)

for step, timestep in enumerate(diffusers_scheduler.timesteps):
alpha_prod_t = diffusers_scheduler.alphas_cumprod[timestep]
Expand All @@ -212,7 +213,7 @@ def test_lcm_diffusers():
generator=refiners_generator,
)

assert allclose(refiners_output, diffusers_output, rtol=0.01), f"outputs differ at step {step}"
assert torch.allclose(refiners_output, diffusers_output, rtol=0.01), f"outputs differ at step {step}"


def test_solver_remove_noise():
Expand All @@ -231,14 +232,14 @@ def test_solver_remove_noise():
diffusers_scheduler.set_timesteps(30)
solver = DDIM(num_inference_steps=30)

sample = randn(1, 4, 32, 32)
noise = randn(1, 4, 32, 32)
sample = torch.randn(1, 4, 32, 32)
noise = torch.randn(1, 4, 32, 32)

for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample) # type: ignore
refiners_output = solver.remove_noise(x=sample, noise=noise, step=step)

assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
assert torch.allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"


def test_solver_device(test_device: Device):
Expand All @@ -247,16 +248,16 @@ def test_solver_device(test_device: Device):
pytest.skip()

scheduler = DDIM(num_inference_steps=30, device=test_device)
x = randn(1, 4, 32, 32, device=test_device)
noise = randn(1, 4, 32, 32, device=test_device)
x = torch.randn(1, 4, 32, 32, device=test_device)
noise = torch.randn(1, 4, 32, 32, device=test_device)
noised = scheduler.add_noise(x, noise, scheduler.first_inference_step)
assert noised.device == test_device


def test_solver_add_noise(test_device: Device):
scheduler = DDIM(num_inference_steps=30, device=test_device)
latent = randn(1, 4, 32, 32, device=test_device)
noise = randn(1, 4, 32, 32, device=test_device)
latent = torch.randn(1, 4, 32, 32, device=test_device)
noise = torch.randn(1, 4, 32, 32, device=test_device)
noised = scheduler.add_noise(
x=latent,
noise=noise,
Expand All @@ -267,8 +268,8 @@ def test_solver_add_noise(test_device: Device):
noise=noise.repeat(2, 1, 1, 1),
step=[0, 0],
)
assert allclose(noised, noised_double[0])
assert allclose(noised, noised_double[1])
assert torch.allclose(noised, noised_double[0])
assert torch.allclose(noised, noised_double[1])


@pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
Expand All @@ -291,20 +292,27 @@ def test_solver_timestep_spacing():
num_train_timesteps=1000,
offset=1,
)
assert equal(linspace_int, tensor([1000, 889, 778, 667, 556, 445, 334, 223, 112, 1]))
assert torch.equal(linspace_int, torch.tensor([1000, 889, 778, 667, 556, 445, 334, 223, 112, 1]))

leading = Solver.generate_timesteps(
spacing=TimestepSpacing.LEADING,
num_inference_steps=10,
num_train_timesteps=1000,
offset=1,
)
assert equal(leading, tensor([901, 801, 701, 601, 501, 401, 301, 201, 101, 1]))
assert torch.equal(leading, torch.tensor([901, 801, 701, 601, 501, 401, 301, 201, 101, 1]))

trailing = Solver.generate_timesteps(
spacing=TimestepSpacing.TRAILING,
num_inference_steps=10,
num_train_timesteps=1000,
offset=1,
)
assert equal(trailing, tensor([1000, 900, 800, 700, 600, 500, 400, 300, 200, 100]))
assert torch.equal(trailing, torch.tensor([1000, 900, 800, 700, 600, 500, 400, 300, 200, 100]))


def test_dpm_bfloat16(test_device: Device):
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
DPMSolver(num_inference_steps=5, dtype=torch.bfloat16) # should not raise

0 comments on commit 83b9312

Please sign in to comment.