Skip to content

Commit

Permalink
add lora bundle system
Browse files Browse the repository at this point in the history
  • Loading branch information
KohakuBlueleaf committed Oct 9, 2023
1 parent 7d60076 commit 2aa485b
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
1 change: 1 addition & 0 deletions extensions-builtin/Lora/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(self, name, network_on_disk: NetworkOnDisk):
self.unet_multiplier = 1.0
self.dyn_dim = None
self.modules = {}
self.bundle_embeddings = {}
self.mtime = None

self.mentioned_name = None
Expand Down
48 changes: 48 additions & 0 deletions extensions-builtin/Lora/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Union

from modules import shared, devices, sd_models, errors, scripts, sd_hijack
from modules.textual_inversion.textual_inversion import Embedding

module_types = [
network_lora.ModuleTypeLora(),
Expand Down Expand Up @@ -149,9 +150,15 @@ def load_network(name, network_on_disk):
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping

matched_networks = {}
bundle_embeddings = {}

for key_network, weight in sd.items():
key_network_without_network_parts, network_part = key_network.split(".", 1)
if key_network_without_network_parts == "bundle_emb":
emb_name, vec_name = network_part.split(".", 1)
emb_dict = bundle_embeddings.get(emb_name, {})
emb_dict[vec_name] = weight
bundle_embeddings[emb_name] = emb_dict

key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
Expand Down Expand Up @@ -195,6 +202,8 @@ def load_network(name, network_on_disk):

net.modules[key] = net_module

net.bundle_embeddings = bundle_embeddings

if keys_failed_to_match:
logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")

Expand All @@ -210,11 +219,14 @@ def purge_networks_from_memory():


def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
emb_db = sd_hijack.model_hijack.embedding_db
already_loaded = {}

for net in loaded_networks:
if net.name in names:
already_loaded[net.name] = net
for emb_name in net.bundle_embeddings:
emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)

loaded_networks.clear()

Expand Down Expand Up @@ -257,6 +269,41 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
loaded_networks.append(net)

for emb_name, data in net.bundle_embeddings.items():
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
vec = emb.detach().to(devices.device, dtype=torch.float32)
shape = vec.shape[-1]
vectors = vec.shape[0]
elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
vectors = data['clip_g'].shape[0]
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
vec = emb.detach().to(devices.device, dtype=torch.float32)
shape = vec.shape[-1]
vectors = vec.shape[0]
else:
raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")

embedding = Embedding(vec, emb_name)
embedding.vectors = vectors
embedding.shape = shape

if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
emb_db.register_embedding(embedding, shared.sd_model)
else:
emb_db.skipped_embeddings[name] = embedding

if failed_to_load_networks:
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))

Expand Down Expand Up @@ -565,6 +612,7 @@ def infotext_pasted(infotext, params):
available_networks = {}
available_network_aliases = {}
loaded_networks = []
loaded_bundle_embeddings = {}
networks_in_memory = {}
available_network_hash_lookup = {}
forbidden_network_aliases = {}
Expand Down

0 comments on commit 2aa485b

Please sign in to comment.