FIX: Error with OLoRA init when using bnb #2011

Merged (8 commits) on Sep 3, 2024
36 changes: 29 additions & 7 deletions src/peft/tuners/lora/layer.py
@@ -25,7 +25,7 @@
 from transformers.pytorch_utils import Conv1D

 from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
-from peft.utils.integrations import dequantize_module_weight, gather_params_ctx
+from peft.utils.integrations import dequantize_module_weight, gather_params_ctx, get_bnb_param_type
 from peft.utils.other import transpose

 from .config import LoraConfig
@@ -167,11 +167,16 @@ def reset_lora_parameters(self, adapter_name, init_lora_weights):
         nn.init.normal_(self.lora_embedding_B[adapter_name])

     def olora_init(self, adapter_name):
-        dtype = self.get_base_layer().weight.dtype
-        if dtype in [torch.int8, torch.uint8]:
-            weight_tensor = dequantize_module_weight(self.get_base_layer())
+        base_layer = self.get_base_layer()
+        orig_weight = base_layer.weight
+        bnb_param_type = get_bnb_param_type(orig_weight)
+        dtype = orig_weight.dtype
+
+        if bnb_param_type:
+            # check without importing bitsandbytes and robust to bnb_4bit_quant_storage=float*
+            weight_tensor = dequantize_module_weight(base_layer)
         elif dtype in [torch.float32, torch.float16, torch.bfloat16]:
-            weight_tensor = self.get_base_layer().weight
+            weight_tensor = orig_weight
         else:
             raise TypeError(f"Unsupported data type for the base layer. Got {dtype}.")

@@ -186,8 +191,25 @@ def olora_init(self, adapter_name):
         self.lora_B[adapter_name].weight.data = Qr.contiguous()

         weight_tensor.data -= scale_factor * self.lora_B[adapter_name].weight @ self.lora_A[adapter_name].weight
-        weight_tensor = weight_tensor.to(dtype)
-        self.get_base_layer().weight.data = weight_tensor
+        if bnb_param_type == "4bit":
+            weight_tensor = orig_weight.__class__(
+                weight_tensor,
+                quant_type=orig_weight.quant_type,
+                quant_storage=orig_weight.quant_storage,
+                compress_statistics=orig_weight.compress_statistics,
+                module=orig_weight.module,
+            ).to(orig_weight.device)
+            base_layer.weight = weight_tensor
+        elif bnb_param_type == "8bit":
+            weight_tensor = orig_weight.__class__(
+                weight_tensor,
+                requires_grad=orig_weight.requires_grad,
+                has_fp16_weights=orig_weight.has_fp16_weights,
+            ).to(orig_weight.device)
+            base_layer.weight = weight_tensor
+        else:
+            weight_tensor = weight_tensor.to(dtype)
+            base_layer.weight.data = weight_tensor

     def pissa_init(self, adapter_name, init_lora_weights):
         weight = self.get_base_layer().weight
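For context, here is a minimal sketch (not part of the PR, assuming bitsandbytes >= 0.43 and a CUDA device) of why the old dtype-based check breaks: with bnb_4bit_quant_storage=torch.float16, the packed 4-bit weight reports a float16 dtype, so testing for torch.int8/torch.uint8 misses it, while the class-name check behind get_bnb_param_type still identifies it as quantized.

import torch
import bitsandbytes as bnb

w = torch.randn(64, 64, dtype=torch.float16)
# Pack the weight as 4-bit NF4, but keep the packed data in a float16 storage buffer.
qw = bnb.nn.Params4bit(
    w, requires_grad=False, quant_type="nf4", quant_storage=torch.float16
).to("cuda")  # quantization happens on the move to the GPU

print(qw.dtype)               # torch.float16 -> a dtype check would treat this as a plain fp16 weight
print(qw.__class__.__name__)  # "Params4bit" -> still detected as a bitsandbytes parameter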
12 changes: 12 additions & 0 deletions src/peft/utils/integrations.py
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import annotations
+
 from contextlib import contextmanager
+from typing import Literal

 import packaging.version
 import torch
@@ -104,3 +107,12 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
     if is_cpu:
         dequantized = dequantized.to(device)
     return dequantized
+
+
+def get_bnb_param_type(param: torch.nn.Parameter) -> Literal[False, "4bit", "8bit"]:
+    """Returns '4bit' or '8bit' if bitsandbytes parameter, else False"""
+    if param.__class__.__name__ == "Params4bit":
+        return "4bit"
+    if param.__class__.__name__ == "Int8Params":
+        return "8bit"
+    return False
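As a quick illustration (not from the PR): because get_bnb_param_type only inspects the parameter's class name, it never needs to import bitsandbytes and returns False for ordinary parameters; a bitsandbytes Params4bit or Int8Params would yield "4bit" or "8bit" respectively.

import torch
from peft.utils.integrations import get_bnb_param_type

plain = torch.nn.Parameter(torch.zeros(2, 2))
assert get_bnb_param_type(plain) is False  # not a bitsandbytes parameter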
35 changes: 35 additions & 0 deletions tests/test_gpu_examples.py
@@ -1786,6 +1786,41 @@ def test_bloomz_olora_8bit(self, device, tmp_path):
         # Same test as test_bloomz_olora_4bit but with 8 bits.
         self.get_errors(bits=8, device=device, tmp_path=tmp_path)

+    @pytest.mark.parametrize("bits", [4, 8])
+    def test_olora_with_quantized_model(self, bits):
+        import bitsandbytes as bnb
+
+        # issue 1999
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        if bits == 4:
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.float16,
+                bnb_4bit_quant_storage=torch.float16,
+                bnb_4bit_use_double_quant=True,
+            )
+        elif bits == 8:
+            bnb_config = BitsAndBytesConfig(load_in_8bit=True)
+        else:
+            raise ValueError("bits must be 4 or 8")
+
+        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
+        model = prepare_model_for_kbit_training(model)
+        config = LoraConfig(init_lora_weights="olora")
+        model = get_peft_model(model, config)
+
+        # check that the correct type is used for the weights
+        base_layer = model.base_model.model.model.decoder.layers[0].self_attn.v_proj.base_layer.weight
+        if bits == 4:
+            assert isinstance(base_layer, bnb.nn.modules.Params4bit)
+        else:
+            assert isinstance(base_layer, bnb.nn.modules.Int8Params)
+
+        inputs = torch.arange(10).unsqueeze(0).to(model.device)
+        logits = model(inputs).logits  # does not raise
+        assert torch.isfinite(logits).all()
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
 class TestLoftQ:
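To reproduce the regression check locally (assuming a CUDA GPU with bitsandbytes installed), the new test can be selected with: pytest tests/test_gpu_examples.py -k test_olora_with_quantized_model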