Skip to content

Commit

Permalink
fix plms cfg++
Browse files Browse the repository at this point in the history
  • Loading branch information
v0xie committed Jun 17, 2024
1 parent 7393993 commit f061c95
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions modules/sd_samplers_timesteps_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def plms_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None
s_in = x.new_ones([x.shape[0]])
s_x = x.new_ones((x.shape[0], 1, 1, 1))
old_eps = []
old_lnu = []

def get_x_prev_and_pred_x0(e_t, noise_uncond, index):
# select parameters corresponding to the currently considered timestep
Expand All @@ -170,30 +171,39 @@ def get_x_prev_and_pred_x0(e_t, noise_uncond, index):
t_next = timesteps[max(index - 1, 0)].item() * s_in

e_t = model(x, ts, **extra_args)
last_noise_uncond = model.last_noise_uncond
last_noise_uncond = model.last_noise_uncond.detach().clone()

if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, last_noise_uncond, index)
e_t_next = model(x_prev, t_next, **extra_args)
last_noise_uncond = model.last_noise_uncond
last_noise_uncond_next = model.last_noise_uncond
e_t_prime = (e_t + e_t_next) / 2
last_noise_uncond_prime = (last_noise_uncond + last_noise_uncond_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
last_noise_uncond_prime = (3 * last_noise_uncond - old_lnu[-1]) / 2
elif len(old_eps) == 2:
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
last_noise_uncond_prime = (23 * last_noise_uncond - 16 * old_lnu[-1] + 5 * old_lnu[-2]) / 12
else:
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
last_noise_uncond_prime = (55 * last_noise_uncond - 59 * old_lnu[-1] + 37 * old_lnu[-2] - 9 * old_lnu[-3]) / 24


x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, last_noise_uncond, index)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, last_noise_uncond_prime, index)

old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)

old_lnu.append(last_noise_uncond)
if len(old_lnu) >= 4:
old_lnu.pop(0)

x = x_prev

if callback is not None:
Expand Down

0 comments on commit f061c95

Please sign in to comment.