Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize DoRA in eval and no dropout #2122

Merged
merged 17 commits into from
Oct 16, 2024
14 changes: 13 additions & 1 deletion docs/source/developer_guides/lora.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,21 @@ from peft import PeftModel
model = PeftModel.from_pretrained(base_model, peft_model_id, ephemeral_gpu_offload=True)
```

DoRA is optimized (computes faster and takes less memory) for models in the evaluation mode, or when dropout is set to 0. We reuse the
base result at those times to get the speedup.
Running [dora finetuning](https://github.com/huggingface/peft/blob/main/examples/dora_finetuning/dora_finetuning.py)
with `CUDA_VISIBLE_DEVICES=0 time python examples/dora_finetuning/dora_finetuning.py --quantize --lora_dropout 0 --batch_size 16 --eval_step 2 --use_dora`
on a 4090 with gradient accumulation set to 2 and max step to 20 resulted with the following observations:

| | Without Optimization | With Optimization |
| :--: | :--: | :--: |
| train_runtime | 359.7298 | **279.2676** |
| train_samples_per_second | 1.779 | **2.292** |
| train_steps_per_second | 0.056 | **0.072** |

#### Caveats

- DoRA only supports linear and Conv2d layers at the moment.
- DoRA only supports embedding, linear, and Conv2d layers at the moment.
- DoRA introduces a bigger overhead than pure LoRA, so it is recommended to merge weights for inference, see [`LoraModel.merge_and_unload`].
- DoRA should work with weights quantized with bitsandbytes ("QDoRA"). However, issues have been reported when using QDoRA with DeepSpeed Zero2.

Expand Down
32 changes: 20 additions & 12 deletions src/peft/tuners/lora/bnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,20 +235,24 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
x = x.to(compute_dtype)

if not self.use_dora[active_adapter]:
output = lora_B(lora_A(dropout(x))) * scaling
result = result + lora_B(lora_A(dropout(x))) * scaling
else:
x = dropout(x)
output = self.lora_magnitude_vector[active_adapter](
if isinstance(dropout, torch.nn.Identity) or not self.training:
base_result = result
else:
x = dropout(x)
base_result = None

result = result + self.lora_magnitude_vector[active_adapter](
x,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
base_layer=self.get_base_layer(),
base_result=base_result,
)
if requires_conversion:
output = output.to(expected_dtype)

result = result + output
result = result.to(expected_dtype)

return result

Expand Down Expand Up @@ -486,20 +490,24 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
x = x.to(lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
output = lora_B(lora_A(dropout(x))) * scaling
result = result + lora_B(lora_A(dropout(x))) * scaling
else:
x = dropout(x)
output = self.lora_magnitude_vector[active_adapter](
if isinstance(dropout, torch.nn.Identity) or not self.training:
base_result = result
else:
x = dropout(x)
base_result = None

result = result + self.lora_magnitude_vector[active_adapter](
x,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
base_layer=self.get_base_layer(),
base_result=base_result,
)
if requires_conversion:
output = output.to(expected_dtype)

result = result + output
result = result.to(expected_dtype)

return result

Expand Down
29 changes: 13 additions & 16 deletions src/peft/tuners/lora/dora.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,11 @@ def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=Fals
weight_norm = weight_norm.to("cpu")
self.weight = nn.Parameter(weight_norm, requires_grad=True)

def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None):
"""
For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
output.
"""
lora_result = lora_B(lora_A(x))

# Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
# calculate the same but using forward.
x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device, dtype=x.dtype)
Expand All @@ -86,19 +84,18 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
# during backpropagation"
weight_norm = weight_norm.detach()
mag_norm_scale = (magnitude / weight_norm).view(1, -1)
result_dora = (mag_norm_scale - 1) * (
F.linear(x, transpose(weight, self.fan_in_fan_out))
) + mag_norm_scale * lora_result * scaling

# Note: Computation could potentially be accelerated by using the code below instead of calculating X@W again.
# This is only correct if dropout=0, otherwise results will differ:
# https://github.com/huggingface/peft/pull/1474#issuecomment-1964682771
# bias = self.get_base_layer().bias
# if bias is not None:
# result = result - bias
# result = mag_norm_scale * result + mag_norm_scale * lora_B(lora_A(x)) * scaling
# if bias is not None:
# result = result + bias

lora_result = lora_B(lora_A(x))

bias = None
if base_result is not None:
bias = base_layer.bias
if bias is not None:
base_result = base_result - bias
else:
base_result = F.linear(x, transpose(weight, self.fan_in_fan_out))

result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_result * scaling

return result_dora

Expand Down
21 changes: 16 additions & 5 deletions src/peft/tuners/lora/hqq.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,24 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
x = x.to(compute_dtype)

if not self.use_dora[active_adapter]:
output = lora_B(lora_A(dropout(x))) * scaling
result = result + lora_B(lora_A(dropout(x))) * scaling
else:
output = self._apply_dora(x, lora_A, lora_B, scaling, active_adapter)
if isinstance(dropout, torch.nn.Identity) or not self.training:
base_result = result
else:
x = dropout(x)
base_result = None

result = result + self.lora_magnitude_vector[active_adapter](
x,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
base_layer=self.get_base_layer(),
base_result=base_result,
)
if requires_conversion:
output = output.to(expected_dtype)

result = result + output
result = result.to(expected_dtype)

return result

Expand Down
8 changes: 7 additions & 1 deletion src/peft/tuners/lora/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,13 +585,19 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
else:
x = dropout(x)
if isinstance(dropout, nn.Identity) or not self.training:
base_result = result
else:
x = dropout(x)
base_result = None

result = result + self.lora_magnitude_vector[active_adapter](
x,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
base_layer=self.get_base_layer(),
base_result=base_result,
)

result = result.to(torch_result_dtype)
Expand Down
18 changes: 8 additions & 10 deletions src/peft/tuners/lora/tp_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,23 +201,21 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
x = x.to(lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
lora_result = lora_A(dropout(x))
if isinstance(lora_result, tuple):
lora_result = lora_result[0]
lora_result = lora_B(lora_result)
if isinstance(lora_result, tuple):
lora_result = lora_result[0]
lora_result = lora_result * scaling

result = result + lora_result
result = result + lora_B(lora_A(dropout(x))) * scaling
else:
x = dropout(x)
if isinstance(dropout, torch.nn.Identity) or not self.training:
base_result = result
else:
x = dropout(x)
base_result = None

result = result + self.lora_magnitude_vector[active_adapter](
x,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
base_layer=self.get_base_layer(),
base_result=base_result,
)

result = result.to(torch_result_dtype)
Expand Down
Loading