Merge pull request #12543 from AUTOMATIC1111/extra-norm-module
Fix MHA error with ex_bias and support ex_bias for layers which don't have bias
AUTOMATIC1111 authored Aug 14, 2023
2 parents c1a31ec + f70ded8 commit 3a4bee1
Showing 1 changed file with 30 additions and 9 deletions.
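Background for the fix: torch.nn.MultiheadAttention does not expose a top-level .bias attribute (its output-projection bias lives on out_proj), so the earlier getattr(self, 'bias', None) checks in the diff below never saw a bias to back up or restore, and applying an ex_bias to such a layer raised the error named in the commit message. A minimal sketch, separate from the commit itself, showing where the bias actually lives:

# Minimal sketch (not part of this commit) of why MultiheadAttention needs
# its own branch: the output bias lives on out_proj, not on a top-level .bias.
import torch

mha = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2)
linear = torch.nn.Linear(8, 8)

print(getattr(mha, "bias", None))   # None: the old getattr-based checks never saw a bias here
print(mha.out_proj.bias.shape)      # torch.Size([8]): the bias the patch backs up and updates
print(linear.bias.shape)            # torch.Size([8]): ordinary layers expose .bias directly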
39 changes: 30 additions & 9 deletions extensions-builtin/Lora/networks.py
@@ -277,7 +277,15 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
             self.weight.copy_(weights_backup)

     if bias_backup is not None:
-        self.bias.copy_(bias_backup)
+        if isinstance(self, torch.nn.MultiheadAttention):
+            self.out_proj.bias.copy_(bias_backup)
+        else:
+            self.bias.copy_(bias_backup)
+    else:
+        if isinstance(self, torch.nn.MultiheadAttention):
+            self.out_proj.bias = None
+        else:
+            self.bias = None


 def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
@@ -304,8 +312,13 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
         self.network_weights_backup = weights_backup

     bias_backup = getattr(self, "network_bias_backup", None)
-    if bias_backup is None and getattr(self, 'bias', None) is not None:
-        bias_backup = self.bias.to(devices.cpu, copy=True)
+    if bias_backup is None:
+        if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
+            bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
+        elif getattr(self, 'bias', None) is not None:
+            bias_backup = self.bias.to(devices.cpu, copy=True)
+        else:
+            bias_backup = None
         self.network_bias_backup = bias_backup

     if current_names != wanted_names:
@@ -323,8 +336,11 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
                         updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

                     self.weight += updown
-                    if ex_bias is not None and getattr(self, 'bias', None) is not None:
-                        self.bias += ex_bias
+                    if ex_bias is not None and hasattr(self, 'bias'):
+                        if self.bias is None:
+                            self.bias = torch.nn.Parameter(ex_bias)
+                        else:
+                            self.bias += ex_bias
             except RuntimeError as e:
                 logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                 extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
@@ -339,14 +355,19 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
         if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
             try:
                 with torch.no_grad():
-                    updown_q = module_q.calc_updown(self.in_proj_weight)
-                    updown_k = module_k.calc_updown(self.in_proj_weight)
-                    updown_v = module_v.calc_updown(self.in_proj_weight)
+                    updown_q, _ = module_q.calc_updown(self.in_proj_weight)
+                    updown_k, _ = module_k.calc_updown(self.in_proj_weight)
+                    updown_v, _ = module_v.calc_updown(self.in_proj_weight)
                     updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
-                    updown_out = module_out.calc_updown(self.out_proj.weight)
+                    updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)

                     self.in_proj_weight += updown_qkv
                     self.out_proj.weight += updown_out
+                    if ex_bias is not None:
+                        if self.out_proj.bias is None:
+                            self.out_proj.bias = torch.nn.Parameter(ex_bias)
+                        else:
+                            self.out_proj.bias += ex_bias

             except RuntimeError as e:
                 logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")