diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..b6e4761
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,129 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..e1fd273
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..8945c67
--- /dev/null
+++ b/README.md
@@ -0,0 +1,26 @@
+# Implementation of Self Attention Guidance in webui
+https://arxiv.org/abs/2210.00939
+
+## Additional setup requirements after installation:
+
+### For AUTOMATIC1111 webui:
+At commit 22bcc7be
+
+Run the following command in the root directory stable-diffusion-webui/:
+```
+git apply --ignore-whitespace extensions/sd_webui_SAG/automatic1111-CFGDenoiser-and-script_callbacks-mod-for-SAG.patch
+```
+
+
+### For vladmandic webui:
+At commit 7c684a8b
+
+Run the following command in the root directory automatic/:
+```
+git apply --ignore-whitespace extensions/sd_webui_SAG/vladmandic-CFGDenoiser-and-script_callbacks-mod-for-SAG.patch
+```
+
+
+Demos with stealth pnginfo:
+![xyz_grid-0014-232592377.png](resources%2Fimg%2Fxyz_grid-0014-232592377.png)
+![xyz_grid-0001-232592377.png](resources%2Fimg%2Fxyz_grid-0001-232592377.png)
\ No newline at end of file
diff --git a/automatic1111-CFGDenoiser-and-script_callbacks-mod-for-SAG.patch b/automatic1111-CFGDenoiser-and-script_callbacks-mod-for-SAG.patch
new file mode 100644
index 0000000..167bf58
--- /dev/null
+++ b/automatic1111-CFGDenoiser-and-script_callbacks-mod-for-SAG.patch
@@ -0,0 +1,118 @@
+From cb129e420612813d74043aa4a6a49575b53e9c14 Mon Sep 17 00:00:00 2001
+From: Ashen
+Date: Fri, 21 Apr 2023 09:40:59 -0700
+Subject: [PATCH] CFGDenoiser and script_callbacks mod for SAG
+
+---
+ modules/script_callbacks.py       | 34 ++++++++++++++++++++++++++++++++++
+ modules/sd_samplers_kdiffusion.py |  7 ++++++-
+ 2 files changed, 40 insertions(+), 1 deletion(-)
+
+diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
+index 07911876..d3d3df14 100644
+--- a/modules/script_callbacks.py
++++ b/modules/script_callbacks.py
+@@ -53,6 +53,21 @@ class CFGDenoiserParams:
+ 
+ 
+ class CFGDenoisedParams:
++    def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
++        self.x = x
++        """Latent image representation in the process of being denoised"""
++
++        self.sampling_step = sampling_step
++        """Current Sampling step number"""
++
++        self.total_sampling_steps = total_sampling_steps
++        """Total number of sampling steps planned"""
++
++        self.inner_model = inner_model
++        """Inner model reference that is being used for denoising"""
++
++
++class AfterCFGCallbackParams:
+     def __init__(self, x, sampling_step, total_sampling_steps):
+         self.x = x
+         """Latent image representation in the process of being denoised"""
+@@ -63,6 +78,8 @@ class CFGDenoisedParams:
+         self.total_sampling_steps = total_sampling_steps
+         """Total number of sampling steps planned"""
+ 
++        self.output_altered = False
++        """A flag for CFGDenoiser that indicates whether the output has been altered by the callback"""
+ 
+ class UiTrainTabParams:
+     def __init__(self, txt2img_preview_params):
+@@ -87,6 +104,7 @@ callback_map = dict(
+     callbacks_image_saved=[],
+     callbacks_cfg_denoiser=[],
+     callbacks_cfg_denoised=[],
++    callbacks_cfg_after_cfg=[],
+     callbacks_before_component=[],
+     callbacks_after_component=[],
+     callbacks_image_grid=[],
+@@ -177,6 +195,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
+             report_exception(c, 'cfg_denoised_callback')
+ 
+ 
++def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
++    for c in callback_map['callbacks_cfg_after_cfg']:
++        try:
++            c.callback(params)
++        except Exception:
++            report_exception(c, 'cfg_after_cfg_callback')
++
++
+ def before_component_callback(component, **kwargs):
+     for c in callback_map['callbacks_before_component']:
+         try:
+@@ -318,6 +344,14 @@ def on_cfg_denoised(callback):
+     add_callback(callback_map['callbacks_cfg_denoised'], callback)
+ 
+ 
++def on_cfg_after_cfg(callback):
++    """register a function to be called in the kdiffusion cfg_denoiser method after cfg calculations have completed.
++    The callback is called with one argument:
++        - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
++    """
++    add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
++
++
+ def on_before_component(callback):
+     """register a function to be called before a component is created.
+     The callback is called with arguments:
+diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
+index e9f08518..6ff55ba6 100644
+--- a/modules/sd_samplers_kdiffusion.py
++++ b/modules/sd_samplers_kdiffusion.py
+@@ -9,6 +9,7 @@ from modules.shared import opts, state
+ import modules.shared as shared
+ from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
+ from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
++from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
+ 
+ samplers_k_diffusion = [
+     ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
+@@ -146,7 +147,7 @@ class CFGDenoiser(torch.nn.Module):
+ 
+         x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
+ 
+-        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
++        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
+         cfg_denoised_callback(denoised_params)
+ 
+         devices.test_for_nans(x_out, "unet")
+@@ -164,6 +165,10 @@
+         if self.mask is not None:
+             denoised = self.init_latent * self.mask + self.nmask * denoised
+ 
++        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
++        cfg_after_cfg_callback(after_cfg_callback_params)
++        if after_cfg_callback_params.output_altered:
++            denoised = after_cfg_callback_params.x
+         self.step += 1
+ 
+         return denoised
+-- 
+2.40.0
+
diff --git a/resources/img/xyz_grid-0001-232592377.png b/resources/img/xyz_grid-0001-232592377.png
new file mode 100644
index 0000000..78aba34
Binary files /dev/null and b/resources/img/xyz_grid-0001-232592377.png differ
diff --git a/resources/img/xyz_grid-0014-232592377.png b/resources/img/xyz_grid-0014-232592377.png
new file mode 100644
index 0000000..310f3a0
Binary files /dev/null and b/resources/img/xyz_grid-0014-232592377.png differ
diff --git a/scripts/SAG.py b/scripts/SAG.py
new file mode 100644
index 0000000..8468ab7
--- /dev/null
+++ b/scripts/SAG.py
@@ -0,0 +1,301 @@
+
+from inspect import isfunction
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+from modules.processing import StableDiffusionProcessing
+
+import math
+
+
+
+import modules.scripts as scripts
+from modules import shared
+import gradio as gr
+
+from modules.script_callbacks import on_cfg_denoiser, CFGDenoiserParams, CFGDenoisedParams, on_cfg_denoised, AfterCFGCallbackParams, on_cfg_after_cfg
+
+import os
+
+from scripts import xyz_grid_support_sag
+
+_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
+def exists(val):
+    return val is not None
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+class LoggedSelfAttention(nn.Module):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, query_dim),
+            nn.Dropout(dropout)
+        )
+        self.attn_probs = None
+
+    def forward(self, x, context=None, mask=None):
+        h = self.heads
+
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+        # force cast to fp32 to avoid overflowing
+        if _ATTN_PRECISION == "fp32":
+            with torch.autocast(enabled=False, device_type='cuda'):
+                q, k = q.float(), k.float()
+                sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        else:
+            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+        del q, k
+
+        if exists(mask):
+            mask = rearrange(mask, 'b ... -> b (...)')
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, 'b j -> (b h) () j', h=h)
+            sim.masked_fill_(~mask, max_neg_value)
+
+        # attention, what we cannot get enough of
+        sim = sim.softmax(dim=-1)
+
+        self.attn_probs = sim
+
+        out = einsum('b i j, b j d -> b i d', sim, v)
+        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+        return self.to_out(out)
+
+def xattn_forward_log(self, x, context=None, mask=None):
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
+    k = self.to_k(context)
+    v = self.to_v(context)
+
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+    # force cast to fp32 to avoid overflowing
+    if _ATTN_PRECISION == "fp32":
+        with torch.autocast(enabled=False, device_type='cuda'):
+            q, k = q.float(), k.float()
+            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+    else:
+        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+    del q, k
+
+    if exists(mask):
+        mask = rearrange(mask, 'b ... -> b (...)')
+        max_neg_value = -torch.finfo(sim.dtype).max
+        mask = repeat(mask, 'b j -> (b h) () j', h=h)
+        sim.masked_fill_(~mask, max_neg_value)
+
+    # attention, what we cannot get enough of
+    sim = sim.softmax(dim=-1)
+
+    self.attn_probs = sim
+    global current_selfattn_map
+    current_selfattn_map = sim
+
+    out = einsum('b i j, b j d -> b i d', sim, v)
+    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+    out = self.to_out(out)
+    global current_outsize
+    current_outsize = out.shape[-2:]
+    return out
+
+saved_original_selfattn_forward = None
+current_selfattn_map = None
+current_sag_guidance_scale = 1.0
+sag_enabled = False
+sag_mask_threshold = 1.0
+
+current_xin = None
+current_outsize = (64, 64)
+current_batch_size = 1
+current_degraded_pred = None
+current_unet_kwargs = {}
+current_uncond_pred = None
+current_degraded_pred_compensation = None
+
+def gaussian_blur_2d(img, kernel_size, sigma):
+    ksize_half = (kernel_size - 1) * 0.5
+
+    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
+
+    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
+
+    x_kernel = pdf / pdf.sum()
+    x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
+
+    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
+    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
+
+    padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
+
+    img = F.pad(img, padding, mode="reflect")
+    img = F.conv2d(img, kernel2d, groups=img.shape[-3])
+
+    return img
+class Script(scripts.Script):
+
+    def __init__(self):
+        pass
+
+    def title(self):
+        return "Self Attention Guidance"
+
+    def show(self, is_img2img):
+        return scripts.AlwaysVisible
+
+    def denoiser_callback(self, parms: CFGDenoiserParams):
+        if not sag_enabled:
+            return
+        global current_xin, current_batch_size
+
+        # logging current uncond size for cond/uncond output separation
+        current_batch_size = parms.text_uncond.shape[0]
+        # logging current input for eps calculation later
+        current_xin = parms.x[-current_batch_size:]
+
+        # logging necessary information for SAG pred
+        current_uncond_emb = parms.text_uncond
+        current_sigma = parms.sigma
+        current_image_cond_in = parms.image_cond
+        global current_unet_kwargs
+        current_unet_kwargs = {
+            "sigma": current_sigma[-current_batch_size:],
+            "image_cond": current_image_cond_in[-current_batch_size:],
+            "text_uncond": current_uncond_emb,
+        }
+
+
+
+    def denoised_callback(self, params: CFGDenoisedParams):
+        if not sag_enabled:
+            return
+        # output from DiscreteEpsDDPMDenoiser is already pred_x0
+        uncond_output = params.x[-current_batch_size:]
+        original_latents = uncond_output
+        global current_uncond_pred
+        current_uncond_pred = uncond_output
+
+        # Produce attention mask
+        # We're only interested in the last current_batch_size*head_count slices of logged self-attention map
+        attn_map = current_selfattn_map[-current_batch_size*8:]
+        bh, hw1, hw2 = attn_map.shape
+        b, latent_channel, latent_h, latent_w = original_latents.shape
+        h = 8
+
+        middle_layer_latent_size = [math.ceil(latent_h/8), math.ceil(latent_w/8)]
+
+        attn_map = attn_map.reshape(b, h, hw1, hw2)
+        attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > sag_mask_threshold
+        attn_mask = (
+            attn_mask.reshape(b, middle_layer_latent_size[0], middle_layer_latent_size[1])
+            .unsqueeze(1)
+            .repeat(1, latent_channel, 1, 1)
+            .type(attn_map.dtype)
+        )
+        attn_mask = F.interpolate(attn_mask, (latent_h, latent_w))
+
+        # Blur according to the self-attention mask
+        degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0)
+        degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask)
+
+        renoised_degraded_latent = degraded_latents - (uncond_output - current_xin)
+        # renoised_degraded_latent = degraded_latents
+        # get predicted x0 in degraded direction
+        global current_degraded_pred_compensation
+        current_degraded_pred_compensation = uncond_output - degraded_latents
+        if shared.sd_model.model.conditioning_key == "crossattn-adm":
+            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
+        else:
+            make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
+        degraded_pred = params.inner_model(renoised_degraded_latent, current_unet_kwargs['sigma'], cond=make_condition_dict([current_unet_kwargs['text_uncond']], [current_unet_kwargs['image_cond']]))
+        global current_degraded_pred
+        current_degraded_pred = degraded_pred
+
+    def cfg_after_cfg_callback(self, params: AfterCFGCallbackParams):
+        if not sag_enabled:
+            return
+
+        params.x = params.x + (current_uncond_pred - (current_degraded_pred + current_degraded_pred_compensation)) * float(current_sag_guidance_scale)
+        params.output_altered = True
+
+
+
+    def ui(self, is_img2img):
+        with gr.Accordion('Self Attention Guidance', open=False):
+            enabled = gr.Checkbox(label="Enabled", value=False)
+            scale = gr.Slider(label='Scale', minimum=-2.0, maximum=10.0, step=0.01, value=0.75)
+            mask_threshold = gr.Slider(label='SAG Mask Threshold', minimum=0.0, maximum=2.0, step=0.01, value=1.0)
+
+        return [enabled, scale, mask_threshold]
+
+
+
+    def process(self, p: StableDiffusionProcessing, *args, **kwargs):
+        enabled, scale, mask_threshold = args
+        global sag_enabled, sag_mask_threshold
+        if enabled:
+
+            sag_enabled = True
+            sag_mask_threshold = mask_threshold
+            global current_sag_guidance_scale
+            current_sag_guidance_scale = scale
+            global saved_original_selfattn_forward
+            # replace target self attention module in unet with ours
+
+            org_attn_module = shared.sd_model.model.diffusion_model.middle_block._modules['1'].transformer_blocks._modules['0'].attn1
+            saved_original_selfattn_forward = org_attn_module.forward
+            org_attn_module.forward = xattn_forward_log.__get__(org_attn_module, org_attn_module.__class__)
+
+            p.extra_generation_params["SAG Guidance Scale"] = scale
+            p.extra_generation_params["SAG Mask Threshold"] = mask_threshold
+
+        else:
+            sag_enabled = False
+
+
+        if not hasattr(self, 'callbacks_added'):
+            on_cfg_denoiser(self.denoiser_callback)
+            on_cfg_denoised(self.denoised_callback)
+            on_cfg_after_cfg(self.cfg_after_cfg_callback)
+            self.callbacks_added = True
+
+
+
+
+
+        return
+
+    def postprocess(self, p, processed, *args):
+        enabled, scale, sag_mask_threshold = args
+        if enabled:
+            # restore original self attention module forward function
+            attn_module = shared.sd_model.model.diffusion_model.middle_block._modules['1'].transformer_blocks._modules[
+                '0'].attn1
+            attn_module.forward = saved_original_selfattn_forward
+        return
+
+xyz_grid_support_sag.initialize(Script)
\ No newline at end of file
diff --git a/scripts/xyz_grid_support_sag.py b/scripts/xyz_grid_support_sag.py
new file mode 100644
index 0000000..b598827
--- /dev/null
+++ b/scripts/xyz_grid_support_sag.py
@@ -0,0 +1,76 @@
+import os
+import os.path
+import modules.scripts as scripts
+
+
+
+
+
+
+xy_grid = None  # XY Grid module
+script_class = None  # SAG scripts.Script class
+
+
+
+
+
+
+
+def update_script_args(p, value, arg_idx):
+    global script_class
+    for s in scripts.scripts_txt2img.alwayson_scripts:
+        if isinstance(s, script_class):
+            args = list(p.script_args)
+            # print(f"Changed arg {arg_idx} from {args[s.args_from + arg_idx - 1]} to {value}")
+            args[s.args_from + arg_idx] = value
+            p.script_args = tuple(args)
+            break
+
+
+
+
+def apply_module(p, x, xs, i):
+    update_script_args(p, True, 0)  # set Enabled to True
+    update_script_args(p, x, 2 + 4 * i)  # enabled, separate_weights, ({module}, model, weight_unet, weight_tenc), ...
+
+
+
+
+def apply_weight(p, x, xs, i):
+    update_script_args(p, True, 0)
+    update_script_args(p, x, 4 + 4 * i)  # enabled, separate_weights, (module, model, {weight_unet, weight_tenc}), ...
+    update_script_args(p, x, 5 + 4 * i)
+
+
+def apply_weight_unet(p, x, xs, i):
+    update_script_args(p, True, 0)
+    update_script_args(p, x, 4 + 4 * i)  # enabled, separate_weights, (module, model, {weight_unet}, weight_tenc), ...
+
+
+def apply_weight_tenc(p, x, xs, i):
+    update_script_args(p, True, 0)
+    update_script_args(p, x, 5 + 4 * i)  # enabled, separate_weights, (module, model, weight_unet, {weight_tenc}), ...
+
+
+def apply_sag_guidance_scale(p, x, xs):
+    update_script_args(p, x, 0)
+    update_script_args(p, x, 1)  # sag_guidance_scale
+
+def apply_sag_mask_threshold(p, x, xs):
+    update_script_args(p, x, 0)
+    update_script_args(p, x, 2)  # sag_mask_threshold
+
+
+
+
+def initialize(script):
+    global xy_grid, script_class
+    xy_grid = None
+    script_class = script
+    for scriptDataTuple in scripts.scripts_data:
+        if os.path.basename(scriptDataTuple.path) == "xy_grid.py" or os.path.basename(scriptDataTuple.path) == "xyz_grid.py":
+            xy_grid = scriptDataTuple.module
+            sag_guidance_scale = xy_grid.AxisOption("SAG Guidance Scale", float, lambda p, x, xs: apply_sag_guidance_scale(p, x, xs), xy_grid.format_value_add_label, None, cost=0.5)
+            sag_mask_threshold = xy_grid.AxisOption("SAG Mask Threshold", float, lambda p, x, xs: apply_sag_mask_threshold(p, x, xs), xy_grid.format_value_add_label, None, cost=0.5)
+            xy_grid.axis_options.extend([sag_guidance_scale, sag_mask_threshold])
+
diff --git a/vladmandic-CFGDenoiser-and-script_callbacks-mod-for-SAG.patch b/vladmandic-CFGDenoiser-and-script_callbacks-mod-for-SAG.patch
new file mode 100644
index 0000000..151e0ab
--- /dev/null
+++ b/vladmandic-CFGDenoiser-and-script_callbacks-mod-for-SAG.patch
@@ -0,0 +1,121 @@
+From 5f7879230e3cc2c39c432c7e488de30ef824d598 Mon Sep 17 00:00:00 2001
+From: Ashen
+Date: Fri, 21 Apr 2023 12:51:42 -0700
+Subject: [PATCH] CFGDenoiser and script_callbacks mod for SAG
+
+---
+ modules/script_callbacks.py       | 36 ++++++++++++++++++++++++++++++++++++
+ modules/sd_samplers_kdiffusion.py |  8 +++++++-
+ 2 files changed, 43 insertions(+), 1 deletion(-)
+
+diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
+index 700911b7..98cb1d98 100644
+--- a/modules/script_callbacks.py
++++ b/modules/script_callbacks.py
+@@ -50,6 +50,21 @@ class CFGDenoiserParams:
+ 
+ 
+ class CFGDenoisedParams:
++    def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
++        self.x = x
++        """Latent image representation in the process of being denoised"""
++
++        self.sampling_step = sampling_step
++        """Current Sampling step number"""
++
++        self.total_sampling_steps = total_sampling_steps
++        """Total number of sampling steps planned"""
++
++        self.inner_model = inner_model
++        """Inner model reference used for denoising"""
++
++
++class AfterCFGCallbackParams:
+     def __init__(self, x, sampling_step, total_sampling_steps):
+         self.x = x
+         """Latent image representation in the process of being denoised"""
+@@ -60,6 +75,10 @@ class CFGDenoisedParams:
+         self.total_sampling_steps = total_sampling_steps
+         """Total number of sampling steps planned"""
+ 
++        self.output_altered = False
++        """A flag for CFGDenoiser indicating whether the output has been altered by the callback"""
++
++
+ 
+ class UiTrainTabParams:
+     def __init__(self, txt2img_preview_params):
+@@ -84,6 +103,7 @@ callback_map = dict(
+     callbacks_image_saved=[],
+     callbacks_cfg_denoiser=[],
+     callbacks_cfg_denoised=[],
++    callbacks_cfg_after_cfg=[],
+     callbacks_before_component=[],
+     callbacks_after_component=[],
+     callbacks_image_grid=[],
+@@ -174,6 +194,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
+             report_exception(e, c, 'cfg_denoised_callback')
+ 
+ 
++def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
++    for c in callback_map['callbacks_cfg_after_cfg']:
++        try:
++            c.callback(params)
++        except Exception as e:
++            report_exception(e, c, 'cfg_after_cfg_callback')
++
++
+ def before_component_callback(component, **kwargs):
+     for c in callback_map['callbacks_before_component']:
+         try:
+@@ -315,6 +343,14 @@ def on_cfg_denoised(callback):
+     add_callback(callback_map['callbacks_cfg_denoised'], callback)
+ 
+ 
++def on_cfg_after_cfg(callback):
++    """register a function to be called in the kdiffusion cfg_denoiser method after cfg calculations are completed.
++    The callback is called with one argument:
++        - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
++    """
++    add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
++
++
+ def on_before_component(callback):
+     """register a function to be called before a component is created.
+     The callback is called with arguments:
+diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
+index cf694098..3e4f882c 100644
+--- a/modules/sd_samplers_kdiffusion.py
++++ b/modules/sd_samplers_kdiffusion.py
+@@ -8,6 +8,7 @@ from modules.shared import opts, state
+ import modules.shared as shared
+ from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
+ from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
++from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
+ 
+ samplers_k_diffusion = [
+     ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
+@@ -145,7 +146,7 @@ class CFGDenoiser(torch.nn.Module):
+ 
+         x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
+ 
+-        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
++        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
+         cfg_denoised_callback(denoised_params)
+ 
+         devices.test_for_nans(x_out, "unet")
+@@ -165,6 +166,11 @@
+         if self.mask is not None:
+             denoised = self.init_latent * self.mask + self.nmask * denoised
+ 
++        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
++        cfg_after_cfg_callback(after_cfg_callback_params)
++        if after_cfg_callback_params.output_altered:
++            denoised = after_cfg_callback_params.x
++
+         self.step += 1
+ 
+         return denoised
+-- 
+2.40.0
+
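
For reference, a minimal sketch of how an extension script could consume the `on_cfg_after_cfg` hook that both patches add. The import, the `AfterCFGCallbackParams` fields, and the registration function come from the patched `script_callbacks.py` above; the adjustment applied to `params.x` is purely illustrative and not part of this extension:

```python
# Hypothetical extension snippet; assumes one of the patches above has been applied.
from modules.script_callbacks import AfterCFGCallbackParams, on_cfg_after_cfg


def adjust_after_cfg(params: AfterCFGCallbackParams):
    # params.x holds the denoised latent after the CFG combination step.
    # Modify it and set output_altered so CFGDenoiser picks up the change.
    params.x = params.x * 0.95  # illustrative scaling only
    params.output_altered = True


on_cfg_after_cfg(adjust_after_cfg)
```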