From 598aef823083397a58257be58bcfea01f1c8849a Mon Sep 17 00:00:00 2001 From: anapnoe <124302297+anapnoe@users.noreply.github.com> Date: Sat, 16 Dec 2023 21:14:48 +0200 Subject: [PATCH] update to latest version --- configs/alt-diffusion-m18-inference.yaml | 73 +++ extensions-builtin/Lora/lora_logger.py | 33 ++ extensions-builtin/Lora/lyco_helpers.py | 47 ++ extensions-builtin/Lora/network.py | 1 + extensions-builtin/Lora/network_glora.py | 33 ++ extensions-builtin/Lora/network_oft.py | 82 ++++ extensions-builtin/Lora/networks.py | 60 ++- .../Lora/ui_extra_networks_lora.py | 7 +- .../scripts/extra_options_section.py | 16 +- extensions-builtin/hypertile/hypertile.py | 351 +++++++++++++++ .../hypertile/scripts/hypertile_script.py | 109 +++++ .../hypertile/scripts/hypertile_xyz.py | 51 +++ .../mobile/javascript/mobile.js | 2 + javascript/dragdrop.js | 2 +- javascript/edit-attention.js | 79 ++-- javascript/extraNetworks.js | 95 +++- javascript/imageviewer.js | 7 +- javascript/inputAccordion.js | 81 ++-- javascript/notification.js | 6 +- javascript/settings.js | 71 +++ javascript/token-counters.js | 26 +- javascript/ui.js | 75 +++- modules/api/api.py | 81 ++-- modules/api/models.py | 39 +- modules/cache.py | 2 +- modules/cmd_args.py | 10 +- modules/config_states.py | 3 +- modules/devices.py | 18 +- modules/errors.py | 18 +- modules/extensions.py | 96 +++- modules/generation_parameters_copypaste.py | 15 +- modules/gfpgan_model.py | 25 +- modules/gitpython_hack.py | 2 +- modules/gradio_extensons.py | 10 + modules/hypernetworks/hypernetwork.py | 4 +- modules/images.py | 19 +- modules/img2img.py | 56 ++- modules/import_hook.py | 11 + modules/initialize.py | 4 +- modules/initialize_util.py | 6 +- modules/launch_utils.py | 42 +- modules/localization.py | 21 +- modules/logging_config.py | 27 +- modules/mac_specific.py | 15 + modules/models/diffusion/ddpm_edit.py | 7 +- modules/options.py | 81 +++- modules/paths.py | 2 +- modules/paths_internal.py | 1 + modules/postprocessing.py | 94 +++- modules/processing.py | 53 ++- modules/processing_scripts/seed.py | 4 +- modules/prompt_parser.py | 9 +- modules/restart.py | 4 +- modules/rng.py | 2 +- modules/script_callbacks.py | 6 +- modules/scripts.py | 137 +++++- modules/scripts_postprocessing.py | 86 +++- modules/sd_disable_initialization.py | 2 +- modules/sd_hijack.py | 41 +- modules/sd_models.py | 70 ++- modules/sd_models_config.py | 5 +- modules/sd_models_types.py | 5 +- modules/sd_samplers_extra.py | 2 +- modules/sd_samplers_timesteps_impl.py | 4 +- modules/sd_unet.py | 14 +- modules/shared_cmd_options.py | 4 +- modules/shared_items.py | 22 +- modules/shared_options.py | 145 +++--- modules/shared_state.py | 2 +- modules/styles.py | 169 +++++-- modules/sub_quadratic_attention.py | 4 +- modules/sysinfo.py | 18 +- modules/textual_inversion/autocrop.py | 239 +++++----- .../textual_inversion/textual_inversion.py | 78 ++-- modules/textual_inversion/ui.py | 7 - modules/txt2img.py | 4 +- modules/ui.py | 420 ++++++------------ modules/ui_common.py | 15 +- modules/ui_extensions.py | 19 +- modules/ui_extra_networks.py | 65 ++- modules/ui_extra_networks_checkpoints.py | 10 +- modules/ui_extra_networks_hypernets.py | 13 +- .../ui_extra_networks_textual_inversion.py | 10 +- modules/ui_extra_networks_user_metadata.py | 2 +- modules/ui_gradio_extensions.py | 6 +- modules/ui_loadsave.py | 24 +- modules/ui_postprocessing.py | 18 +- modules/ui_prompt_styles.py | 32 +- modules/ui_settings.py | 59 ++- modules/ui_toprow.py | 143 ++++++ modules/upscaler.py | 6 +- modules/xlmr_m18.py | 164 +++++++ modules/xpu_specific.py | 59 +++ scripts/postprocessing_caption.py | 30 ++ scripts/postprocessing_codeformer.py | 16 +- .../postprocessing_create_flipped_copies.py | 32 ++ scripts/postprocessing_focal_crop.py | 54 +++ scripts/postprocessing_gfpgan.py | 13 +- scripts/postprocessing_split_oversized.py | 71 +++ scripts/postprocessing_upscale.py | 14 +- scripts/processing_autosized_crop.py | 64 +++ scripts/prompts_from_file.py | 32 +- scripts/xyz_grid.py | 7 +- 103 files changed, 3505 insertions(+), 1045 deletions(-) create mode 100644 configs/alt-diffusion-m18-inference.yaml create mode 100644 extensions-builtin/Lora/lora_logger.py create mode 100644 extensions-builtin/Lora/network_glora.py create mode 100644 extensions-builtin/Lora/network_oft.py create mode 100644 extensions-builtin/hypertile/hypertile.py create mode 100644 extensions-builtin/hypertile/scripts/hypertile_script.py create mode 100644 extensions-builtin/hypertile/scripts/hypertile_xyz.py create mode 100644 javascript/settings.js create mode 100644 modules/ui_toprow.py create mode 100644 modules/xlmr_m18.py create mode 100644 modules/xpu_specific.py create mode 100644 scripts/postprocessing_caption.py create mode 100644 scripts/postprocessing_create_flipped_copies.py create mode 100644 scripts/postprocessing_focal_crop.py create mode 100644 scripts/postprocessing_split_oversized.py create mode 100644 scripts/processing_autosized_crop.py diff --git a/configs/alt-diffusion-m18-inference.yaml b/configs/alt-diffusion-m18-inference.yaml new file mode 100644 index 00000000000..41a031d55f0 --- /dev/null +++ b/configs/alt-diffusion-m18-inference.yaml @@ -0,0 +1,73 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: modules.xlmr_m18.BertSeriesModelWithTransformation + params: + name: "XLMR-Large" diff --git a/extensions-builtin/Lora/lora_logger.py b/extensions-builtin/Lora/lora_logger.py new file mode 100644 index 00000000000..d51de29704f --- /dev/null +++ b/extensions-builtin/Lora/lora_logger.py @@ -0,0 +1,33 @@ +import sys +import copy +import logging + + +class ColoredFormatter(logging.Formatter): + COLORS = { + "DEBUG": "\033[0;36m", # CYAN + "INFO": "\033[0;32m", # GREEN + "WARNING": "\033[0;33m", # YELLOW + "ERROR": "\033[0;31m", # RED + "CRITICAL": "\033[0;37;41m", # WHITE ON RED + "RESET": "\033[0m", # RESET COLOR + } + + def format(self, record): + colored_record = copy.copy(record) + levelname = colored_record.levelname + seq = self.COLORS.get(levelname, self.COLORS["RESET"]) + colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}" + return super().format(colored_record) + + +logger = logging.getLogger("lora") +logger.propagate = False + + +if not logger.handlers: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s") + ) + logger.addHandler(handler) diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py index 279b34bc928..1679a0ce633 100644 --- a/extensions-builtin/Lora/lyco_helpers.py +++ b/extensions-builtin/Lora/lyco_helpers.py @@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid): up = up.reshape(up.size(0), -1) down = down.reshape(down.size(0), -1) return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down) + + +# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py +def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: + ''' + return a tuple of two value of input dimension decomposed by the number closest to factor + second value is higher or equal than first value. + + In LoRA with Kroneckor Product, first value is a value for weight scale. + secon value is a value for weight. + + Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different. + + examples) + factor + -1 2 4 8 16 ... + 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 + 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 + 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 + 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 + 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 + 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 + ''' + + if factor > 0 and (dimension % factor) == 0: + m = factor + n = dimension // factor + if m > n: + n, m = m, n + return m, n + if factor < 0: + factor = dimension + m, n = 1, dimension + length = m + n + while m length or new_m>factor: + break + else: + m, n = new_m, new_n + if m > n: + n, m = m, n + return m, n + diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index d8e8dfb7ff0..6021fd8de0f 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -93,6 +93,7 @@ def __init__(self, name, network_on_disk: NetworkOnDisk): self.unet_multiplier = 1.0 self.dyn_dim = None self.modules = {} + self.bundle_embeddings = {} self.mtime = None self.mentioned_name = None diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py new file mode 100644 index 00000000000..492d487078d --- /dev/null +++ b/extensions-builtin/Lora/network_glora.py @@ -0,0 +1,33 @@ + +import network + +class ModuleTypeGLora(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]): + return NetworkModuleGLora(net, weights) + + return None + +# adapted from https://github.com/KohakuBlueleaf/LyCORIS +class NetworkModuleGLora(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + if hasattr(self.sd_module, 'weight'): + self.shape = self.sd_module.weight.shape + + self.w1a = weights.w["a1.weight"] + self.w1b = weights.w["b1.weight"] + self.w2a = weights.w["a2.weight"] + self.w2b = weights.w["b2.weight"] + + def calc_updown(self, orig_weight): + w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) + w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) + w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) + w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + + output_shape = [w1a.size(0), w1b.size(1)] + updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a)) + + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py new file mode 100644 index 00000000000..fa647020f0a --- /dev/null +++ b/extensions-builtin/Lora/network_oft.py @@ -0,0 +1,82 @@ +import torch +import network +from lyco_helpers import factorization +from einops import rearrange + + +class ModuleTypeOFT(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]): + return NetworkModuleOFT(net, weights) + + return None + +# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py +class NetworkModuleOFT(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + + super().__init__(net, weights) + + self.lin_module = None + self.org_module: list[torch.Module] = [self.sd_module] + + self.scale = 1.0 + + # kohya-ss + if "oft_blocks" in weights.w.keys(): + self.is_kohya = True + self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size) + self.alpha = weights.w["alpha"] # alpha is constraint + self.dim = self.oft_blocks.shape[0] # lora dim + # LyCORIS + elif "oft_diag" in weights.w.keys(): + self.is_kohya = False + self.oft_blocks = weights.w["oft_diag"] + # self.alpha is unused + self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size) + + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] + is_conv = type(self.sd_module) in [torch.nn.Conv2d] + is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported + + if is_linear: + self.out_dim = self.sd_module.out_features + elif is_conv: + self.out_dim = self.sd_module.out_channels + elif is_other_linear: + self.out_dim = self.sd_module.embed_dim + + if self.is_kohya: + self.constraint = self.alpha * self.out_dim + self.num_blocks = self.dim + self.block_size = self.out_dim // self.dim + else: + self.constraint = None + self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) + + def calc_updown(self, orig_weight): + oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + eye = torch.eye(self.block_size, device=self.oft_blocks.device) + + if self.is_kohya: + block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) + + R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + + # This errors out for MultiheadAttention, might need to be handled up-stream + merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + merged_weight = torch.einsum( + 'k n m, k n ... -> k m ...', + R, + merged_weight + ) + merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') + + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + output_shape = orig_weight.shape + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 96f935b236f..629bf85376d 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -5,16 +5,21 @@ import lora_patches import network import network_lora +import network_glora import network_hada import network_ia3 import network_lokr import network_full import network_norm +import network_oft import torch from typing import Union from modules import shared, devices, sd_models, errors, scripts, sd_hijack +import modules.textual_inversion.textual_inversion as textual_inversion + +from lora_logger import logger module_types = [ network_lora.ModuleTypeLora(), @@ -23,6 +28,8 @@ network_lokr.ModuleTypeLokr(), network_full.ModuleTypeFull(), network_norm.ModuleTypeNorm(), + network_glora.ModuleTypeGLora(), + network_oft.ModuleTypeOFT(), ] @@ -149,9 +156,20 @@ def load_network(name, network_on_disk): is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping matched_networks = {} + bundle_embeddings = {} for key_network, weight in sd.items(): - key_network_without_network_parts, network_part = key_network.split(".", 1) + key_network_without_network_parts, _, network_part = key_network.partition(".") + + if key_network_without_network_parts == "bundle_emb": + emb_name, vec_name = network_part.split(".", 1) + emb_dict = bundle_embeddings.get(emb_name, {}) + if vec_name.split('.')[0] == 'string_to_param': + _, k2 = vec_name.split('.', 1) + emb_dict['string_to_param'] = {k2: weight} + else: + emb_dict[vec_name] = weight + bundle_embeddings[emb_name] = emb_dict key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2) sd_module = shared.sd_model.network_layer_mapping.get(key, None) @@ -174,6 +192,17 @@ def load_network(name, network_on_disk): key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) + # kohya_ss OFT module + elif sd_module is None and "oft_unet" in key_network_without_network_parts: + key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + + # KohakuBlueLeaf OFT module + if sd_module is None and "oft_diag" in key: + key = key_network_without_network_parts.replace("lora_unet", "diffusion_model") + key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + if sd_module is None: keys_failed_to_match[key_network] = key continue @@ -195,6 +224,14 @@ def load_network(name, network_on_disk): net.modules[key] = net_module + embeddings = {} + for emb_name, data in bundle_embeddings.items(): + embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name) + embedding.loaded = None + embeddings[emb_name] = embedding + + net.bundle_embeddings = embeddings + if keys_failed_to_match: logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}") @@ -210,11 +247,15 @@ def purge_networks_from_memory(): def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None): + emb_db = sd_hijack.model_hijack.embedding_db already_loaded = {} for net in loaded_networks: if net.name in names: already_loaded[net.name] = net + for emb_name, embedding in net.bundle_embeddings.items(): + if embedding.loaded: + emb_db.register_embedding_by_name(None, shared.sd_model, emb_name) loaded_networks.clear() @@ -257,6 +298,21 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0 loaded_networks.append(net) + for emb_name, embedding in net.bundle_embeddings.items(): + if embedding.loaded is None and emb_name in emb_db.word_embeddings: + logger.warning( + f'Skip bundle embedding: "{emb_name}"' + ' as it was already loaded from embeddings folder' + ) + continue + + embedding.loaded = False + if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape: + embedding.loaded = True + emb_db.register_embedding(embedding, shared.sd_model) + else: + emb_db.skipped_embeddings[name] = embedding + if failed_to_load_networks: sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks)) @@ -418,6 +474,7 @@ def network_forward(module, input, original_forward): def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): self.network_current_names = () self.network_weights_backup = None + self.network_bias_backup = None def network_Linear_forward(self, input): @@ -564,6 +621,7 @@ def infotext_pasted(infotext, params): available_networks = {} available_network_aliases = {} loaded_networks = [] +loaded_bundle_embeddings = {} networks_in_memory = {} available_network_hash_lookup = {} forbidden_network_aliases = {} diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 55409a7829d..df02c663b12 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -17,6 +17,8 @@ def refresh(self): def create_item(self, name, index=None, enable_filter=True): lora_on_disk = networks.available_networks.get(name) + if lora_on_disk is None: + return path, ext = os.path.splitext(lora_on_disk.filename) @@ -66,9 +68,10 @@ def create_item(self, name, index=None, enable_filter=True): return item def list_items(self): - for index, name in enumerate(networks.available_networks): + # instantiate a list to protect against concurrent modification + names = list(networks.available_networks) + for index, name in enumerate(names): item = self.create_item(name, index) - if item is not None: yield item diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index 983f87ff033..ac2c3de4643 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -23,11 +23,12 @@ def ui(self, is_img2img): self.setting_names = [] self.infotext_fields = [] extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img + elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img") mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping} with gr.Blocks() as interface: - with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group(): + with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname): row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols) @@ -64,11 +65,14 @@ def before_process(self, p, *args): p.override_settings[name] = value -shared.options_templates.update(shared.options_section(('ui', "User interface"), { - "extra_options_txt2img": shared.OptionInfo([], "Options in main UI - txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(), - "extra_options_img2img": shared.OptionInfo([], "Options in main UI - img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(), - "extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(), - "extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui() +shared.options_templates.update(shared.options_section(('settings_in_ui', "Settings in UI", "ui"), { + "settings_in_ui": shared.OptionHTML(""" +This page allows you to add some settings to the main interface of txt2img and img2img tabs. +"""), + "extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(), + "extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(), + "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 20}).info("displayed amount will depend on the actual browser window width").needs_reload_ui(), + "extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui() })) diff --git a/extensions-builtin/hypertile/hypertile.py b/extensions-builtin/hypertile/hypertile.py new file mode 100644 index 00000000000..0f40e2d3925 --- /dev/null +++ b/extensions-builtin/hypertile/hypertile.py @@ -0,0 +1,351 @@ +""" +Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE +Warn: The patch works well only if the input image has a width and height that are multiples of 128 +Original author: @tfernd Github: https://github.com/tfernd/HyperTile +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable + +from functools import wraps, cache + +import math +import torch.nn as nn +import random + +from einops import rearrange + + +@dataclass +class HypertileParams: + depth = 0 + layer_name = "" + tile_size: int = 0 + swap_size: int = 0 + aspect_ratio: float = 1.0 + forward = None + enabled = False + + + +# TODO add SD-XL layers +DEPTH_LAYERS = { + 0: [ + # SD 1.5 U-Net (diffusers) + "down_blocks.0.attentions.0.transformer_blocks.0.attn1", + "down_blocks.0.attentions.1.transformer_blocks.0.attn1", + "up_blocks.3.attentions.0.transformer_blocks.0.attn1", + "up_blocks.3.attentions.1.transformer_blocks.0.attn1", + "up_blocks.3.attentions.2.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "input_blocks.1.1.transformer_blocks.0.attn1", + "input_blocks.2.1.transformer_blocks.0.attn1", + "output_blocks.9.1.transformer_blocks.0.attn1", + "output_blocks.10.1.transformer_blocks.0.attn1", + "output_blocks.11.1.transformer_blocks.0.attn1", + # SD 1.5 VAE + "decoder.mid_block.attentions.0", + "decoder.mid.attn_1", + ], + 1: [ + # SD 1.5 U-Net (diffusers) + "down_blocks.1.attentions.0.transformer_blocks.0.attn1", + "down_blocks.1.attentions.1.transformer_blocks.0.attn1", + "up_blocks.2.attentions.0.transformer_blocks.0.attn1", + "up_blocks.2.attentions.1.transformer_blocks.0.attn1", + "up_blocks.2.attentions.2.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "input_blocks.4.1.transformer_blocks.0.attn1", + "input_blocks.5.1.transformer_blocks.0.attn1", + "output_blocks.6.1.transformer_blocks.0.attn1", + "output_blocks.7.1.transformer_blocks.0.attn1", + "output_blocks.8.1.transformer_blocks.0.attn1", + ], + 2: [ + # SD 1.5 U-Net (diffusers) + "down_blocks.2.attentions.0.transformer_blocks.0.attn1", + "down_blocks.2.attentions.1.transformer_blocks.0.attn1", + "up_blocks.1.attentions.0.transformer_blocks.0.attn1", + "up_blocks.1.attentions.1.transformer_blocks.0.attn1", + "up_blocks.1.attentions.2.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "input_blocks.7.1.transformer_blocks.0.attn1", + "input_blocks.8.1.transformer_blocks.0.attn1", + "output_blocks.3.1.transformer_blocks.0.attn1", + "output_blocks.4.1.transformer_blocks.0.attn1", + "output_blocks.5.1.transformer_blocks.0.attn1", + ], + 3: [ + # SD 1.5 U-Net (diffusers) + "mid_block.attentions.0.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "middle_block.1.transformer_blocks.0.attn1", + ], +} +# XL layers, thanks for GitHub@gel-crabs for the help +DEPTH_LAYERS_XL = { + 0: [ + # SD 1.5 U-Net (diffusers) + "down_blocks.0.attentions.0.transformer_blocks.0.attn1", + "down_blocks.0.attentions.1.transformer_blocks.0.attn1", + "up_blocks.3.attentions.0.transformer_blocks.0.attn1", + "up_blocks.3.attentions.1.transformer_blocks.0.attn1", + "up_blocks.3.attentions.2.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "input_blocks.4.1.transformer_blocks.0.attn1", + "input_blocks.5.1.transformer_blocks.0.attn1", + "output_blocks.3.1.transformer_blocks.0.attn1", + "output_blocks.4.1.transformer_blocks.0.attn1", + "output_blocks.5.1.transformer_blocks.0.attn1", + # SD 1.5 VAE + "decoder.mid_block.attentions.0", + "decoder.mid.attn_1", + ], + 1: [ + # SD 1.5 U-Net (diffusers) + #"down_blocks.1.attentions.0.transformer_blocks.0.attn1", + #"down_blocks.1.attentions.1.transformer_blocks.0.attn1", + #"up_blocks.2.attentions.0.transformer_blocks.0.attn1", + #"up_blocks.2.attentions.1.transformer_blocks.0.attn1", + #"up_blocks.2.attentions.2.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "input_blocks.4.1.transformer_blocks.1.attn1", + "input_blocks.5.1.transformer_blocks.1.attn1", + "output_blocks.3.1.transformer_blocks.1.attn1", + "output_blocks.4.1.transformer_blocks.1.attn1", + "output_blocks.5.1.transformer_blocks.1.attn1", + "input_blocks.7.1.transformer_blocks.0.attn1", + "input_blocks.8.1.transformer_blocks.0.attn1", + "output_blocks.0.1.transformer_blocks.0.attn1", + "output_blocks.1.1.transformer_blocks.0.attn1", + "output_blocks.2.1.transformer_blocks.0.attn1", + "input_blocks.7.1.transformer_blocks.1.attn1", + "input_blocks.8.1.transformer_blocks.1.attn1", + "output_blocks.0.1.transformer_blocks.1.attn1", + "output_blocks.1.1.transformer_blocks.1.attn1", + "output_blocks.2.1.transformer_blocks.1.attn1", + "input_blocks.7.1.transformer_blocks.2.attn1", + "input_blocks.8.1.transformer_blocks.2.attn1", + "output_blocks.0.1.transformer_blocks.2.attn1", + "output_blocks.1.1.transformer_blocks.2.attn1", + "output_blocks.2.1.transformer_blocks.2.attn1", + "input_blocks.7.1.transformer_blocks.3.attn1", + "input_blocks.8.1.transformer_blocks.3.attn1", + "output_blocks.0.1.transformer_blocks.3.attn1", + "output_blocks.1.1.transformer_blocks.3.attn1", + "output_blocks.2.1.transformer_blocks.3.attn1", + "input_blocks.7.1.transformer_blocks.4.attn1", + "input_blocks.8.1.transformer_blocks.4.attn1", + "output_blocks.0.1.transformer_blocks.4.attn1", + "output_blocks.1.1.transformer_blocks.4.attn1", + "output_blocks.2.1.transformer_blocks.4.attn1", + "input_blocks.7.1.transformer_blocks.5.attn1", + "input_blocks.8.1.transformer_blocks.5.attn1", + "output_blocks.0.1.transformer_blocks.5.attn1", + "output_blocks.1.1.transformer_blocks.5.attn1", + "output_blocks.2.1.transformer_blocks.5.attn1", + "input_blocks.7.1.transformer_blocks.6.attn1", + "input_blocks.8.1.transformer_blocks.6.attn1", + "output_blocks.0.1.transformer_blocks.6.attn1", + "output_blocks.1.1.transformer_blocks.6.attn1", + "output_blocks.2.1.transformer_blocks.6.attn1", + "input_blocks.7.1.transformer_blocks.7.attn1", + "input_blocks.8.1.transformer_blocks.7.attn1", + "output_blocks.0.1.transformer_blocks.7.attn1", + "output_blocks.1.1.transformer_blocks.7.attn1", + "output_blocks.2.1.transformer_blocks.7.attn1", + "input_blocks.7.1.transformer_blocks.8.attn1", + "input_blocks.8.1.transformer_blocks.8.attn1", + "output_blocks.0.1.transformer_blocks.8.attn1", + "output_blocks.1.1.transformer_blocks.8.attn1", + "output_blocks.2.1.transformer_blocks.8.attn1", + "input_blocks.7.1.transformer_blocks.9.attn1", + "input_blocks.8.1.transformer_blocks.9.attn1", + "output_blocks.0.1.transformer_blocks.9.attn1", + "output_blocks.1.1.transformer_blocks.9.attn1", + "output_blocks.2.1.transformer_blocks.9.attn1", + ], + 2: [ + # SD 1.5 U-Net (diffusers) + "mid_block.attentions.0.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "middle_block.1.transformer_blocks.0.attn1", + "middle_block.1.transformer_blocks.1.attn1", + "middle_block.1.transformer_blocks.2.attn1", + "middle_block.1.transformer_blocks.3.attn1", + "middle_block.1.transformer_blocks.4.attn1", + "middle_block.1.transformer_blocks.5.attn1", + "middle_block.1.transformer_blocks.6.attn1", + "middle_block.1.transformer_blocks.7.attn1", + "middle_block.1.transformer_blocks.8.attn1", + "middle_block.1.transformer_blocks.9.attn1", + ], + 3 : [] # TODO - separate layers for SD-XL +} + + +RNG_INSTANCE = random.Random() + +@cache +def get_divisors(value: int, min_value: int, /, max_options: int = 1) -> list[int]: + """ + Returns divisors of value that + x * min_value <= value + in big -> small order, amount of divisors is limited by max_options + """ + max_options = max(1, max_options) # at least 1 option should be returned + min_value = min(min_value, value) + divisors = [i for i in range(min_value, value + 1) if value % i == 0] # divisors in small -> big order + ns = [value // i for i in divisors[:max_options]] # has at least 1 element # big -> small order + return ns + + +def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: + """ + Returns a random divisor of value that + x * min_value <= value + if max_options is 1, the behavior is deterministic + """ + ns = get_divisors(value, min_value, max_options=max_options) # get cached divisors + idx = RNG_INSTANCE.randint(0, len(ns) - 1) + + return ns[idx] + + +def set_hypertile_seed(seed: int) -> None: + RNG_INSTANCE.seed(seed) + + +@cache +def largest_tile_size_available(width: int, height: int) -> int: + """ + Calculates the largest tile size available for a given width and height + Tile size is always a power of 2 + """ + gcd = math.gcd(width, height) + largest_tile_size_available = 1 + while gcd % (largest_tile_size_available * 2) == 0: + largest_tile_size_available *= 2 + return largest_tile_size_available + + +def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]: + """ + Finds h and w such that h*w = hw and h/w = aspect_ratio + We check all possible divisors of hw and return the closest to the aspect ratio + """ + divisors = [i for i in range(2, hw + 1) if hw % i == 0] # all divisors of hw + pairs = [(i, hw // i) for i in divisors] # all pairs of divisors of hw + ratios = [w/h for h, w in pairs] # all ratios of pairs of divisors of hw + closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio)) # closest ratio to aspect_ratio + closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio + return closest_pair + + +@cache +def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]: + """ + Finds h and w such that h*w = hw and h/w = aspect_ratio + """ + h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) + # find h and w such that h*w = hw and h/w = aspect_ratio + if h * w != hw: + w_candidate = hw / h + # check if w is an integer + if not w_candidate.is_integer(): + h_candidate = hw / w + # check if h is an integer + if not h_candidate.is_integer(): + return iterative_closest_divisors(hw, aspect_ratio) + else: + h = int(h_candidate) + else: + w = int(w_candidate) + return h, w + + +def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable: + + @wraps(params.forward) + def wrapper(*args, **kwargs): + if not params.enabled: + return params.forward(*args, **kwargs) + + latent_tile_size = max(128, params.tile_size) // 8 + x = args[0] + + # VAE + if x.ndim == 4: + b, c, h, w = x.shape + + nh = random_divisor(h, latent_tile_size, params.swap_size) + nw = random_divisor(w, latent_tile_size, params.swap_size) + + if nh * nw > 1: + x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles + + out = params.forward(x, *args[1:], **kwargs) + + if nh * nw > 1: + out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw) + + # U-Net + else: + hw: int = x.size(1) + h, w = find_hw_candidates(hw, params.aspect_ratio) + assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}" + + factor = 2 ** params.depth if scale_depth else 1 + nh = random_divisor(h, latent_tile_size * factor, params.swap_size) + nw = random_divisor(w, latent_tile_size * factor, params.swap_size) + + if nh * nw > 1: + x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) + + out = params.forward(x, *args[1:], **kwargs) + + if nh * nw > 1: + out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) + out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw) + + return out + + return wrapper + + +def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False): + hypertile_layers = getattr(model, "__webui_hypertile_layers", None) + if hypertile_layers is None: + if not enable: + return + + hypertile_layers = {} + layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS + + for depth in range(4): + for layer_name, module in model.named_modules(): + if any(layer_name.endswith(try_name) for try_name in layers[depth]): + params = HypertileParams() + module.__webui_hypertile_params = params + params.forward = module.forward + params.depth = depth + params.layer_name = layer_name + module.forward = self_attn_forward(params) + + hypertile_layers[layer_name] = 1 + + model.__webui_hypertile_layers = hypertile_layers + + aspect_ratio = width / height + tile_size = min(largest_tile_size_available(width, height), tile_size_max) + + for layer_name, module in model.named_modules(): + if layer_name in hypertile_layers: + params = module.__webui_hypertile_params + + params.tile_size = tile_size + params.swap_size = swap_size + params.aspect_ratio = aspect_ratio + params.enabled = enable and params.depth <= max_depth diff --git a/extensions-builtin/hypertile/scripts/hypertile_script.py b/extensions-builtin/hypertile/scripts/hypertile_script.py new file mode 100644 index 00000000000..52ba655fe37 --- /dev/null +++ b/extensions-builtin/hypertile/scripts/hypertile_script.py @@ -0,0 +1,109 @@ +import hypertile +from modules import scripts, script_callbacks, shared +from scripts.hypertile_xyz import add_axis_options + + +class ScriptHypertile(scripts.Script): + name = "Hypertile" + + def title(self): + return self.name + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def process(self, p, *args): + hypertile.set_hypertile_seed(p.all_seeds[0]) + + configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet) + + self.add_infotext(p) + + def before_hr(self, p, *args): + + enable = shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet + + # exclusive hypertile seed for the second pass + if enable: + hypertile.set_hypertile_seed(p.all_seeds[0]) + + configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=enable) + + if enable and not shared.opts.hypertile_enable_unet: + p.extra_generation_params["Hypertile U-Net second pass"] = True + + self.add_infotext(p, add_unet_params=True) + + def add_infotext(self, p, add_unet_params=False): + def option(name): + value = getattr(shared.opts, name) + default_value = shared.opts.get_default(name) + return None if value == default_value else value + + if shared.opts.hypertile_enable_unet: + p.extra_generation_params["Hypertile U-Net"] = True + + if shared.opts.hypertile_enable_unet or add_unet_params: + p.extra_generation_params["Hypertile U-Net max depth"] = option('hypertile_max_depth_unet') + p.extra_generation_params["Hypertile U-Net max tile size"] = option('hypertile_max_tile_unet') + p.extra_generation_params["Hypertile U-Net swap size"] = option('hypertile_swap_size_unet') + + if shared.opts.hypertile_enable_vae: + p.extra_generation_params["Hypertile VAE"] = True + p.extra_generation_params["Hypertile VAE max depth"] = option('hypertile_max_depth_vae') + p.extra_generation_params["Hypertile VAE max tile size"] = option('hypertile_max_tile_vae') + p.extra_generation_params["Hypertile VAE swap size"] = option('hypertile_swap_size_vae') + + +def configure_hypertile(width, height, enable_unet=True): + hypertile.hypertile_hook_model( + shared.sd_model.first_stage_model, + width, + height, + swap_size=shared.opts.hypertile_swap_size_vae, + max_depth=shared.opts.hypertile_max_depth_vae, + tile_size_max=shared.opts.hypertile_max_tile_vae, + enable=shared.opts.hypertile_enable_vae, + ) + + hypertile.hypertile_hook_model( + shared.sd_model.model, + width, + height, + swap_size=shared.opts.hypertile_swap_size_unet, + max_depth=shared.opts.hypertile_max_depth_unet, + tile_size_max=shared.opts.hypertile_max_tile_unet, + enable=enable_unet, + is_sdxl=shared.sd_model.is_sdxl + ) + + +def on_ui_settings(): + import gradio as gr + + options = { + "hypertile_explanation": shared.OptionHTML(""" + Hypertile optimizes the self-attention layer within U-Net and VAE models, + resulting in a reduction in computation time ranging from 1 to 4 times. The larger the generated image is, the greater the + benefit. + """), + + "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net", infotext="Hypertile U-Net").info("enables hypertile for all modes, including hires fix second pass; noticeable change in details of the generated picture"), + "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass", infotext="Hypertile U-Net second pass").info("enables hypertile just for hires fix second pass - regardless of whether the above setting is enabled"), + "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile U-Net max depth").info("larger = more neural network layers affected; minor effect on performance"), + "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-Net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile U-Net max tile size").info("larger = worse performance"), + "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-Net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile U-Net swap size"), + + "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE", infotext="Hypertile VAE").info("minimal change in the generated picture"), + "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile VAE max depth"), + "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile VAE max tile size"), + "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile VAE swap size"), + } + + for name, opt in options.items(): + opt.section = ('hypertile', "Hypertile") + shared.opts.add_option(name, opt) + + +script_callbacks.on_ui_settings(on_ui_settings) +script_callbacks.on_before_ui(add_axis_options) diff --git a/extensions-builtin/hypertile/scripts/hypertile_xyz.py b/extensions-builtin/hypertile/scripts/hypertile_xyz.py new file mode 100644 index 00000000000..9e96ae3c527 --- /dev/null +++ b/extensions-builtin/hypertile/scripts/hypertile_xyz.py @@ -0,0 +1,51 @@ +from modules import scripts +from modules.shared import opts + +xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module + +def int_applier(value_name:str, min_range:int = -1, max_range:int = -1): + """ + Returns a function that applies the given value to the given value_name in opts.data. + """ + def validate(value_name:str, value:str): + value = int(value) + # validate value + if not min_range == -1: + assert value >= min_range, f"Value {value} for {value_name} must be greater than or equal to {min_range}" + if not max_range == -1: + assert value <= max_range, f"Value {value} for {value_name} must be less than or equal to {max_range}" + def apply_int(p, x, xs): + validate(value_name, x) + opts.data[value_name] = int(x) + return apply_int + +def bool_applier(value_name:str): + """ + Returns a function that applies the given value to the given value_name in opts.data. + """ + def validate(value_name:str, value:str): + assert value.lower() in ["true", "false"], f"Value {value} for {value_name} must be either true or false" + def apply_bool(p, x, xs): + validate(value_name, x) + value_boolean = x.lower() == "true" + opts.data[value_name] = value_boolean + return apply_bool + +def add_axis_options(): + extra_axis_options = [ + xyz_grid.AxisOption("[Hypertile] Unet First pass Enabled", str, bool_applier("hypertile_enable_unet"), choices=xyz_grid.boolean_choice(reverse=True)), + xyz_grid.AxisOption("[Hypertile] Unet Second pass Enabled", str, bool_applier("hypertile_enable_unet_secondpass"), choices=xyz_grid.boolean_choice(reverse=True)), + xyz_grid.AxisOption("[Hypertile] Unet Max Depth", int, int_applier("hypertile_max_depth_unet", 0, 3), choices=lambda: [str(x) for x in range(4)]), + xyz_grid.AxisOption("[Hypertile] Unet Max Tile Size", int, int_applier("hypertile_max_tile_unet", 0, 512)), + xyz_grid.AxisOption("[Hypertile] Unet Swap Size", int, int_applier("hypertile_swap_size_unet", 0, 64)), + xyz_grid.AxisOption("[Hypertile] VAE Enabled", str, bool_applier("hypertile_enable_vae"), choices=xyz_grid.boolean_choice(reverse=True)), + xyz_grid.AxisOption("[Hypertile] VAE Max Depth", int, int_applier("hypertile_max_depth_vae", 0, 3), choices=lambda: [str(x) for x in range(4)]), + xyz_grid.AxisOption("[Hypertile] VAE Max Tile Size", int, int_applier("hypertile_max_tile_vae", 0, 512)), + xyz_grid.AxisOption("[Hypertile] VAE Swap Size", int, int_applier("hypertile_swap_size_vae", 0, 64)), + ] + set_a = {opt.label for opt in xyz_grid.axis_options} + set_b = {opt.label for opt in extra_axis_options} + if set_a.intersection(set_b): + return + + xyz_grid.axis_options.extend(extra_axis_options) diff --git a/extensions-builtin/mobile/javascript/mobile.js b/extensions-builtin/mobile/javascript/mobile.js index 652f07ac7ec..bff1acedff3 100644 --- a/extensions-builtin/mobile/javascript/mobile.js +++ b/extensions-builtin/mobile/javascript/mobile.js @@ -12,6 +12,8 @@ function isMobile() { } function reportWindowSize() { + if (gradioApp().querySelector('.toprow-compact-tools')) return; // not applicable for compact prompt layout + var currentlyMobile = isMobile(); if (currentlyMobile == isSetupForMobile) return; isSetupForMobile = currentlyMobile; diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js index 5803daea5ef..d680daf52f2 100644 --- a/javascript/dragdrop.js +++ b/javascript/dragdrop.js @@ -119,7 +119,7 @@ window.addEventListener('paste', e => { } const firstFreeImageField = visibleImageFields - .filter(el => el.querySelector('input[type=file]'))?.[0]; + .filter(el => !el.querySelector('img'))?.[0]; dropReplaceImage( firstFreeImageField ? diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index 8906c8922e1..688c2f112d6 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -18,37 +18,43 @@ function keyupEditAttention(event) { const before = text.substring(0, selectionStart); let beforeParen = before.lastIndexOf(OPEN); if (beforeParen == -1) return false; - let beforeParenClose = before.lastIndexOf(CLOSE); - while (beforeParenClose !== -1 && beforeParenClose > beforeParen) { - beforeParen = before.lastIndexOf(OPEN, beforeParen - 1); - beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1); - } + + let beforeClosingParen = before.lastIndexOf(CLOSE); + if (beforeClosingParen != -1 && beforeClosingParen > beforeParen) return false; // Find closing parenthesis around current cursor const after = text.substring(selectionStart); let afterParen = after.indexOf(CLOSE); if (afterParen == -1) return false; - let afterParenOpen = after.indexOf(OPEN); - while (afterParenOpen !== -1 && afterParen > afterParenOpen) { - afterParen = after.indexOf(CLOSE, afterParen + 1); - afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1); - } - if (beforeParen === -1 || afterParen === -1) return false; + + let afterOpeningParen = after.indexOf(OPEN); + if (afterOpeningParen != -1 && afterOpeningParen < afterParen) return false; // Set the selection to the text between the parenthesis const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen); - const lastColon = parenContent.lastIndexOf(":"); - selectionStart = beforeParen + 1; - selectionEnd = selectionStart + lastColon; + if (/.*:-?[\d.]+/s.test(parenContent)) { + const lastColon = parenContent.lastIndexOf(":"); + selectionStart = beforeParen + 1; + selectionEnd = selectionStart + lastColon; + } else { + selectionStart = beforeParen + 1; + selectionEnd = selectionStart + parenContent.length; + } + target.setSelectionRange(selectionStart, selectionEnd); return true; } function selectCurrentWord() { if (selectionStart !== selectionEnd) return false; - const delimiters = opts.keyedit_delimiters + " \r\n\t"; + const whitespace_delimiters = {"Tab": "\t", "Carriage Return": "\r", "Line Feed": "\n"}; + let delimiters = opts.keyedit_delimiters; + + for (let i of opts.keyedit_delimiters_whitespace) { + delimiters += whitespace_delimiters[i]; + } - // seek backward until to find beggining + // seek backward to find beginning while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) { selectionStart--; } @@ -63,7 +69,7 @@ function keyupEditAttention(event) { } // If the user hasn't selected anything, let's select their current parenthesis block or word - if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) { + if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')') && !selectCurrentParenthesisBlock('[', ']')) { selectCurrentWord(); } @@ -71,33 +77,54 @@ function keyupEditAttention(event) { var closeCharacter = ')'; var delta = opts.keyedit_precision_attention; + var start = selectionStart > 0 ? text[selectionStart - 1] : ""; + var end = text[selectionEnd]; - if (selectionStart > 0 && text[selectionStart - 1] == '<') { + if (start == '<') { closeCharacter = '>'; delta = opts.keyedit_precision_extra; - } else if (selectionStart == 0 || text[selectionStart - 1] != "(") { + } else if (start == '(' && end == ')' || start == '[' && end == ']') { // convert old-style (((emphasis))) + let numParen = 0; + + while (text[selectionStart - numParen - 1] == start && text[selectionEnd + numParen] == end) { + numParen++; + } + if (start == "[") { + weight = (1 / 1.1) ** numParen; + } else { + weight = 1.1 ** numParen; + } + + weight = Math.round(weight / opts.keyedit_precision_attention) * opts.keyedit_precision_attention; + + text = text.slice(0, selectionStart - numParen) + "(" + text.slice(selectionStart, selectionEnd) + ":" + weight + ")" + text.slice(selectionEnd + numParen); + selectionStart -= numParen - 1; + selectionEnd -= numParen - 1; + } else if (start != '(') { // do not include spaces at the end while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') { - selectionEnd -= 1; + selectionEnd--; } + if (selectionStart == selectionEnd) { return; } text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd); - selectionStart += 1; - selectionEnd += 1; + selectionStart++; + selectionEnd++; } - var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; - var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end)); + if (text[selectionEnd] != ':') return; + var weightLength = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; + var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + weightLength)); if (isNaN(weight)) return; weight += isPlus ? delta : -delta; weight = parseFloat(weight.toPrecision(12)); - if (String(weight).length == 1) weight += ".0"; + if (Number.isInteger(weight)) weight += ".0"; if (closeCharacter == ')' && weight == 1) { var endParenPos = text.substring(selectionEnd).indexOf(')'); @@ -105,7 +132,7 @@ function keyupEditAttention(event) { selectionStart--; selectionEnd--; } else { - text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end); + text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + weightLength); } target.focus(); diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 493f31af28a..98a7abb745c 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -26,8 +26,9 @@ function setupExtraNetworksForTab(tabname) { var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs'); var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input'); + var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container'); + var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt'); - sort.dataset.sortkey = 'sortDefault'; tabs.appendChild(searchDiv); tabs.appendChild(sort); tabs.appendChild(sortOrder); @@ -49,20 +50,23 @@ function setupExtraNetworksForTab(tabname) { elem.style.display = visible ? "" : "none"; }); + + applySort(); }; var applySort = function() { + var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); + var reverse = sortOrder.classList.contains("sortReverse"); - var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim(); - sortKey = sortKey ? "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1) : ""; - var sortKeyStore = sortKey ? sortKey + (reverse ? "Reverse" : "") : ""; - if (!sortKey || sortKeyStore == sort.dataset.sortkey) { + var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; + sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); + var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length; + + if (sortKeyStore == sort.dataset.sortkey) { return; } - sort.dataset.sortkey = sortKeyStore; - var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); cards.forEach(function(card) { card.originalParentElement = card.parentElement; }); @@ -88,15 +92,13 @@ function setupExtraNetworksForTab(tabname) { }; search.addEventListener("input", applyFilter); - applyFilter(); - ["change", "blur", "click"].forEach(function(evt) { - sort.querySelector("input").addEventListener(evt, applySort); - }); sortOrder.addEventListener("click", function() { sortOrder.classList.toggle("sortReverse"); applySort(); }); + applyFilter(); + extraNetworksApplySort[tabname] = applySort; extraNetworksApplyFilter[tabname] = applyFilter; var showDirsUpdate = function() { @@ -109,11 +111,51 @@ function setupExtraNetworksForTab(tabname) { showDirsUpdate(); } +function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt) { + if (!gradioApp().querySelector('.toprow-compact-tools')) return; // only applicable for compact prompt layout + + var promptContainer = gradioApp().getElementById(tabname + '_prompt_container'); + var prompt = gradioApp().getElementById(tabname + '_prompt_row'); + var negPrompt = gradioApp().getElementById(tabname + '_neg_prompt_row'); + var elem = id ? gradioApp().getElementById(id) : null; + + if (showNegativePrompt && elem) { + elem.insertBefore(negPrompt, elem.firstChild); + } else { + promptContainer.insertBefore(negPrompt, promptContainer.firstChild); + } + + if (showPrompt && elem) { + elem.insertBefore(prompt, elem.firstChild); + } else { + promptContainer.insertBefore(prompt, promptContainer.firstChild); + } + + if (elem) { + elem.classList.toggle('extra-page-prompts-active', showNegativePrompt || showPrompt); + } +} + + +function extraNetworksUrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate) + extraNetworksMovePromptToTab(tabname, '', false, false); +} + +function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { // called from python when user selects an extra networks tab + extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt); + +} + function applyExtraNetworkFilter(tabname) { setTimeout(extraNetworksApplyFilter[tabname], 1); } +function applyExtraNetworkSort(tabname) { + setTimeout(extraNetworksApplySort[tabname], 1); +} + var extraNetworksApplyFilter = {}; +var extraNetworksApplySort = {}; var activePromptTextarea = {}; function setupExtraNetworks() { @@ -140,14 +182,15 @@ function setupExtraNetworks() { onUiLoaded(setupExtraNetworks); -var re_extranet = /<([^:]+:[^:]+):[\d.]+>(.*)/; -var re_extranet_g = /\s+<([^:]+:[^:]+):[\d.]+>/g; +var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/; +var re_extranet_g = /<([^:^>]+:[^:]+):[\d.]+>/g; function tryToRemoveExtraNetworkFromPrompt(textarea, text) { var m = text.match(re_extranet); var replaced = false; var newTextareaText; if (m) { + var extraTextBeforeNet = opts.extra_networks_add_text_separator; var extraTextAfterNet = m[2]; var partToSearch = m[1]; var foundAtPosition = -1; @@ -161,8 +204,13 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) { return found; }); - if (foundAtPosition >= 0 && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { - newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length); + if (foundAtPosition >= 0) { + if (newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { + newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length); + } + if (newTextareaText.substr(foundAtPosition - extraTextBeforeNet.length, extraTextBeforeNet.length) == extraTextBeforeNet) { + newTextareaText = newTextareaText.substr(0, foundAtPosition - extraTextBeforeNet.length) + newTextareaText.substr(foundAtPosition); + } } } else { newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) { @@ -216,27 +264,24 @@ function extraNetworksSearchButton(tabs_id, event) { var globalPopup = null; var globalPopupInner = null; + function closePopup() { if (!globalPopup) return; - globalPopup.style.display = "none"; } + function popup(contents) { if (!globalPopup) { globalPopup = document.createElement('div'); - globalPopup.onclick = closePopup; globalPopup.classList.add('global-popup'); var close = document.createElement('div'); close.classList.add('global-popup-close'); - close.onclick = closePopup; + close.addEventListener("click", closePopup); close.title = "Close"; globalPopup.appendChild(close); globalPopupInner = document.createElement('div'); - globalPopupInner.onclick = function(event) { - event.stopPropagation(); return false; - }; globalPopupInner.classList.add('global-popup-inner'); globalPopup.appendChild(globalPopupInner); @@ -335,7 +380,7 @@ function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) { function extraNetworksRefreshSingleCard(page, tabname, name) { requestGet("./sd_extra_networks/get-single-card", {page: page, tabname: tabname, name: name}, function(data) { if (data && data.html) { - var card = gradioApp().querySelector('.card[data-name=' + JSON.stringify(name) + ']'); // likely using the wrong stringify function + var card = gradioApp().querySelector(`#${tabname}_${page.replace(" ", "_")}_cards > .card[data-name="${name}"]`); var newDiv = document.createElement('DIV'); newDiv.innerHTML = data.html; @@ -347,3 +392,9 @@ function extraNetworksRefreshSingleCard(page, tabname, name) { } }); } + +window.addEventListener("keydown", function(event) { + if (event.key == "Escape") { + closePopup(); + } +}); diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index c21d396eefd..625c5d148df 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -33,8 +33,11 @@ function updateOnBackgroundChange() { const modalImage = gradioApp().getElementById("modalImage"); if (modalImage && modalImage.offsetParent) { let currentButton = selected_gallery_button(); - - if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) { + let preview = gradioApp().querySelectorAll('.livePreview > img'); + if (opts.js_live_preview_in_modal_lightbox && preview.length > 0) { + // show preview image if available + modalImage.src = preview[preview.length - 1].src; + } else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) { modalImage.src = currentButton.children[0].src; if (modalImage.style.display === 'none') { const modal = gradioApp().getElementById("lightboxModal"); diff --git a/javascript/inputAccordion.js b/javascript/inputAccordion.js index f2839852ee7..7570309aa73 100644 --- a/javascript/inputAccordion.js +++ b/javascript/inputAccordion.js @@ -1,37 +1,68 @@ -var observerAccordionOpen = new MutationObserver(function(mutations) { - mutations.forEach(function(mutationRecord) { - var elem = mutationRecord.target; - var open = elem.classList.contains('open'); +function inputAccordionChecked(id, checked) { + var accordion = gradioApp().getElementById(id); + accordion.visibleCheckbox.checked = checked; + accordion.onVisibleCheckboxChange(); +} - var accordion = elem.parentNode; - accordion.classList.toggle('input-accordion-open', open); +function setupAccordion(accordion) { + var labelWrap = accordion.querySelector('.label-wrap'); + var gradioCheckbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input"); + var extra = gradioApp().querySelector('#' + accordion.id + "-extra"); + var span = labelWrap.querySelector('span'); + var linked = true; - var checkbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input"); - checkbox.checked = open; - updateInput(checkbox); + var isOpen = function() { + return labelWrap.classList.contains('open'); + }; - var extra = gradioApp().querySelector('#' + accordion.id + "-extra"); - if (extra) { - extra.style.display = open ? "" : "none"; - } + var observerAccordionOpen = new MutationObserver(function(mutations) { + mutations.forEach(function(mutationRecord) { + accordion.classList.toggle('input-accordion-open', isOpen()); + + if (linked) { + accordion.visibleCheckbox.checked = isOpen(); + accordion.onVisibleCheckboxChange(); + } + }); }); -}); + observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']}); -function inputAccordionChecked(id, checked) { - var label = gradioApp().querySelector('#' + id + " .label-wrap"); - if (label.classList.contains('open') != checked) { - label.click(); + if (extra) { + labelWrap.insertBefore(extra, labelWrap.lastElementChild); } + + accordion.onChecked = function(checked) { + if (isOpen() != checked) { + labelWrap.click(); + } + }; + + var visibleCheckbox = document.createElement('INPUT'); + visibleCheckbox.type = 'checkbox'; + visibleCheckbox.checked = isOpen(); + visibleCheckbox.id = accordion.id + "-visible-checkbox"; + visibleCheckbox.className = gradioCheckbox.className + " input-accordion-checkbox"; + span.insertBefore(visibleCheckbox, span.firstChild); + + accordion.visibleCheckbox = visibleCheckbox; + accordion.onVisibleCheckboxChange = function() { + if (linked && isOpen() != visibleCheckbox.checked) { + labelWrap.click(); + } + + gradioCheckbox.checked = visibleCheckbox.checked; + updateInput(gradioCheckbox); + }; + + visibleCheckbox.addEventListener('click', function(event) { + linked = false; + event.stopPropagation(); + }); + visibleCheckbox.addEventListener('input', accordion.onVisibleCheckboxChange); } onUiLoaded(function() { for (var accordion of gradioApp().querySelectorAll('.input-accordion')) { - var labelWrap = accordion.querySelector('.label-wrap'); - observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']}); - - var extra = gradioApp().querySelector('#' + accordion.id + "-extra"); - if (extra) { - labelWrap.insertBefore(extra, labelWrap.lastElementChild); - } + setupAccordion(accordion); } }); diff --git a/javascript/notification.js b/javascript/notification.js index 6d79956125c..3ee972ae166 100644 --- a/javascript/notification.js +++ b/javascript/notification.js @@ -26,7 +26,11 @@ onAfterUiUpdate(function() { lastHeadImg = headImg; // play notification sound if available - gradioApp().querySelector('#audio_notification audio')?.play(); + const notificationAudio = gradioApp().querySelector('#audio_notification audio'); + if (notificationAudio) { + notificationAudio.volume = opts.notification_volume / 100.0 || 1.0; + notificationAudio.play(); + } if (document.hasFocus()) return; diff --git a/javascript/settings.js b/javascript/settings.js new file mode 100644 index 00000000000..e6009290ab3 --- /dev/null +++ b/javascript/settings.js @@ -0,0 +1,71 @@ +let settingsExcludeTabsFromShowAll = { + settings_tab_defaults: 1, + settings_tab_sysinfo: 1, + settings_tab_actions: 1, + settings_tab_licenses: 1, +}; + +function settingsShowAllTabs() { + gradioApp().querySelectorAll('#settings > div').forEach(function(elem) { + if (settingsExcludeTabsFromShowAll[elem.id]) return; + + elem.style.display = "block"; + }); +} + +function settingsShowOneTab() { + gradioApp().querySelector('#settings_show_one_page').click(); +} + +onUiLoaded(function() { + var edit = gradioApp().querySelector('#settings_search'); + var editTextarea = gradioApp().querySelector('#settings_search > label > input'); + var buttonShowAllPages = gradioApp().getElementById('settings_show_all_pages'); + var settings_tabs = gradioApp().querySelector('#settings div'); + + onEdit('settingsSearch', editTextarea, 250, function() { + var searchText = (editTextarea.value || "").trim().toLowerCase(); + + gradioApp().querySelectorAll('#settings > div[id^=settings_] div[id^=column_settings_] > *').forEach(function(elem) { + var visible = elem.textContent.trim().toLowerCase().indexOf(searchText) != -1; + elem.style.display = visible ? "" : "none"; + }); + + if (searchText != "") { + settingsShowAllTabs(); + } else { + settingsShowOneTab(); + } + }); + + settings_tabs.insertBefore(edit, settings_tabs.firstChild); + settings_tabs.appendChild(buttonShowAllPages); + + + buttonShowAllPages.addEventListener("click", settingsShowAllTabs); +}); + + +onOptionsChanged(function() { + if (gradioApp().querySelector('#settings .settings-category')) return; + + var sectionMap = {}; + gradioApp().querySelectorAll('#settings > div > button').forEach(function(x) { + sectionMap[x.textContent.trim()] = x; + }); + + opts._categories.forEach(function(x) { + var section = x[0]; + var category = x[1]; + + var span = document.createElement('SPAN'); + span.textContent = category; + span.className = 'settings-category'; + + var sectionElem = sectionMap[section]; + if (!sectionElem) return; + + sectionElem.parentElement.insertBefore(span, sectionElem); + }); +}); + diff --git a/javascript/token-counters.js b/javascript/token-counters.js index 9d81a723b01..2ecc7d91010 100644 --- a/javascript/token-counters.js +++ b/javascript/token-counters.js @@ -1,10 +1,9 @@ -let promptTokenCountDebounceTime = 800; -let promptTokenCountTimeouts = {}; -var promptTokenCountUpdateFunctions = {}; +let promptTokenCountUpdateFunctions = {}; function update_txt2img_tokens(...args) { // Called from Gradio update_token_counter("txt2img_token_button"); + update_token_counter("txt2img_negative_token_button"); if (args.length == 2) { return args[0]; } @@ -14,6 +13,7 @@ function update_txt2img_tokens(...args) { function update_img2img_tokens(...args) { // Called from Gradio update_token_counter("img2img_token_button"); + update_token_counter("img2img_negative_token_button"); if (args.length == 2) { return args[0]; } @@ -21,16 +21,7 @@ function update_img2img_tokens(...args) { } function update_token_counter(button_id) { - if (opts.disable_token_counters) { - return; - } - if (promptTokenCountTimeouts[button_id]) { - clearTimeout(promptTokenCountTimeouts[button_id]); - } - promptTokenCountTimeouts[button_id] = setTimeout( - () => gradioApp().getElementById(button_id)?.click(), - promptTokenCountDebounceTime, - ); + promptTokenCountUpdateFunctions[button_id]?.(); } @@ -69,10 +60,11 @@ function setupTokenCounting(id, id_counter, id_button) { prompt.parentElement.insertBefore(counter, prompt); prompt.parentElement.style.position = "relative"; - promptTokenCountUpdateFunctions[id] = function() { - update_token_counter(id_button); - }; - textarea.addEventListener("input", promptTokenCountUpdateFunctions[id]); + var func = onEdit(id, textarea, 800, function() { + gradioApp().getElementById(id_button)?.click(); + }); + promptTokenCountUpdateFunctions[id] = func; + promptTokenCountUpdateFunctions[id_button] = func; } function setupTokenCounters() { diff --git a/javascript/ui.js b/javascript/ui.js index dbab8334f92..18c9f891afc 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -9,7 +9,7 @@ function set_theme(theme) { function all_gallery_buttons() { var allGalleryButtons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery].gradio-gallery .thumbnails > .thumbnail-item.thumbnail-small'); - var visibleGalleryButtons = []; + var visibleGalleryButtons = []; allGalleryButtons.forEach(function(elem) { if (elem.parentElement.offsetParent) { visibleGalleryButtons.push(elem); @@ -170,6 +170,23 @@ function submit_img2img() { return res; } +function submit_extras() { + showSubmitButtons('extras', false); + + var id = randomId(); + + requestProgress(id, gradioApp().getElementById('extras_gallery_container'), gradioApp().getElementById('extras_gallery'), function() { + showSubmitButtons('extras', true); + }); + + var res = create_submit_args(arguments); + + res[0] = id; + + console.log(res); + return res; +} + function restoreProgressTxt2img() { showRestoreProgressButton("txt2img", false); var id = localGet("txt2img_task_id"); @@ -198,9 +215,33 @@ function restoreProgressImg2img() { } +/** + * Configure the width and height elements on `tabname` to accept + * pasting of resolutions in the form of "width x height". + */ +function setupResolutionPasting(tabname) { + var width = gradioApp().querySelector(`#${tabname}_width input[type=number]`); + var height = gradioApp().querySelector(`#${tabname}_height input[type=number]`); + for (const el of [width, height]) { + el.addEventListener('paste', function(event) { + var pasteData = event.clipboardData.getData('text/plain'); + var parsed = pasteData.match(/^\s*(\d+)\D+(\d+)\s*$/); + if (parsed) { + width.value = parsed[1]; + height.value = parsed[2]; + updateInput(width); + updateInput(height); + event.preventDefault(); + } + }); + } +} + onUiLoaded(function() { showRestoreProgressButton('txt2img', localGet("txt2img_task_id")); showRestoreProgressButton('img2img', localGet("img2img_task_id")); + setupResolutionPasting('txt2img'); + setupResolutionPasting('img2img'); }); @@ -263,21 +304,6 @@ onAfterUiUpdate(function() { json_elem.parentElement.style.display = "none"; setupTokenCounters(); - - var show_all_pages = gradioApp().getElementById('settings_show_all_pages'); - var settings_tabs = gradioApp().querySelector('#settings div'); - if (show_all_pages && settings_tabs) { - settings_tabs.appendChild(show_all_pages); - show_all_pages.onclick = function() { - gradioApp().querySelectorAll('#settings > div').forEach(function(elem) { - if (elem.id == "settings_tab_licenses") { - return; - } - - elem.style.display = "block"; - }); - }; - } }); onOptionsChanged(function() { @@ -366,3 +392,20 @@ function switchWidthHeight(tabname) { updateInput(height); return []; } + + +var onEditTimers = {}; + +// calls func after afterMs milliseconds has passed since the input elem has beed enited by user +function onEdit(editId, elem, afterMs, func) { + var edited = function() { + var existingTimer = onEditTimers[editId]; + if (existingTimer) clearTimeout(existingTimer); + + onEditTimers[editId] = setTimeout(func, afterMs); + }; + + elem.addEventListener("input", edited); + + return edited; +} diff --git a/modules/api/api.py b/modules/api/api.py index e6edffe7144..b3d74e513a3 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -17,19 +17,17 @@ from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.textual_inversion.textual_inversion import create_embedding, train_embedding -from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork -from PIL import PngImagePlugin,Image -from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases +from PIL import PngImagePlugin, Image from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices -from typing import Dict, List, Any +from typing import Any import piexif import piexif.helper from contextlib import closing @@ -103,7 +101,8 @@ def decode_base64_to_image(encoding): def encode_pil_to_base64(image): with io.BytesIO() as output_bytes: - + if isinstance(image, str): + return image if opts.samples_format.lower() == 'png': use_metadata = False metadata = PngImagePlugin.PngInfo() @@ -221,28 +220,28 @@ def __init__(self, app: FastAPI, queue_lock: Lock): self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel) self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel) - self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem]) - self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem]) - self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem]) - self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem]) - self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem]) - self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem]) - self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem]) - self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem]) - self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem]) + self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem]) + self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem]) + self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem]) + self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem]) + self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem]) + self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem]) + self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem]) + self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem]) + self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem]) self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse) - self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse) self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse) self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse) self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse) self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) - self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) + self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo]) + self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem]) if shared.cmd_opts.api_server_stop: self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"]) @@ -473,9 +472,6 @@ def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest): return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) def pnginfoapi(self, req: models.PNGInfoRequest): - if(not req.image.strip()): - return models.PNGInfoResponse(info="") - image = decode_base64_to_image(req.image.strip()) if image is None: return models.PNGInfoResponse(info="") @@ -484,9 +480,10 @@ def pnginfoapi(self, req: models.PNGInfoRequest): if geninfo is None: geninfo = "" - items = {**{'parameters': geninfo}, **items} + params = generation_parameters_copypaste.parse_generation_parameters(geninfo) + script_callbacks.infotext_pasted_callback(geninfo, params) - return models.PNGInfoResponse(info=geninfo, items=items) + return models.PNGInfoResponse(info=geninfo, items=items, parameters=params) def progressapi(self, req: models.ProgressRequest = Depends()): # copy from check_progress_call of ui.py @@ -541,12 +538,12 @@ def interruptapi(self): return {} def unloadapi(self): - unload_model_weights() + sd_models.unload_model_weights() return {} def reloadapi(self): - reload_model_weights() + sd_models.send_model_to_device(shared.sd_model) return {} @@ -564,9 +561,9 @@ def get_config(self): return options - def set_config(self, req: Dict[str, Any]): + def set_config(self, req: dict[str, Any]): checkpoint_name = req.get("sd_model_checkpoint", None) - if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases: + if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases: raise RuntimeError(f"model {checkpoint_name!r} not found") for k, v in req.items(): @@ -676,19 +673,6 @@ def create_hypernetwork(self, args: dict): finally: shared.state.end() - def preprocess(self, args: dict): - try: - shared.state.begin(job="preprocess") - preprocess(**args) # quick operation unless blip/booru interrogation is enabled - shared.state.end() - return models.PreprocessResponse(info='preprocess complete') - except KeyError as e: - return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}") - except Exception as e: - return models.PreprocessResponse(info=f"preprocess error: {e}") - finally: - shared.state.end() - def train_embedding(self, args: dict): try: shared.state.begin(job="train_embedding") @@ -770,6 +754,25 @@ def get_memory(self): cuda = {'error': f'{err}'} return models.MemoryResponse(ram=ram, cuda=cuda) + def get_extensions_list(self): + from modules import extensions + extensions.list_extensions() + ext_list = [] + for ext in extensions.extensions: + ext: extensions.Extension + ext.read_info_from_repo() + if ext.remote is not None: + ext_list.append({ + "name": ext.name, + "remote": ext.remote, + "branch": ext.branch, + "commit_hash":ext.commit_hash, + "commit_date":ext.commit_date, + "version":ext.version, + "enabled":ext.enabled + }) + return ext_list + def launch(self, server_name, port, root_path): self.app.include_router(self.router) uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path) diff --git a/modules/api/models.py b/modules/api/models.py index 6a574771c33..33894b3e694 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,12 +1,10 @@ import inspect from pydantic import BaseModel, Field, create_model -from typing import Any, Optional -from typing_extensions import Literal +from typing import Any, Optional, Literal from inflection import underscore from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img from modules.shared import sd_upscalers, opts, parser -from typing import Dict, List API_NOT_ALLOWED = [ "self", @@ -130,12 +128,12 @@ def generate_model(self): ).generate_model() class TextToImageResponse(BaseModel): - images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") parameters: dict info: str class ImageToImageResponse(BaseModel): - images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") parameters: dict info: str @@ -168,17 +166,18 @@ class FileData(BaseModel): name: str = Field(title="File name") class ExtrasBatchImagesRequest(ExtrasBaseRequest): - imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings") + imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings") class ExtrasBatchImagesResponse(ExtraBaseResponse): - images: List[str] = Field(title="Images", description="The generated images in base64 format.") + images: list[str] = Field(title="Images", description="The generated images in base64 format.") class PNGInfoRequest(BaseModel): image: str = Field(title="Image", description="The base64 encoded PNG image") class PNGInfoResponse(BaseModel): info: str = Field(title="Image info", description="A string with the parameters used to generate the image") - items: dict = Field(title="Items", description="An object containing all the info the image had") + items: dict = Field(title="Items", description="A dictionary containing all the other fields the image had") + parameters: dict = Field(title="Parameters", description="A dictionary with parsed generation info fields") class ProgressRequest(BaseModel): skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization") @@ -203,9 +202,6 @@ class TrainResponse(BaseModel): class CreateResponse(BaseModel): info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.") -class PreprocessResponse(BaseModel): - info: str = Field(title="Preprocess info", description="Response string from preprocessing task.") - fields = {} for key, metadata in opts.data_labels.items(): value = opts.data.get(key) @@ -232,8 +228,8 @@ class PreprocessResponse(BaseModel): class SamplerItem(BaseModel): name: str = Field(title="Name") - aliases: List[str] = Field(title="Aliases") - options: Dict[str, str] = Field(title="Options") + aliases: list[str] = Field(title="Aliases") + options: dict[str, str] = Field(title="Options") class UpscalerItem(BaseModel): name: str = Field(title="Name") @@ -284,8 +280,8 @@ class EmbeddingItem(BaseModel): vectors: int = Field(title="Vectors", description="The number of vectors in the embedding") class EmbeddingsResponse(BaseModel): - loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model") - skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)") + loaded: dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model") + skipped: dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)") class MemoryResponse(BaseModel): ram: dict = Field(title="RAM", description="System memory stats") @@ -303,11 +299,20 @@ class ScriptArg(BaseModel): minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI") maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI") step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI") - choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument") + choices: Optional[list[str]] = Field(default=None, title="Choices", description="Possible values for the argument") class ScriptInfo(BaseModel): name: str = Field(default=None, title="Name", description="Script name") is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script") is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script") - args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments") + args: list[ScriptArg] = Field(title="Arguments", description="List of script's arguments") + +class ExtensionItem(BaseModel): + name: str = Field(title="Name", description="Extension name") + remote: str = Field(title="Remote", description="Extension Repository URL") + branch: str = Field(title="Branch", description="Extension Repository Branch") + commit_hash: str = Field(title="Commit Hash", description="Extension Repository Commit Hash") + version: str = Field(title="Version", description="Extension Version") + commit_date: str = Field(title="Commit Date", description="Extension Repository Commit Date") + enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled") diff --git a/modules/cache.py b/modules/cache.py index ff26a2132d9..2d37e7b99d5 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -32,7 +32,7 @@ def thread_func(): with cache_lock: cache_filename_tmp = cache_filename + "-" with open(cache_filename_tmp, "w", encoding="utf8") as file: - json.dump(cache_data, file, indent=4) + json.dump(cache_data, file, indent=4, ensure_ascii=False) os.replace(cache_filename_tmp, cache_filename) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index aab62286e24..da93eb2669f 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -70,6 +70,7 @@ parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) +parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device") parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model") parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) @@ -90,7 +91,7 @@ parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) -parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) +parser.add_argument("--enable-console-prompts", action='store_true', help="does not do anything", default=False) # Legacy compatibility, use as default value shared.opts.enable_console_prompts parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)") @@ -107,13 +108,14 @@ parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True) -parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions") +parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the default in earlier versions") parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers") parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') -parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') +parser.add_argument('--add-stop-route', action='store_true', help='does not do anything') parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api') parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn') parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) -parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False) +parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) +parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) diff --git a/modules/config_states.py b/modules/config_states.py index b766aef11d8..651793c7f6f 100644 --- a/modules/config_states.py +++ b/modules/config_states.py @@ -4,7 +4,6 @@ import os import json -import time import tqdm from datetime import datetime @@ -38,7 +37,7 @@ def list_config_states(): config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True) for cs in config_states: - timestamp = time.asctime(time.gmtime(cs["created_at"])) + timestamp = datetime.fromtimestamp(cs["created_at"]).strftime('%Y-%m-%d %H:%M:%S') name = cs.get("name", "Config") full_name = f"{name}: {timestamp}" all_config_states[full_name] = cs diff --git a/modules/devices.py b/modules/devices.py index c01f06024b4..ea1f712f950 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -8,6 +8,13 @@ if sys.platform == "darwin": from modules import mac_specific +if shared.cmd_opts.use_ipex: + from modules import xpu_specific + + +def has_xpu() -> bool: + return shared.cmd_opts.use_ipex and xpu_specific.has_xpu + def has_mps() -> bool: if sys.platform != "darwin": @@ -30,6 +37,9 @@ def get_optimal_device_name(): if has_mps(): return "mps" + if has_xpu(): + return xpu_specific.get_xpu_device_string() + return "cpu" @@ -38,7 +48,7 @@ def get_optimal_device(): def get_device_for(task): - if task in shared.cmd_opts.use_cpu: + if task in shared.cmd_opts.use_cpu or "all" in shared.cmd_opts.use_cpu: return cpu return get_optimal_device() @@ -54,13 +64,17 @@ def torch_gc(): if has_mps(): mac_specific.torch_mps_gc() + if has_xpu(): + xpu_specific.torch_xpu_gc() + def enable_tf32(): if torch.cuda.is_available(): # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 - if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())): + device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device() + if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True diff --git a/modules/errors.py b/modules/errors.py index 8c339464d46..eb234a83811 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -6,6 +6,21 @@ exception_records = [] +def format_traceback(tb): + return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)] + + +def format_exception(e, tb): + return {"exception": str(e), "traceback": format_traceback(tb)} + + +def get_exceptions(): + try: + return list(reversed(exception_records)) + except Exception as e: + return str(e) + + def record_exception(): _, e, tb = sys.exc_info() if e is None: @@ -14,8 +29,7 @@ def record_exception(): if exception_records and exception_records[-1] == e: return - from modules import sysinfo - exception_records.append(sysinfo.format_exception(e, tb)) + exception_records.append(format_exception(e, tb)) if len(exception_records) > 5: exception_records.pop(0) diff --git a/modules/extensions.py b/modules/extensions.py index bf9a1878f5d..1899cd52975 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,11 +1,14 @@ +from __future__ import annotations + +import configparser import os import threading +import re from modules import shared, errors, cache, scripts from modules.gitpython_hack import Repo from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 -extensions = [] os.makedirs(extensions_dir, exist_ok=True) @@ -19,11 +22,55 @@ def active(): return [x for x in extensions if x.enabled] +class ExtensionMetadata: + filename = "metadata.ini" + config: configparser.ConfigParser + canonical_name: str + requires: list + + def __init__(self, path, canonical_name): + self.config = configparser.ConfigParser() + + filepath = os.path.join(path, self.filename) + if os.path.isfile(filepath): + try: + self.config.read(filepath) + except Exception: + errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True) + + self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name) + self.canonical_name = canonical_name.lower().strip() + + self.requires = self.get_script_requirements("Requires", "Extension") + + def get_script_requirements(self, field, section, extra_section=None): + """reads a list of requirements from the config; field is the name of the field in the ini file, + like Requires or Before, and section is the name of the [section] in the ini file; additionally, + reads more requirements from [extra_section] if specified.""" + + x = self.config.get(section, field, fallback='') + + if extra_section: + x = x + ', ' + self.config.get(extra_section, field, fallback='') + + return self.parse_list(x.lower()) + + def parse_list(self, text): + """converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])""" + + if not text: + return [] + + # both "," and " " are accepted as separator + return [x for x in re.split(r"[,\s]+", text.strip()) if x] + + class Extension: lock = threading.Lock() cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version'] + metadata: ExtensionMetadata - def __init__(self, name, path, enabled=True, is_builtin=False): + def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None): self.name = name self.path = path self.enabled = enabled @@ -36,6 +83,8 @@ def __init__(self, name, path, enabled=True, is_builtin=False): self.branch = None self.remote = None self.have_info_from_repo = False + self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower()) + self.canonical_name = metadata.canonical_name def to_dict(self): return {x: getattr(self, x) for x in self.cached_fields} @@ -56,6 +105,7 @@ def read_from_repo(): self.do_read_info_from_repo() return self.to_dict() + try: d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo) self.from_dict(d) @@ -136,9 +186,6 @@ def fetch_and_reset_hard(self, commit='origin'): def list_extensions(): extensions.clear() - if not os.path.isdir(extensions_dir): - return - if shared.cmd_opts.disable_all_extensions: print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***") elif shared.opts.disable_all_extensions == "all": @@ -148,18 +195,43 @@ def list_extensions(): elif shared.opts.disable_all_extensions == "extra": print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***") - extension_paths = [] - for dirname in [extensions_dir, extensions_builtin_dir]: + loaded_extensions = {} + + # scan through extensions directory and load metadata + for dirname in [extensions_builtin_dir, extensions_dir]: if not os.path.isdir(dirname): - return + continue for extension_dirname in sorted(os.listdir(dirname)): path = os.path.join(dirname, extension_dirname) if not os.path.isdir(path): continue - extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir)) + canonical_name = extension_dirname + metadata = ExtensionMetadata(path, canonical_name) + + # check for duplicated canonical names + already_loaded_extension = loaded_extensions.get(metadata.canonical_name) + if already_loaded_extension is not None: + errors.report(f'Duplicate canonical name "{canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False) + continue + + is_builtin = dirname == extensions_builtin_dir + extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata) + extensions.append(extension) + loaded_extensions[canonical_name] = extension + + # check for requirements + for extension in extensions: + for req in extension.metadata.requires: + required_extension = loaded_extensions.get(req) + if required_extension is None: + errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False) + continue + + if not extension.enabled: + errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False) + continue + - for dirname, path, is_builtin in extension_paths: - extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin) - extensions.append(extension) +extensions: list[Extension] = [] diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index d39f2ebac36..4efe53e0c48 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -1,3 +1,4 @@ +from __future__ import annotations import base64 import io import json @@ -9,15 +10,12 @@ from modules import shared, ui_tempdir, script_callbacks, processing from PIL import Image -re_param_code = r'\s*([\w ]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' +re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_imagesize = re.compile(r"^(\d+)x(\d+)$") re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") type_of_gr_update = type(gr.update()) -paste_fields = {} -registered_param_bindings = [] - class ParamBinding: def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None): @@ -30,6 +28,10 @@ def __init__(self, paste_button, tabname, source_text_component=None, source_ima self.paste_field_names = paste_field_names or [] +paste_fields: dict[str, dict] = {} +registered_param_bindings: list[ParamBinding] = [] + + def reset(): paste_fields.clear() registered_param_bindings.clear() @@ -113,7 +115,6 @@ def register_paste_params_button(binding: ParamBinding): def connect_paste_params_buttons(): - binding: ParamBinding for binding in registered_param_bindings: destination_image_component = paste_fields[binding.tabname]["init_img"] fields = paste_fields[binding.tabname]["fields"] @@ -313,6 +314,9 @@ def parse_generation_parameters(x: str): if "VAE Decoder" not in res: res["VAE Decoder"] = "Full" + skip = set(shared.opts.infotext_skip_pasting) + res = {k: v for k, v in res.items() if k not in skip} + return res @@ -443,3 +447,4 @@ def paste_settings(params): outputs=[], show_progress=False, ) + diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 8e0f13bdc7d..01d668ecdaf 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -9,6 +9,7 @@ model_dir = "GFPGAN" user_path = None model_path = os.path.join(paths.models_path, model_dir) +model_file_path = None model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" have_gfpgan = False loaded_gfpgan_model = None @@ -17,6 +18,7 @@ def gfpgann(): global loaded_gfpgan_model global model_path + global model_file_path if loaded_gfpgan_model is not None: loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan) return loaded_gfpgan_model @@ -24,17 +26,24 @@ def gfpgann(): if gfpgan_constructor is None: return None - models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN") + models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth']) + if len(models) == 1 and models[0].startswith("http"): model_file = models[0] elif len(models) != 0: - latest_file = max(models, key=os.path.getctime) + gfp_models = [] + for item in models: + if 'GFPGAN' in os.path.basename(item): + gfp_models.append(item) + latest_file = max(gfp_models, key=os.path.getctime) model_file = latest_file else: print("Unable to load gfpgan model!") return None + if hasattr(facexlib.detection.retinaface, 'device'): facexlib.detection.retinaface.device = devices.device_gfpgan + model_file_path = model_file model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan) loaded_gfpgan_model = model @@ -77,19 +86,25 @@ def setup_model(dirname): global user_path global have_gfpgan global gfpgan_constructor + global model_file_path + + facexlib_path = model_path + + if dirname is not None: + facexlib_path = dirname load_file_from_url_orig = gfpgan.utils.load_file_from_url facex_load_file_from_url_orig = facexlib.detection.load_file_from_url facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url def my_load_file_from_url(**kwargs): - return load_file_from_url_orig(**dict(kwargs, model_dir=model_path)) + return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path)) def facex_load_file_from_url(**kwargs): - return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None)) + return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None)) def facex_load_file_from_url2(**kwargs): - return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None)) + return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None)) gfpgan.utils.load_file_from_url = my_load_file_from_url facexlib.detection.load_file_from_url = facex_load_file_from_url diff --git a/modules/gitpython_hack.py b/modules/gitpython_hack.py index e537c1df93e..b55f0640e5e 100644 --- a/modules/gitpython_hack.py +++ b/modules/gitpython_hack.py @@ -23,7 +23,7 @@ def get_object_header(self, ref: str | bytes) -> tuple[str, str, int]: ) return self._parse_object_header(ret) - def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]: + def stream_object_data(self, ref: str) -> tuple[str, str, int, Git.CatFileContentStream]: # Not really streaming, per se; this buffers the entire object in memory. # Shouldn't be a problem for our use case, since we're only using this for # object headers (commit objects). diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py index aac742ef0c1..34e38700a62 100644 --- a/modules/gradio_extensons.py +++ b/modules/gradio_extensons.py @@ -47,10 +47,20 @@ def Block_get_config(self): def BlockContext_init(self, *args, **kwargs): + if scripts.scripts_current is not None: + scripts.scripts_current.before_component(self, **kwargs) + + scripts.script_callbacks.before_component_callback(self, **kwargs) + res = original_BlockContext_init(self, *args, **kwargs) add_classes_to_gradio_component(self) + scripts.script_callbacks.after_component_callback(self, **kwargs) + + if scripts.scripts_current is not None: + scripts.scripts_current.after_component(self, **kwargs) + return res diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 70f1cbd26b6..be3e4648486 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -468,7 +468,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks() -def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(id_task, hypernetwork_name: str, learn_rate: float, batch_size: int, gradient_step: int, data_root: str, log_directory: str, training_width: int, training_height: int, varsize: bool, steps: int, clip_grad_mode: str, clip_grad_value: float, shuffle_tags: bool, tag_drop_out: bool, latent_sampling_method: str, use_weight: bool, create_image_every: int, save_hypernetwork_every: int, template_filename: str, preview_from_txt2img: bool, preview_prompt: str, preview_negative_prompt: str, preview_steps: int, preview_sampler_name: str, preview_cfg_scale: float, preview_seed: int, preview_width: int, preview_height: int): from modules import images, processing save_hypernetwork_every = save_hypernetwork_every or 0 @@ -698,7 +698,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi p.prompt = preview_prompt p.negative_prompt = preview_negative_prompt p.steps = preview_steps - p.sampler_name = sd_samplers.samplers[preview_sampler_index].name + p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()] p.cfg_scale = preview_cfg_scale p.seed = preview_seed p.width = preview_width diff --git a/modules/images.py b/modules/images.py index eb644733898..daf4eebe4b1 100644 --- a/modules/images.py +++ b/modules/images.py @@ -561,6 +561,8 @@ def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_p }) piexif.insert(exif_bytes, filename) + elif extension.lower() == ".gif": + image.save(filename, format=image_format, comment=geninfo) else: image.save(filename, format=image_format, quality=opts.jpeg_quality) @@ -661,7 +663,13 @@ def _atomically_save_image(image_to_save, filename_without_extension, extension) save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name) - os.replace(temp_file_path, filename_without_extension + extension) + filename = filename_without_extension + extension + if shared.opts.save_images_replace_action != "Replace": + n = 0 + while os.path.exists(filename): + n += 1 + filename = f"{filename_without_extension}-{n}{extension}" + os.replace(temp_file_path, filename) fullfn_without_extension, extension = os.path.splitext(params.filename) if hasattr(os, 'statvfs'): @@ -718,7 +726,12 @@ def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]: geninfo = items.pop('parameters', None) if "exif" in items: - exif = piexif.load(items["exif"]) + exif_data = items["exif"] + try: + exif = piexif.load(exif_data) + except OSError: + # memory / exif was not valid so piexif tried to read from a file + exif = None exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'') try: exif_comment = piexif.helper.UserComment.load(exif_comment) @@ -728,6 +741,8 @@ def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]: if exif_comment: items['exif comment'] = exif_comment geninfo = exif_comment + elif "comment" in items: # for gif + geninfo = items["comment"].decode('utf8', errors="ignore") for field in IGNORED_INFO_KEYS: items.pop(field, None) diff --git a/modules/img2img.py b/modules/img2img.py index 1519e132b2b..c583290a0ea 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -10,6 +10,7 @@ from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state +from modules.sd_models import get_closet_checkpoint_match import modules.shared as shared import modules.processing as processing from modules.ui import plaintext_to_html @@ -41,7 +42,10 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal cfg_scale = p.cfg_scale sampler_name = p.sampler_name steps = p.steps - + override_settings = p.override_settings + sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None)) + batch_results = None + discard_further_results = False for i, image in enumerate(images): state.job = f"{i+1} out of {len(images)}" if state.skipped: @@ -104,16 +108,42 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal p.sampler_name = parsed_parameters.get("Sampler", sampler_name) p.steps = int(parsed_parameters.get("Steps", steps)) + model_info = get_closet_checkpoint_match(parsed_parameters.get("Model hash", None)) + if model_info is not None: + p.override_settings['sd_model_checkpoint'] = model_info.name + elif sd_model_checkpoint_override: + p.override_settings['sd_model_checkpoint'] = sd_model_checkpoint_override + else: + p.override_settings.pop("sd_model_checkpoint", None) + + if output_dir: + p.outpath_samples = output_dir + p.override_settings['save_to_dirs'] = False + p.override_settings['save_images_replace_action'] = "Add number suffix" + if p.n_iter > 1 or p.batch_size > 1: + p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]' + else: + p.override_settings['samples_filename_pattern'] = f'{image_path.stem}' + proc = modules.scripts.scripts_img2img.run(p, *args) + if proc is None: - if output_dir: - p.outpath_samples = output_dir - p.override_settings['save_to_dirs'] = False - if p.n_iter > 1 or p.batch_size > 1: - p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]' - else: - p.override_settings['samples_filename_pattern'] = f'{image_path.stem}' - process_images(p) + p.override_settings.pop('save_images_replace_action', None) + proc = process_images(p) + + if not discard_further_results and proc: + if batch_results: + batch_results.images.extend(proc.images) + batch_results.infotexts.extend(proc.infotexts) + else: + batch_results = proc + + if 0 <= shared.opts.img2img_batch_show_results_limit < len(batch_results.images): + discard_further_results = True + batch_results.images = batch_results.images[:int(shared.opts.img2img_batch_show_results_limit)] + batch_results.infotexts = batch_results.infotexts[:int(shared.opts.img2img_batch_show_results_limit)] + + return batch_results def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args): @@ -189,7 +219,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s p.user = request.username - if shared.cmd_opts.enable_console_prompts: + if shared.opts.enable_console_prompts: print(f"\nimg2img: {prompt}", file=shared.progress_print_out) if mask: @@ -198,10 +228,10 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s with closing(p): if is_batch: assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" + processed = process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir) - process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir) - - processed = Processed(p, [], p.seed, "") + if processed is None: + processed = Processed(p, [], p.seed, "") else: processed = modules.scripts.scripts_img2img.run(p, *args) if processed is None: diff --git a/modules/import_hook.py b/modules/import_hook.py index 28c67dfa897..eba9a372929 100644 --- a/modules/import_hook.py +++ b/modules/import_hook.py @@ -3,3 +3,14 @@ # this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it if "--xformers" not in "".join(sys.argv): sys.modules["xformers"] = None + +# Hack to fix a changed import in torchvision 0.17+, which otherwise breaks +# basicsr; see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13985 +try: + import torchvision.transforms.functional_tensor # noqa: F401 +except ImportError: + try: + import torchvision.transforms.functional as functional + sys.modules["torchvision.transforms.functional_tensor"] = functional + except ImportError: + pass # shrug... diff --git a/modules/initialize.py b/modules/initialize.py index 6c163c75846..99f5a461649 100644 --- a/modules/initialize.py +++ b/modules/initialize.py @@ -151,8 +151,8 @@ def load_model(): from modules import devices devices.first_time_calculation() - - Thread(target=load_model).start() + if not shared.cmd_opts.skip_load_model_at_start: + Thread(target=load_model).start() from modules import shared_items shared_items.reload_hypernetworks() diff --git a/modules/initialize_util.py b/modules/initialize_util.py index e1803487125..07c578acdb6 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -150,10 +150,14 @@ def dumpstacks(): def configure_sigint_handler(): # make the program just exit at ctrl+c without waiting for anything + + from modules import shared + def sigint_handler(sig, frame): print(f'Interrupted with signal {sig} in {frame}') - dumpstacks() + if shared.opts.dump_stacks_on_signal: + dumpstacks() os._exit(0) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 6e54d06367c..29506f24965 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -6,6 +6,7 @@ import shutil import sys import importlib.util +import importlib.metadata import platform import json from functools import lru_cache @@ -64,7 +65,7 @@ def check_python_version(): @lru_cache() def commit_hash(): try: - return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip() + return subprocess.check_output([git, "-C", script_path, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip() except Exception: return "" @@ -72,7 +73,7 @@ def commit_hash(): @lru_cache() def git_tag(): try: - return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip() + return subprocess.check_output([git, "-C", script_path, "describe", "--tags"], shell=False, encoding='utf8').strip() except Exception: try: @@ -119,11 +120,16 @@ def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_ def is_installed(package): try: - spec = importlib.util.find_spec(package) - except ModuleNotFoundError: - return False + dist = importlib.metadata.distribution(package) + except importlib.metadata.PackageNotFoundError: + try: + spec = importlib.util.find_spec(package) + except ModuleNotFoundError: + return False + + return spec is not None - return spec is not None + return dist is not None def repo_dir(name): @@ -310,6 +316,26 @@ def requirements_met(requirements_file): def prepare_environment(): torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") + if args.use_ipex: + if platform.system() == "Windows": + # The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main + # This is NOT an Intel official release so please use it at your own risk!! + # See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details. + # + # Strengths (over official IPEX 2.0.110 windows release): + # - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399 + # - Bundles minimal oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system. + # - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465 + # Limitation: + # - Only works for python 3.10 + url_prefix = "https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle" + torch_command = os.environ.get('TORCH_COMMAND', f"pip install {url_prefix}/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl {url_prefix}/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl {url_prefix}/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl") + else: + # Using official IPEX release for linux since it's already an AOT build. + # However, users still have to install oneAPI toolkit and activate oneAPI environment manually. + # See https://intel.github.io/intel-extension-for-pytorch/index.html#installation for details. + torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/") + torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20') @@ -352,6 +378,8 @@ def prepare_environment(): run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True) startup_timer.record("install torch") + if args.use_ipex: + args.skip_torch_cuda_test = True if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"): raise RuntimeError( 'Torch is not able to use GPU; ' @@ -441,7 +469,7 @@ def dump_sysinfo(): import datetime text = sysinfo.get() - filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt" + filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json" with open(filename, "w", encoding="utf8") as file: file.write(text) diff --git a/modules/localization.py b/modules/localization.py index c132028856f..108f792e96f 100644 --- a/modules/localization.py +++ b/modules/localization.py @@ -14,21 +14,24 @@ def list_localizations(dirname): if ext.lower() != ".json": continue - localizations[fn] = os.path.join(dirname, file) + localizations[fn] = [os.path.join(dirname, file)] for file in scripts.list_scripts("localizations", ".json"): fn, ext = os.path.splitext(file.filename) - localizations[fn] = file.path + if fn not in localizations: + localizations[fn] = [] + localizations[fn].append(file.path) def localization_js(current_localization_name: str) -> str: - fn = localizations.get(current_localization_name, None) + fns = localizations.get(current_localization_name, None) data = {} - if fn is not None: - try: - with open(fn, "r", encoding="utf8") as file: - data = json.load(file) - except Exception: - errors.report(f"Error loading localization from {fn}", exc_info=True) + if fns is not None: + for fn in fns: + try: + with open(fn, "r", encoding="utf8") as file: + data.update(json.load(file)) + except Exception: + errors.report(f"Error loading localization from {fn}", exc_info=True) return f"window.localization = {json.dumps(data)}" diff --git a/modules/logging_config.py b/modules/logging_config.py index ad55898885b..0e18b984eb8 100644 --- a/modules/logging_config.py +++ b/modules/logging_config.py @@ -1,16 +1,41 @@ import os import logging +try: + from tqdm.auto import tqdm + + class TqdmLoggingHandler(logging.Handler): + def __init__(self, level=logging.INFO): + super().__init__(level) + + def emit(self, record): + try: + msg = self.format(record) + tqdm.write(msg) + self.flush() + except Exception: + self.handleError(record) + + TQDM_IMPORTED = True +except ImportError: + # tqdm does not exist before first launch + # I will import once the UI finishes seting up the enviroment and reloads. + TQDM_IMPORTED = False def setup_logging(loglevel): if loglevel is None: loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL") + loghandlers = [] + + if TQDM_IMPORTED: + loghandlers.append(TqdmLoggingHandler()) + if loglevel: log_level = getattr(logging, loglevel.upper(), None) or logging.INFO logging.basicConfig( level=log_level, format='%(asctime)s %(levelname)s [%(name)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S', + handlers=loghandlers ) - diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 89256c5b060..d96d86d792c 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -1,6 +1,7 @@ import logging import torch +from torch import Tensor import platform from modules.sd_hijack_utils import CondFunc from packaging import version @@ -51,6 +52,17 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs): return cumsum_func(input, *args, **kwargs) +# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046 +def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor: + try: + return orig_func(*args, **kwargs) + except RuntimeError as e: + if "not implemented for" in str(e) and "Half" in str(e): + input_tensor = args[0] + return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype) + else: + print(f"An unexpected RuntimeError occurred: {str(e)}") + if has_mps: if platform.mac_ver()[0].startswith("13.2."): # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124) @@ -77,6 +89,9 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs): # MPS workaround for https://github.com/pytorch/pytorch/issues/96113 CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps') + # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046 + CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None) + # MPS workaround for https://github.com/pytorch/pytorch/issues/92311 if platform.processor() == 'i386': for funcName in ['torch.argmax', 'torch.Tensor.argmax']: diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py index b892d5fc7b0..6db340da40b 100644 --- a/modules/models/diffusion/ddpm_edit.py +++ b/modules/models/diffusion/ddpm_edit.py @@ -24,10 +24,15 @@ from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config from ldm.modules.ema import LitEma from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution -from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like from ldm.models.diffusion.ddim import DDIMSampler +try: + from ldm.models.autoencoder import VQModelInterface +except Exception: + class VQModelInterface: + pass __conditioning_keys__ = {'concat': 'c_concat', 'crossattn': 'c_crossattn', diff --git a/modules/options.py b/modules/options.py index 0616ee95b05..6d8e61966d4 100644 --- a/modules/options.py +++ b/modules/options.py @@ -1,5 +1,6 @@ import json import sys +from dataclasses import dataclass import gradio as gr @@ -8,13 +9,14 @@ class OptionInfo: - def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False): + def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False, category_id=None): self.default = default self.label = label self.component = component self.component_args = component_args self.onchange = onchange self.section = section + self.category_id = category_id self.refresh = refresh self.do_not_save = False @@ -63,7 +65,11 @@ def __init__(self, text): def options_section(section_identifier, options_dict): for v in options_dict.values(): - v.section = section_identifier + if len(section_identifier) == 2: + v.section = section_identifier + elif len(section_identifier) == 3: + v.section = section_identifier[0:2] + v.category_id = section_identifier[2] return options_dict @@ -76,7 +82,7 @@ class Options: def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts): self.data_labels = data_labels - self.data = {k: v.default for k, v in self.data_labels.items()} + self.data = {k: v.default for k, v in self.data_labels.items() if not v.do_not_save} self.restricted_opts = restricted_opts def __setattr__(self, key, value): @@ -158,7 +164,7 @@ def save(self, filename): assert not cmd_opts.freeze_settings, "saving settings is disabled" with open(filename, "w", encoding="utf8") as file: - json.dump(self.data, file, indent=4) + json.dump(self.data, file, indent=4, ensure_ascii=False) def same_type(self, x, y): if x is None or y is None: @@ -206,21 +212,59 @@ def dumpjson(self): d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()} d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None} d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None} + + item_categories = {} + for item in self.data_labels.values(): + category = categories.mapping.get(item.category_id) + category = "Uncategorized" if category is None else category.label + if category not in item_categories: + item_categories[category] = item.section[1] + + # _categories is a list of pairs: [section, category]. Each section (a setting page) will get a special heading above it with the category as text. + d["_categories"] = [[v, k] for k, v in item_categories.items()] + [["Defaults", "Other"]] + return json.dumps(d) def add_option(self, key, info): self.data_labels[key] = info + if key not in self.data and not info.do_not_save: + self.data[key] = info.default def reorder(self): - """reorder settings so that all items related to section always go together""" + """Reorder settings so that: + - all items related to section always go together + - all sections belonging to a category go together + - sections inside a category are ordered alphabetically + - categories are ordered by creation order + + Category is a superset of sections: for category "postprocessing" there could be multiple sections: "face restoration", "upscaling". + + This function also changes items' category_id so that all items belonging to a section have the same category_id. + """ + + category_ids = {} + section_categories = {} - section_ids = {} settings_items = self.data_labels.items() for _, item in settings_items: - if item.section not in section_ids: - section_ids[item.section] = len(section_ids) + if item.section not in section_categories: + section_categories[item.section] = item.category_id + + for _, item in settings_items: + item.category_id = section_categories.get(item.section) + + for category_id in categories.mapping: + if category_id not in category_ids: + category_ids[category_id] = len(category_ids) - self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section])) + def sort_key(x): + item: OptionInfo = x[1] + category_order = category_ids.get(item.category_id, len(category_ids)) + section_order = item.section[1] + + return category_order, section_order + + self.data_labels = dict(sorted(settings_items, key=sort_key)) def cast_value(self, key, value): """casts an arbitrary to the same type as this setting's value with key @@ -243,3 +287,22 @@ def cast_value(self, key, value): value = expected_type(value) return value + + +@dataclass +class OptionsCategory: + id: str + label: str + +class OptionsCategories: + def __init__(self): + self.mapping = {} + + def register_category(self, category_id, label): + if category_id in self.mapping: + return category_id + + self.mapping[category_id] = OptionsCategory(category_id, label) + + +categories = OptionsCategories() diff --git a/modules/paths.py b/modules/paths.py index 2505233999b..187b949612a 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -1,6 +1,6 @@ import os import sys -from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401 +from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, cwd # noqa: F401 import modules.safe # noqa: F401 diff --git a/modules/paths_internal.py b/modules/paths_internal.py index 005a9b0aa75..89131a54fa1 100644 --- a/modules/paths_internal.py +++ b/modules/paths_internal.py @@ -8,6 +8,7 @@ commandline_args = os.environ.get('COMMANDLINE_ARGS', "") sys.argv += shlex.split(commandline_args) +cwd = os.getcwd() modules_path = os.path.dirname(os.path.realpath(__file__)) script_path = os.path.dirname(modules_path) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index cf04d38b059..0c59fad480b 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -29,11 +29,7 @@ def get_images(extras_mode, image, image_folder, input_dir): image_list = shared.listfiles(input_dir) for filename in image_list: - try: - image = Image.open(filename) - except Exception: - continue - yield image, filename + yield filename, filename else: assert image, 'image not selected' yield image, None @@ -45,43 +41,97 @@ def get_images(extras_mode, image, image_folder, input_dir): infotext = '' - for image_data, name in get_images(extras_mode, image, image_folder, input_dir): + data_to_process = list(get_images(extras_mode, image, image_folder, input_dir)) + shared.state.job_count = len(data_to_process) + + for image_placeholder, name in data_to_process: image_data: Image.Image + shared.state.nextjob() shared.state.textinfo = name + shared.state.skipped = False + + if shared.state.interrupted: + break + + if isinstance(image_placeholder, str): + try: + image_data = Image.open(image_placeholder) + except Exception: + continue + else: + image_data = image_placeholder + + shared.state.assign_current_image(image_data) parameters, existing_pnginfo = images.read_info_from_image(image_data) if parameters: existing_pnginfo["parameters"] = parameters - pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB")) + initial_pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB")) - scripts.scripts_postproc.run(pp, args) + scripts.scripts_postproc.run(initial_pp, args) - if opts.use_original_name_batch and name is not None: - basename = os.path.splitext(os.path.basename(name))[0] - else: - basename = '' + if shared.state.skipped: + continue + + used_suffixes = {} + for pp in [initial_pp, *initial_pp.extra_images]: + suffix = pp.get_suffix(used_suffixes) - infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None]) + if opts.use_original_name_batch and name is not None: + basename = os.path.splitext(os.path.basename(name))[0] + forced_filename = basename + suffix + else: + basename = '' + forced_filename = None - if opts.enable_pnginfo: - pp.image.info = existing_pnginfo - pp.image.info["postprocessing"] = infotext + infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None]) - if save_output: - images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) + if opts.enable_pnginfo: + pp.image.info = existing_pnginfo + pp.image.info["postprocessing"] = infotext - if extras_mode != 2 or show_extras_results: - outputs.append(pp.image) + if save_output: + fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix) + + if pp.caption: + caption_filename = os.path.splitext(fullfn)[0] + ".txt" + if os.path.isfile(caption_filename): + with open(caption_filename, encoding="utf8") as file: + existing_caption = file.read().strip() + else: + existing_caption = "" + + action = shared.opts.postprocessing_existing_caption_action + if action == 'Prepend' and existing_caption: + caption = f"{existing_caption} {pp.caption}" + elif action == 'Append' and existing_caption: + caption = f"{pp.caption} {existing_caption}" + elif action == 'Keep' and existing_caption: + caption = existing_caption + else: + caption = pp.caption + + caption = caption.strip() + if caption: + with open(caption_filename, "w", encoding="utf8") as file: + file.write(caption) + + if extras_mode != 2 or show_extras_results: + outputs.append(pp.image) image_data.close() devices.torch_gc() - + shared.state.end() return outputs, ui_common.plaintext_to_html(infotext), '' +def run_postprocessing_webui(id_task, *args, **kwargs): + return run_postprocessing(*args, **kwargs) + + def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): """old handler for API""" @@ -97,9 +147,11 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ "upscaler_2_visibility": extras_upscaler_2_visibility, }, "GFPGAN": { + "enable": True, "gfpgan_visibility": gfpgan_visibility, }, "CodeFormer": { + "enable": True, "codeformer_visibility": codeformer_visibility, "codeformer_weight": codeformer_weight, }, diff --git a/modules/processing.py b/modules/processing.py index e124e7f0dd2..6f01c95f5b6 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -142,7 +142,7 @@ class StableDiffusionProcessing: overlay_images: list = None eta: float = None do_not_reload_embeddings: bool = False - denoising_strength: float = 0 + denoising_strength: float = None ddim_discretize: str = None s_min_uncond: float = None s_churn: float = None @@ -296,7 +296,7 @@ def depth2img_image_conditioning(self, source_image): return conditioning def edit_image_conditioning(self, source_image): - conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method)) + conditioning_image = shared.sd_model.encode_first_stage(source_image).mode() return conditioning_image @@ -533,6 +533,7 @@ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", self.all_seeds = all_seeds or p.all_seeds or [self.seed] self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed] self.infotexts = infotexts or [info] + self.version = program_version() def js(self): obj = { @@ -567,6 +568,7 @@ def js(self): "job_timestamp": self.job_timestamp, "clip_skip": self.clip_skip, "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning, + "version": self.version, } return json.dumps(obj) @@ -677,8 +679,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Size": f"{p.width}x{p.height}", "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None, "Model": p.sd_model_name if opts.add_model_name_to_info else None, - "VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None, - "VAE": p.sd_vae_name if opts.add_model_name_to_info else None, + "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None, + "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None, "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), @@ -709,7 +711,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None: p.scripts.before_process(p) - stored_opts = {k: opts.data[k] for k in p.override_settings.keys()} + stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data} try: # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint @@ -797,7 +799,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: infotexts = [] output_images = [] - with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) @@ -871,7 +872,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() @@ -884,6 +884,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: devices.torch_gc() + state.nextjob() + if p.scripts is not None: p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n) @@ -936,27 +938,27 @@ def infotext(index=0, use_main_prompt=False): if opts.enable_pnginfo: image.info["parameters"] = text output_images.append(image) - if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): - image_mask = p.mask_for_overlay.convert('RGB') - image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') - - if opts.save_mask: - images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask") - - if opts.save_mask_composite: - images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite") - - if opts.return_mask: - output_images.append(image_mask) - - if opts.return_mask_composite: - output_images.append(image_mask_composite) + if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay: + if opts.return_mask or opts.save_mask: + image_mask = p.mask_for_overlay.convert('RGB') + if save_samples and opts.save_mask: + images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask") + if opts.return_mask: + output_images.append(image_mask) + + if opts.return_mask_composite or opts.save_mask_composite: + image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') + if save_samples and opts.save_mask_composite: + images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite") + if opts.return_mask_composite: + output_images.append(image_mask_composite) del x_samples_ddim devices.torch_gc() - state.nextjob() + if not infotexts: + infotexts.append(Processed(p, []).infotext(p, 0)) p.color_corrections = None @@ -1142,6 +1144,7 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs if not self.enable_hr: return samples + devices.torch_gc() if self.latent_scale_mode is None: decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) @@ -1151,8 +1154,6 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs with sd_models.SkipWritingToConfig(): sd_models.reload_model_weights(info=self.hr_checkpoint_info) - devices.torch_gc() - return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts) def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts): @@ -1160,7 +1161,6 @@ def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_stre return samples self.is_hr_pass = True - target_width = self.hr_upscale_to_x target_height = self.hr_upscale_to_y @@ -1249,7 +1249,6 @@ def save_intermediate(image, index): decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) self.is_hr_pass = False - return decoded_samples def close(self): diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py index 265087fd85e..ccaee9731a8 100644 --- a/modules/processing_scripts/seed.py +++ b/modules/processing_scripts/seed.py @@ -29,8 +29,8 @@ def ui(self, is_img2img): else: self.seed = gr.Number(label='Seed', value=-1, elem_id=self.elem_id("seed"), min_width=100, precision=0) - random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), label='Random seed') - reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), label='Reuse seed') + random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), tooltip="Set seed to -1, which will cause a new random number to be used every time") + reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), tooltip="Reuse seed from last generation, mostly useful if it was randomized") seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id("subseed_show"), value=False) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 334efeef317..cba1345545d 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -2,10 +2,9 @@ import re from collections import namedtuple -from typing import List import lark -# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" +# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][: in background:0.25] [shoddy:masterful:0.5]" # will be represented with prompt_schedule like this (assuming steps=100): # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy'] # [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy'] @@ -240,14 +239,14 @@ def get_multicond_prompt_list(prompts: SdConditioning | list[str]): class ComposableScheduledPromptConditioning: def __init__(self, schedules, weight=1.0): - self.schedules: List[ScheduledPromptConditioning] = schedules + self.schedules: list[ScheduledPromptConditioning] = schedules self.weight: float = weight class MulticondLearnedConditioning: def __init__(self, shape, batch): self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS - self.batch: List[List[ComposableScheduledPromptConditioning]] = batch + self.batch: list[list[ComposableScheduledPromptConditioning]] = batch def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning: @@ -278,7 +277,7 @@ def shape(self): return self["crossattn"].shape -def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step): +def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step): param = c[0][0].cond is_dict = isinstance(param, dict) diff --git a/modules/restart.py b/modules/restart.py index 18eacaf377e..2dd6493b450 100644 --- a/modules/restart.py +++ b/modules/restart.py @@ -14,7 +14,9 @@ def is_restartable() -> bool: def restart_program() -> None: """creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again""" - (Path(script_path) / "tmp" / "restart").touch() + tmpdir = Path(script_path) / "tmp" + tmpdir.mkdir(parents=True, exist_ok=True) + (tmpdir / "restart").touch() stop_program() diff --git a/modules/rng.py b/modules/rng.py index 21113c9449c..17bc83344ee 100644 --- a/modules/rng.py +++ b/modules/rng.py @@ -110,7 +110,7 @@ def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resiz self.is_first = True def first(self): - noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8) + noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], int(self.seed_resize_from_h) // 8, int(self.seed_resize_from_w // 8)) xs = [] diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index c99695eb3d9..9ed7ad21d1b 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,7 +1,7 @@ import inspect import os from collections import namedtuple -from typing import Optional, Dict, Any +from typing import Optional, Any from fastapi import FastAPI from gradio import Blocks @@ -258,7 +258,7 @@ def image_grid_callback(params: ImageGridLoopParams): report_exception(c, 'image_grid') -def infotext_pasted_callback(infotext: str, params: Dict[str, Any]): +def infotext_pasted_callback(infotext: str, params: dict[str, Any]): for c in callback_map['callbacks_infotext_pasted']: try: c.callback(infotext, params) @@ -449,7 +449,7 @@ def on_infotext_pasted(callback): """register a function to be called before applying an infotext. The callback is called with two arguments: - infotext: str - raw infotext. - - result: Dict[str, any] - parsed infotext parameters. + - result: dict[str, any] - parsed infotext parameters. """ add_callback(callback_map['callbacks_infotext_pasted'], callback) diff --git a/modules/scripts.py b/modules/scripts.py index e8518ad0fba..7f9454eb578 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -311,20 +311,113 @@ def basedir(): postprocessing_scripts_data = [] ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) +def topological_sort(dependencies): + """Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies. + Ignores errors relating to missing dependeencies or circular dependencies + """ + + visited = {} + result = [] + + def inner(name): + visited[name] = True + + for dep in dependencies.get(name, []): + if dep in dependencies and dep not in visited: + inner(dep) + + result.append(name) + + for depname in dependencies: + if depname not in visited: + inner(depname) + + return result + + +@dataclass +class ScriptWithDependencies: + script_canonical_name: str + file: ScriptFile + requires: list + load_before: list + load_after: list + def list_scripts(scriptdirname, extension, *, include_extensions=True): - scripts_list = [] + scripts = {} + + loaded_extensions = {ext.canonical_name: ext for ext in extensions.active()} + loaded_extensions_scripts = {ext.canonical_name: [] for ext in extensions.active()} + + # build script dependency map + root_script_basedir = os.path.join(paths.script_path, scriptdirname) + if os.path.exists(root_script_basedir): + for filename in sorted(os.listdir(root_script_basedir)): + if not os.path.isfile(os.path.join(root_script_basedir, filename)): + continue + + if os.path.splitext(filename)[1].lower() != extension: + continue - basedir = os.path.join(paths.script_path, scriptdirname) - if os.path.exists(basedir): - for filename in sorted(os.listdir(basedir)): - scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) + script_file = ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename)) + scripts[filename] = ScriptWithDependencies(filename, script_file, [], [], []) if include_extensions: for ext in extensions.active(): - scripts_list += ext.list_files(scriptdirname, extension) - - scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] + extension_scripts_list = ext.list_files(scriptdirname, extension) + for extension_script in extension_scripts_list: + if not os.path.isfile(extension_script.path): + continue + + script_canonical_name = ("builtin/" if ext.is_builtin else "") + ext.canonical_name + "/" + extension_script.filename + relative_path = scriptdirname + "/" + extension_script.filename + + script = ScriptWithDependencies( + script_canonical_name=script_canonical_name, + file=extension_script, + requires=ext.metadata.get_script_requirements("Requires", relative_path, scriptdirname), + load_before=ext.metadata.get_script_requirements("Before", relative_path, scriptdirname), + load_after=ext.metadata.get_script_requirements("After", relative_path, scriptdirname), + ) + + scripts[script_canonical_name] = script + loaded_extensions_scripts[ext.canonical_name].append(script) + + for script_canonical_name, script in scripts.items(): + # load before requires inverse dependency + # in this case, append the script name into the load_after list of the specified script + for load_before in script.load_before: + # if this requires an individual script to be loaded before + other_script = scripts.get(load_before) + if other_script: + other_script.load_after.append(script_canonical_name) + + # if this requires an extension + other_extension_scripts = loaded_extensions_scripts.get(load_before) + if other_extension_scripts: + for other_script in other_extension_scripts: + other_script.load_after.append(script_canonical_name) + + # if After mentions an extension, remove it and instead add all of its scripts + for load_after in list(script.load_after): + if load_after not in scripts and load_after in loaded_extensions_scripts: + script.load_after.remove(load_after) + + for other_script in loaded_extensions_scripts.get(load_after, []): + script.load_after.append(other_script.script_canonical_name) + + dependencies = {} + + for script_canonical_name, script in scripts.items(): + for required_script in script.requires: + if required_script not in scripts and required_script not in loaded_extensions: + errors.report(f'Script "{script_canonical_name}" requires "{required_script}" to be loaded, but it is not.', exc_info=False) + + dependencies[script_canonical_name] = script.load_after + + ordered_scripts = topological_sort(dependencies) + scripts_list = [scripts[script_canonical_name].file for script_canonical_name in ordered_scripts] return scripts_list @@ -365,15 +458,9 @@ def register_scripts_from_module(module): elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing): postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) - def orderby(basedir): - # 1st webui, 2nd extensions-builtin, 3rd extensions - priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0} - for key in priority: - if basedir.startswith(key): - return priority[key] - return 9999 - - for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]): + # here the scripts_list is already ordered + # processing_script is not considered though + for scriptfile in scripts_list: try: if scriptfile.basedir != paths.script_path: sys.path = [scriptfile.basedir] + sys.path @@ -473,17 +560,25 @@ def apply_on_before_component_callbacks(self): on_after.clear() def create_script_ui(self, script): - import modules.api.models as api_models script.args_from = len(self.inputs) script.args_to = len(self.inputs) + try: + self.create_script_ui_inner(script) + except Exception: + errors.report(f"Error creating UI for {script.name}: ", exc_info=True) + + def create_script_ui_inner(self, script): + import modules.api.models as api_models + controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img) if controls is None: return script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower() + api_args = [] for control in controls: @@ -491,11 +586,15 @@ def create_script_ui(self, script): arg_info = api_models.ScriptArg(label=control.label or "") - for field in ("value", "minimum", "maximum", "step", "choices"): + for field in ("value", "minimum", "maximum", "step"): v = getattr(control, field, None) if v is not None: setattr(arg_info, field, v) + choices = getattr(control, 'choices', None) # as of gradio 3.41, some items in choices are strings, and some are tuples where the first elem is the string + if choices is not None: + arg_info.choices = [x[0] if isinstance(x, tuple) else x for x in choices] + api_args.append(arg_info) script.api_info = api_models.ScriptInfo( diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py index bac1335dc35..901cad080c0 100644 --- a/modules/scripts_postprocessing.py +++ b/modules/scripts_postprocessing.py @@ -1,13 +1,56 @@ +import dataclasses import os import gradio as gr from modules import errors, shared +@dataclasses.dataclass +class PostprocessedImageSharedInfo: + target_width: int = None + target_height: int = None + + class PostprocessedImage: def __init__(self, image): self.image = image self.info = {} + self.shared = PostprocessedImageSharedInfo() + self.extra_images = [] + self.nametags = [] + self.disable_processing = False + self.caption = None + + def get_suffix(self, used_suffixes=None): + used_suffixes = {} if used_suffixes is None else used_suffixes + suffix = "-".join(self.nametags) + if suffix: + suffix = "-" + suffix + + if suffix not in used_suffixes: + used_suffixes[suffix] = 1 + return suffix + + for i in range(1, 100): + proposed_suffix = suffix + "-" + str(i) + + if proposed_suffix not in used_suffixes: + used_suffixes[proposed_suffix] = 1 + return proposed_suffix + + return suffix + + def create_copy(self, new_image, *, nametags=None, disable_processing=False): + pp = PostprocessedImage(new_image) + pp.shared = self.shared + pp.nametags = self.nametags.copy() + pp.info = self.info.copy() + pp.disable_processing = disable_processing + + if nametags is not None: + pp.nametags += nametags + + return pp class ScriptPostprocessing: @@ -42,10 +85,17 @@ def process(self, pp: PostprocessedImage, **args): pass - def image_changed(self): - pass + def process_firstpass(self, pp: PostprocessedImage, **args): + """ + Called for all scripts before calling process(). Scripts can examine the image here and set fields + of the pp object to communicate things to other scripts. + args contains a dictionary with all values returned by components from ui() + """ + pass + def image_changed(self): + pass def wrap_call(func, filename, funcname, *args, default=None, **kwargs): @@ -118,16 +168,42 @@ def setup_ui(self): return inputs def run(self, pp: PostprocessedImage, args): - for script in self.scripts_in_preferred_order(): - shared.state.job = script.name + scripts = [] + for script in self.scripts_in_preferred_order(): script_args = args[script.args_from:script.args_to] process_args = {} for (name, _component), value in zip(script.controls.items(), script_args): process_args[name] = value - script.process(pp, **process_args) + scripts.append((script, process_args)) + + for script, process_args in scripts: + script.process_firstpass(pp, **process_args) + + all_images = [pp] + + for script, process_args in scripts: + if shared.state.skipped: + break + + shared.state.job = script.name + + for single_image in all_images.copy(): + + if not single_image.disable_processing: + script.process(single_image, **process_args) + + for extra_image in single_image.extra_images: + if not isinstance(extra_image, PostprocessedImage): + extra_image = single_image.create_copy(extra_image) + + all_images.append(extra_image) + + single_image.extra_images.clear() + + pp.extra_images = all_images[1:] def create_args_for_run(self, scripts_args): if not self.ui_created: diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 8863107ae6f..273a7edd8b4 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -215,7 +215,7 @@ def load_state_dict(original, module, state_dict, strict=True): would be on the meta device. """ - if state_dict == sd: + if state_dict is sd: state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()} original(module, state_dict, strict=strict) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 592f00551f1..e139d9964cb 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -2,14 +2,15 @@ from torch.nn.functional import silu from types import MethodType -from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet +from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts -from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr +from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18 import ldm.modules.attention import ldm.modules.diffusionmodules.model import ldm.modules.diffusionmodules.openaimodel +import ldm.models.diffusion.ddpm import ldm.models.diffusion.ddim import ldm.models.diffusion.plms import ldm.modules.encoders.modules @@ -37,6 +38,12 @@ optimizers = [] current_optimizer: sd_hijack_optimizations.SdOptimization = None +ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward) +ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward) + +sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward) +sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward) + def list_optimizers(): new_optimizers = script_callbacks.list_optimizers_callback() @@ -181,6 +188,20 @@ def apply_optimizations(self, option=None): errors.display(e, "applying cross attention optimization") undo_optimizations() + def convert_sdxl_to_ssd(self, m): + """Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)""" + + delattr(m.model.diffusion_model.middle_block, '1') + delattr(m.model.diffusion_model.middle_block, '2') + for i in ['9', '8', '7', '6', '5', '4']: + delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i) + delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i) + delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i) + delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i) + delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1') + delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1') + devices.torch_gc() + def hijack(self, m): conditioner = getattr(m, 'conditioner', None) if conditioner: @@ -208,7 +229,7 @@ def hijack(self, m): else: m.cond_stage_model = conditioner - if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: + if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self) @@ -239,10 +260,17 @@ def flatten(el): self.layers = flatten(m) - if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'): - ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward + import modules.models.diffusion.ddpm_edit + + if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion): + sd_unet.original_forward = ldm_original_forward + elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion): + sd_unet.original_forward = ldm_original_forward + elif isinstance(m, sgm.models.diffusion.DiffusionEngine): + sd_unet.original_forward = sgm_original_forward + else: + sd_unet.original_forward = None - ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward def undo_hijack(self, m): conditioner = getattr(m, 'conditioner', None) @@ -279,7 +307,6 @@ def undo_hijack(self, m): self.layers = None self.clip = None - ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui def apply_circular(self, enable): if self.circular_enabled == enable: diff --git a/modules/sd_models.py b/modules/sd_models.py index 930d0bee5c8..9355f1e16b7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -1,22 +1,22 @@ import collections import os.path import sys -import gc import threading import torch import re import safetensors.torch -from omegaconf import OmegaConf +from omegaconf import OmegaConf, ListConfig from os import mkdir from urllib import request import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack +from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches from modules.timer import Timer import tomesd +import numpy as np model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(paths.models_path, model_dir)) @@ -49,11 +49,12 @@ class CheckpointInfo: def __init__(self, filename): self.filename = filename abspath = os.path.abspath(filename) + abs_ckpt_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) if shared.cmd_opts.ckpt_dir is not None else None self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors" - if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): - name = abspath.replace(shared.cmd_opts.ckpt_dir, '') + if abs_ckpt_dir and abspath.startswith(abs_ckpt_dir): + name = abspath.replace(abs_ckpt_dir, '') elif abspath.startswith(model_path): name = abspath.replace(model_path, '') else: @@ -129,9 +130,12 @@ def calculate_shorthash(self): def setup_model(): + """called once at startup to do various one-time tasks related to SD models""" + os.makedirs(model_path, exist_ok=True) enable_midas_autodownload() + patch_given_betas() def checkpoint_tiles(use_short=False): @@ -226,15 +230,19 @@ def select_checkpoint(): return checkpoint_info -checkpoint_dict_replacements = { +checkpoint_dict_replacements_sd1 = { 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.', 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.', 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.', } +checkpoint_dict_replacements_sd2_turbo = { # Converts SD 2.1 Turbo from SGM to LDM format. + 'conditioner.embedders.0.': 'cond_stage_model.', +} + -def transform_checkpoint_dict_key(k): - for text, replacement in checkpoint_dict_replacements.items(): +def transform_checkpoint_dict_key(k, replacements): + for text, replacement in replacements.items(): if k.startswith(text): k = replacement + k[len(text):] @@ -245,9 +253,14 @@ def get_state_dict_from_checkpoint(pl_sd): pl_sd = pl_sd.pop("state_dict", pl_sd) pl_sd.pop("state_dict", None) + is_sd2_turbo = 'conditioner.embedders.0.model.ln_final.weight' in pl_sd and pl_sd['conditioner.embedders.0.model.ln_final.weight'].size()[0] == 1024 + sd = {} for k, v in pl_sd.items(): - new_key = transform_checkpoint_dict_key(k) + if is_sd2_turbo: + new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd2_turbo) + else: + new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd1) if new_key is not None: sd[new_key] = v @@ -309,6 +322,8 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): if checkpoint_info in checkpoints_loaded: # use checkpoint cache print(f"Loading weights [{sd_model_hash}] from cache") + # move to end as latest + checkpoints_loaded.move_to_end(checkpoint_info) return checkpoints_loaded[checkpoint_info] print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") @@ -346,16 +361,19 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.is_sdxl = hasattr(model, 'conditioner') model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model') model.is_sd1 = not model.is_sdxl and not model.is_sd2 - + model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys() if model.is_sdxl: sd_models_xl.extend_sdxl(model) - model.load_state_dict(state_dict, strict=False) - timer.record("apply weights to model") + if model.is_ssd: + sd_hijack.model_hijack.convert_sdxl_to_ssd(model) if shared.opts.sd_checkpoint_cache > 0: # cache newly loaded model - checkpoints_loaded[checkpoint_info] = state_dict + checkpoints_loaded[checkpoint_info] = state_dict.copy() + + model.load_state_dict(state_dict, strict=False) + timer.record("apply weights to model") del state_dict @@ -453,6 +471,20 @@ def load_model_wrapper(model_type): midas.api.load_model = load_model_wrapper +def patch_given_betas(): + import ldm.models.diffusion.ddpm + + def patched_register_schedule(*args, **kwargs): + """a modified version of register_schedule function that converts plain list from Omegaconf into numpy""" + + if isinstance(args[1], ListConfig): + args = (args[0], np.array(args[1]), *args[2:]) + + original_register_schedule(*args, **kwargs) + + original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule) + + def repair_config(sd_config): if not hasattr(sd_config.model.params, "use_ema"): @@ -777,17 +809,7 @@ def reload_model_weights(sd_model=None, info=None): def unload_model_weights(sd_model=None, info=None): - timer = Timer() - - if model_data.sd_model: - model_data.sd_model.to(devices.cpu) - sd_hijack.model_hijack.undo_hijack(model_data.sd_model) - model_data.sd_model = None - sd_model = None - gc.collect() - devices.torch_gc() - - print(f"Unloaded weights {timer.summary()}.") + send_model_to_cpu(sd_model or shared.sd_model) return sd_model diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 08dd03f19c7..deab2f6e237 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -21,7 +21,7 @@ config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") - +config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml") def is_using_v_parameterization_for_sd2(state_dict): """ @@ -95,7 +95,10 @@ def guess_model_config_from_state_dict(sd, filename): if diffusion_model_input.shape[1] == 8: return config_instruct_pix2pix + if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None: + if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024: + return config_alt_diffusion_m18 return config_alt_diffusion return config_default diff --git a/modules/sd_models_types.py b/modules/sd_models_types.py index 4ff81d4c679..ca5eac6011e 100644 --- a/modules/sd_models_types.py +++ b/modules/sd_models_types.py @@ -22,7 +22,10 @@ class WebuiSdModel(LatentDiffusion): """structure with additional information about the file with model's weights""" is_sdxl: bool - """True if the model's architecture is SDXL""" + """True if the model's architecture is SDXL or SSD""" + + is_ssd: bool + """True if the model is SSD""" is_sd2: bool """True if the model's architecture is SD 2.x""" diff --git a/modules/sd_samplers_extra.py b/modules/sd_samplers_extra.py index 24491ca1616..638d7560690 100644 --- a/modules/sd_samplers_extra.py +++ b/modules/sd_samplers_extra.py @@ -60,7 +60,7 @@ def heun_step(x, old_sigma, new_sigma, second_order=True): sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1] while restart_times > 0: restart_times -= 1 - step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])]) + step_list.extend(zip(sigma_restart[:-1], sigma_restart[1:])) last_sigma = None for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable): diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py index 9b33232b4eb..dcb2b208b7e 100644 --- a/modules/sd_samplers_timesteps_impl.py +++ b/modules/sd_samplers_timesteps_impl.py @@ -11,7 +11,7 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0): alphas_cumprod = model.inner_model.inner_model.alphas_cumprod alphas = alphas_cumprod[timesteps] - alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32) + alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32) sqrt_one_minus_alphas = torch.sqrt(1 - alphas) sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy())) @@ -43,7 +43,7 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta= def plms(model, x, timesteps, extra_args=None, callback=None, disable=None): alphas_cumprod = model.inner_model.inner_model.alphas_cumprod alphas = alphas_cumprod[timesteps] - alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32) + alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32) sqrt_one_minus_alphas = torch.sqrt(1 - alphas) extra_args = {} if extra_args is None else extra_args diff --git a/modules/sd_unet.py b/modules/sd_unet.py index 5525cfbc3a0..a771849c8c2 100644 --- a/modules/sd_unet.py +++ b/modules/sd_unet.py @@ -1,12 +1,11 @@ import torch.nn -import ldm.modules.diffusionmodules.openaimodel from modules import script_callbacks, shared, devices unet_options = [] current_unet_option = None current_unet = None - +original_forward = None # not used, only left temporarily for compatibility def list_unets(): new_unets = script_callbacks.list_unets_callback() @@ -84,9 +83,12 @@ def deactivate(self): pass -def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs): - if current_unet is not None: - return current_unet.forward(x, timesteps, context, *args, **kwargs) +def create_unet_forward(original_forward): + def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs): + if current_unet is not None: + return current_unet.forward(x, timesteps, context, *args, **kwargs) + + return original_forward(self, x, timesteps, context, *args, **kwargs) - return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs) + return UNetModel_forward diff --git a/modules/shared_cmd_options.py b/modules/shared_cmd_options.py index e9dcacd8abb..bb0c4acb07e 100644 --- a/modules/shared_cmd_options.py +++ b/modules/shared_cmd_options.py @@ -14,5 +14,5 @@ else: cmd_opts, _ = parser.parse_known_args() - -cmd_opts.disable_extension_access = any([cmd_opts.share, cmd_opts.listen, cmd_opts.ngrok, cmd_opts.server_name]) and not cmd_opts.enable_insecure_extension_access +cmd_opts.webui_is_non_local = any([cmd_opts.share, cmd_opts.listen, cmd_opts.ngrok, cmd_opts.server_name]) +cmd_opts.disable_extension_access = cmd_opts.webui_is_non_local and not cmd_opts.enable_insecure_extension_access diff --git a/modules/shared_items.py b/modules/shared_items.py index 84d69c8df43..991971ad0fb 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -44,9 +44,9 @@ def refresh_unet_list(): modules.sd_unet.list_unets() -def list_checkpoint_tiles(): +def list_checkpoint_tiles(use_short=False): import modules.sd_models - return modules.sd_models.checkpoint_tiles() + return modules.sd_models.checkpoint_tiles(use_short) def refresh_checkpoints(): @@ -66,7 +66,25 @@ def reload_hypernetworks(): shared.hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) +def get_infotext_names(): + from modules import generation_parameters_copypaste, shared + res = {} + + for info in shared.opts.data_labels.values(): + if info.infotext: + res[info.infotext] = 1 + + for tab_data in generation_parameters_copypaste.paste_fields.values(): + for _, name in tab_data.get("fields") or []: + if isinstance(name, str): + res[name] = 1 + + return list(res) + + ui_reorder_categories_builtin_items = [ + "prompt", + "image", "inpaint", "sampler", "accordions", diff --git a/modules/shared_options.py b/modules/shared_options.py index 7739b80291f..3b77f7df7ce 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -3,7 +3,7 @@ from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401 from modules.shared_cmd_options import cmd_opts -from modules.options import options_section, OptionInfo, OptionHTML +from modules.options import options_section, OptionInfo, OptionHTML, categories options_templates = {} hide_dirs = shared.hide_dirs @@ -21,12 +21,19 @@ "outdir_init_images" } -options_templates.update(options_section(('saving-images', "Saving images/grids"), { +categories.register_category("saving", "Saving images") +categories.register_category("sd", "Stable Diffusion") +categories.register_category("ui", "User Interface") +categories.register_category("system", "System") +categories.register_category("postprocessing", "Postprocessing") +categories.register_category("training", "Training") + +options_templates.update(options_section(('saving-images', "Saving images/grids", "saving"), { "samples_save": OptionInfo(True, "Always save all generated images"), "samples_format": OptionInfo('png', 'File format for images'), "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"), "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs), - + "save_images_replace_action": OptionInfo("Replace", "Saving the image to an existing file", gr.Radio, {"choices": ["Replace", "Add number suffix"], **hide_dirs}), "grid_save": OptionInfo(True, "Always save all generated image grids"), "grid_format": OptionInfo('png', 'File format for grids'), "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"), @@ -39,8 +46,6 @@ "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}), "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}), - "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), - "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."), "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."), "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), @@ -62,9 +67,12 @@ "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"), "save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."), + + "notification_audio": OptionInfo(True, "Play notification sound after image generation").info("notification.mp3 should be present in the root directory").needs_reload_ui(), + "notification_volume": OptionInfo(100, "Notification sound volume", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}).info("in %"), })) -options_templates.update(options_section(('saving-paths', "Paths for saving"), { +options_templates.update(options_section(('saving-paths', "Paths for saving", "saving"), { "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs), "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs), "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs), @@ -76,7 +84,7 @@ "outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs), })) -options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), { +options_templates.update(options_section(('saving-to-dirs', "Saving to a directory", "saving"), { "save_to_dirs": OptionInfo(True, "Save images to a subdirectory"), "grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"), "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"), @@ -84,22 +92,23 @@ "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}), })) -options_templates.update(options_section(('upscaling', "Upscaling"), { +options_templates.update(options_section(('upscaling', "Upscaling", "postprocessing"), { "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}), })) -options_templates.update(options_section(('face-restoration', "Face restoration"), { +options_templates.update(options_section(('face-restoration', "Face restoration", "postprocessing"), { "face_restoration": OptionInfo(False, "Restore faces", infotext='Face restoration').info("will use a third-party model on generation result to reconstruct faces"), "face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in shared.face_restorers]}), "code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"), "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"), })) -options_templates.update(options_section(('system', "System"), { +options_templates.update(options_section(('system', "System", "system"), { "auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}), + "enable_console_prompts": OptionInfo(shared.cmd_opts.enable_console_prompts, "Print prompts to console when generating with txt2img and img2img."), "show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(), "show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(), "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"), @@ -109,15 +118,16 @@ "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."), + "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."), })) -options_templates.update(options_section(('API', "API"), { +options_templates.update(options_section(('API', "API", "system"), { "api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True), "api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True), "api_useragent": OptionInfo("", "User agent for requests", restrict_api=True), })) -options_templates.update(options_section(('training', "Training"), { +options_templates.update(options_section(('training', "Training", "training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."), @@ -132,8 +142,8 @@ "training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."), })) -options_templates.update(options_section(('sd', "Stable Diffusion"), { - "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'), +options_templates.update(options_section(('sd', "Stable Diffusion", "sd"), { + "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'), "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}), "sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"), @@ -149,14 +159,14 @@ "hires_fix_refiner_pass": OptionInfo("second pass", "Hires fix: which pass to enable refiner for", gr.Radio, {"choices": ["first pass", "second pass", "both passes"]}, infotext="Hires refiner"), })) -options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { +options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"), { "sdxl_crop_top": OptionInfo(0, "crop top coordinate"), "sdxl_crop_left": OptionInfo(0, "crop left coordinate"), "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"), "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"), })) -options_templates.update(options_section(('vae', "VAE"), { +options_templates.update(options_section(('vae', "VAE", "sd"), { "sd_vae_explanation": OptionHTML(""" VAE is a neural network that transforms a standard RGB image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling @@ -171,7 +181,7 @@ "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"), })) -options_templates.update(options_section(('img2img', "img2img"), { +options_templates.update(options_section(('img2img', "img2img", "sd"), { "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Conditional mask weight'), "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.0, "maximum": 1.5, "step": 0.001}, infotext='Noise multiplier'), "img2img_extra_noise": OptionInfo(0.0, "Extra noise multiplier for img2img and hires fix", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Extra noise').info("0 = disabled (default); should be lower than denoising strength"), @@ -184,9 +194,10 @@ "img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(), "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"), "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"), + "img2img_batch_show_results_limit": OptionInfo(32, "Show the first N batch img2img results in UI", gr.Slider, {"minimum": -1, "maximum": 1000, "step": 1}).info('0: disable, -1: show all images. Too many images can cause lag'), })) -options_templates.update(options_section(('optimizations', "Optimizations"), { +options_templates.update(options_section(('optimizations', "Optimizations", "sd"), { "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}), "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"), @@ -197,7 +208,7 @@ "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), })) -options_templates.update(options_section(('compatibility', "Compatibility"), { +options_templates.update(options_section(('compatibility', "Compatibility", "sd"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."), @@ -222,14 +233,17 @@ "deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"), })) -options_templates.update(options_section(('extra_networks', "Extra Networks"), { +options_templates.update(options_section(('extra_networks', "Extra Networks", "sd"), { "extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."), + "extra_networks_dir_button_function": OptionInfo(False, "Add a '/' to the beginning of directory buttons").info("Buttons will display the contents of the selected directory without acting as a search filter."), "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'), "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}), "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"), "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"), "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"), "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"), + "extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(), + "extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(), "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(), "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"), @@ -237,42 +251,66 @@ "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks), })) -options_templates.update(options_section(('ui', "User interface"), { - "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(), - "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the gallery.").needs_reload_ui(), - "gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"), - "gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("an be any valid CSS value").needs_reload_ui(), - "return_grid": OptionInfo(True, "Show grid in results for web"), - "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), - "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), - "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), - "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), - "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), - "js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"), - "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"), - "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), +options_templates.update(options_section(('ui_prompt_editing', "Prompt editing", "ui"), { + "keyedit_precision_attention": OptionInfo(0.1, "Precision for (attention:1.1) when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), + "keyedit_precision_extra": OptionInfo(0.05, "Precision for when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), + "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"), + "keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}), + "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"), + "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(), +})) + +options_templates.update(options_section(('ui_gallery', "Gallery", "ui"), { + "return_grid": OptionInfo(True, "Show grid in gallery"), + "do_not_show_images": OptionInfo(False, "Do not show any images in gallery"), + "js_modal_lightbox": OptionInfo(True, "Full page image viewer: enable"), + "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Full page image viewer: show images zoomed in by default"), + "js_modal_lightbox_gamepad": OptionInfo(False, "Full page image viewer: navigate with gamepad"), + "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Full page image viewer: gamepad repeat period").info("in milliseconds"), + "gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("can be any valid CSS value, for example 768px or 20em").needs_reload_ui(), +})) + +options_templates.update(options_section(('ui_alternatives', "UI alternatives", "ui"), { + "compact_prompt_box": OptionInfo(False, "Compact prompt layout").info("puts prompt and negative prompt inside the Generate tab, leaving more vertical space for the image on the right").needs_reload_ui(), "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_reload_ui(), "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(), - "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), - "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), - "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"), - "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"), + "sd_checkpoint_dropdown_use_short": OptionInfo(False, "Checkpoint dropdown: use filenames without paths").info("models in subdirectories like photo/sd15.ckpt will be listed as just sd15.ckpt"), + "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(), + "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(), + "txt2img_settings_accordion": OptionInfo(False, "Settings in txt2img hidden under Accordion").needs_reload_ui(), + "img2img_settings_accordion": OptionInfo(False, "Settings in img2img hidden under Accordion").needs_reload_ui(), +})) + +options_templates.update(options_section(('ui', "User interface", "ui"), { + "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(), "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(), "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), - "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(), - "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(), - "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(), - "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(), + "ui_reorder_list": OptionInfo([], "UI item order for txt2img/img2img tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(), + "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the gallery.").needs_reload_ui(), + "gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"), + "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), + "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), + "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), })) -options_templates.update(options_section(('infotext', "Infotext"), { - "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), - "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), - "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"), - "add_version_to_infotext": OptionInfo(True, "Add program version to generation information"), +options_templates.update(options_section(('infotext', "Infotext", "ui"), { + "infotext_explanation": OptionHTML(""" +Infotext is what this software calls the text that contains generation parameters and can be used to generate the same picture again. +It is displayed in UI below the image. To use infotext, paste it into the prompt and click the ↙️ paste button. +"""), + "enable_pnginfo": OptionInfo(True, "Write infotext to metadata of the generated image"), + "save_txt": OptionInfo(False, "Create a text file with infotext next to every generated image"), + + "add_model_name_to_info": OptionInfo(True, "Add model name to infotext"), + "add_model_hash_to_info": OptionInfo(True, "Add model hash to infotext"), + "add_vae_name_to_info": OptionInfo(True, "Add VAE name to infotext"), + "add_vae_hash_to_info": OptionInfo(True, "Add VAE hash to infotext"), + "add_user_name_to_info": OptionInfo(False, "Add user name to infotext when authenticated"), + "add_version_to_infotext": OptionInfo(True, "Add program version to infotext"), "disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"), + "infotext_skip_pasting": OptionInfo([], "Disregard fields from pasted infotext", ui_components.DropdownMulti, lambda: {"choices": shared_items.get_infotext_names()}), "infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""
  • Ignore: keep prompt and styles dropdown as it is.
  • Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).
  • @@ -282,7 +320,7 @@ })) -options_templates.update(options_section(('ui', "Live previews"), { +options_templates.update(options_section(('ui', "Live previews", "ui"), { "show_progressbar": OptionInfo(True, "Show progressbar"), "live_previews_enable": OptionInfo(True, "Show live previews of the created image"), "live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}), @@ -293,9 +331,10 @@ "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"), "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"), + "js_live_preview_in_modal_lightbox": OptionInfo(False, "Show Live preview in full page image viewer"), })) -options_templates.update(options_section(('sampler-params', "Sampler parameters"), { +options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), { "hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(), "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unpredictable results"), "eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"), @@ -305,8 +344,8 @@ 's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'), 'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"), - 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule max sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"), - 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule min sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"), + 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule min sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"), + 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule max sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"), 'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}, infotext='ENSD').info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"), 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma", infotext='Discard penultimate sigma').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"), @@ -317,10 +356,11 @@ 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'), })) -options_templates.update(options_section(('postprocessing', "Postprocessing"), { +options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), { 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), + 'postprocessing_existing_caption_action': OptionInfo("Ignore", "Action for existing captions", gr.Radio, {"choices": ["Ignore", "Keep", "Prepend", "Append"]}).info("when generating captions using postprocessing; Ignore = use generated; Keep = use original; Prepend/Append = combine both"), })) options_templates.update(options_section((None, "Hidden options"), { @@ -329,4 +369,3 @@ "restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"), "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), })) - diff --git a/modules/shared_state.py b/modules/shared_state.py index 91740d38fb1..e18a4b20d5e 100644 --- a/modules/shared_state.py +++ b/modules/shared_state.py @@ -103,6 +103,7 @@ def dict(self): def begin(self, job: str = "(unknown)"): self.sampling_step = 0 + self.time_start = time.time() self.job_count = -1 self.processing_has_refined_job_count = False self.job_no = 0 @@ -114,7 +115,6 @@ def begin(self, job: str = "(unknown)"): self.skipped = False self.interrupted = False self.textinfo = None - self.time_start = time.time() self.job = job devices.torch_gc() log.info("Starting job %s", job) diff --git a/modules/styles.py b/modules/styles.py index 0740fe1b1c0..81d9800d184 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -1,7 +1,7 @@ import csv +import fnmatch import os import os.path -import re import typing import shutil @@ -10,6 +10,7 @@ class PromptStyle(typing.NamedTuple): name: str prompt: str negative_prompt: str + path: str = None def merge_prompts(style_prompt: str, prompt: str) -> str: @@ -29,38 +30,61 @@ def apply_styles_to_prompt(prompt, styles): return prompt -re_spaces = re.compile(" +") +def unwrap_style_text_from_prompt(style_text, prompt): + """ + Checks the prompt to see if the style text is wrapped around it. If so, + returns True plus the prompt text without the style text. Otherwise, returns + False with the original prompt. - -def extract_style_text_from_prompt(style_text, prompt): - stripped_prompt = re.sub(re_spaces, " ", prompt.strip()) - stripped_style_text = re.sub(re_spaces, " ", style_text.strip()) + Note that the "cleaned" version of the style text is only used for matching + purposes here. It isn't returned; the original style text is not modified. + """ + stripped_prompt = prompt + stripped_style_text = style_text if "{prompt}" in stripped_style_text: - left, right = stripped_style_text.split("{prompt}", 2) + # Work out whether the prompt is wrapped in the style text. If so, we + # return True and the "inner" prompt text that isn't part of the style. + try: + left, right = stripped_style_text.split("{prompt}", 2) + except ValueError as e: + # If the style text has multple "{prompt}"s, we can't split it into + # two parts. This is an error, but we can't do anything about it. + print(f"Unable to compare style text to prompt:\n{style_text}") + print(f"Error: {e}") + return False, prompt if stripped_prompt.startswith(left) and stripped_prompt.endswith(right): - prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)] + prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)] return True, prompt else: + # Work out whether the given prompt ends with the style text. If so, we + # return True and the prompt text up to where the style text starts. if stripped_prompt.endswith(stripped_style_text): - prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)] - - if prompt.endswith(', '): + prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)] + if prompt.endswith(", "): prompt = prompt[:-2] - return True, prompt return False, prompt -def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt): +def extract_original_prompts(style: PromptStyle, prompt, negative_prompt): + """ + Takes a style and compares it to the prompt and negative prompt. If the style + matches, returns True plus the prompt and negative prompt with the style text + removed. Otherwise, returns False with the original prompt and negative prompt. + """ if not style.prompt and not style.negative_prompt: return False, prompt, negative_prompt - match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt) + match_positive, extracted_positive = unwrap_style_text_from_prompt( + style.prompt, prompt + ) if not match_positive: return False, prompt, negative_prompt - match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt) + match_negative, extracted_negative = unwrap_style_text_from_prompt( + style.negative_prompt, negative_prompt + ) if not match_negative: return False, prompt, negative_prompt @@ -69,25 +93,84 @@ def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt): class StyleDatabase: def __init__(self, path: str): - self.no_style = PromptStyle("None", "", "") + self.no_style = PromptStyle("None", "", "", None) self.styles = {} self.path = path + folder, file = os.path.split(self.path) + filename, _, ext = file.partition('*') + self.default_path = os.path.join(folder, filename + ext) + + self.prompt_fields = [field for field in PromptStyle._fields if field != "path"] + self.reload() def reload(self): + """ + Clears the style database and reloads the styles from the CSV file(s) + matching the path used to initialize the database. + """ self.styles.clear() - if not os.path.exists(self.path): + path, filename = os.path.split(self.path) + + if "*" in filename: + fileglob = filename.split("*")[0] + "*.csv" + filelist = [] + for file in os.listdir(path): + if fnmatch.fnmatch(file, fileglob): + filelist.append(file) + # Add a visible divider to the style list + half_len = round(len(file) / 2) + divider = f"{'-' * (20 - half_len)} {file.upper()}" + divider = f"{divider} {'-' * (40 - len(divider))}" + self.styles[divider] = PromptStyle( + f"{divider}", None, None, "do_not_save" + ) + # Add styles from this CSV file + self.load_from_csv(os.path.join(path, file)) + if len(filelist) == 0: + print(f"No styles found in {path} matching {fileglob}") + return + elif not os.path.exists(self.path): + print(f"Style database not found: {self.path}") return + else: + self.load_from_csv(self.path) - with open(self.path, "r", encoding="utf-8-sig", newline='') as file: + def load_from_csv(self, path: str): + with open(path, "r", encoding="utf-8-sig", newline="") as file: reader = csv.DictReader(file, skipinitialspace=True) for row in reader: + # Ignore empty rows or rows starting with a comment + if not row or row["name"].startswith("#"): + continue # Support loading old CSV format with "name, text"-columns prompt = row["prompt"] if "prompt" in row else row["text"] negative_prompt = row.get("negative_prompt", "") - self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt) + # Add style to database + self.styles[row["name"]] = PromptStyle( + row["name"], prompt, negative_prompt, path + ) + + def get_style_paths(self) -> set: + """Returns a set of all distinct paths of files that styles are loaded from.""" + # Update any styles without a path to the default path + for style in list(self.styles.values()): + if not style.path: + self.styles[style.name] = style._replace(path=self.default_path) + + # Create a list of all distinct paths, including the default path + style_paths = set() + style_paths.add(self.default_path) + for _, style in self.styles.items(): + if style.path: + style_paths.add(style.path) + + # Remove any paths for styles that are just list dividers + style_paths.discard("do_not_save") + + return style_paths def get_style_prompts(self, styles): return [self.styles.get(x, self.no_style).prompt for x in styles] @@ -96,20 +179,40 @@ def get_negative_style_prompts(self, styles): return [self.styles.get(x, self.no_style).negative_prompt for x in styles] def apply_styles_to_prompt(self, prompt, styles): - return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles]) + return apply_styles_to_prompt( + prompt, [self.styles.get(x, self.no_style).prompt for x in styles] + ) def apply_negative_styles_to_prompt(self, prompt, styles): - return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]) - - def save_styles(self, path: str) -> None: - # Always keep a backup file around - if os.path.exists(path): - shutil.copy(path, f"{path}.bak") - - with open(path, "w", encoding="utf-8-sig", newline='') as file: - writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) - writer.writeheader() - writer.writerows(style._asdict() for k, style in self.styles.items()) + return apply_styles_to_prompt( + prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles] + ) + + def save_styles(self, path: str = None) -> None: + # The path argument is deprecated, but kept for backwards compatibility + _ = path + + style_paths = self.get_style_paths() + + csv_names = [os.path.split(path)[1].lower() for path in style_paths] + + for style_path in style_paths: + # Always keep a backup file around + if os.path.exists(style_path): + shutil.copy(style_path, f"{style_path}.bak") + + # Write the styles to the CSV file + with open(style_path, "w", encoding="utf-8-sig", newline="") as file: + writer = csv.DictWriter(file, fieldnames=self.prompt_fields) + writer.writeheader() + for style in (s for s in self.styles.values() if s.path == style_path): + # Skip style list dividers, e.g. "STYLES.CSV" + if style.name.lower().strip("# ") in csv_names: + continue + # Write style fields, ignoring the path field + writer.writerow( + {k: v for k, v in style._asdict().items() if k != "path"} + ) def extract_styles_from_prompt(self, prompt, negative_prompt): extracted = [] @@ -120,7 +223,9 @@ def extract_styles_from_prompt(self, prompt, negative_prompt): found_style = None for style in applicable_styles: - is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt) + is_match, new_prompt, new_neg_prompt = extract_original_prompts( + style, prompt, negative_prompt + ) if is_match: found_style = style prompt = new_prompt diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index ae4ee4bbec0..4cb561ef207 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -15,7 +15,7 @@ from torch import Tensor from torch.utils.checkpoint import checkpoint import math -from typing import Optional, NamedTuple, List +from typing import Optional, NamedTuple def narrow_trunc( @@ -97,7 +97,7 @@ def chunk_scanner(chunk_idx: int) -> AttnChunk: ) return summarize_chunk(query, key_chunk, value_chunk) - chunks: List[AttnChunk] = [ + chunks: list[AttnChunk] = [ chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) ] acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) diff --git a/modules/sysinfo.py b/modules/sysinfo.py index 2db7551dcf0..b669edd0cfd 100644 --- a/modules/sysinfo.py +++ b/modules/sysinfo.py @@ -1,7 +1,6 @@ import json import os import sys -import traceback import platform import hashlib @@ -84,7 +83,7 @@ def get_dict(): "Checksum": checksum_token, "Commandline": get_argv(), "Torch env info": get_torch_sysinfo(), - "Exceptions": get_exceptions(), + "Exceptions": errors.get_exceptions(), "CPU": { "model": platform.processor(), "count logical": psutil.cpu_count(logical=True), @@ -104,21 +103,6 @@ def get_dict(): return res -def format_traceback(tb): - return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)] - - -def format_exception(e, tb): - return {"exception": str(e), "traceback": format_traceback(tb)} - - -def get_exceptions(): - try: - return list(reversed(errors.exception_records)) - except Exception as e: - return str(e) - - def get_environment(): return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist} diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index 1675e39a54b..e223a2e0cc9 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -3,6 +3,8 @@ import os import numpy as np from PIL import ImageDraw +from modules import paths_internal +from pkg_resources import parse_version GREEN = "#0F0" BLUE = "#00F" @@ -25,7 +27,6 @@ def crop_image(im, settings): elif is_portrait(settings.crop_width, settings.crop_height): scale_by = settings.crop_height / im.height - im = im.resize((int(im.width * scale_by), int(im.height * scale_by))) im_debug = im.copy() @@ -69,6 +70,7 @@ def crop_image(im, settings): return results + def focal_point(im, settings): corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else [] entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else [] @@ -78,118 +80,120 @@ def focal_point(im, settings): weight_pref_total = 0 if corner_points: - weight_pref_total += settings.corner_points_weight + weight_pref_total += settings.corner_points_weight if entropy_points: - weight_pref_total += settings.entropy_points_weight + weight_pref_total += settings.entropy_points_weight if face_points: - weight_pref_total += settings.face_points_weight + weight_pref_total += settings.face_points_weight corner_centroid = None if corner_points: - corner_centroid = centroid(corner_points) - corner_centroid.weight = settings.corner_points_weight / weight_pref_total - pois.append(corner_centroid) + corner_centroid = centroid(corner_points) + corner_centroid.weight = settings.corner_points_weight / weight_pref_total + pois.append(corner_centroid) entropy_centroid = None if entropy_points: - entropy_centroid = centroid(entropy_points) - entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total - pois.append(entropy_centroid) + entropy_centroid = centroid(entropy_points) + entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total + pois.append(entropy_centroid) face_centroid = None if face_points: - face_centroid = centroid(face_points) - face_centroid.weight = settings.face_points_weight / weight_pref_total - pois.append(face_centroid) + face_centroid = centroid(face_points) + face_centroid.weight = settings.face_points_weight / weight_pref_total + pois.append(face_centroid) average_point = poi_average(pois, settings) if settings.annotate_image: - d = ImageDraw.Draw(im) - max_size = min(im.width, im.height) * 0.07 - if corner_centroid is not None: - color = BLUE - box = corner_centroid.bounding(max_size * corner_centroid.weight) - d.text((box[0], box[1]-15), f"Edge: {corner_centroid.weight:.02f}", fill=color) - d.ellipse(box, outline=color) - if len(corner_points) > 1: - for f in corner_points: - d.rectangle(f.bounding(4), outline=color) - if entropy_centroid is not None: - color = "#ff0" - box = entropy_centroid.bounding(max_size * entropy_centroid.weight) - d.text((box[0], box[1]-15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color) - d.ellipse(box, outline=color) - if len(entropy_points) > 1: - for f in entropy_points: - d.rectangle(f.bounding(4), outline=color) - if face_centroid is not None: - color = RED - box = face_centroid.bounding(max_size * face_centroid.weight) - d.text((box[0], box[1]-15), f"Face: {face_centroid.weight:.02f}", fill=color) - d.ellipse(box, outline=color) - if len(face_points) > 1: - for f in face_points: - d.rectangle(f.bounding(4), outline=color) - - d.ellipse(average_point.bounding(max_size), outline=GREEN) + d = ImageDraw.Draw(im) + max_size = min(im.width, im.height) * 0.07 + if corner_centroid is not None: + color = BLUE + box = corner_centroid.bounding(max_size * corner_centroid.weight) + d.text((box[0], box[1] - 15), f"Edge: {corner_centroid.weight:.02f}", fill=color) + d.ellipse(box, outline=color) + if len(corner_points) > 1: + for f in corner_points: + d.rectangle(f.bounding(4), outline=color) + if entropy_centroid is not None: + color = "#ff0" + box = entropy_centroid.bounding(max_size * entropy_centroid.weight) + d.text((box[0], box[1] - 15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color) + d.ellipse(box, outline=color) + if len(entropy_points) > 1: + for f in entropy_points: + d.rectangle(f.bounding(4), outline=color) + if face_centroid is not None: + color = RED + box = face_centroid.bounding(max_size * face_centroid.weight) + d.text((box[0], box[1] - 15), f"Face: {face_centroid.weight:.02f}", fill=color) + d.ellipse(box, outline=color) + if len(face_points) > 1: + for f in face_points: + d.rectangle(f.bounding(4), outline=color) + + d.ellipse(average_point.bounding(max_size), outline=GREEN) return average_point def image_face_points(im, settings): if settings.dnn_model_path is not None: - detector = cv2.FaceDetectorYN.create( - settings.dnn_model_path, - "", - (im.width, im.height), - 0.9, # score threshold - 0.3, # nms threshold - 5000 # keep top k before nms - ) - faces = detector.detect(np.array(im)) - results = [] - if faces[1] is not None: - for face in faces[1]: - x = face[0] - y = face[1] - w = face[2] - h = face[3] - results.append( - PointOfInterest( - int(x + (w * 0.5)), # face focus left/right is center - int(y + (h * 0.33)), # face focus up/down is close to the top of the head - size = w, - weight = 1/len(faces[1]) - ) - ) - return results + detector = cv2.FaceDetectorYN.create( + settings.dnn_model_path, + "", + (im.width, im.height), + 0.9, # score threshold + 0.3, # nms threshold + 5000 # keep top k before nms + ) + faces = detector.detect(np.array(im)) + results = [] + if faces[1] is not None: + for face in faces[1]: + x = face[0] + y = face[1] + w = face[2] + h = face[3] + results.append( + PointOfInterest( + int(x + (w * 0.5)), # face focus left/right is center + int(y + (h * 0.33)), # face focus up/down is close to the top of the head + size=w, + weight=1 / len(faces[1]) + ) + ) + return results else: - np_im = np.array(im) - gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) - - tries = [ - [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] - ] - for t in tries: - classifier = cv2.CascadeClassifier(t[0]) - minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side - try: - faces = classifier.detectMultiScale(gray, scaleFactor=1.1, - minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) - except Exception: - continue - - if faces: - rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] - return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects] + np_im = np.array(im) + gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) + + tries = [ + [f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01], + [f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05] + ] + for t in tries: + classifier = cv2.CascadeClassifier(t[0]) + minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side + try: + faces = classifier.detectMultiScale(gray, scaleFactor=1.1, + minNeighbors=7, minSize=(minsize, minsize), + flags=cv2.CASCADE_SCALE_IMAGE) + except Exception: + continue + + if faces: + rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] + return [PointOfInterest((r[0] + r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0] - r[2]), + weight=1 / len(rects)) for r in rects] return [] @@ -198,7 +202,7 @@ def image_corner_points(im, settings): # naive attempt at preventing focal points from collecting at watermarks near the bottom gd = ImageDraw.Draw(grayscale) - gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") + gd.rectangle([0, im.height * .9, im.width, im.height], fill="#999") np_im = np.array(grayscale) @@ -206,7 +210,7 @@ def image_corner_points(im, settings): np_im, maxCorners=100, qualityLevel=0.04, - minDistance=min(grayscale.width, grayscale.height)*0.06, + minDistance=min(grayscale.width, grayscale.height) * 0.06, useHarrisDetector=False, ) @@ -215,8 +219,8 @@ def image_corner_points(im, settings): focal_points = [] for point in points: - x, y = point.ravel() - focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points))) + x, y = point.ravel() + focal_points.append(PointOfInterest(x, y, size=4, weight=1 / len(points))) return focal_points @@ -225,13 +229,13 @@ def image_entropy_points(im, settings): landscape = im.height < im.width portrait = im.height > im.width if landscape: - move_idx = [0, 2] - move_max = im.size[0] + move_idx = [0, 2] + move_max = im.size[0] elif portrait: - move_idx = [1, 3] - move_max = im.size[1] + move_idx = [1, 3] + move_max = im.size[1] else: - return [] + return [] e_max = 0 crop_current = [0, 0, settings.crop_width, settings.crop_height] @@ -241,14 +245,14 @@ def image_entropy_points(im, settings): e = image_entropy(crop) if (e > e_max): - e_max = e - crop_best = list(crop_current) + e_max = e + crop_best = list(crop_current) crop_current[move_idx[0]] += 4 crop_current[move_idx[1]] += 4 - x_mid = int(crop_best[0] + settings.crop_width/2) - y_mid = int(crop_best[1] + settings.crop_height/2) + x_mid = int(crop_best[0] + settings.crop_width / 2) + y_mid = int(crop_best[1] + settings.crop_height / 2) return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)] @@ -294,22 +298,23 @@ def is_square(w, h): return w == h -def download_and_cache_models(dirname): - download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true' - model_file_name = 'face_detection_yunet.onnx' +model_dir_opencv = os.path.join(paths_internal.models_path, 'opencv') +if parse_version(cv2.__version__) >= parse_version('4.8'): + model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet_2023mar.onnx') + model_url = 'https://github.com/opencv/opencv_zoo/blob/b6e370b10f641879a87890d44e42173077154a05/models/face_detection_yunet/face_detection_yunet_2023mar.onnx?raw=true' +else: + model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet.onnx') + model_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true' - os.makedirs(dirname, exist_ok=True) - cache_file = os.path.join(dirname, model_file_name) - if not os.path.exists(cache_file): - print(f"downloading face detection model from '{download_url}' to '{cache_file}'") - response = requests.get(download_url) - with open(cache_file, "wb") as f: +def download_and_cache_models(): + if not os.path.exists(model_file_path): + os.makedirs(model_dir_opencv, exist_ok=True) + print(f"downloading face detection model from '{model_url}' to '{model_file_path}'") + response = requests.get(model_url) + with open(model_file_path, "wb") as f: f.write(response.content) - - if os.path.exists(cache_file): - return cache_file - return None + return model_file_path class PointOfInterest: diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index aa79dc09843..04dda585cc9 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -181,40 +181,7 @@ def load_from_file(self, path, filename): else: return - - # textual inversion embeddings - if 'string_to_param' in data: - param_dict = data['string_to_param'] - param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - vec = emb.detach().to(devices.device, dtype=torch.float32) - shape = vec.shape[-1] - vectors = vec.shape[0] - elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding - vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()} - shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] - vectors = data['clip_g'].shape[0] - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts - assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - vec = emb.detach().to(devices.device, dtype=torch.float32) - shape = vec.shape[-1] - vectors = vec.shape[0] - else: - raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") - - embedding = Embedding(vec, name) - embedding.step = data.get('step', None) - embedding.sd_checkpoint = data.get('sd_checkpoint', None) - embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) - embedding.vectors = vectors - embedding.shape = shape - embedding.filename = path - embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '') + embedding = create_embedding_from_data(data, name, filename=filename, filepath=path) if self.expected_shape == -1 or self.expected_shape == embedding.shape: self.register_embedding(embedding, shared.sd_model) @@ -313,6 +280,45 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): return fn +def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None): + if 'string_to_param' in data: # textual inversion embeddings + param_dict = data['string_to_param'] + param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] + elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding + vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()} + shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] + vectors = data['clip_g'].shape[0] + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] + else: + raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + + embedding = Embedding(vec, name) + embedding.step = data.get('step', None) + embedding.sd_checkpoint = data.get('sd_checkpoint', None) + embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) + embedding.vectors = vectors + embedding.shape = shape + + if filepath: + embedding.filename = filepath + embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '') + + return embedding + + def write_loss(log_directory, filename, step, epoch_len, values): if shared.opts.training_write_csv_every == 0: return @@ -386,7 +392,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat assert log_directory, "Log directory is empty" -def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_name, preview_cfg_scale, preview_seed, preview_width, preview_height): from modules import processing save_embedding_every = save_embedding_every or 0 @@ -590,7 +596,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st p.prompt = preview_prompt p.negative_prompt = preview_negative_prompt p.steps = preview_steps - p.sampler_name = sd_samplers.samplers[preview_sampler_index].name + p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()] p.cfg_scale = preview_cfg_scale p.seed = preview_seed p.width = preview_width diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index 35c4feeff45..f149ad1f0f2 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -3,7 +3,6 @@ import gradio as gr import modules.textual_inversion.textual_inversion -import modules.textual_inversion.preprocess from modules import sd_hijack, shared @@ -15,12 +14,6 @@ def create_embedding(name, initialization_text, nvpt, overwrite_old): return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", "" -def preprocess(*args): - modules.textual_inversion.preprocess.preprocess(*args) - - return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", "" - - def train_embedding(*args): assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible' diff --git a/modules/txt2img.py b/modules/txt2img.py index 1ee592ad944..e4e18ceb6dd 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -3,7 +3,7 @@ import modules.scripts from modules import processing from modules.generation_parameters_copypaste import create_override_settings_dict -from modules.shared import opts, cmd_opts +from modules.shared import opts import modules.shared as shared from modules.ui import plaintext_to_html import gradio as gr @@ -45,7 +45,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step p.user = request.username - if cmd_opts.enable_console_prompts: + if shared.opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) with closing(p): diff --git a/modules/ui.py b/modules/ui.py index 579bab9800c..d80486dd4ac 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -4,6 +4,7 @@ import sys from functools import reduce import warnings +from contextlib import ExitStack import gradio as gr import gradio.utils @@ -12,7 +13,7 @@ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules import gradio_extensons # noqa: F401 -from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, ui_extra_networks +from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow from modules.paths import script_path from modules.ui_common import create_refresh_button @@ -25,7 +26,6 @@ import modules.textual_inversion.ui as textual_inversion_ui import modules.textual_inversion.textual_inversion as textual_inversion import modules.shared as shared -import modules.images from modules import prompt_parser from modules.sd_hijack import model_hijack from modules.generation_parameters_copypaste import image_from_url_text @@ -151,11 +151,15 @@ def connect_clear_prompt(button): ) -def update_token_counter(text, steps): +def update_token_counter(text, steps, *, is_positive=True): try: text, _ = extra_networks.parse_prompt(text) - _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) + if is_positive: + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) + else: + prompt_flat_list = [text] + prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) except Exception: @@ -169,76 +173,9 @@ def update_token_counter(text, steps): return f"{token_count}/{max_length}" -class Toprow: - """Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation""" +def update_negative_prompt_token_counter(text, steps): + return update_token_counter(text, steps, is_positive=False) - def __init__(self, is_img2img): - id_part = "img2img" if is_img2img else "txt2img" - self.id_part = id_part - - with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"): - with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6): - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) - self.prompt_img = gr.File(label="", elem_id=f"{id_part}_prompt_image", file_count="single", type="binary", visible=False) - - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) - - self.button_interrogate = None - self.button_deepbooru = None - if is_img2img: - with gr.Column(scale=1, elem_classes="interrogate-col"): - self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") - - with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"): - with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"): - self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt") - self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip") - self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') - - self.skip.click( - fn=lambda: shared.state.skip(), - inputs=[], - outputs=[], - ) - - self.interrupt.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - with gr.Row(elem_id=f"{id_part}_tools"): - self.paste = ToolButton(value=paste_symbol, elem_id="paste") - self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") - self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False) - - self.token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"]) - self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - self.negative_token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"]) - self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button") - - self.clear_prompt_button.click( - fn=lambda *x: x, - _js="confirm_clear_prompt", - inputs=[self.prompt, self.negative_prompt], - outputs=[self.prompt, self.negative_prompt], - ) - - self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt) - - self.prompt_img.change( - fn=modules.images.image_data, - inputs=[self.prompt_img], - outputs=[self.prompt, self.prompt_img], - show_progress=False, - ) def setup_progressbar(*args, **kwargs): @@ -278,8 +215,8 @@ def apply_setting(key, value): return getattr(opts, key) -def create_output_panel(tabname, outdir): - return ui_common.create_output_panel(tabname, outdir) +def create_output_panel(tabname, outdir, toprow=None): + return ui_common.create_output_panel(tabname, outdir, toprow) def create_sampler_and_steps_selection(choices, tabname): @@ -326,7 +263,7 @@ def create_ui(): scripts.scripts_txt2img.initialize_scripts(is_img2img=False) with gr.Blocks(analytics_enabled=False) as txt2img_interface: - toprow = Toprow(is_img2img=False) + toprow = ui_toprow.Toprow(is_img2img=False, is_compact=shared.opts.compact_prompt_box) dummy_component = gr.Label(visible=False) @@ -334,10 +271,17 @@ def create_ui(): extra_tabs.__enter__() with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False): - with gr.Column(variant='compact', elem_id="txt2img_settings"): + with ExitStack() as stack: + if shared.opts.txt2img_settings_accordion: + stack.enter_context(gr.Accordion("Open for Settings", open=False)) + stack.enter_context(gr.Column(variant='compact', elem_id="txt2img_settings")) + scripts.scripts_txt2img.prepare_ui() for category in ordered_ui_categories(): + if category == "prompt": + toprow.create_inline_toprow_prompts() + if category == "sampler": steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img") @@ -348,7 +292,7 @@ def create_ui(): height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"): - res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims") + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", tooltip="Switch width/height") if opts.dimensions_and_batch_together: with gr.Column(elem_id="txt2img_column_batch"): @@ -432,7 +376,7 @@ def create_ui(): show_progress=False, ) - txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) + txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow) txt2img_args = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), @@ -533,7 +477,7 @@ def create_ui(): ] toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter]) - toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) + toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img') ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) @@ -544,13 +488,17 @@ def create_ui(): scripts.scripts_img2img.initialize_scripts(is_img2img=True) with gr.Blocks(analytics_enabled=False) as img2img_interface: - toprow = Toprow(is_img2img=True) + toprow = ui_toprow.Toprow(is_img2img=True, is_compact=shared.opts.compact_prompt_box) extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs") extra_tabs.__enter__() with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False): - with gr.Column(variant='compact', elem_id="img2img_settings"): + with ExitStack() as stack: + if shared.opts.img2img_settings_accordion: + stack.enter_context(gr.Accordion("Open for Settings", open=False)) + stack.enter_context(gr.Column(variant='compact', elem_id="img2img_settings")) + copy_image_buttons = [] copy_image_destinations = {} @@ -567,85 +515,89 @@ def add_copy_image_controls(tab_name, elem): button = gr.Button(title) copy_image_buttons.append((button, name, elem)) - with gr.Tabs(elem_id="mode_img2img"): - img2img_selected_tab = gr.State(0) - - with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height) - add_copy_image_controls('img2img', init_img) - - with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: - sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color) - add_copy_image_controls('sketch', sketch) - - with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color) - add_copy_image_controls('inpaint', init_img_with_mask) - - with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: - inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color) - inpaint_color_sketch_orig = gr.State(None) - add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) - - def update_orig(image, state): - if image is not None: - same_size = state is not None and state.size == image.size - has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) - edited = same_size and has_exact_match - return image if not edited or state is None else state - - inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) - - with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: - init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask") - - with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: - hidden = '
    Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' - gr.HTML( - "

    Process images in a directory on the same machine where the server is running." + - "
    Use an empty output directory to save pictures normally instead of writing to the output directory." + - f"
    Add inpaint batch mask directory to enable inpaint batch processing." - f"{hidden}

    " - ) - img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") - img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") - img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") - with gr.Accordion("PNG info", open=False): - img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info") - img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir") - img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.") - - img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch] - - for i, tab in enumerate(img2img_tabs): - tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab]) - - def copy_image(img): - if isinstance(img, dict) and 'image' in img: - return img['image'] - - return img - - for button, name, elem in copy_image_buttons: - button.click( - fn=copy_image, - inputs=[elem], - outputs=[copy_image_destinations[name]], - ) - button.click( - fn=lambda: None, - _js=f"switch_to_{name.replace(' ', '_')}", - inputs=[], - outputs=[], - ) - - with FormRow(): - resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - scripts.scripts_img2img.prepare_ui() for category in ordered_ui_categories(): + if category == "prompt": + toprow.create_inline_toprow_prompts() + + if category == "image": + with gr.Tabs(elem_id="mode_img2img"): + img2img_selected_tab = gr.State(0) + + with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height) + add_copy_image_controls('img2img', init_img) + + with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: + sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color) + add_copy_image_controls('sketch', sketch) + + with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color) + add_copy_image_controls('inpaint', init_img_with_mask) + + with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: + inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color) + inpaint_color_sketch_orig = gr.State(None) + add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) + + def update_orig(image, state): + if image is not None: + same_size = state is not None and state.size == image.size + has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) + edited = same_size and has_exact_match + return image if not edited or state is None else state + + inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) + + with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: + init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") + init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask") + + with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: + hidden = '
    Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' + gr.HTML( + "

    Process images in a directory on the same machine where the server is running." + + "
    Use an empty output directory to save pictures normally instead of writing to the output directory." + + f"
    Add inpaint batch mask directory to enable inpaint batch processing." + f"{hidden}

    " + ) + img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") + img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") + with gr.Accordion("PNG info", open=False): + img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info") + img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir") + img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps", "Model hash"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.") + + img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch] + + for i, tab in enumerate(img2img_tabs): + tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab]) + + def copy_image(img): + if isinstance(img, dict) and 'image' in img: + return img['image'] + + return img + + for button, name, elem in copy_image_buttons: + button.click( + fn=copy_image, + inputs=[elem], + outputs=[copy_image_destinations[name]], + ) + button.click( + fn=lambda: None, + _js=f"switch_to_{name.replace(' ', '_')}", + inputs=[], + outputs=[], + ) + + with FormRow(): + resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") + if category == "sampler": steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img") @@ -661,8 +613,8 @@ def copy_image(img): width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"): - res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn") - detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn") + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn", tooltip="Switch width/height") + detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn", tooltip="Auto detect size from img2img") with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by: scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale") @@ -683,12 +635,6 @@ def copy_image(img): scale_by.release(**on_change_args) button_update_resize_to.click(**on_change_args) - # the code below is meant to update the resolution label after the image in the image selection UI has changed. - # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests. - # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs. - for component in [init_img, sketch]: - component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False) - tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab]) tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab]) @@ -746,20 +692,26 @@ def copy_image(img): with gr.Column(scale=4): inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - def select_img2img_tab(tab): - return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), - - for i, elem in enumerate(img2img_tabs): - elem.select( - fn=lambda tab=i: select_img2img_tab(tab), - inputs=[], - outputs=[inpaint_controls, mask_alpha], - ) - if category not in {"accordions"}: scripts.scripts_img2img.setup_ui_for_section(category) - img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) + # the code below is meant to update the resolution label after the image in the image selection UI has changed. + # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests. + # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs. + for component in [init_img, sketch]: + component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False) + + def select_img2img_tab(tab): + return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), + + for i, elem in enumerate(img2img_tabs): + elem.select( + fn=lambda tab=i: select_img2img_tab(tab), + inputs=[], + outputs=[inpaint_controls, mask_alpha], + ) + + img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples, toprow) img2img_args = dict( fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), @@ -960,71 +912,6 @@ def select_img2img_tab(tab): with gr.Column(): create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") - with gr.Tab(label="Preprocess images", id="preprocess_images"): - process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") - process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") - process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") - process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") - - with gr.Row(): - process_keep_original_size = gr.Checkbox(label='Keep original size', elem_id="train_process_keep_original_size") - process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") - process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") - process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") - process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop") - process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") - - with gr.Row(visible=False) as process_split_extra_row: - process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") - process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") - - with gr.Row(visible=False) as process_focal_crop_row: - process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") - process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") - process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") - process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") - - with gr.Column(visible=False) as process_multicrop_col: - gr.Markdown('Each image is center-cropped with an automatically chosen width and height.') - with gr.Row(): - process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim") - process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="train_process_multicrop_maxdim") - with gr.Row(): - process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, elem_id="train_process_multicrop_minarea") - process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea") - with gr.Row(): - process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective") - process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - with gr.Row(): - interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") - run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") - - process_split.change( - fn=lambda show: gr_show(show), - inputs=[process_split], - outputs=[process_split_extra_row], - ) - - process_focal_crop.change( - fn=lambda show: gr_show(show), - inputs=[process_focal_crop], - outputs=[process_focal_crop_row], - ) - - process_multicrop.change( - fn=lambda show: gr_show(show), - inputs=[process_multicrop], - outputs=[process_multicrop_col], - ) - def get_textual_inversion_template_names(): return sorted(textual_inversion.textual_inversion_templates) @@ -1125,42 +1012,6 @@ def get_textual_inversion_template_names(): ] ) - run_preprocess.click( - fn=wrap_gradio_gpu_call(textual_inversion_ui.preprocess, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - dummy_component, - process_src, - process_dst, - process_width, - process_height, - preprocess_txt_action, - process_keep_original_size, - process_flip, - process_split, - process_caption, - process_caption_deepbooru, - process_split_threshold, - process_overlap_ratio, - process_focal_crop, - process_focal_crop_face_weight, - process_focal_crop_entropy_weight, - process_focal_crop_edges_weight, - process_focal_crop_debug, - process_multicrop, - process_multicrop_mindim, - process_multicrop_maxdim, - process_multicrop_minarea, - process_multicrop_maxarea, - process_multicrop_objective, - process_multicrop_threshold, - ], - outputs=[ - ti_output, - ti_outcome, - ], - ) - train_embedding.click( fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", @@ -1234,12 +1085,6 @@ def get_textual_inversion_template_names(): outputs=[], ) - interrupt_preprocessing.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file) settings = ui_settings.UiSettings() @@ -1286,7 +1131,7 @@ def get_textual_inversion_template_names(): loadsave.setup_ui() - if os.path.exists(os.path.join(script_path, "notification.mp3")): + if os.path.exists(os.path.join(script_path, "notification.mp3")) and shared.opts.notification_audio: gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) footer = shared.html("footer.html") @@ -1338,7 +1183,6 @@ def versions_html(): def setup_ui_api(app): from pydantic import BaseModel, Field - from typing import List class QuicksettingsHint(BaseModel): name: str = Field(title="Name of the quicksettings field") @@ -1347,7 +1191,7 @@ class QuicksettingsHint(BaseModel): def quicksettings_hint(): return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()] - app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint]) + app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=list[QuicksettingsHint]) app.add_api_route("/internal/ping", lambda: {}, methods=["GET"]) @@ -1357,7 +1201,7 @@ def download_sysinfo(attachment=False): from fastapi.responses import PlainTextResponse text = sysinfo.get() - filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt" + filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json" return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'}) diff --git a/modules/ui_common.py b/modules/ui_common.py index 84a7d7f2756..032ec4af762 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -104,7 +104,7 @@ def __init__(self, d=None): return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") -def create_output_panel(tabname, outdir): +def create_output_panel(tabname, outdir, toprow=None): def open_folder(f): if not os.path.exists(f): @@ -130,12 +130,15 @@ def open_folder(f): else: sp.Popen(["xdg-open", path]) - with gr.Column(variant='panel', elem_id=f"{tabname}_results"): - with gr.Group(elem_id=f"{tabname}_gallery_container"): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None) + with gr.Column(elem_id=f"{tabname}_results"): + if toprow: + toprow.create_inline_toprow_image() - generation_info = None - with gr.Column(): + with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"): + with gr.Group(elem_id=f"{tabname}_gallery_container"): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None) + + generation_info = None with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"): open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.") diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 2e8c1d6d21d..dc1e34c8af8 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -65,7 +65,7 @@ def save_config_state(name): filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json") print(f"Saving backup of webui/extension state to {filename}.") with open(filename, "w", encoding="utf-8") as f: - json.dump(current_config_state, f, indent=4) + json.dump(current_config_state, f, indent=4, ensure_ascii=False) config_states.list_config_states() new_value = next(iter(config_states.all_config_states.keys()), "Current") new_choices = ["Current"] + list(config_states.all_config_states.keys()) @@ -197,7 +197,7 @@ def update_config_states_table(state_name): config_state = config_states.all_config_states[state_name] config_name = config_state.get("name", "Config") - created_date = time.asctime(time.gmtime(config_state["created_at"])) + created_date = datetime.fromtimestamp(config_state["created_at"]).strftime('%Y-%m-%d %H:%M:%S') filepath = config_state.get("filepath", "") try: @@ -335,6 +335,11 @@ def normalize_git_url(url): return url +def get_extension_dirname_from_url(url): + *parts, last_part = url.split('/') + return normalize_git_url(last_part) + + def install_extension_from_url(dirname, url, branch_name=None): check_access() @@ -346,10 +351,7 @@ def install_extension_from_url(dirname, url, branch_name=None): assert url, 'No URL specified' if dirname is None or dirname == "": - *parts, last_part = url.split('/') - last_part = normalize_git_url(last_part) - - dirname = last_part + dirname = get_extension_dirname_from_url(url) target_dir = os.path.join(extensions.extensions_dir, dirname) assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}' @@ -449,7 +451,8 @@ def get_date(info: dict, key): def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""): extlist = available_extensions["extensions"] - installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions} + installed_extensions = {extension.name for extension in extensions.extensions} + installed_extension_urls = {normalize_git_url(extension.remote) for extension in extensions.extensions if extension.remote is not None} tags = available_extensions.get("tags", {}) tags_to_hide = set(hide_tags) @@ -482,7 +485,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" if url is None: continue - existing = installed_extension_urls.get(normalize_git_url(url), None) + existing = get_extension_dirname_from_url(url) in installed_extensions or normalize_git_url(url) in installed_extension_urls extension_tags = extension_tags + ["installed"] if existing else extension_tags if any(x for x in extension_tags if x in tags_to_hide): diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 063bd7b80e6..fe5d3ba3338 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -1,3 +1,4 @@ +import functools import os.path import urllib.parse from pathlib import Path @@ -15,6 +16,17 @@ extra_pages = [] allowed_dirs = set() +default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"] + + +@functools.cache +def allowed_preview_extensions_with_extra(extra_extensions=None): + return set(default_allowed_preview_extensions) | set(extra_extensions or []) + + +def allowed_preview_extensions(): + return allowed_preview_extensions_with_extra((shared.opts.samples_format, )) + def register_page(page): """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions""" @@ -33,9 +45,9 @@ def fetch_file(filename: str = ""): if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs): raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") - ext = os.path.splitext(filename)[1].lower() - if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"): - raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.") + ext = os.path.splitext(filename)[1].lower()[1:] + if ext not in allowed_preview_extensions(): + raise ValueError(f"File cannot be fetched: {filename}. Extensions allowed: {allowed_preview_extensions()}.") # would profit from returning 304 return FileResponse(filename, headers={"Accept-Ranges": "bytes"}) @@ -91,6 +103,7 @@ def __init__(self, title): self.name = title.lower() self.id_page = self.name.replace(" ", "_") self.card_page = shared.html("extra-networks-card.html") + self.allow_prompt = True self.allow_negative_prompt = False self.metadata = {} self.items = {} @@ -138,8 +151,13 @@ def create_html(self, tabname): continue subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/") - while subdir.startswith("/"): - subdir = subdir[1:] + + if shared.opts.extra_networks_dir_button_function: + if not subdir.startswith("/"): + subdir = "/" + subdir + else: + while subdir.startswith("/"): + subdir = subdir[1:] is_empty = len(os.listdir(x)) == 0 if not is_empty and not subdir.endswith("/"): @@ -213,9 +231,9 @@ def create_html_for_item(self, item, tabname): metadata_button = "" metadata = item.get("metadata") if metadata: - metadata_button = f"" + metadata_button = f"" - edit_button = f"
    " + edit_button = f"
    " local_path = "" filename = item.get("filename", "") @@ -235,7 +253,7 @@ def create_html_for_item(self, item, tabname): if search_only and shared.opts.extra_networks_hidden_models == "Never": return "" - sort_keys = " ".join([html.escape(f'data-sort-{k}={v}') for k, v in item.get("sort_keys", {}).items()]).strip() + sort_keys = " ".join([f'data-sort-{k}="{html.escape(str(v))}"' for k, v in item.get("sort_keys", {}).items()]).strip() args = { "background_image": background_image, @@ -266,6 +284,7 @@ def get_sort_keys(self, path): "date_created": int(stat.st_ctime or 0), "date_modified": int(stat.st_mtime or 0), "name": pth.name.lower(), + "path": str(pth.parent).lower(), } def find_preview(self, path): @@ -273,11 +292,7 @@ def find_preview(self, path): Find a preview PNG for a given path (without extension) and call link_preview on it. """ - preview_extensions = ["png", "jpg", "jpeg", "webp"] - if shared.opts.samples_format not in preview_extensions: - preview_extensions.append(shared.opts.samples_format) - - potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], []) + potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions()], []) for file in potential_files: if os.path.isfile(file): @@ -359,7 +374,10 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs = [] for page in ui.stored_extra_pages: - with gr.Tab(page.title, id=page.id_page) as tab: + with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab: + with gr.Column(elem_id=f"{tabname}_{page.id_page}_prompts", elem_classes=["extra-page-prompts"]): + pass + elem_id = f"{tabname}_{page.id_page}_cards_html" page_elem = gr.HTML('Loading...', elem_id=elem_id) ui.pages.append(page_elem) @@ -373,19 +391,28 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs.append(tab) edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) - dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") - button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False) + dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") + button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) + tab_controls = [edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs] + for tab in unrelated_tabs: - tab.select(fn=lambda: [gr.update(visible=False) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False) + tab.select(fn=lambda: [gr.update(visible=False) for _ in tab_controls], _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=tab_controls, show_progress=False) + + for page, tab in zip(ui.stored_extra_pages, related_tabs): + allow_prompt = "true" if page.allow_prompt else "false" + allow_negative_prompt = "true" if page.allow_negative_prompt else "false" + + jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}_prompts" + '", ' + allow_prompt + ', ' + allow_negative_prompt + ');' + + tab.select(fn=lambda: [gr.update(visible=True) for _ in tab_controls], _js='function(){ ' + jscode + ' }', inputs=[], outputs=tab_controls, show_progress=False) - for tab in related_tabs: - tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False) + dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }") def pages_html(): if not ui.pages_contents: diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index ca6c26076f9..1693e71f16f 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -10,11 +10,16 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): def __init__(self): super().__init__('Checkpoints') + self.allow_prompt = False + def refresh(self): shared.refresh_checkpoints() def create_item(self, name, index=None, enable_filter=True): checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name) + if checkpoint is None: + return + path, ext = os.path.splitext(checkpoint.filename) return { "name": checkpoint.name_for_extra, @@ -30,9 +35,12 @@ def create_item(self, name, index=None, enable_filter=True): } def list_items(self): + # instantiate a list to protect against concurrent modification names = list(sd_models.checkpoints_list) for index, name in enumerate(names): - yield self.create_item(name, index) + item = self.create_item(name, index) + if item is not None: + yield item def allowed_directories_for_previews(self): return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 4cedf085196..c96c4fa3b12 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -13,7 +13,10 @@ def refresh(self): shared.reload_hypernetworks() def create_item(self, name, index=None, enable_filter=True): - full_path = shared.hypernetworks[name] + full_path = shared.hypernetworks.get(name) + if full_path is None: + return + path, ext = os.path.splitext(full_path) sha256 = sha256_from_cache(full_path, f'hypernet/{name}') shorthash = sha256[0:10] if sha256 else None @@ -31,8 +34,12 @@ def create_item(self, name, index=None, enable_filter=True): } def list_items(self): - for index, name in enumerate(shared.hypernetworks): - yield self.create_item(name, index) + # instantiate a list to protect against concurrent modification + names = list(shared.hypernetworks) + for index, name in enumerate(names): + item = self.create_item(name, index) + if item is not None: + yield item def allowed_directories_for_previews(self): return [shared.cmd_opts.hypernetwork_dir] diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index 55ef0ea7b54..1b334fda174 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -14,6 +14,8 @@ def refresh(self): def create_item(self, name, index=None, enable_filter=True): embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name) + if embedding is None: + return path, ext = os.path.splitext(embedding.filename) return { @@ -29,8 +31,12 @@ def create_item(self, name, index=None, enable_filter=True): } def list_items(self): - for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings): - yield self.create_item(name, index) + # instantiate a list to protect against concurrent modification + names = list(sd_hijack.model_hijack.embedding_db.word_embeddings) + for index, name in enumerate(names): + item = self.create_item(name, index) + if item is not None: + yield item def allowed_directories_for_previews(self): return list(sd_hijack.model_hijack.embedding_db.embedding_dirs) diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index bfec140cc73..36a807fcdf9 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -134,7 +134,7 @@ def write_user_metadata(self, name, metadata): basename, ext = os.path.splitext(filename) with open(basename + '.json', "w", encoding="utf8") as file: - json.dump(metadata, file, indent=4) + json.dump(metadata, file, indent=4, ensure_ascii=False) def save_user_metadata(self, name, desc, notes): user_metadata = self.get_user_metadata(name) diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py index b824b113732..0d368f8b2c4 100644 --- a/modules/ui_gradio_extensions.py +++ b/modules/ui_gradio_extensions.py @@ -2,12 +2,12 @@ import gradio as gr from modules import localization, shared, scripts -from modules.paths import script_path, data_path +from modules.paths import script_path, data_path, cwd def webpath(fn): - if fn.startswith(script_path): - web_path = os.path.relpath(fn, script_path).replace('\\', '/') + if fn.startswith(cwd): + web_path = os.path.relpath(fn, cwd) else: web_path = os.path.abspath(fn) diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py index ec8fa8e89e3..7826786ccde 100644 --- a/modules/ui_loadsave.py +++ b/modules/ui_loadsave.py @@ -4,7 +4,7 @@ import gradio as gr from modules import errors -from modules.ui_components import ToolButton +from modules.ui_components import ToolButton, InputAccordion def radio_choices(comp): # gradio 3.41 changes choices from list of values to list of pairs @@ -32,8 +32,6 @@ def __init__(self, filename): self.error_loading = True errors.display(e, "loading settings") - - def add_component(self, path, x): """adds component to the registry of tracked components""" @@ -43,20 +41,24 @@ def apply_field(obj, field, condition=None, init_field=None): key = f"{path}/{field}" if getattr(obj, 'custom_script_source', None) is not None: - key = f"customscript/{obj.custom_script_source}/{key}" + key = f"customscript/{obj.custom_script_source}/{key}" if getattr(obj, 'do_not_save_to_config', False): return saved_value = self.ui_settings.get(key, None) + + if isinstance(obj, gr.Accordion) and isinstance(x, InputAccordion) and field == 'value': + field = 'open' + if saved_value is None: self.ui_settings[key] = getattr(obj, field) elif condition and not condition(saved_value): pass else: - if isinstance(x, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies + if isinstance(obj, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies saved_value = str(saved_value) - elif isinstance(x, gr.Number) and field == 'value': + elif isinstance(obj, gr.Number) and field == 'value': try: saved_value = float(saved_value) except ValueError: @@ -67,7 +69,7 @@ def apply_field(obj, field, condition=None, init_field=None): init_field(saved_value) if field == 'value' and key not in self.component_mapping: - self.component_mapping[key] = x + self.component_mapping[key] = obj if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton, gr.Button] and x.visible: apply_field(x, 'visible') @@ -100,6 +102,12 @@ def check_dropdown(val): apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None)) + if type(x) == InputAccordion: + if x.accordion.visible: + apply_field(x.accordion, 'visible') + apply_field(x, 'value') + apply_field(x.accordion, 'value') + def check_tab_id(tab_id): tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children)) if type(tab_id) == str: @@ -133,7 +141,7 @@ def read_from_file(self): def write_to_file(self, current_ui_settings): with open(self.filename, "w", encoding="utf8") as file: - json.dump(current_ui_settings, file, indent=4) + json.dump(current_ui_settings, file, indent=4, ensure_ascii=False) def dump_defaults(self): """saves default values to a file unless tjhe file is present and there was an error loading default values at start""" diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index 802e1ce71a1..13d888e48d1 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -1,9 +1,10 @@ import gradio as gr -from modules import scripts, shared, ui_common, postprocessing, call_queue +from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow import modules.generation_parameters_copypaste as parameters_copypaste def create_ui(): + dummy_component = gr.Label(visible=False) tab_index = gr.State(value=0) with gr.Row(equal_height=False, variant='compact'): @@ -20,11 +21,13 @@ def create_ui(): extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") - submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') - script_inputs = scripts.scripts_postproc.setup_ui() with gr.Column(): + toprow = ui_toprow.Toprow(is_compact=True, is_img2img=False, id_part="extras") + toprow.create_inline_toprow_image() + submit = toprow.submit + result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples) tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index]) @@ -32,8 +35,10 @@ def create_ui(): tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index]) submit.click( - fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), + fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing_webui, extra_outputs=[None, '']), + _js="submit_extras", inputs=[ + dummy_component, tab_index, extras_image, image_batch, @@ -45,8 +50,9 @@ def create_ui(): outputs=[ result_images, html_info_x, - html_info, - ] + html_log, + ], + show_progress=False, ) parameters_copypaste.add_paste_fields("extras", extras_image, None) diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py index e2a88a41193..02bd41e2a3a 100644 --- a/modules/ui_prompt_styles.py +++ b/modules/ui_prompt_styles.py @@ -4,6 +4,7 @@ styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️ styles_materialize_symbol = '\U0001f4cb' # 📋 +styles_copy_symbol = '\U0001f4dd' # 📝 def select_style(name): @@ -52,6 +53,8 @@ def refresh_styles(): class UiPromptStyles: def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt): self.tabname = tabname + self.main_ui_prompt = main_ui_prompt + self.main_ui_negative_prompt = main_ui_negative_prompt with gr.Row(elem_id=f"{tabname}_styles_row"): self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles") @@ -61,13 +64,14 @@ def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt): with gr.Row(): self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.") ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles") - self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.") + self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply_dialog", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.") + self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f"{tabname}_style_copy", tooltip="Copy main UI prompt to style.") with gr.Row(): - self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3) + self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3, elem_classes=["prompt"]) with gr.Row(): - self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3) + self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3, elem_classes=["prompt"]) with gr.Row(): self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False) @@ -96,15 +100,21 @@ def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt): show_progress=False, ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False) - self.materialize.click( - fn=materialize_styles, - inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown], - outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown], + self.setup_apply_button(self.materialize) + + self.copy.click( + fn=lambda p, n: (p, n), + inputs=[main_ui_prompt, main_ui_negative_prompt], + outputs=[self.prompt, self.neg_prompt], show_progress=False, - ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False) + ) ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close) - - - + def setup_apply_button(self, button): + button.click( + fn=materialize_styles, + inputs=[self.main_ui_prompt, self.main_ui_negative_prompt, self.dropdown], + outputs=[self.main_ui_prompt, self.main_ui_negative_prompt, self.dropdown], + show_progress=False, + ).then(fn=None, _js="function(){update_"+self.tabname+"_tokens(); closePopup();}", show_progress=False) diff --git a/modules/ui_settings.py b/modules/ui_settings.py index 8ff9c074718..e054d00ab04 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -1,10 +1,11 @@ import gradio as gr -from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo +from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer from modules.call_queue import wrap_gradio_call from modules.shared import opts from modules.ui_components import FormRow from modules.ui_gradio_extensions import reload_javascript +from concurrent.futures import ThreadPoolExecutor, as_completed def get_value_for_setting(key): @@ -63,6 +64,9 @@ class UiSettings: quicksettings_list = None quicksettings_names = None text_settings = None + show_all_pages = None + show_one_page = None + search_input = None def run_settings(self, *args): changed = [] @@ -135,7 +139,7 @@ def create_ui(self, loadsave, dummy_component): gr.Group() current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text) current_tab.__enter__() - current_row = gr.Column(variant='compact') + current_row = gr.Column(elem_id=f"column_settings_{elem_id}", variant='compact') current_row.__enter__() previous_section = item.section @@ -173,26 +177,43 @@ def create_ui(self, loadsave, dummy_component): download_localization = gr.Button(value='Download localization template', elem_id="download_localization") reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") with gr.Row(): - unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model") - reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model") + unload_sd_model = gr.Button(value='Unload SD checkpoint to RAM', elem_id="sett_unload_sd_model") + reload_sd_model = gr.Button(value='Load SD checkpoint to VRAM from RAM', elem_id="sett_reload_sd_model") + with gr.Row(): + calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id="calculate_all_checkpoint_hash") + calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1) with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"): gr.HTML(shared.html("licenses.html"), elem_id="licenses") - gr.Button(value="Show all pages", elem_id="settings_show_all_pages") + self.show_all_pages = gr.Button(value="Show all pages", elem_id="settings_show_all_pages") + self.show_one_page = gr.Button(value="Show only one page", elem_id="settings_show_one_page", visible=False) + self.show_one_page.click(lambda: None) + + self.search_input = gr.Textbox(value="", elem_id="settings_search", max_lines=1, placeholder="Search...", show_label=False) self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) + def call_func_and_return_text(func, text): + def handler(): + t = timer.Timer() + func() + t.record(text) + + return f'{text} in {t.total:.1f}s' + + return handler + unload_sd_model.click( - fn=sd_models.unload_model_weights, + fn=call_func_and_return_text(sd_models.unload_model_weights, 'Unloaded the checkpoint'), inputs=[], - outputs=[] + outputs=[self.result] ) reload_sd_model.click( - fn=sd_models.reload_model_weights, + fn=call_func_and_return_text(lambda: sd_models.send_model_to_device(shared.sd_model), 'Loaded the checkpoint'), inputs=[], - outputs=[] + outputs=[self.result] ) request_notifications.click( @@ -241,6 +262,21 @@ def check_file(x): outputs=[sysinfo_check_output], ) + def calculate_all_checkpoint_hash_fn(max_thread): + checkpoints_list = sd_models.checkpoints_list.values() + with ThreadPoolExecutor(max_workers=max_thread) as executor: + futures = [executor.submit(checkpoint.calculate_shorthash) for checkpoint in checkpoints_list] + completed = 0 + for _ in as_completed(futures): + completed += 1 + print(f"{completed} / {len(checkpoints_list)} ") + print("Finish calculating hash for all checkpoints") + + calculate_all_checkpoint_hash.click( + fn=calculate_all_checkpoint_hash_fn, + inputs=[calculate_all_checkpoint_hash_threads], + ) + self.interface = settings_interface def add_quicksettings(self): @@ -294,3 +330,8 @@ def get_settings_values(): outputs=[self.component_dict[k] for k in component_keys], queue=False, ) + + def search(self, text): + print(text) + + return [gr.update(visible=text in (comp.label or "")) for comp in self.components] diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py new file mode 100644 index 00000000000..9effb3c5bab --- /dev/null +++ b/modules/ui_toprow.py @@ -0,0 +1,143 @@ +import gradio as gr + +from modules import shared, ui_prompt_styles +import modules.images + +from modules.ui_components import ToolButton + + +class Toprow: + """Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation""" + + prompt = None + prompt_img = None + negative_prompt = None + + button_interrogate = None + button_deepbooru = None + + interrupt = None + skip = None + submit = None + + paste = None + clear_prompt_button = None + apply_styles = None + restore_progress_button = None + + token_counter = None + token_button = None + negative_token_counter = None + negative_token_button = None + + ui_styles = None + + submit_box = None + + def __init__(self, is_img2img, is_compact=False, id_part=None): + if id_part is None: + id_part = "img2img" if is_img2img else "txt2img" + + self.id_part = id_part + self.is_img2img = is_img2img + self.is_compact = is_compact + + if not is_compact: + with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"): + self.create_classic_toprow() + else: + self.create_submit_box() + + def create_classic_toprow(self): + self.create_prompts() + + with gr.Column(scale=1, elem_id=f"{self.id_part}_actions_column"): + self.create_submit_box() + + self.create_tools_row() + + self.create_styles_ui() + + def create_inline_toprow_prompts(self): + if not self.is_compact: + return + + self.create_prompts() + + with gr.Row(elem_classes=["toprow-compact-stylerow"]): + with gr.Column(elem_classes=["toprow-compact-tools"]): + self.create_tools_row() + with gr.Column(): + self.create_styles_ui() + + def create_inline_toprow_image(self): + if not self.is_compact: + return + + self.submit_box.render() + + def create_prompts(self): + with gr.Column(elem_id=f"{self.id_part}_prompt_container", elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6): + with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]): + self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) + self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False) + + with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]): + self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) + + self.prompt_img.change( + fn=modules.images.image_data, + inputs=[self.prompt_img], + outputs=[self.prompt, self.prompt_img], + show_progress=False, + ) + + def create_submit_box(self): + with gr.Row(elem_id=f"{self.id_part}_generate_box", elem_classes=["generate-box"] + (["generate-box-compact"] if self.is_compact else []), render=not self.is_compact) as submit_box: + self.submit_box = submit_box + + self.interrupt = gr.Button('Interrupt', elem_id=f"{self.id_part}_interrupt", elem_classes="generate-box-interrupt") + self.skip = gr.Button('Skip', elem_id=f"{self.id_part}_skip", elem_classes="generate-box-skip") + self.submit = gr.Button('Generate', elem_id=f"{self.id_part}_generate", variant='primary') + + self.skip.click( + fn=lambda: shared.state.skip(), + inputs=[], + outputs=[], + ) + + self.interrupt.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + def create_tools_row(self): + with gr.Row(elem_id=f"{self.id_part}_tools"): + from modules.ui import paste_symbol, clear_prompt_symbol, restore_progress_symbol + + self.paste = ToolButton(value=paste_symbol, elem_id="paste", tooltip="Read generation parameters from prompt or last generation if prompt is empty into user interface.") + self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{self.id_part}_clear_prompt", tooltip="Clear prompt") + self.apply_styles = ToolButton(value=ui_prompt_styles.styles_materialize_symbol, elem_id=f"{self.id_part}_style_apply", tooltip="Apply all selected styles to prompts.") + + if self.is_img2img: + self.button_interrogate = ToolButton('📎', tooltip='Interrogate CLIP - use CLIP neural network to create a text describing the image, and put it into the prompt field', elem_id="interrogate") + self.button_deepbooru = ToolButton('📦', tooltip='Interrogate DeepBooru - use DeepBooru neural network to create a text describing the image, and put it into the prompt field', elem_id="deepbooru") + + self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{self.id_part}_restore_progress", visible=False, tooltip="Restore progress") + + self.token_counter = gr.HTML(value="0/75", elem_id=f"{self.id_part}_token_counter", elem_classes=["token-counter"]) + self.token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_token_button") + self.negative_token_counter = gr.HTML(value="0/75", elem_id=f"{self.id_part}_negative_token_counter", elem_classes=["token-counter"]) + self.negative_token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_negative_token_button") + + self.clear_prompt_button.click( + fn=lambda *x: x, + _js="confirm_clear_prompt", + inputs=[self.prompt, self.negative_prompt], + outputs=[self.prompt, self.negative_prompt], + ) + + def create_styles_ui(self): + self.ui_styles = ui_prompt_styles.UiPromptStyles(self.id_part, self.prompt, self.negative_prompt) + self.ui_styles.setup_apply_button(self.apply_styles) diff --git a/modules/upscaler.py b/modules/upscaler.py index e682bbaa26c..b256e085b6d 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -57,6 +57,9 @@ def upscale(self, img: PIL.Image, scale, selected_model: str = None): dest_h = int((img.height * scale) // 8 * 8) for _ in range(3): + if img.width >= dest_w and img.height >= dest_h: + break + shape = (img.width, img.height) img = self.do_upscale(img, selected_model) @@ -64,9 +67,6 @@ def upscale(self, img: PIL.Image, scale, selected_model: str = None): if shape == (img.width, img.height): break - if img.width >= dest_w and img.height >= dest_h: - break - if img.width != dest_w or img.height != dest_h: img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS) diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py new file mode 100644 index 00000000000..a727e865529 --- /dev/null +++ b/modules/xlmr_m18.py @@ -0,0 +1,164 @@ +from transformers import BertPreTrainedModel,BertConfig +import torch.nn as nn +import torch +from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig +from transformers import XLMRobertaModel,XLMRobertaTokenizer +from typing import Optional + +class BertSeriesConfig(BertConfig): + def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): + + super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + +class RobertaSeriesConfig(XLMRobertaConfig): + def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + + +class BertSeriesModelWithTransformation(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + config_class = BertSeriesConfig + + def __init__(self, config=None, **kargs): + # modify initialization for autoloading + if config is None: + config = XLMRobertaConfig() + config.attention_probs_dropout_prob= 0.1 + config.bos_token_id=0 + config.eos_token_id=2 + config.hidden_act='gelu' + config.hidden_dropout_prob=0.1 + config.hidden_size=1024 + config.initializer_range=0.02 + config.intermediate_size=4096 + config.layer_norm_eps=1e-05 + config.max_position_embeddings=514 + + config.num_attention_heads=16 + config.num_hidden_layers=24 + config.output_past=True + config.pad_token_id=1 + config.position_embedding_type= "absolute" + + config.type_vocab_size= 1 + config.use_cache=True + config.vocab_size= 250002 + config.project_dim = 1024 + config.learn_encoder = False + super().__init__(config) + self.roberta = XLMRobertaModel(config) + self.transformation = nn.Linear(config.hidden_size,config.project_dim) + # self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') + # self.pooler = lambda x: x[:,0] + # self.post_init() + + self.has_pre_transformation = True + if self.has_pre_transformation: + self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim) + self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_init() + + def encode(self,c): + device = next(self.parameters()).device + text = self.tokenizer(c, + truncation=True, + max_length=77, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt") + text["input_ids"] = torch.tensor(text["input_ids"]).to(device) + text["attention_mask"] = torch.tensor( + text['attention_mask']).to(device) + features = self(**text) + return features['projection_state'] + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) : + r""" + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + + outputs = self.roberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + + # # last module outputs + # sequence_output = outputs[0] + + + # # project every module + # sequence_output_ln = self.pre_LN(sequence_output) + + # # pooler + # pooler_output = self.pooler(sequence_output_ln) + # pooler_output = self.transformation(pooler_output) + # projection_state = self.transformation(outputs.last_hidden_state) + + if self.has_pre_transformation: + sequence_output2 = outputs["hidden_states"][-2] + sequence_output2 = self.pre_LN(sequence_output2) + projection_state2 = self.transformation_pre(sequence_output2) + + return { + "projection_state": projection_state2, + "last_hidden_state": outputs.last_hidden_state, + "hidden_states": outputs.hidden_states, + "attentions": outputs.attentions, + } + else: + projection_state = self.transformation(outputs.last_hidden_state) + return { + "projection_state": projection_state, + "last_hidden_state": outputs.last_hidden_state, + "hidden_states": outputs.hidden_states, + "attentions": outputs.attentions, + } + + + # return { + # 'pooler_output':pooler_output, + # 'last_hidden_state':outputs.last_hidden_state, + # 'hidden_states':outputs.hidden_states, + # 'attentions':outputs.attentions, + # 'projection_state':projection_state, + # 'sequence_out': sequence_output + # } + + +class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): + base_model_prefix = 'roberta' + config_class= RobertaSeriesConfig diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py new file mode 100644 index 00000000000..d8da94a0efd --- /dev/null +++ b/modules/xpu_specific.py @@ -0,0 +1,59 @@ +from modules import shared +from modules.sd_hijack_utils import CondFunc + +has_ipex = False +try: + import torch + import intel_extension_for_pytorch as ipex # noqa: F401 + has_ipex = True +except Exception: + pass + + +def check_for_xpu(): + return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available() + + +def get_xpu_device_string(): + if shared.cmd_opts.device_id is not None: + return f"xpu:{shared.cmd_opts.device_id}" + return "xpu" + + +def torch_xpu_gc(): + with torch.xpu.device(get_xpu_device_string()): + torch.xpu.empty_cache() + + +has_xpu = check_for_xpu() + +if has_xpu: + # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device + CondFunc('torch.Generator', + lambda orig_func, device=None: torch.xpu.Generator(device), + lambda orig_func, device=None: device is not None and device.type == "xpu") + + # W/A for some OPs that could not handle different input dtypes + CondFunc('torch.nn.functional.layer_norm', + lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: + orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), + lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: + weight is not None and input.dtype != weight.data.dtype) + CondFunc('torch.nn.modules.GroupNorm.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.nn.modules.linear.Linear.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.nn.modules.conv.Conv2d.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.bmm', + lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out), + lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype) + CondFunc('torch.cat', + lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out), + lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) + CondFunc('torch.nn.functional.scaled_dot_product_attention', + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal), + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) diff --git a/scripts/postprocessing_caption.py b/scripts/postprocessing_caption.py new file mode 100644 index 00000000000..469190e088a --- /dev/null +++ b/scripts/postprocessing_caption.py @@ -0,0 +1,30 @@ +from modules import scripts_postprocessing, ui_components, deepbooru, shared +import gradio as gr + + +class ScriptPostprocessingCeption(scripts_postprocessing.ScriptPostprocessing): + name = "Caption" + order = 4000 + + def ui(self): + with ui_components.InputAccordion(False, label="Caption") as enable: + option = gr.CheckboxGroup(value=["Deepbooru"], choices=["Deepbooru", "BLIP"], show_label=False) + + return { + "enable": enable, + "option": option, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option): + if not enable: + return + + captions = [pp.caption] + + if "Deepbooru" in option: + captions.append(deepbooru.model.tag(pp.image)) + + if "BLIP" in option: + captions.append(shared.interrogator.generate_caption(pp.image)) + + pp.caption = ", ".join([x for x in captions if x]) diff --git a/scripts/postprocessing_codeformer.py b/scripts/postprocessing_codeformer.py index a7d80d40e2b..e1e156ddcaa 100644 --- a/scripts/postprocessing_codeformer.py +++ b/scripts/postprocessing_codeformer.py @@ -1,28 +1,28 @@ from PIL import Image import numpy as np -from modules import scripts_postprocessing, codeformer_model +from modules import scripts_postprocessing, codeformer_model, ui_components import gradio as gr -from modules.ui_components import FormRow - class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing): name = "CodeFormer" order = 3000 def ui(self): - with FormRow(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility") - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight") + with ui_components.InputAccordion(False, label="CodeFormer") as enable: + with gr.Row(): + codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Visibility", value=1.0, elem_id="extras_codeformer_visibility") + codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight") return { + "enable": enable, "codeformer_visibility": codeformer_visibility, "codeformer_weight": codeformer_weight, } - def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight): - if codeformer_visibility == 0: + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, codeformer_visibility, codeformer_weight): + if codeformer_visibility == 0 or not enable: return restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight) diff --git a/scripts/postprocessing_create_flipped_copies.py b/scripts/postprocessing_create_flipped_copies.py new file mode 100644 index 00000000000..55c18fcf48f --- /dev/null +++ b/scripts/postprocessing_create_flipped_copies.py @@ -0,0 +1,32 @@ +from PIL import ImageOps, Image + +from modules import scripts_postprocessing, ui_components +import gradio as gr + + +class ScriptPostprocessingCreateFlippedCopies(scripts_postprocessing.ScriptPostprocessing): + name = "Create flipped copies" + order = 4000 + + def ui(self): + with ui_components.InputAccordion(False, label="Create flipped copies") as enable: + with gr.Row(): + option = gr.CheckboxGroup(value=["Horizontal"], choices=["Horizontal", "Vertical", "Both"], show_label=False) + + return { + "enable": enable, + "option": option, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option): + if not enable: + return + + if "Horizontal" in option: + pp.extra_images.append(ImageOps.mirror(pp.image)) + + if "Vertical" in option: + pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)) + + if "Both" in option: + pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).transpose(Image.Transpose.FLIP_LEFT_RIGHT)) diff --git a/scripts/postprocessing_focal_crop.py b/scripts/postprocessing_focal_crop.py new file mode 100644 index 00000000000..d67b3a1556e --- /dev/null +++ b/scripts/postprocessing_focal_crop.py @@ -0,0 +1,54 @@ + +from modules import scripts_postprocessing, ui_components, errors +import gradio as gr + +from modules.textual_inversion import autocrop + + +class ScriptPostprocessingFocalCrop(scripts_postprocessing.ScriptPostprocessing): + name = "Auto focal point crop" + order = 4000 + + def ui(self): + with ui_components.InputAccordion(False, label="Auto focal point crop") as enable: + face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_face_weight") + entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_entropy_weight") + edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_edges_weight") + debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") + + return { + "enable": enable, + "face_weight": face_weight, + "entropy_weight": entropy_weight, + "edges_weight": edges_weight, + "debug": debug, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, face_weight, entropy_weight, edges_weight, debug): + if not enable: + return + + if not pp.shared.target_width or not pp.shared.target_height: + return + + dnn_model_path = None + try: + dnn_model_path = autocrop.download_and_cache_models() + except Exception: + errors.report("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", exc_info=True) + + autocrop_settings = autocrop.Settings( + crop_width=pp.shared.target_width, + crop_height=pp.shared.target_height, + face_points_weight=face_weight, + entropy_points_weight=entropy_weight, + corner_points_weight=edges_weight, + annotate_image=debug, + dnn_model_path=dnn_model_path, + ) + + result, *others = autocrop.crop_image(pp.image, autocrop_settings) + + pp.image = result + pp.extra_images = [pp.create_copy(x, nametags=["focal-crop-debug"], disable_processing=True) for x in others] + diff --git a/scripts/postprocessing_gfpgan.py b/scripts/postprocessing_gfpgan.py index d854f3f7748..6e7566055d2 100644 --- a/scripts/postprocessing_gfpgan.py +++ b/scripts/postprocessing_gfpgan.py @@ -1,26 +1,25 @@ from PIL import Image import numpy as np -from modules import scripts_postprocessing, gfpgan_model +from modules import scripts_postprocessing, gfpgan_model, ui_components import gradio as gr -from modules.ui_components import FormRow - class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing): name = "GFPGAN" order = 2000 def ui(self): - with FormRow(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility") + with ui_components.InputAccordion(False, label="GFPGAN") as enable: + gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Visibility", value=1.0, elem_id="extras_gfpgan_visibility") return { + "enable": enable, "gfpgan_visibility": gfpgan_visibility, } - def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility): - if gfpgan_visibility == 0: + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, gfpgan_visibility): + if gfpgan_visibility == 0 or not enable: return restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8)) diff --git a/scripts/postprocessing_split_oversized.py b/scripts/postprocessing_split_oversized.py new file mode 100644 index 00000000000..2522ed8f0c9 --- /dev/null +++ b/scripts/postprocessing_split_oversized.py @@ -0,0 +1,71 @@ +import math + +from modules import scripts_postprocessing, ui_components +import gradio as gr + + +def split_pic(image, inverse_xy, width, height, overlap_ratio): + if inverse_xy: + from_w, from_h = image.height, image.width + to_w, to_h = height, width + else: + from_w, from_h = image.width, image.height + to_w, to_h = width, height + h = from_h * to_w // from_w + if inverse_xy: + image = image.resize((h, to_w)) + else: + image = image.resize((to_w, h)) + + split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio))) + y_step = (h - to_h) / (split_count - 1) + for i in range(split_count): + y = int(y_step * i) + if inverse_xy: + splitted = image.crop((y, 0, y + to_h, to_w)) + else: + splitted = image.crop((0, y, to_w, y + to_h)) + yield splitted + + +class ScriptPostprocessingSplitOversized(scripts_postprocessing.ScriptPostprocessing): + name = "Split oversized images" + order = 4000 + + def ui(self): + with ui_components.InputAccordion(False, label="Split oversized images") as enable: + with gr.Row(): + split_threshold = gr.Slider(label='Threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_split_threshold") + overlap_ratio = gr.Slider(label='Overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="postprocess_overlap_ratio") + + return { + "enable": enable, + "split_threshold": split_threshold, + "overlap_ratio": overlap_ratio, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, split_threshold, overlap_ratio): + if not enable: + return + + width = pp.shared.target_width + height = pp.shared.target_height + + if not width or not height: + return + + if pp.image.height > pp.image.width: + ratio = (pp.image.width * height) / (pp.image.height * width) + inverse_xy = False + else: + ratio = (pp.image.height * width) / (pp.image.width * height) + inverse_xy = True + + if ratio >= 1.0 and ratio > split_threshold: + return + + result, *others = split_pic(pp.image, inverse_xy, width, height, overlap_ratio) + + pp.image = result + pp.extra_images = [pp.create_copy(x) for x in others] + diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py index edb70ac01ca..ed709688de4 100644 --- a/scripts/postprocessing_upscale.py +++ b/scripts/postprocessing_upscale.py @@ -29,7 +29,7 @@ def ui(self): upscaling_resize_w = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w") upscaling_resize_h = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h") with gr.Column(elem_id="upscaling_dimensions_row", scale=1, elem_classes="dimensions-tools"): - upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn") + upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn", tooltip="Switch width/height") upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") with FormRow(): @@ -81,6 +81,14 @@ def upscale(self, image, info, upscaler, upscale_mode, upscale_by, upscale_to_w return image + def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0): + if upscale_mode == 1: + pp.shared.target_width = upscale_to_width + pp.shared.target_height = upscale_to_height + else: + pp.shared.target_width = int(pp.image.width * upscale_by) + pp.shared.target_height = int(pp.image.height * upscale_by) + def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0): if upscaler_1_name == "None": upscaler_1_name = None @@ -126,6 +134,10 @@ def ui(self): "upscaler_name": upscaler_name, } + def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None): + pp.shared.target_width = int(pp.image.width * upscale_by) + pp.shared.target_height = int(pp.image.height * upscale_by) + def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None): if upscaler_name is None or upscaler_name == "None": return diff --git a/scripts/processing_autosized_crop.py b/scripts/processing_autosized_crop.py new file mode 100644 index 00000000000..d1c16e13b6a --- /dev/null +++ b/scripts/processing_autosized_crop.py @@ -0,0 +1,64 @@ +from PIL import Image + +from modules import scripts_postprocessing, ui_components +import gradio as gr + + +def center_crop(image: Image, w: int, h: int): + iw, ih = image.size + if ih / h < iw / w: + sw = w * ih / h + box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih + else: + sh = h * iw / w + box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2 + return image.resize((w, h), Image.Resampling.LANCZOS, box) + + +def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold): + iw, ih = image.size + err = lambda w, h: 1 - (lambda x: x if x < 1 else 1 / x)(iw / ih / (w / h)) + wh = max(((w, h) for w in range(mindim, maxdim + 1, 64) for h in range(mindim, maxdim + 1, 64) + if minarea <= w * h <= maxarea and err(w, h) <= threshold), + key=lambda wh: (wh[0] * wh[1], -err(*wh))[::1 if objective == 'Maximize area' else -1], + default=None + ) + return wh and center_crop(image, *wh) + + +class ScriptPostprocessingAutosizedCrop(scripts_postprocessing.ScriptPostprocessing): + name = "Auto-sized crop" + order = 4000 + + def ui(self): + with ui_components.InputAccordion(False, label="Auto-sized crop") as enable: + gr.Markdown('Each image is center-cropped with an automatically chosen width and height.') + with gr.Row(): + mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="postprocess_multicrop_mindim") + maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="postprocess_multicrop_maxdim") + with gr.Row(): + minarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area lower bound", value=64 * 64, elem_id="postprocess_multicrop_minarea") + maxarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area upper bound", value=640 * 640, elem_id="postprocess_multicrop_maxarea") + with gr.Row(): + objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="postprocess_multicrop_objective") + threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="postprocess_multicrop_threshold") + + return { + "enable": enable, + "mindim": mindim, + "maxdim": maxdim, + "minarea": minarea, + "maxarea": maxarea, + "objective": objective, + "threshold": threshold, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, mindim, maxdim, minarea, maxarea, objective, threshold): + if not enable: + return + + cropped = multicrop_pic(pp.image, mindim, maxdim, minarea, maxarea, objective, threshold) + if cropped is not None: + pp.image = cropped + else: + print(f"skipped {pp.image.width}x{pp.image.height} image (can't find suitable size within error threshold)") diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 50320d553bd..a4a2f24dd25 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -5,11 +5,17 @@ import modules.scripts as scripts import gradio as gr -from modules import sd_samplers, errors +from modules import sd_samplers, errors, sd_models from modules.processing import Processed, process_images from modules.shared import state +def process_model_tag(tag): + info = sd_models.get_closet_checkpoint_match(tag) + assert info is not None, f'Unknown checkpoint: {tag}' + return info.name + + def process_string_tag(tag): return tag @@ -27,7 +33,7 @@ def process_boolean_tag(tag): prompt_tags = { - "sd_model": None, + "sd_model": process_model_tag, "outpath_samples": process_string_tag, "outpath_grids": process_string_tag, "prompt_for_display": process_string_tag, @@ -108,6 +114,7 @@ def title(self): def ui(self, is_img2img): checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate")) checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch")) + prompt_position = gr.Radio(["start", "end"], label="Insert prompts at the", elem_id=self.elem_id("prompt_position"), value="start") prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt")) file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file")) @@ -118,9 +125,9 @@ def ui(self, is_img2img): # We don't shrink back to 1, because that causes the control to ignore [enter], and it may # be unclear to the user that shift-enter is needed. prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt], show_progress=False) - return [checkbox_iterate, checkbox_iterate_batch, prompt_txt] + return [checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt] - def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str): + def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt: str): lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x] p.do_not_save_grid = True @@ -156,7 +163,22 @@ def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str): copy_p = copy.copy(p) for k, v in args.items(): - setattr(copy_p, k, v) + if k == "sd_model": + copy_p.override_settings['sd_model_checkpoint'] = v + else: + setattr(copy_p, k, v) + + if args.get("prompt") and p.prompt: + if prompt_position == "start": + copy_p.prompt = args.get("prompt") + " " + p.prompt + else: + copy_p.prompt = p.prompt + " " + args.get("prompt") + + if args.get("negative_prompt") and p.negative_prompt: + if prompt_position == "start": + copy_p.negative_prompt = args.get("negative_prompt") + " " + p.negative_prompt + else: + copy_p.negative_prompt = p.negative_prompt + " " + args.get("negative_prompt") proc = process_images(copy_p) images += proc.images diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 939d86053bd..0dc255bc43d 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -205,13 +205,14 @@ def csv_string_to_list_strip(data_str): class AxisOption: - def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None): + def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None, prepare=None): self.label = label self.type = type self.apply = apply self.format_value = format_value self.confirm = confirm self.cost = cost + self.prepare = prepare self.choices = choices @@ -536,6 +537,8 @@ def process_axis(opt, vals, vals_dropdown): if opt.choices is not None and not csv_mode: valslist = vals_dropdown + elif opt.prepare is not None: + valslist = opt.prepare(vals) else: valslist = csv_string_to_list_strip(vals) @@ -773,6 +776,8 @@ def cell(x, y, z, ix, iy, iz): # TODO: See previous comment about intentional data misalignment. adj_g = g-1 if g > 0 else g images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed) + if not include_sub_grids: # if not include_sub_grids then skip saving after the first grid + break if not include_sub_grids: # Done with sub-grids, drop all related information: