diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index 8a5dc9970f..4a030dab94 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -25,7 +25,7 @@
 from transformers.pytorch_utils import Conv1D
 
 from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
-from peft.utils.integrations import dequantize_module_weight, gather_params_ctx
+from peft.utils.integrations import dequantize_module_weight, gather_params_ctx, get_bnb_param_type
 from peft.utils.other import transpose
 
 from .config import LoraConfig
@@ -167,11 +167,16 @@ def reset_lora_parameters(self, adapter_name, init_lora_weights):
             nn.init.normal_(self.lora_embedding_B[adapter_name])
 
     def olora_init(self, adapter_name):
-        dtype = self.get_base_layer().weight.dtype
-        if dtype in [torch.int8, torch.uint8]:
-            weight_tensor = dequantize_module_weight(self.get_base_layer())
+        base_layer = self.get_base_layer()
+        orig_weight = base_layer.weight
+        bnb_param_type = get_bnb_param_type(orig_weight)
+        dtype = orig_weight.dtype
+
+        if bnb_param_type:
+            # check without importing bitsandbytes and robust to bnb_4bit_quant_storage=float*
+            weight_tensor = dequantize_module_weight(base_layer)
         elif dtype in [torch.float32, torch.float16, torch.bfloat16]:
-            weight_tensor = self.get_base_layer().weight
+            weight_tensor = orig_weight
         else:
             raise TypeError(f"Unsupported data type for the base layer. Got {dtype}.")
 
@@ -186,8 +191,25 @@ def olora_init(self, adapter_name):
         self.lora_B[adapter_name].weight.data = Qr.contiguous()
 
         weight_tensor.data -= scale_factor * self.lora_B[adapter_name].weight @ self.lora_A[adapter_name].weight
-        weight_tensor = weight_tensor.to(dtype)
-        self.get_base_layer().weight.data = weight_tensor
+        if bnb_param_type == "4bit":
+            weight_tensor = orig_weight.__class__(
+                weight_tensor,
+                quant_type=orig_weight.quant_type,
+                quant_storage=orig_weight.quant_storage,
+                compress_statistics=orig_weight.compress_statistics,
+                module=orig_weight.module,
+            ).to(orig_weight.device)
+            base_layer.weight = weight_tensor
+        elif bnb_param_type == "8bit":
+            weight_tensor = orig_weight.__class__(
+                weight_tensor,
+                requires_grad=orig_weight.requires_grad,
+                has_fp16_weights=orig_weight.has_fp16_weights,
+            ).to(orig_weight.device)
+            base_layer.weight = weight_tensor
+        else:
+            weight_tensor = weight_tensor.to(dtype)
+            base_layer.weight.data = weight_tensor
 
     def pissa_init(self, adapter_name, init_lora_weights):
         weight = self.get_base_layer().weight
diff --git a/src/peft/utils/integrations.py b/src/peft/utils/integrations.py
index 02c56f6830..4a23809317 100644
--- a/src/peft/utils/integrations.py
+++ b/src/peft/utils/integrations.py
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 from contextlib import contextmanager
+from typing import Literal
 
 import packaging.version
 import torch
@@ -104,3 +107,12 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
     if is_cpu:
         dequantized = dequantized.to(device)
     return dequantized
+
+
+def get_bnb_param_type(param: torch.nn.Parameter) -> Literal[False, "4bit", "8bit"]:
+    """Returns '4bit' or '8bit' if bitsandbytes parameter, else False"""
+    if param.__class__.__name__ == "Params4bit":
+        return "4bit"
+    if param.__class__.__name__ == "Int8Params":
+        return "8bit"
+    return False
diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py
index 5a45bbb91b..c4b8948b35 100644
--- a/tests/test_gpu_examples.py
+++ b/tests/test_gpu_examples.py
@@ -1786,6 +1786,41 @@ def test_bloomz_olora_8bit(self, device, tmp_path):
         # Same test as test_bloomz_olora_4bit but with 8 bits.
         self.get_errors(bits=8, device=device, tmp_path=tmp_path)
 
+    @pytest.mark.parametrize("bits", [4, 8])
+    def test_olora_with_quantized_model(self, bits):
+        import bitsandbytes as bnb
+
+        # issue 1999
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        if bits == 4:
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.float16,
+                bnb_4bit_quant_storage=torch.float16,
+                bnb_4bit_use_double_quant=True,
+            )
+        elif bits == 8:
+            bnb_config = BitsAndBytesConfig(load_in_8bit=True)
+        else:
+            raise ValueError("bits must be 4 or 8")
+
+        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
+        model = prepare_model_for_kbit_training(model)
+        config = LoraConfig(init_lora_weights="olora")
+        model = get_peft_model(model, config)
+
+        # check that the correct type is used for the weights
+        base_layer = model.base_model.model.model.decoder.layers[0].self_attn.v_proj.base_layer.weight
+        if bits == 4:
+            assert isinstance(base_layer, bnb.nn.modules.Params4bit)
+        else:
+            assert isinstance(base_layer, bnb.nn.modules.Int8Params)
+
+        inputs = torch.arange(10).unsqueeze(0).to(model.device)
+        logits = model(inputs).logits  # does not raise
+        assert torch.isfinite(logits).all()
+
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
 class TestLoftQ:
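For reference, a minimal sketch of how the fixed path is exercised from user code. The checkpoint, quantization settings, and module path simply mirror `test_olora_with_quantized_model` above and are illustrative; a CUDA machine with bitsandbytes installed is assumed.

```python
# Sketch only: mirrors the new test; assumes a CUDA device with bitsandbytes available.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_storage=torch.float16,  # the storage dtype that tripped up the old dtype check
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
model = prepare_model_for_kbit_training(model)

# With this patch, olora_init dequantizes the bnb weight, applies the QR-based init,
# and re-quantizes the residual back into Params4bit (or Int8Params for 8-bit).
model = get_peft_model(model, LoraConfig(init_lora_weights="olora"))
```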