Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize DoRA in eval and no dropout #2122

Merged
merged 17 commits into from
Oct 16, 2024
15 changes: 12 additions & 3 deletions src/peft/tuners/lora/dora.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,11 @@ def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=Fals
weight_norm = weight_norm.to("cpu")
self.weight = nn.Parameter(weight_norm, requires_grad=True)

def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
def forward(self, x, *, lora_A, lora_B, scaling, base_layer, do_optimize=False, result=None):
"""
For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
output.
"""
lora_result = lora_B(lora_A(x))

# Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
# calculate the same but using forward.
x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device, dtype=x.dtype)
Expand All @@ -86,6 +84,17 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
# during backpropagation"
weight_norm = weight_norm.detach()
mag_norm_scale = (magnitude / weight_norm).view(1, -1)

if do_optimize:
bias = base_layer.bias
if bias is not None:
result = result - bias
result = mag_norm_scale * result + mag_norm_scale * lora_B(lora_A(x)) * scaling
if bias is not None:
result = result + bias
return result

lora_result = lora_B(lora_A(x))
result_dora = (mag_norm_scale - 1) * (
F.linear(x, transpose(weight, self.fan_in_fan_out))
) + mag_norm_scale * lora_result * scaling
Expand Down
29 changes: 21 additions & 8 deletions src/peft/tuners/lora/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,14 +585,27 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
else:
x = dropout(x)
result = result + self.lora_magnitude_vector[active_adapter](
x,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
base_layer=self.get_base_layer(),
)
if isinstance(dropout, nn.Identity) or not self.training:
result = self.lora_magnitude_vector[active_adapter](
ariG23498 marked this conversation as resolved.
Show resolved Hide resolved
x,
ariG23498 marked this conversation as resolved.
Show resolved Hide resolved
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
base_layer=self.get_base_layer(),
do_optimize=True,
result=result,
)
else:
x = dropout(x)
result = result + self.lora_magnitude_vector[active_adapter](
x,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
base_layer=self.get_base_layer(),
do_optimize=False,
result=None,
)

result = result.to(torch_result_dtype)

Expand Down
Loading