update to sdxl
anapnoe committed Oct 31, 2023
1 parent 73e222e commit f2844bd
Showing 233 changed files with 8,015 additions and 8,768 deletions.
10 changes: 9 additions & 1 deletion extensions-builtin/Lora/extra_networks_lora.py
@@ -6,9 +6,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
def __init__(self):
super().__init__('lora')

self.errors = {}
"""mapping of network names to the number of errors the network had during operation"""

def activate(self, p, params_list):
additional = shared.opts.sd_lora

self.errors.clear()

if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
@@ -56,4 +61,7 @@ def activate(self, p, params_list):
p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)

def deactivate(self, p):
pass
if self.errors:
p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))

self.errors.clear()
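
For reference, the string that the new deactivate() attaches via p.comment() is built straight from the errors mapping; with made-up network names it evaluates to:

errors = {"detail_tweaker": 2, "style_lora": 1}   # illustrative contents of self.errors
"Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in errors.items())
# -> 'Networks with errors: detail_tweaker (2), style_lora (1)'
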
31 changes: 31 additions & 0 deletions extensions-builtin/Lora/lora_patches.py
@@ -0,0 +1,31 @@
import torch

import networks
from modules import patches


class LoraPatches:
def __init__(self):
self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)
self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)
self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)
self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward)
self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict)
self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward)
self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict)
self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward)
self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict)

def undo(self):
self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward')
self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict')
self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward')
self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict')
self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward')
self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict')
self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward')
self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict')
self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward')
self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict')
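
The class above routes every patch through the project's modules.patches helper so each replacement can later be undone. The underlying idea is plain attribute monkey-patching on the torch classes; a minimal standalone sketch of that pattern (this shows the general technique, not the project's helper API):

import torch

original_forward = torch.nn.Linear.forward        # keep the stock implementation

def lora_aware_forward(self, x):
    # a real implementation would apply the active network weights here
    return original_forward(self, x)

torch.nn.Linear.forward = lora_aware_forward       # patch
# ... run the model; Linear layers now route through the replacement ...
torch.nn.Linear.forward = original_forward         # undo, restoring stock behaviour
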

8 changes: 6 additions & 2 deletions extensions-builtin/Lora/network.py
@@ -1,3 +1,4 @@
from __future__ import annotations
import os
from collections import namedtuple
import enum
@@ -132,7 +133,7 @@ def calc_scale(self):

return 1.0

def finalize_updown(self, updown, orig_weight, output_shape):
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
if self.bias is not None:
updown = updown.reshape(self.bias.shape)
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
@@ -144,7 +145,10 @@ def finalize_updown(self, updown, orig_weight, output_shape):
if orig_weight.size().numel() == updown.size().numel():
updown = updown.reshape(orig_weight.shape)

return updown * self.calc_scale() * self.multiplier()
if ex_bias is not None:
ex_bias = ex_bias * self.multiplier()

return updown * self.calc_scale() * self.multiplier(), ex_bias

def calc_updown(self, target):
raise NotImplementedError()
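
The notable interface change in this hunk is that finalize_updown() now returns a pair: the weight delta scaled by calc_scale() * multiplier(), plus an optional bias delta scaled by multiplier() only. A self-contained illustration of that scaling with placeholder values:

import torch

multiplier, scale = 0.8, 1.0        # stand-ins for self.multiplier() and self.calc_scale()
updown = torch.ones(4, 4)           # weight delta
ex_bias = torch.ones(4)             # optional extra bias delta, may be None

if ex_bias is not None:
    ex_bias = ex_bias * multiplier  # the bias delta only picks up the multiplier

result = updown * scale * multiplier, ex_bias   # mirrors the new two-value return
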
7 changes: 6 additions & 1 deletion extensions-builtin/Lora/network_full.py
@@ -14,9 +14,14 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)

self.weight = weights.w.get("diff")
self.ex_bias = weights.w.get("diff_b")

def calc_updown(self, orig_weight):
output_shape = self.weight.shape
updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
if self.ex_bias is not None:
ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
else:
ex_bias = None

return self.finalize_updown(updown, orig_weight, output_shape)
return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
28 changes: 28 additions & 0 deletions extensions-builtin/Lora/network_norm.py
@@ -0,0 +1,28 @@
import network


class ModuleTypeNorm(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
if all(x in weights.w for x in ["w_norm", "b_norm"]):
return NetworkModuleNorm(net, weights)

return None


class NetworkModuleNorm(network.NetworkModule):
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)

self.w_norm = weights.w.get("w_norm")
self.b_norm = weights.w.get("b_norm")

def calc_updown(self, orig_weight):
output_shape = self.w_norm.shape
updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)

if self.b_norm is not None:
ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
else:
ex_bias = None

return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
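
The new norm module produces a weight delta ("w_norm") and an optional bias delta ("b_norm") for a normalization layer. How those deltas are applied to the layer sits outside this commit, so the following is only a rough sketch of the intended effect, using zero-valued placeholder deltas:

import torch

layer = torch.nn.LayerNorm(8)       # any norm layer with weight and bias
w_delta = torch.zeros(8)            # stands in for the "w_norm" tensor
b_delta = torch.zeros(8)            # stands in for the optional "b_norm" tensor

with torch.no_grad():
    layer.weight += w_delta         # the updown value returned by calc_updown
    if b_delta is not None:
        layer.bias += b_delta       # the ex_bias value passed to finalize_updown
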