Skip to content

Commit

Permalink
Merge pull request #10357 from catboxanon/sag
Browse files Browse the repository at this point in the history
Add/modify CFG callbacks for Self-Attention Guidance extension
  • Loading branch information
AUTOMATIC1111 committed May 14, 2023
2 parents 4051d51 + 8abfc95 commit cb9a3a7
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
35 changes: 35 additions & 0 deletions modules/script_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,21 @@ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, te


class CFGDenoisedParams:
def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
self.x = x
"""Latent image representation in the process of being denoised"""

self.sampling_step = sampling_step
"""Current Sampling step number"""

self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned"""

self.inner_model = inner_model
"""Inner model reference used for denoising"""


class AfterCFGCallbackParams:
def __init__(self, x, sampling_step, total_sampling_steps):
self.x = x
"""Latent image representation in the process of being denoised"""
Expand All @@ -63,6 +78,9 @@ def __init__(self, x, sampling_step, total_sampling_steps):
self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned"""

self.output_altered = False
"""A flag for CFGDenoiser indicating whether the output has been altered by the callback"""


class UiTrainTabParams:
def __init__(self, txt2img_preview_params):
Expand All @@ -87,6 +105,7 @@ def __init__(self, imgs, cols, rows):
callbacks_image_saved=[],
callbacks_cfg_denoiser=[],
callbacks_cfg_denoised=[],
callbacks_cfg_after_cfg=[],
callbacks_before_component=[],
callbacks_after_component=[],
callbacks_image_grid=[],
Expand Down Expand Up @@ -186,6 +205,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
report_exception(c, 'cfg_denoised_callback')


def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
for c in callback_map['callbacks_cfg_after_cfg']:
try:
c.callback(params)
except Exception:
report_exception(c, 'cfg_after_cfg_callback')


def before_component_callback(component, **kwargs):
for c in callback_map['callbacks_before_component']:
try:
Expand Down Expand Up @@ -332,6 +359,14 @@ def on_cfg_denoised(callback):
add_callback(callback_map['callbacks_cfg_denoised'], callback)


def on_cfg_after_cfg(callback):
"""register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
The callback is called with one argument:
- params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
"""
add_callback(callback_map['callbacks_cfg_after_cfg'], callback)


def on_before_component(callback):
"""register a function to be called before a component is created.
The callback is called with arguments:
Expand Down
8 changes: 7 additions & 1 deletion modules/sd_samplers_kdiffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback

samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
Expand Down Expand Up @@ -160,7 +161,7 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be

denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
cfg_denoised_callback(denoised_params)

devices.test_for_nans(x_out, "unet")
Expand All @@ -180,6 +181,11 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised

after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
cfg_after_cfg_callback(after_cfg_callback_params)
if after_cfg_callback_params.output_altered:
denoised = after_cfg_callback_params.x

self.step += 1
return denoised

Expand Down

0 comments on commit cb9a3a7

Please sign in to comment.