ENH Argument to enable bias for LoRA B (#2237)
This PR adds the argument lora_bias which, if set to True (default:
False), adds a bias term to the LoRA B module.

Typically, this should be disabled. The main use case is when the LoRA
weights were extracted from fully fine-tuned parameters, in which case
the bias of those fine-tuned parameters can be taken into account.

Merging is supported for this argument when using vanilla LoRA layers or
bitsandbytes LoRA layers. Other types of LoRA layers don't support
merging.

This option is also disabled for non-standard LoRA weight initialization
like LoftQ, as well as for embedding layers (since they use
nn.Parameter).
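
For context, a minimal usage sketch showing how the new option is enabled from user code. The base model and target modules below are placeholder assumptions, not part of this commit:

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# placeholder base model; any model containing the named target modules works
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

config = LoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj"],  # placeholder target modules
    lora_bias=True,  # adds a trainable bias term to each lora_B module
)
peft_model = get_peft_model(base_model, config)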
BenjaminBossan authored Nov 27, 2024
1 parent 60978d7 commit 943daf1
Showing 15 changed files with 315 additions and 8 deletions.
16 changes: 15 additions & 1 deletion src/peft/tuners/lora/aqlm.py
@@ -35,13 +35,27 @@ def __init__(
lora_dropout: float = 0.0,
init_lora_weights: bool = True,
use_rslora: bool = False,
use_dora: bool = False,
lora_bias: bool = False,
**kwargs,
):
if use_dora:
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

super().__init__()
LoraLayer.__init__(self, base_layer)

self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
self.update_layer(
adapter_name,
r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
lora_bias=lora_bias,
)

def forward(self, x: torch.Tensor):
# note: logic differs from default Linear because merging is not supported
16 changes: 15 additions & 1 deletion src/peft/tuners/lora/awq.py
@@ -36,8 +36,13 @@ def __init__(
lora_dropout: float = 0.0,
init_lora_weights: bool = True,
use_rslora: bool = False,
use_dora: bool = False,
lora_bias: bool = False,
**kwargs,
):
if use_dora:
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

super().__init__()
LoraLayer.__init__(self, base_layer)

@@ -46,7 +51,16 @@ def __init__(
self.quant_linear_module = base_layer

self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
self.update_layer(
adapter_name,
r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
lora_bias=lora_bias,
)

def forward(self, x: torch.Tensor):
result = self.quant_linear_module(x)
25 changes: 25 additions & 0 deletions src/peft/tuners/lora/bnb.py
@@ -41,6 +41,7 @@ def __init__(
init_lora_weights: bool = True,
use_rslora: bool = False,
use_dora: bool = False,
lora_bias: bool = False,
**kwargs,
) -> None:
super().__init__()
@@ -56,6 +57,7 @@ def __init__(
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
lora_bias=lora_bias,
)

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
@@ -118,6 +120,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
self.get_base_layer().weight = bnb.nn.Int8Params(
w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
).to(weight.device)
if self.lora_bias[active_adapter]:
bias_data = self.get_base_layer().bias.data + self.lora_B[active_adapter].bias
if safe_merge and not torch.isfinite(bias_data):
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
self.get_base_layer().bias.data = bias_data

state.reset_grads()
self.merged_adapters.append(active_adapter)

@@ -154,6 +164,9 @@ def unmerge(self) -> None:
self.get_base_layer().weight = bnb.nn.Int8Params(
w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
).to(weight.device)

if self.lora_bias[active_adapter]:
self.get_base_layer().bias.data -= self.lora_B[active_adapter].bias
state.reset_grads()

def get_delta_weight(self, adapter):
@@ -298,6 +311,7 @@ def __init__(
init_lora_weights: bool = True,
use_rslora: bool = False,
use_dora: bool = False,
lora_bias: bool = False,
**kwargs,
) -> None:
super().__init__()
@@ -313,6 +327,7 @@ def __init__(
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
lora_bias=lora_bias,
)

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
@@ -372,6 +387,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
kwargs["requires_grad"] = False
kwargs.pop("data", None)
self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
if self.lora_bias[active_adapter]:
bias_data = self.get_base_layer().bias.data + self.lora_B[active_adapter].bias
if safe_merge and not torch.isfinite(bias_data):
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
self.get_base_layer().bias.data = bias_data

self.merged_adapters.append(active_adapter)

def unmerge(self) -> None:
@@ -407,6 +430,8 @@ def unmerge(self) -> None:
kwargs["requires_grad"] = False
kwargs.pop("data", None)
self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
if self.lora_bias[active_adapter]:
self.get_base_layer().bias.data -= self.lora_B[active_adapter].bias

def get_delta_weight(self, adapter):
return (
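
The bitsandbytes changes above fold the bias of lora_B into the base layer's bias during merge and subtract it again during unmerge. A small illustrative sketch of that arithmetic, using a plain nn.Linear instead of a quantized layer (not part of the commit):

import torch
import torch.nn as nn

torch.manual_seed(0)
base = nn.Linear(16, 16)              # stand-in for the dequantized base layer
lora_B = nn.Linear(4, 16, bias=True)  # LoRA B module with the new bias term

# merge: add the LoRA B bias onto the base bias
merged_bias = base.bias.data + lora_B.bias.data

# unmerge: subtracting the same bias restores the original values
restored_bias = merged_bias - lora_B.bias.data
assert torch.allclose(restored_bias, base.bias.data)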
23 changes: 23 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -212,6 +212,10 @@ class LoraConfig(PeftConfig):
all have separate LoRA adapters attached to them.
runtime_config (`LoraRuntimeConfig`):
Runtime configurations (which are not saved or restored).
lora_bias (`bool`):
Defaults to `False`. Whether to enable the bias term for the LoRA B parameter. Typically, this should be
disabled. The main use case for this is when the LoRA weights were extracted from fully fine-tuned
parameters so the bias of those parameters can be taken into account.
"""

r: int = field(default=8, metadata={"help": "Lora attention dimension"})
@@ -391,6 +395,16 @@ class LoraConfig(PeftConfig):
runtime_config: LoraRuntimeConfig = field(
default_factory=LoraRuntimeConfig, metadata={"help": "Runtime configurations"}
)
lora_bias: bool = field(
default=False,
metadata={
"help": (
"Whether to enable the bias term for the LoRA B parameter. Typically, this should be disabled. The "
"main use case for this is when the LoRA weights were extracted from fully fine-tuned parameters so "
"the bias of those parameters can be taken into account."
)
},
)

def to_dict(self):
"""
@@ -446,6 +460,15 @@ def __post_init__(self):
elif self.init_lora_weights != "eva" and self.eva_config is not None:
warnings.warn("`eva_config` specified but will be ignored when `init_lora_weights` is not 'eva'.")

if self.lora_bias:
if self.init_lora_weights not in (True, False):
raise ValueError(
f"The argument lora_bias=True is only supported with init_lora_weights=True or False, got "
f"init_lora_weights={self.init_lora_weights} instead."
)
if self.use_dora:
raise ValueError("The argument lora_bias=True is not supported for DoRA, please pass use_dora=False")

# Using post training conversion of modified base weights to restore their initial values (PiSSA, OLoRA) cannot
# be correctly done when using rslora + rank_pattern/alpha_pattern. We can't really know if the user intends
# this when they'll eventually call save_pretrained (i.e. if they'll pass
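
The checks added to __post_init__ above reject unsupported combinations. A quick illustration, assuming a peft installation that includes this change:

from peft import LoraConfig

# supported: lora_bias with the default (boolean) weight initialization
LoraConfig(lora_bias=True)

# rejected: lora_bias together with DoRA
try:
    LoraConfig(lora_bias=True, use_dora=True)
except ValueError as exc:
    print(exc)

# rejected: lora_bias with a non-boolean init scheme such as "gaussian"
try:
    LoraConfig(lora_bias=True, init_lora_weights="gaussian")
except ValueError as exc:
    print(exc)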
16 changes: 15 additions & 1 deletion src/peft/tuners/lora/eetq.py
@@ -33,8 +33,13 @@ def __init__(
lora_dropout: float = 0.0,
init_lora_weights: bool = True,
use_rslora: bool = False,
use_dora: bool = False,
lora_bias: bool = False,
**kwargs,
):
if use_dora:
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

super().__init__()
LoraLayer.__init__(self, base_layer)

@@ -43,7 +48,16 @@ def __init__(
self.quant_linear_module = base_layer

self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
self.update_layer(
adapter_name,
r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
lora_bias=lora_bias,
)

def forward(self, x: torch.Tensor):
result = self.quant_linear_module(x)
1 change: 1 addition & 0 deletions src/peft/tuners/lora/eva.py
@@ -497,6 +497,7 @@ def _load_eva_state_dict(
"lora_dropout": peft_config.lora_dropout,
"use_rslora": peft_config.use_rslora,
"use_dora": peft_config.use_dora,
"lora_bias": peft_config.lora_bias,
}
missing_eva_inits = []
new_target_modules = []
2 changes: 2 additions & 0 deletions src/peft/tuners/lora/gptq.py
@@ -32,6 +32,7 @@ def __init__(
init_lora_weights: bool = True,
use_rslora: bool = False,
use_dora: bool = False,
lora_bias: bool = False,
**kwargs,
):
super().__init__()
@@ -52,6 +53,7 @@ def __init__(
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
lora_bias=lora_bias,
)

def forward(self, x: torch.Tensor):
5 changes: 5 additions & 0 deletions src/peft/tuners/lora/hqq.py
@@ -41,8 +41,12 @@ def __init__(
init_lora_weights: bool = True,
use_rslora: bool = False,
use_dora: bool = False,
lora_bias: bool = False,
**kwargs,
) -> None:
if lora_bias:
raise ValueError(f"{self.__class__.__name__} does not support lora_bias yet, set it to False")

super().__init__()
LoraLayer.__init__(self, base_layer)
self.fan_in_fan_out = False
@@ -56,6 +60,7 @@ def __init__(
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
lora_bias=lora_bias,
)

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: