Skip to content

Commit

Permalink
Fix rescale
Browse files Browse the repository at this point in the history
  • Loading branch information
KohakuBlueleaf authored Feb 18, 2024
1 parent 9044129 commit 5a8dd0c
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion extensions-builtin/Lora/network_oft.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
self.is_boft = False
if weights.w["oft_diag"].dim() == 4:
self.is_boft = True
self.rescale = weight.w.get('rescale', None)
self.rescale = weights.w.get('rescale', None)
if self.rescale is not None:
self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))

is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
Expand Down

0 comments on commit 5a8dd0c

Please sign in to comment.