Skip to content

Commit

Permalink
plms cfg++
Browse files Browse the repository at this point in the history
  • Loading branch information
v0xie committed Jun 17, 2024
1 parent 663a4d8 commit 06c2452
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 0 deletions.
1 change: 1 addition & 0 deletions modules/sd_samplers_timesteps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
('DDIM CFG++', sd_samplers_timesteps_impl.ddim_cfgpp, ['ddim_cfgpp'], {}),
('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}),
('PLMS CFG++', sd_samplers_timesteps_impl.plms_cfgpp, ['plms_cfgpp'], {}),
('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}),
]

Expand Down
61 changes: 61 additions & 0 deletions modules/sd_samplers_timesteps_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,67 @@ def get_x_prev_and_pred_x0(e_t, index):

return x

@torch.no_grad()
def plms_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None):
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)

extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
s_x = x.new_ones((x.shape[0], 1, 1, 1))
old_eps = []

def get_x_prev_and_pred_x0(e_t, noise_uncond, index):
# select parameters corresponding to the currently considered timestep
a_t = alphas[index].item() * s_x
a_prev = alphas_prev[index].item() * s_x
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x

# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()

# direction pointing to x_t
dir_xt = (1. - a_prev).sqrt() * noise_uncond
x_prev = a_prev.sqrt() * pred_x0 + dir_xt
return x_prev, pred_x0

for i in tqdm.trange(len(timesteps) - 1, disable=disable):
index = len(timesteps) - 1 - i
ts = timesteps[index].item() * s_in
t_next = timesteps[max(index - 1, 0)].item() * s_in

e_t = model(x, ts, **extra_args)

if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = model(x_prev, t_next, **extra_args)
e_t_prime = (e_t + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
elif len(old_eps) == 2:
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
else:
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

last_noise_uncond = model.last_noise_uncond
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, last_noise_uncond, index)

old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)

x = x_prev

if callback is not None:
callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})

return x

class UniPCCFG(uni_pc.UniPC):
def __init__(self, cfg_model, extra_args, callback, *args, **kwargs):
Expand Down

0 comments on commit 06c2452

Please sign in to comment.