diff --git a/src/peft/tuners/lora/dora.py b/src/peft/tuners/lora/dora.py
index 3125e3c716..74f50cf457 100644
--- a/src/peft/tuners/lora/dora.py
+++ b/src/peft/tuners/lora/dora.py
@@ -62,12 +62,12 @@ def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=Fals
             weight_norm = weight_norm.to("cpu")
         self.weight = nn.Parameter(weight_norm, requires_grad=True)

-    def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
+    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_layer_result, dropout):
         """
         For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
         output.
         """
-        lora_result = lora_B(lora_A(x))
+        lora_result = lora_B(lora_A(dropout(x)))

         # Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
         # calculate the same but using forward.
@@ -86,9 +86,7 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
         # during backpropagation"
         weight_norm = weight_norm.detach()
         mag_norm_scale = (magnitude / weight_norm).view(1, -1)
-        result_dora = (mag_norm_scale - 1) * (
-            F.linear(x, transpose(weight, self.fan_in_fan_out))
-        ) + mag_norm_scale * lora_result * scaling
+        result_dora = (mag_norm_scale - 1) * base_layer_result + mag_norm_scale * lora_result * scaling

         # Note: Computation could potentially be accelerated by using the code below instead of calculating X@W again.
         # This is only correct if dropout=0, otherwise results will differ:
@@ -142,7 +140,7 @@ def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
         weight_norm = weight.norm(p=2, dim=dim, keepdim=True).transpose(1, 0)
         return weight_norm

-    def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
+    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_layer_result, dropout):
         """
         For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
         output.
@@ -160,17 +158,9 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
         # during backpropagation"
         weight_norm = weight_norm.detach()
         mag_norm_scale = magnitude / weight_norm
-        result_dora = (mag_norm_scale - 1) * (
-            self.conv_fn(
-                x,
-                weight,
-                bias=None,
-                stride=base_layer.stride,
-                padding=base_layer.padding,
-                dilation=base_layer.dilation,
-                groups=base_layer.groups,
-            )
-        ) + mag_norm_scale * lora_B(lora_A(x)) * scaling
+
+        # The base layer has already computed the convolution; we do not need to compute it again.
+        result_dora = (mag_norm_scale - 1) * base_layer_result + mag_norm_scale * lora_B(lora_A(dropout(x))) * scaling

         return result_dora

diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index f3359ca9a8..67961d213d 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -585,13 +585,14 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 if not self.use_dora[active_adapter]:
                     result = result + lora_B(lora_A(dropout(x))) * scaling
                 else:
-                    x = dropout(x)
                     result = result + self.lora_magnitude_vector[active_adapter](
                         x,
                         lora_A=lora_A,
                         lora_B=lora_B,
                         scaling=scaling,
                         base_layer=self.get_base_layer(),
+                        base_layer_result=result,
+                        dropout=dropout
                     )

             result = result.to(torch_result_dtype)
@@ -904,6 +905,7 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
             out_kernel = out_stride = (1,) * (self._kernel_dim - 2)
         self.lora_A[adapter_name] = conv_layer(self.in_features, r, kernel_size, stride, padding, bias=False)
         self.lora_B[adapter_name] = conv_layer(r, self.out_features, out_kernel, out_stride, bias=False)
+
         if use_rslora:
             self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
         else:
@@ -1088,7 +1090,6 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
     def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         self._check_forward_args(x, *args, **kwargs)
         adapter_names = kwargs.pop("adapter_names", None)
-
         if self.disable_adapters:
             if self.merged:
                 self.unmerge()
@@ -1113,13 +1114,14 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                 if not self.use_dora[active_adapter]:
                     result = result + lora_B(lora_A(dropout(x))) * scaling
                 else:
-                    x = dropout(x)
                     result = result + self.lora_magnitude_vector[active_adapter](
                         x,
                         lora_A=lora_A,
                         lora_B=lora_B,
                         scaling=scaling,
                         base_layer=self.get_base_layer(),
+                        base_layer_result=result,
+                        dropout=dropout
                     )

             result = result.to(torch_result_dtype)
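For context, the sketch below is not part of the PR; the sizes and names (in_f, out_f, r, scaling) are made up for illustration. It checks the identity the dora.py change relies on: when dropout is inactive, the base layer has no bias, and fan_in_fan_out is not in play, scaling the already computed base layer output by (mag_norm_scale - 1) gives the same result_dora as re-running F.linear on x, so the extra matmul (or convolution, in the ConvNd case) can be dropped.

# Standalone sketch, assuming dropout disabled, bias=False, and no fan_in_fan_out handling.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
in_f, out_f, r, scaling = 16, 8, 4, 0.5

base = nn.Linear(in_f, out_f, bias=False)
lora_A = nn.Linear(in_f, r, bias=False)
lora_B = nn.Linear(r, out_f, bias=False)
magnitude = nn.Parameter(base.weight.norm(p=2, dim=1).detach())

x = torch.randn(2, in_f)

# DoRA pieces shared by both formulations
lora_weight = lora_B.weight @ lora_A.weight
weight_norm = (base.weight + scaling * lora_weight).norm(p=2, dim=1).detach()
mag_norm_scale = (magnitude / weight_norm).view(1, -1)
lora_result = lora_B(lora_A(x))

# Old formulation: recompute x @ W.T inside the DoRA path
old = (mag_norm_scale - 1) * F.linear(x, base.weight) + mag_norm_scale * lora_result * scaling

# New formulation: reuse the output the base layer already produced
base_layer_result = base(x)
new = (mag_norm_scale - 1) * base_layer_result + mag_norm_scale * lora_result * scaling

assert torch.allclose(old, new, atol=1e-6)

With dropout active the two paths are no longer identical: the reused base layer output is computed from the un-dropped input while the LoRA branch now sees dropout(x), in line with the existing note in dora.py that this kind of shortcut is only exact for dropout=0.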