FIX: Transpose weight matrix based on fan_in_fan_out condition in PiSSA initialization (huggingface#2103)

Previously, the weight matrix was converted to float32 without considering the need for transposition. This update ensures that the weight matrix is transposed when the fan_in_fan_out condition is met, resolving dimension mismatch issues during GPT-2 training.
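
For context, GPT-2's attention projections are transformers Conv1D modules, which store their weight as (in_features, out_features) rather than the (out_features, in_features) layout of nn.Linear; LoRA marks such layers with fan_in_fan_out=True. The sketch below is not part of this commit: it only illustrates the shape problem and the fix, and its local transpose helper merely mirrors the behaviour assumed of peft's utility (flip the matrix when the flag is set, otherwise return it unchanged).

import torch

def transpose(weight, fan_in_fan_out):
    # Mirrors the behaviour assumed of peft's transpose helper:
    # Conv1D-style layers (fan_in_fan_out=True) store weight as (in, out),
    # so they must be flipped to the (out, in) layout that nn.Linear uses.
    return weight.T if fan_in_fan_out else weight

in_features, out_features, r = 768, 2304, 8  # GPT-2 c_attn shapes, illustrative rank

# transformers.Conv1D keeps its weight as (in_features, out_features).
conv1d_weight = torch.randn(in_features, out_features)

# Without this transpose, the SVD factors come out with swapped dimensions
# and lora_B @ lora_A cannot be subtracted from the base weight.
weight = transpose(conv1d_weight.to(torch.float32), fan_in_fan_out=True)
V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
lora_A = torch.diag(torch.sqrt(S[:r])) @ Uh[:r]    # (r, in_features)
lora_B = V[:, :r] @ torch.diag(torch.sqrt(S[:r]))  # (out_features, r)
residual = weight - lora_B @ lora_A                # (out_features, in_features)

# Map the residual back to the layer's native layout before storing it.
base_weight = transpose(residual, fan_in_fan_out=True)
assert base_weight.shape == conv1d_weight.shape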
Yang Su committed Sep 26, 2024
1 parent ccc3501 commit 2a513f6
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/peft/tuners/lora/layer.py
@@ -221,7 +221,7 @@ def pissa_init(self, adapter_name, init_lora_weights):
                 "Please initialize PiSSA under float32, float16, or bfloat16. "
                 "Subsequently, re-quantize the residual model to help minimize quantization errors."
             )
-        weight = weight.to(torch.float32)
+        weight = transpose(weight.to(torch.float32), self.fan_in_fan_out)
         if init_lora_weights == "pissa":
             # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
             V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False)
@@ -245,7 +245,7 @@ def pissa_init(self, adapter_name, init_lora_weights):
         self.lora_A[adapter_name].weight.data = lora_A
         self.lora_B[adapter_name].weight.data = lora_B
         weight = weight.data - self.scaling[adapter_name] * lora_B @ lora_A
-        weight = weight.to(dtype)
+        weight = transpose(weight.to(dtype), self.fan_in_fan_out)
         self.get_base_layer().weight.data = weight

     def loftq_init(self, adapter_name):
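
For reference, the training scenario described in the commit message can be set up along the following lines. This is a hedged usage sketch rather than part of the change; the target module name, rank, and alpha are illustrative.

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("gpt2")

# GPT-2's c_attn projection is a Conv1D layer, so fan_in_fan_out must be True;
# before this fix, PiSSA initialization hit a dimension mismatch on such layers.
config = LoraConfig(
    r=8,
    lora_alpha=8,
    target_modules=["c_attn"],
    fan_in_fan_out=True,
    init_lora_weights="pissa",
)
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()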
