From 3078001439d25b66ef5627c9e3d431aa23bbed73 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sun, 14 May 2023 01:49:41 +0000 Subject: [PATCH 1/2] Add/modify CFG callbacks Required by self-attn guidance extension https://github.com/ashen-sensored/sd_webui_SAG --- modules/script_callbacks.py | 35 +++++++++++++++++++++++++++++++ modules/sd_samplers_kdiffusion.py | 8 ++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 7d9dd736121..e83c6ecf539 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -53,6 +53,21 @@ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, te class CFGDenoisedParams: + def __init__(self, x, sampling_step, total_sampling_steps, inner_model): + self.x = x + """Latent image representation in the process of being denoised""" + + self.sampling_step = sampling_step + """Current Sampling step number""" + + self.total_sampling_steps = total_sampling_steps + """Total number of sampling steps planned""" + + self.inner_model = inner_model + """Inner model reference that is being used for denoising""" + + +class AfterCFGCallbackParams: def __init__(self, x, sampling_step, total_sampling_steps): self.x = x """Latent image representation in the process of being denoised""" @@ -63,6 +78,9 @@ def __init__(self, x, sampling_step, total_sampling_steps): self.total_sampling_steps = total_sampling_steps """Total number of sampling steps planned""" + self.output_altered = False + """A flag for CFGDenoiser that indicates whether the output has been altered by the callback""" + class UiTrainTabParams: def __init__(self, txt2img_preview_params): @@ -87,6 +105,7 @@ def __init__(self, imgs, cols, rows): callbacks_image_saved=[], callbacks_cfg_denoiser=[], callbacks_cfg_denoised=[], + callbacks_cfg_after_cfg=[], callbacks_before_component=[], callbacks_after_component=[], callbacks_image_grid=[], @@ -186,6 +205,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams): report_exception(c, 'cfg_denoised_callback') +def cfg_after_cfg_callback(params: AfterCFGCallbackParams): + for c in callback_map['callbacks_cfg_after_cfg']: + try: + c.callback(params) + except Exception: + report_exception(c, 'cfg_after_cfg_callback') + + def before_component_callback(component, **kwargs): for c in callback_map['callbacks_before_component']: try: @@ -332,6 +359,14 @@ def on_cfg_denoised(callback): add_callback(callback_map['callbacks_cfg_denoised'], callback) +def on_cfg_after_cfg(callback): + """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations has completed. + The callback is called with one argument: + - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details. + """ + add_callback(callback_map['callbacks_cfg_after_cfg'], callback) + + def on_before_component(callback): """register a function to be called before a component is created. The callback is called with arguments: diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index e9e41818c81..55f0d3a3e25 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -8,6 +8,7 @@ import modules.shared as shared from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback +from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback samplers_k_diffusion = [ ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), @@ -160,7 +161,7 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes]) x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be - denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps) + denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model) cfg_denoised_callback(denoised_params) devices.test_for_nans(x_out, "unet") @@ -180,6 +181,11 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): if self.mask is not None: denoised = self.init_latent * self.mask + self.nmask * denoised + after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps) + cfg_after_cfg_callback(after_cfg_callback_params) + if after_cfg_callback_params.output_altered: + denoised = after_cfg_callback_params.x + self.step += 1 return denoised From 8abfc95013d247c8a863d048574bc1f9d1eb0443 Mon Sep 17 00:00:00 2001 From: Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> Date: Sun, 14 May 2023 12:56:34 +0800 Subject: [PATCH 2/2] Update script_callbacks.py --- modules/script_callbacks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index e83c6ecf539..57dfd457f22 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -64,7 +64,7 @@ def __init__(self, x, sampling_step, total_sampling_steps, inner_model): """Total number of sampling steps planned""" self.inner_model = inner_model - """Inner model reference that is being used for denoising""" + """Inner model reference used for denoising""" class AfterCFGCallbackParams: @@ -79,7 +79,7 @@ def __init__(self, x, sampling_step, total_sampling_steps): """Total number of sampling steps planned""" self.output_altered = False - """A flag for CFGDenoiser that indicates whether the output has been altered by the callback""" + """A flag for CFGDenoiser indicating whether the output has been altered by the callback""" class UiTrainTabParams: @@ -360,9 +360,9 @@ def on_cfg_denoised(callback): def on_cfg_after_cfg(callback): - """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations has completed. + """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed. The callback is called with one argument: - - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details. + - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation. """ add_callback(callback_map['callbacks_cfg_after_cfg'], callback)