From 860f7838c885ada7d48bb91fbc65b5f1843b9bc6 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Thu, 5 Dec 2024 11:09:56 -0500 Subject: [PATCH] ENH: Updates for upcoming BNB Int8 release (#2245) * Updates to prepare for bitsandbytes release --- src/peft/tuners/adalora/model.py | 1 - src/peft/tuners/ia3/model.py | 1 - src/peft/tuners/lora/bnb.py | 1 - src/peft/tuners/vera/model.py | 1 - src/peft/utils/integrations.py | 14 +++++++------- tests/test_common_gpu.py | 17 ++++++++--------- 6 files changed, 15 insertions(+), 20 deletions(-) diff --git a/src/peft/tuners/adalora/model.py b/src/peft/tuners/adalora/model.py index d85f4b8cdb..db5759a5ac 100644 --- a/src/peft/tuners/adalora/model.py +++ b/src/peft/tuners/adalora/model.py @@ -174,7 +174,6 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): kwargs.update( { "has_fp16_weights": target_base_layer.state.has_fp16_weights, - "memory_efficient_backward": target_base_layer.state.memory_efficient_backward, "threshold": target_base_layer.state.threshold, "index": target_base_layer.index, } diff --git a/src/peft/tuners/ia3/model.py b/src/peft/tuners/ia3/model.py index 09fd905bac..b7ce3f93a6 100644 --- a/src/peft/tuners/ia3/model.py +++ b/src/peft/tuners/ia3/model.py @@ -103,7 +103,6 @@ def _create_new_module(ia3_config, adapter_name, target, **kwargs): eightbit_kwargs.update( { "has_fp16_weights": target_base_layer.state.has_fp16_weights, - "memory_efficient_backward": target_base_layer.state.memory_efficient_backward, "threshold": target_base_layer.state.threshold, "index": target_base_layer.index, } diff --git a/src/peft/tuners/lora/bnb.py b/src/peft/tuners/lora/bnb.py index 4921a4ae57..fbb1c712dd 100644 --- a/src/peft/tuners/lora/bnb.py +++ b/src/peft/tuners/lora/bnb.py @@ -288,7 +288,6 @@ def dispatch_bnb_8bit(target: torch.nn.Module, adapter_name: str, **kwargs): eightbit_kwargs.update( { "has_fp16_weights": target.state.has_fp16_weights, - "memory_efficient_backward": target.state.memory_efficient_backward, "threshold": target.state.threshold, "index": target.index, } diff --git a/src/peft/tuners/vera/model.py b/src/peft/tuners/vera/model.py index f863a074e6..e129620ce8 100644 --- a/src/peft/tuners/vera/model.py +++ b/src/peft/tuners/vera/model.py @@ -314,7 +314,6 @@ def _create_new_module(vera_config, vera_A, vera_B, adapter_name, target, **kwar eightbit_kwargs.update( { "has_fp16_weights": target_base_layer.state.has_fp16_weights, - "memory_efficient_backward": target_base_layer.state.memory_efficient_backward, "threshold": target_base_layer.state.threshold, "index": target_base_layer.index, } diff --git a/src/peft/utils/integrations.py b/src/peft/utils/integrations.py index df65084608..5be6300c0f 100644 --- a/src/peft/utils/integrations.py +++ b/src/peft/utils/integrations.py @@ -101,13 +101,13 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None): if state.SCB is None: state.SCB = weight.SCB - im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device) - im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im) - im, Sim = bnb.functional.transform(im, "col32") - if state.CxB is None: - state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB) - out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB) - dequantized = bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t() + if hasattr(bnb.functional, "int8_vectorwise_dequant"): + # Use bitsandbytes API if available (requires v0.45.0+) + dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB) + else: + # Multiply by (scale/127) to dequantize. + dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3 + if is_cpu: dequantized = dequantized.to(device) return dequantized diff --git a/tests/test_common_gpu.py b/tests/test_common_gpu.py index 5f14d9baab..09b8b4e901 100644 --- a/tests/test_common_gpu.py +++ b/tests/test_common_gpu.py @@ -767,8 +767,8 @@ def test_8bit_merge_lora(self): with torch.inference_mode(): out_after_merge = F.softmax(model(random_input).logits, dim=-1) - atol = 0.01 - rtol = 10 + atol = 1e-3 + rtol = 1 assert not torch.allclose(out_base, out_before_merge, atol=atol, rtol=rtol) assert torch.allclose(out_before_merge, out_after_merge, atol=atol, rtol=rtol) assert isinstance(model, PeftModel) @@ -803,8 +803,8 @@ def test_8bit_merge_and_disable_lora(self): with torch.inference_mode(): out_after = F.softmax(model(random_input).logits, dim=-1) - atol = 0.01 - rtol = 10 + atol = 1e-3 + rtol = 1 assert not torch.allclose(out_base, out_before, atol=atol, rtol=rtol) assert torch.allclose(out_base, out_after, atol=atol, rtol=rtol) assert isinstance(model, PeftModel) @@ -838,8 +838,8 @@ def test_8bit_merge_lora_with_bias(self): with torch.inference_mode(): out_after_merge = F.softmax(model(random_input).logits, dim=-1) - atol = 0.01 - rtol = 10 + atol = 1e-3 + rtol = 1 assert not torch.allclose(out_base, out_before_merge, atol=atol, rtol=rtol) assert torch.allclose(out_before_merge, out_after_merge, atol=atol, rtol=rtol) @@ -1294,9 +1294,8 @@ def test_8bit_dora_merging(self): model = model.merge_and_unload() out_unloaded = F.softmax(model(random_input).logits, dim=-1) - # 8bit merging less precise than 4bit - atol = 0.01 - rtol = 10 + atol = 1e-3 + rtol = 1 # sanity check that using DoRA changes the results assert not torch.allclose(out_base, out_dora, atol=atol, rtol=rtol) assert torch.allclose(out_dora, out_merged, atol=atol, rtol=rtol)