MNT Remove deprecated use of load_in_8bit (#1811)
Don't pass load_in_8bit to AutoModel.from_pretrained; instead, use
BitsAndBytesConfig.

There was already a PR to clean this up (#1552), but a slightly later
PR (#1518) re-added this usage.
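
For reference, a minimal sketch of the pattern this commit enforces (it reuses the facebook/opt-125m checkpoint and torch.float32 dtype from the tests below; any other checkpoint works the same way):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Deprecated: passing load_in_8bit directly to from_pretrained
# model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", load_in_8bit=True)

# Recommended: wrap the 8-bit setting in a BitsAndBytesConfig and pass it
# via quantization_config
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float32,
)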

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
BenjaminBossan and younesbelkada authored May 30, 2024
1 parent 8cd2cb6 commit cb0bf07
1 changed file: tests/test_common_gpu.py (4 additions, 8 deletions)
@@ -826,11 +826,7 @@ def test_8bit_lora_mixed_adapter_batches_lora(self):
         # check that we can pass mixed adapter names to the model
         # note that with 8bit, we have quite a bit of imprecision, therefore we use softmax and higher tolerances
         torch.manual_seed(3000)
-        bnb_config = BitsAndBytesConfig(
-            load_in_8bit=True,
-            bnb_4bit_use_double_quant=False,
-            bnb_4bit_compute_dtype=torch.float32,
-        )
+        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
         model = AutoModelForCausalLM.from_pretrained(
             "facebook/opt-125m",
             quantization_config=bnb_config,
@@ -951,7 +947,7 @@ def test_8bit_dora_inference(self):
         # check for same result with and without DoRA when initializing with init_lora_weights=False
         model = AutoModelForCausalLM.from_pretrained(
             "facebook/opt-125m",
-            load_in_8bit=True,
+            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
             torch_dtype=torch.float32,
         ).eval()

@@ -964,7 +960,7 @@ def test_8bit_dora_inference(self):
 
         model = AutoModelForCausalLM.from_pretrained(
             "facebook/opt-125m",
-            load_in_8bit=True,
+            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
             torch_dtype=torch.float32,
         )
         torch.manual_seed(0)
@@ -1042,7 +1038,7 @@ def test_8bit_dora_merging(self):
         torch.manual_seed(0)
         model = AutoModelForCausalLM.from_pretrained(
             "facebook/opt-125m",
-            load_in_8bit=True,
+            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
             torch_dtype=torch.float32,
         ).eval()

