FIX: Error with OLoRA init when using bnb #2011

Merged on Sep 3, 2024 (8 commits). The diff below shows changes from 7 commits.
28 changes: 21 additions & 7 deletions src/peft/tuners/lora/layer.py
@@ -25,7 +25,7 @@
 from transformers.pytorch_utils import Conv1D

 from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
-from peft.utils.integrations import dequantize_module_weight, gather_params_ctx
+from peft.utils.integrations import dequantize_module_weight, gather_params_ctx, get_bnb_param_type
 from peft.utils.other import transpose

 from .config import LoraConfig
@@ -167,11 +167,16 @@ def reset_lora_parameters(self, adapter_name, init_lora_weights):
             nn.init.normal_(self.lora_embedding_B[adapter_name])

     def olora_init(self, adapter_name):
-        dtype = self.get_base_layer().weight.dtype
-        if dtype in [torch.int8, torch.uint8]:
-            weight_tensor = dequantize_module_weight(self.get_base_layer())
+        base_layer = self.get_base_layer()
+        orig_weight = base_layer.weight
+        bnb_param_type = get_bnb_param_type(orig_weight)
+        dtype = orig_weight.dtype
+
+        if bnb_param_type:
+            # check without importing bitsandbytes and robust to bnb_4bit_quant_storage=float*
+            weight_tensor = dequantize_module_weight(base_layer)
         elif dtype in [torch.float32, torch.float16, torch.bfloat16]:
-            weight_tensor = self.get_base_layer().weight
+            weight_tensor = orig_weight
         else:
             raise TypeError(f"Unsupported data type for the base layer. Got {dtype}.")

@@ -186,8 +191,17 @@ def olora_init(self, adapter_name):
         self.lora_B[adapter_name].weight.data = Qr.contiguous()

         weight_tensor.data -= scale_factor * self.lora_B[adapter_name].weight @ self.lora_A[adapter_name].weight
-        weight_tensor = weight_tensor.to(dtype)
-        self.get_base_layer().weight.data = weight_tensor
+        if bnb_param_type == "4bit":
+            weight_tensor = orig_weight.__class__(weight_tensor, quant_type=orig_weight.quant_type).to(
+                orig_weight.device
+            )
+            base_layer.weight = weight_tensor
+        elif bnb_param_type == "8bit":
+            weight_tensor = orig_weight.__class__(weight_tensor, requires_grad=False).to(orig_weight.device)
+            base_layer.weight = weight_tensor
+        else:
+            weight_tensor = weight_tensor.to(dtype)
+            base_layer.weight.data = weight_tensor
Review thread on the lines above:

Member:
Why are we quantizing the weights this time?

BenjaminBossan (Member, Author), Aug 30, 2024:
Normally for bnb weights, the tensors are flattened, e.g. shape [64, 1]. But after dequantizing, the weight_tensor that we assign here is not flat anymore, e.g. shape [16, 16]. My reasoning was that we should get back a "correct" tensor, so it is better to re-initialize it.

I tried removing this and just doing base_layer.weight.data = weight_tensor, and curiously this seems to work too: the test passes even though the shape is now wrong. This makes me wonder whether bnb somehow handles this automatically and we should not re-initialize (which could cause its own problems). Not sure, any suggestions?
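For illustration, here is a minimal sketch of the packed-versus-dequantized shapes being discussed. It assumes bitsandbytes is installed and a CUDA device is available; the exact packed shape depends on the chosen quant_storage dtype (the [64, 1] vs. [16, 16] example above is consistent with float16 storage, while the default uint8 storage of a 16x16 fp16 weight packs to [128, 1]).

```python
# Minimal sketch of the packed vs. dequantized shapes (assumes bitsandbytes
# is installed and a CUDA device is available).
import bitsandbytes as bnb
import torch

w = torch.randn(16, 16, dtype=torch.float16)

# For Params4bit, moving the parameter to the GPU triggers the actual quantization.
p4 = bnb.nn.Params4bit(w, requires_grad=False, quant_type="nf4").to("cuda")
print(p4.data.shape)         # packed single-column storage, e.g. torch.Size([128, 1])
print(p4.quant_state.shape)  # the original shape is tracked in the quant_state

# Dequantizing restores the original 2D weight.
deq = bnb.functional.dequantize_4bit(p4.data, p4.quant_state)
print(deq.shape)             # torch.Size([16, 16])
```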

SunMarc (Member), Aug 30, 2024:
> I tried what happens when I remove this and just do base_layer.weight.data = weight_tensor

Wow, that's really strange indeed. I tried to check the code in bnb and it doesn't look like they handle this. cc @matthewdouglas

> This makes me wonder if bnb somehow handles this automatically and we should not re-initialize (which could cause its own problems)?

I think that's fine as long as you pass the relevant kwargs that you can get from orig_weight. However, make sure not to pass the bnb_quantized arg for the 4-bit case. Then, with to(orig_weight.device), it should quantize the weights properly.

BenjaminBossan (Member, Author):
Great, thanks for the additional info.

> However, make sure to not pass bnb_quantized arg for the 4-bit case. Then, with to(orig_weight.device), it should quantize the weights properly.

To clarify, is the present code in alignment with what you suggest, or do I need to call to(orig_weight.device) too?

> Wow that's really strange indeed. I tried to check the code in bnb and it doesn't look like they handle this.

Okay, then it's probably better to get Matthew's opinion before merging this.

matthewdouglas (Member), Aug 30, 2024:
As far as the shapes go, both 4-bit and 8-bit have mechanisms in place to track the original shapes, but it's different for each: Linear8bitLt has state.SB, and for 4-bit that information is part of the quant_state. The main expectation is that it is all stored in a contiguous row-major format.

That said, it's not really clear to me that dequantize_module_weight() is doing all that it would need to do. Maybe it would pass the test here, but I would think the updated weights would not be quantized properly afterwards, so re-initializing is probably the way to go.

> To clarify, is the present code in alignment with what you suggest or do I need to call to(orig_weight.device) too?

You'd want to have .to(orig_weight.device) in addition to the other kwargs, as @SunMarc mentioned.

BenjaminBossan (Member, Author):
I updated the inits to take all the arguments into account. Unfortunately, this may get out of date if bnb is updated, but I don't think there is a method such as bnb.create_param_like(tensor) that would let us offload this work to bnb.

It would be great if you could do a final pass over the change.
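For reference, a sketch of what re-initializing the bnb parameters with a fuller set of constructor arguments could look like, as a drop-in for the 4-bit/8-bit branch in the diff above. The updated commit itself is not shown in this diff; the keyword arguments below are assumed from bitsandbytes' Params4bit and Int8Params signatures and may differ from the code that was actually merged.

```python
# Hedged sketch only; drops into olora_init where bnb_param_type, orig_weight,
# base_layer and weight_tensor are defined as in the diff above. Kwargs are
# assumed from bitsandbytes' Params4bit/Int8Params, not copied from the merged commit.
if bnb_param_type == "4bit":
    weight_tensor = orig_weight.__class__(
        weight_tensor,
        quant_type=orig_weight.quant_type,
        quant_storage=orig_weight.quant_storage,
        compress_statistics=orig_weight.compress_statistics,
        blocksize=orig_weight.blocksize,
        module=orig_weight.module,
        # bnb_quantized is deliberately not passed, so .to(device) re-quantizes
    ).to(orig_weight.device)
    base_layer.weight = weight_tensor
elif bnb_param_type == "8bit":
    weight_tensor = orig_weight.__class__(
        weight_tensor,
        requires_grad=orig_weight.requires_grad,
        has_fp16_weights=orig_weight.has_fp16_weights,
    ).to(orig_weight.device)
    base_layer.weight = weight_tensor
```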

Member:
Can't we just pass orig_weight.__dict__ as the kwargs? This is how we did it in transformers.

BenjaminBossan (Member, Author):

Hmm, I wonder if that's really more robust. If a new attribute is added that is not an __init__ argument, this would fail, right?

class Foo:
    def __init__(self, x):
        self.x = x
        self.y = 123

foo = Foo("hi")
Foo(**foo.__dict__)  # TypeError: Foo.__init__() got an unexpected keyword argument 'y'

So no matter what, this code may break if there is some change to the __init__ code in bnb.

Member:
Oh yeah, that's right :/

BenjaminBossan (Member, Author):
Okay, I merged as-is then. The code is going to break eventually one way or the other :D


     def pissa_init(self, adapter_name, init_lora_weights):
         weight = self.get_base_layer().weight
12 changes: 12 additions & 0 deletions src/peft/utils/integrations.py
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import annotations
+
 from contextlib import contextmanager
+from typing import Literal

 import packaging.version
 import torch
@@ -104,3 +107,12 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
     if is_cpu:
         dequantized = dequantized.to(device)
     return dequantized
+
+
+def get_bnb_param_type(param: torch.nn.Parameter) -> Literal[False, "4bit", "8bit"]:
+    """Returns '4bit' or '8bit' if bitsandbytes parameter, else False"""
+    if param.__class__.__name__ == "Params4bit":
+        return "4bit"
+    if param.__class__.__name__ == "Int8Params":
+        return "8bit"
+    return False
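The helper compares class names instead of using isinstance so that PEFT does not have to import bitsandbytes just to branch on the parameter type, and the check keeps working with bnb_4bit_quant_storage set to a float dtype, where the storage dtype alone is not informative. A hypothetical usage sketch (the plain nn.Linear is only there to make the snippet self-contained):

```python
# Hypothetical usage sketch; mirrors the branch in olora_init above.
import torch.nn as nn

from peft.utils.integrations import dequantize_module_weight, get_bnb_param_type

layer = nn.Linear(16, 16)                          # plain fp32 layer for illustration
bnb_param_type = get_bnb_param_type(layer.weight)  # False here; "4bit"/"8bit" for bnb layers
if bnb_param_type:
    weight = dequantize_module_weight(layer)       # work on a regular floating-point tensor
else:
    weight = layer.weight
```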
35 changes: 35 additions & 0 deletions tests/test_gpu_examples.py
@@ -1786,6 +1786,41 @@ def test_bloomz_olora_8bit(self, device, tmp_path):
         # Same test as test_bloomz_olora_4bit but with 8 bits.
         self.get_errors(bits=8, device=device, tmp_path=tmp_path)

+    @pytest.mark.parametrize("bits", [4, 8])
+    def test_olora_with_quantized_model(self, bits):
+        import bitsandbytes as bnb
+
+        # issue 1999
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        if bits == 4:
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.float16,
+                bnb_4bit_quant_storage=torch.float16,
+                bnb_4bit_use_double_quant=True,
+            )
+        elif bits == 8:
+            bnb_config = BitsAndBytesConfig(load_in_8bit=True)
+        else:
+            raise ValueError("bits must be 4 or 8")
+
+        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
+        model = prepare_model_for_kbit_training(model)
+        config = LoraConfig(init_lora_weights="olora")
+        model = get_peft_model(model, config)
+
+        # check that the correct type is used for the weights
+        base_layer = model.base_model.model.model.decoder.layers[0].self_attn.v_proj.base_layer.weight
+        if bits == 4:
+            assert isinstance(base_layer, bnb.nn.modules.Params4bit)
+        else:
+            assert isinstance(base_layer, bnb.nn.modules.Int8Params)
+
+        inputs = torch.arange(10).unsqueeze(0).to(model.device)
+        logits = model(inputs).logits # does not raise
+        assert torch.isfinite(logits).all()
+

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
 class TestLoftQ:
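Assuming a machine with a CUDA GPU and bitsandbytes installed, the new test can be selected on its own with pytest's keyword filter, for example: pytest tests/test_gpu_examples.py -k test_olora_with_quantized_model.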