From f061c95e0f5d84d96ba95321f8a0b6604b792cfa Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 16 Jun 2024 19:13:15 -0700 Subject: [PATCH] fix plms cfg++ --- modules/sd_samplers_timesteps_impl.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py index 03aaae188a9..84047a4d36a 100644 --- a/modules/sd_samplers_timesteps_impl.py +++ b/modules/sd_samplers_timesteps_impl.py @@ -149,6 +149,7 @@ def plms_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None s_in = x.new_ones([x.shape[0]]) s_x = x.new_ones((x.shape[0], 1, 1, 1)) old_eps = [] + old_lnu = [] def get_x_prev_and_pred_x0(e_t, noise_uncond, index): # select parameters corresponding to the currently considered timestep @@ -170,30 +171,39 @@ def get_x_prev_and_pred_x0(e_t, noise_uncond, index): t_next = timesteps[max(index - 1, 0)].item() * s_in e_t = model(x, ts, **extra_args) - last_noise_uncond = model.last_noise_uncond + last_noise_uncond = model.last_noise_uncond.detach().clone() if len(old_eps) == 0: # Pseudo Improved Euler (2nd order) x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, last_noise_uncond, index) e_t_next = model(x_prev, t_next, **extra_args) - last_noise_uncond = model.last_noise_uncond + last_noise_uncond_next = model.last_noise_uncond e_t_prime = (e_t + e_t_next) / 2 + last_noise_uncond_prime = (last_noise_uncond + last_noise_uncond_next) / 2 elif len(old_eps) == 1: # 2nd order Pseudo Linear Multistep (Adams-Bashforth) e_t_prime = (3 * e_t - old_eps[-1]) / 2 + last_noise_uncond_prime = (3 * last_noise_uncond - old_lnu[-1]) / 2 elif len(old_eps) == 2: # 3nd order Pseudo Linear Multistep (Adams-Bashforth) e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + last_noise_uncond_prime = (23 * last_noise_uncond - 16 * old_lnu[-1] + 5 * old_lnu[-2]) / 12 else: # 4nd order Pseudo Linear Multistep (Adams-Bashforth) e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + last_noise_uncond_prime = (55 * last_noise_uncond - 59 * old_lnu[-1] + 37 * old_lnu[-2] - 9 * old_lnu[-3]) / 24 + - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, last_noise_uncond, index) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, last_noise_uncond_prime, index) old_eps.append(e_t) if len(old_eps) >= 4: old_eps.pop(0) + old_lnu.append(last_noise_uncond) + if len(old_lnu) >= 4: + old_lnu.pop(0) + x = x_prev if callback is not None: