Skip to content

Commit

Permalink
fix: DDIM restoration when batch_size > 1
Browse files Browse the repository at this point in the history
  • Loading branch information
royale authored and beniz committed Aug 22, 2023
1 parent bd16f1e commit ccb445b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions models/modules/diffusion_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,12 +378,12 @@ def ddim_p_mean_variance(
y_0_hat.clamp_(-1.0, 1.0)

gamma_t = self.extract(
getattr(self.denoise_fn.model, "gammas_" + phase), t, x_shape=(1, 1)
getattr(self.denoise_fn.model, "gammas_" + phase), t, x_shape=(1, 1, 1, 1)
).to(y_t.device)
gamma_prevt = self.extract(
getattr(self.denoise_fn.model, "gammas_prev_" + phase),
prevt + 1,
x_shape=(1, 1),
x_shape=(1, 1, 1, 1),
).to(y_t.device)

## denoising formula for model_mean witih DDIM
Expand Down

0 comments on commit ccb445b

Please sign in to comment.