ENH: Updates for upcoming BNB Int8 release (#2245)
* Updates to prepare for bitsandbytes release
matthewdouglas authored Dec 5, 2024
1 parent 15712db commit 860f783
Showing 6 changed files with 15 additions and 20 deletions.
1 change: 0 additions & 1 deletion src/peft/tuners/adalora/model.py
@@ -174,7 +174,6 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
     kwargs.update(
         {
             "has_fp16_weights": target_base_layer.state.has_fp16_weights,
-            "memory_efficient_backward": target_base_layer.state.memory_efficient_backward,
             "threshold": target_base_layer.state.threshold,
             "index": target_base_layer.index,
         }
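The same one-line removal is repeated in the ia3, lora, and vera tuners below: newer bitsandbytes releases drop the memory_efficient_backward flag from MatmulLtState, so PEFT stops forwarding it. A minimal sketch of the remaining pattern (not the PEFT code itself; the layer shape is made up), showing how the Int8 kwargs are still read off the wrapped Linear8bitLt layer:

# Sketch only: gather the Int8 kwargs that are still forwarded after this change.
# Assumes a bitsandbytes version whose MatmulLtState no longer exposes
# memory_efficient_backward; the remaining attributes exist in current releases.
import bitsandbytes as bnb

def collect_eightbit_kwargs(target_base_layer: bnb.nn.Linear8bitLt) -> dict:
    return {
        "has_fp16_weights": target_base_layer.state.has_fp16_weights,
        "threshold": target_base_layer.state.threshold,
        "index": target_base_layer.index,
    }

# Illustrative usage with an arbitrary layer shape:
base = bnb.nn.Linear8bitLt(16, 32, bias=False, has_fp16_weights=False, threshold=6.0)
print(collect_eightbit_kwargs(base))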
1 change: 0 additions & 1 deletion src/peft/tuners/ia3/model.py
@@ -103,7 +103,6 @@ def _create_new_module(ia3_config, adapter_name, target, **kwargs):
     eightbit_kwargs.update(
         {
             "has_fp16_weights": target_base_layer.state.has_fp16_weights,
-            "memory_efficient_backward": target_base_layer.state.memory_efficient_backward,
             "threshold": target_base_layer.state.threshold,
             "index": target_base_layer.index,
         }
1 change: 0 additions & 1 deletion src/peft/tuners/lora/bnb.py
@@ -288,7 +288,6 @@ def dispatch_bnb_8bit(target: torch.nn.Module, adapter_name: str, **kwargs):
     eightbit_kwargs.update(
         {
             "has_fp16_weights": target.state.has_fp16_weights,
-            "memory_efficient_backward": target.state.memory_efficient_backward,
             "threshold": target.state.threshold,
             "index": target.index,
         }
1 change: 0 additions & 1 deletion src/peft/tuners/vera/model.py
@@ -314,7 +314,6 @@ def _create_new_module(vera_config, vera_A, vera_B, adapter_name, target, **kwargs):
     eightbit_kwargs.update(
         {
             "has_fp16_weights": target_base_layer.state.has_fp16_weights,
-            "memory_efficient_backward": target_base_layer.state.memory_efficient_backward,
             "threshold": target_base_layer.state.threshold,
             "index": target_base_layer.index,
         }
14 changes: 7 additions & 7 deletions src/peft/utils/integrations.py
@@ -101,13 +101,13 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
     if state.SCB is None:
         state.SCB = weight.SCB

-    im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
-    im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
-    im, Sim = bnb.functional.transform(im, "col32")
-    if state.CxB is None:
-        state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
-    out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
-    dequantized = bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
+    if hasattr(bnb.functional, "int8_vectorwise_dequant"):
+        # Use bitsandbytes API if available (requires v0.45.0+)
+        dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
+    else:
+        # Multiply by (scale/127) to dequantize.
+        dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3

     if is_cpu:
         dequantized = dequantized.to(device)
     return dequantized
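The new branch prefers bnb.functional.int8_vectorwise_dequant when the installed bitsandbytes provides it (v0.45.0 or later) and otherwise falls back to a plain row-wise rescale; the hard-coded constant 7.874015718698502e-3 is simply 1/127, the inverse of the int8 quantization range. A small sketch of that fallback arithmetic, with made-up shapes and scales:

# Sketch of the fallback dequantization: each int8 row is scaled by its
# per-row absmax (SCB) divided by 127. Shapes and values are illustrative only.
import torch

int8_weight = torch.randint(-127, 128, (4, 8), dtype=torch.int8)  # quantized weight
SCB = torch.rand(4) * 3 + 0.1                                     # per-row absmax scales

dequant_explicit = int8_weight.float() * SCB.view(-1, 1) / 127.0
dequant_constant = int8_weight.float() * SCB.view(-1, 1) * 7.874015718698502e-3

# The magic constant is just 1/127, so both forms agree to float precision.
assert torch.allclose(dequant_explicit, dequant_constant)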
17 changes: 8 additions & 9 deletions tests/test_common_gpu.py
@@ -767,8 +767,8 @@ def test_8bit_merge_lora(self):
     with torch.inference_mode():
         out_after_merge = F.softmax(model(random_input).logits, dim=-1)

-    atol = 0.01
-    rtol = 10
+    atol = 1e-3
+    rtol = 1
     assert not torch.allclose(out_base, out_before_merge, atol=atol, rtol=rtol)
     assert torch.allclose(out_before_merge, out_after_merge, atol=atol, rtol=rtol)
     assert isinstance(model, PeftModel)
@@ -803,8 +803,8 @@ def test_8bit_merge_and_disable_lora(self):
     with torch.inference_mode():
         out_after = F.softmax(model(random_input).logits, dim=-1)

-    atol = 0.01
-    rtol = 10
+    atol = 1e-3
+    rtol = 1
     assert not torch.allclose(out_base, out_before, atol=atol, rtol=rtol)
     assert torch.allclose(out_base, out_after, atol=atol, rtol=rtol)
     assert isinstance(model, PeftModel)
@@ -838,8 +838,8 @@ def test_8bit_merge_lora_with_bias(self):
     with torch.inference_mode():
         out_after_merge = F.softmax(model(random_input).logits, dim=-1)

-    atol = 0.01
-    rtol = 10
+    atol = 1e-3
+    rtol = 1
     assert not torch.allclose(out_base, out_before_merge, atol=atol, rtol=rtol)
     assert torch.allclose(out_before_merge, out_after_merge, atol=atol, rtol=rtol)

@@ -1294,9 +1294,8 @@ def test_8bit_dora_merging(self):
     model = model.merge_and_unload()
     out_unloaded = F.softmax(model(random_input).logits, dim=-1)

-    # 8bit merging less precise than 4bit
-    atol = 0.01
-    rtol = 10
+    atol = 1e-3
+    rtol = 1
     # sanity check that using DoRA changes the results
     assert not torch.allclose(out_base, out_dora, atol=atol, rtol=rtol)
     assert torch.allclose(out_dora, out_merged, atol=atol, rtol=rtol)
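The tightened tolerances matter because torch.allclose(a, b, atol=..., rtol=...) accepts any elementwise difference up to atol + rtol * |b|; with the old rtol = 10 that bound dwarfed the softmax outputs being compared, so the negative (not torch.allclose) sanity checks were far weaker than they looked. A small illustration with made-up values:

# Worked example of the tolerance formula behind these assertions:
# torch.allclose(a, b, atol, rtol) requires |a - b| <= atol + rtol * |b| elementwise.
import torch

a = torch.tensor([0.30, 0.70])  # made-up softmax-style outputs
b = torch.tensor([0.05, 0.95])

# Old tolerances: the bound 0.01 + 10 * |b| swallows the 0.25 difference,
# so clearly different outputs still count as "close".
print(torch.allclose(a, b, atol=0.01, rtol=10))  # True

# New tolerances: the bound 1e-3 + |b| is 0.051 for the first element,
# so the same outputs are now reported as different.
print(torch.allclose(a, b, atol=1e-3, rtol=1))   # False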
