diff --git a/models/modules/diffusion_generator.py b/models/modules/diffusion_generator.py index 58ae92f5f..209093393 100644 --- a/models/modules/diffusion_generator.py +++ b/models/modules/diffusion_generator.py @@ -378,12 +378,12 @@ def ddim_p_mean_variance( y_0_hat.clamp_(-1.0, 1.0) gamma_t = self.extract( - getattr(self.denoise_fn.model, "gammas_" + phase), t, x_shape=(1, 1) + getattr(self.denoise_fn.model, "gammas_" + phase), t, x_shape=(1, 1, 1, 1) ).to(y_t.device) gamma_prevt = self.extract( getattr(self.denoise_fn.model, "gammas_prev_" + phase), prevt + 1, - x_shape=(1, 1), + x_shape=(1, 1, 1, 1), ).to(y_t.device) ## denoising formula for model_mean witih DDIM