Support inference with OFT networks #13692

Merged: 26 commits, Nov 19, 2023

Commits:
ec718f7  wip incorrect OFT implementation (v0xie, Oct 18, 2023)
1c6efdb  inference working but SLOW (v0xie, Oct 18, 2023)
853e21d  faster by using cached R in forward (v0xie, Oct 18, 2023)
eb01d7f  faster by calculating R in updown and using cached R in forward (v0xie, Oct 18, 2023)
321680c  refactor: fix constraint, re-use get_weight (v0xie, Oct 19, 2023)
d10c4db  style: formatting (v0xie, Oct 19, 2023)
0550659  style: fix ambiguous variable name (v0xie, Oct 19, 2023)
2d8c894  refactor: use forward hook instead of custom forward (v0xie, Oct 21, 2023)
7683547  fix: return orig weights during updown, merge weights before forward (v0xie, Oct 21, 2023)
fce86ab  fix: support multiplier, no forward pass hook (v0xie, Oct 21, 2023)
76f5abd  style: cleanup oft (v0xie, Oct 21, 2023)
de8ee92  fix: use merge_weight to cache value (v0xie, Oct 22, 2023)
4a50c96  refactor: remove used OFT functions (v0xie, Oct 22, 2023)
3b8515d  fix: multiplier applied twice in finalize_updown (v0xie, Oct 22, 2023)
6523edb  style: conform style (v0xie, Oct 22, 2023)
a2fad6e  test implementation based on kohaku diag-oft implementation (v0xie, Nov 2, 2023)
65ccd63  detect diag_oft type (v0xie, Nov 2, 2023)
d727ddf  no idea what i'm doing, trying to support both type of OFT, kblueleaf… (v0xie, Nov 2, 2023)
fe1967a  skip multihead attn for now (v0xie, Nov 4, 2023)
f6c8201  refactor: move factorization to lyco_helpers, separate calc_updown fo… (v0xie, Nov 4, 2023)
1dd25be  Merge pull request #1 from v0xie/oft-faster (v0xie, Nov 4, 2023)
329c8ba  refactor: use same updown for both kohya OFT and LyCORIS diag-oft (v0xie, Nov 4, 2023)
bbf00a9  refactor: remove unused function (v0xie, Nov 4, 2023)
7edd50f  Merge pull request #2 from v0xie/network-oft-change-impl (v0xie, Nov 4, 2023)
d6d0b22  fix: ignore calc_scale() for COFT which has very small alpha (v0xie, Nov 15, 2023)
eb667e7  feat: LyCORIS/kohya OFT network support (v0xie, Nov 16, 2023)
47 changes: 47 additions & 0 deletions extensions-builtin/Lora/lyco_helpers.py
@@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid):
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)
    return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)


# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
    '''
    Return a tuple of two values for the input dimension, decomposed by the number closest to factor;
    the second value is greater than or equal to the first.
    In LoRA with a Kronecker product, the first value is used for the weight scale and
    the second value for the weight itself. Because the Kronecker product is non-commutative
    (A⊗B ≠ B⊗A), the meanings of the two matrices differ slightly.
    Examples:
        factor
        -1               2                4                8                16 ...
        127 -> 1, 127    127 -> 1, 127    127 -> 1, 127    127 -> 1, 127    127 -> 1, 127
        128 -> 8, 16     128 -> 2, 64     128 -> 4, 32     128 -> 8, 16     128 -> 8, 16
        250 -> 10, 25    250 -> 2, 125    250 -> 2, 125    250 -> 5, 50     250 -> 10, 25
        360 -> 8, 45     360 -> 2, 180    360 -> 4, 90     360 -> 8, 45     360 -> 12, 30
        512 -> 16, 32    512 -> 2, 256    512 -> 4, 128    512 -> 8, 64     512 -> 16, 32
        1024 -> 32, 32   1024 -> 2, 512   1024 -> 4, 256   1024 -> 8, 128   1024 -> 16, 64
    '''

    if factor > 0 and (dimension % factor) == 0:
        m = factor
        n = dimension // factor
        if m > n:
            n, m = m, n
        return m, n
    if factor < 0:
        factor = dimension
    m, n = 1, dimension
    length = m + n
    while m < n:
        new_m = m + 1
        while dimension % new_m != 0:
            new_m += 1
        new_n = dimension // new_m
        if new_m + new_n > length or new_m > factor:
            break
        else:
            m, n = new_m, new_n
    if m > n:
        n, m = m, n
    return m, n
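
A quick way to sanity-check the function is to reproduce the docstring's examples; NetworkModuleOFT below uses it to split a layer's output dimension into (block_size, num_blocks) for LyCORIS diag-OFT weights. A minimal sketch, assuming lyco_helpers is importable from the extension directory:

from lyco_helpers import factorization

# the default factor=-1 picks the most balanced divisor pair
assert factorization(128) == (8, 16)
assert factorization(1024) == (32, 32)
# a positive factor that divides the dimension is used directly as the first value
assert factorization(512, factor=4) == (4, 128)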

97 changes: 97 additions & 0 deletions extensions-builtin/Lora/network_oft.py
@@ -0,0 +1,97 @@
import torch
import network
from lyco_helpers import factorization
from einops import rearrange


class ModuleTypeOFT(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if "oft_blocks" in weights.w or "oft_diag" in weights.w:
            return NetworkModuleOFT(net, weights)

        return None

# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
class NetworkModuleOFT(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):

        super().__init__(net, weights)

        self.lin_module = None
        self.org_module: list[torch.nn.Module] = [self.sd_module]

        # kohya-ss
        if "oft_blocks" in weights.w.keys():
            self.is_kohya = True
            self.oft_blocks = weights.w["oft_blocks"]  # (num_blocks, block_size, block_size)
            self.alpha = weights.w["alpha"]  # alpha is the COFT constraint
            self.dim = self.oft_blocks.shape[0]  # lora dim
        # LyCORIS
        elif "oft_diag" in weights.w.keys():
            self.is_kohya = False
            self.oft_blocks = weights.w["oft_diag"]
            # self.alpha is unused
            self.dim = self.oft_blocks.shape[1]  # block_size from (num_blocks, block_size, block_size)

        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
        is_conv = type(self.sd_module) in [torch.nn.Conv2d]
        is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]  # unsupported

        if is_linear:
            self.out_dim = self.sd_module.out_features
        elif is_conv:
            self.out_dim = self.sd_module.out_channels
        elif is_other_linear:
            self.out_dim = self.sd_module.embed_dim

        if self.is_kohya:
            self.constraint = self.alpha * self.out_dim
            self.num_blocks = self.dim
            self.block_size = self.out_dim // self.dim
        else:
            self.constraint = None
            self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)

    def calc_updown_kb(self, orig_weight, multiplier):
        oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
        oft_blocks = oft_blocks - oft_blocks.transpose(1, 2)  # ensure each block is skew-symmetric: Q = B - B^T

        # build one rotation block per group of output channels: R = I + multiplier * Q
        R = oft_blocks * multiplier + torch.eye(self.block_size, device=orig_weight.device)

        # This errors out for MultiheadAttention, might need to be handled upstream
        merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
        merged_weight = torch.einsum(
            'k n m, k n ... -> k m ...',
            R,
            merged_weight
        )
        merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')

        updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
        output_shape = orig_weight.shape
        return self.finalize_updown(updown, orig_weight, output_shape)

    def calc_updown(self, orig_weight):
        # if alpha is a very small number, as in COFT, calc_scale() would return an almost-zero
        # number, so we ignore it and use only the network multiplier
        multiplier = self.multiplier()
        return self.calc_updown_kb(orig_weight, multiplier)

    # override to remove the multiplier/scale factor; the multiplier is already applied in calc_updown_kb
    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
            updown = updown.reshape(output_shape)

        if len(output_shape) == 4:
            updown = updown.reshape(output_shape)

        if orig_weight.size().numel() == updown.size().numel():
            updown = updown.reshape(orig_weight.shape)

        if ex_bias is not None:
            ex_bias = ex_bias * self.multiplier()

        return updown, ex_bias
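
The rearrange/einsum pair in calc_updown_kb applies one rotation block per group of output channels without ever materializing the full block-diagonal matrix. The sketch below checks that equivalence; it is illustrative only (names and sizes are made up, not from the PR), and note that the einsum contracts R over its first block axis, i.e. it applies each block transposed:

import torch
from einops import rearrange

num_blocks, block_size = 4, 8
out_dim = num_blocks * block_size

W = torch.randn(out_dim, 16)                 # stand-in for orig_weight
blocks = torch.randn(num_blocks, block_size, block_size)
Q = blocks - blocks.transpose(1, 2)          # skew-symmetric blocks, as in the PR
R = torch.eye(block_size) + Q                # multiplier = 1 for simplicity

# grouped einsum path, exactly as in calc_updown_kb
w = rearrange(W, '(k n) ... -> k n ...', k=num_blocks, n=block_size)
w = torch.einsum('k n m, k n ... -> k m ...', R, w)
fast = rearrange(w, 'k m ... -> (k m) ...')

# reference: explicit block-diagonal product (transposed, to match the einsum)
reference = torch.block_diag(*R).T @ W
assert torch.allclose(fast, reference, atol=1e-5)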
13 changes: 13 additions & 0 deletions extensions-builtin/Lora/networks.py
@@ -11,6 +11,7 @@
import network_lokr
import network_full
import network_norm
import network_oft

import torch
from typing import Union
@@ -28,6 +29,7 @@
    network_full.ModuleTypeFull(),
    network_norm.ModuleTypeNorm(),
    network_glora.ModuleTypeGLora(),
    network_oft.ModuleTypeOFT(),
]


@@ -189,6 +191,17 @@ def load_network(name, network_on_disk):
            key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        # kohya-ss OFT module
        elif sd_module is None and "oft_unet" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        # KohakuBlueLeaf OFT module
        if sd_module is None and "oft_diag" in key:
            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
            key = key.replace("lora_te1_text_model", "0_transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        if sd_module is None:
            keys_failed_to_match[key_network] = key
            continue
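
For the kohya-ss branch above, the remap is a plain string substitution. A hypothetical example (the key name is illustrative, not taken from a real checkpoint):

# a kohya-ss OFT key such as "oft_unet_input_blocks_4_1_proj_in.oft_blocks" arrives here as
key_network_without_network_parts = "oft_unet_input_blocks_4_1_proj_in"
key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
# key == "diffusion_model_input_blocks_4_1_proj_in", which is then looked up
# in shared.sd_model.network_layer_mapping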