Skip to content

Commit

Permalink
trained_betas ignored in some schedulers (apple#635)
Browse files Browse the repository at this point in the history
* correcting the beta value assignment

* updating DDIM and LMSDiscreteFlax schedulers

* bringing back the changes that were lost as part of main branch merge
  • Loading branch information
vishnu-anirudh authored Sep 29, 2022
1 parent f10576a commit 3dacbb9
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/scheduling_ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(

if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear":
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/scheduling_lms_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(

if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear":
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/scheduling_lms_discrete_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
if beta_schedule == "linear":
elif beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/scheduling_pndm.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def __init__(

if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear":
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/scheduling_pndm_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __init__(
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
if beta_schedule == "linear":
elif beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
Expand Down

0 comments on commit 3dacbb9

Please sign in to comment.