Skip to content

Commit

Permalink
re-derive sqrt alpha bar and sqrt one minus alphabar
Browse files Browse the repository at this point in the history
This is the only place these values are ever referenced outside of training code so this change is very justifiable and more consistent.
  • Loading branch information
drhead authored Dec 9, 2023
1 parent 78acdcf commit 5381405
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion modules/sd_samplers_timesteps.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, model, *args, **kwargs):
self.inner_model = model

def predict_eps_from_z_and_v(self, x_t, t, v):
return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t

def forward(self, input, timesteps, **kwargs):
model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
Expand Down

0 comments on commit 5381405

Please sign in to comment.