[Idefics2] - Fix FA2 call for Perceiver layer (#32275)
* Fix FA2 call for Perceiver layer

* [run_slow] idefics2

* [run_slow] idefics2

* [run_slow] idefics2

* Fix up

* [run_slow] idefics2

* [run_slow] idefics2

* [run_slow] idefics2
amyeroberts authored Jul 31, 2024
1 parent b75ad56 commit 5f1fcc2
Showing 2 changed files with 46 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/idefics2/modeling_idefics2.py
@@ -894,7 +894,7 @@ def forward(
             attention_mask,
             q_len,
             dropout=dropout_rate,
-            sliding_window=False,
+            sliding_window=None,
             is_causal=self.is_causal,
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
         )
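The one-line change above swaps a boolean for `None`. The `sliding_window` argument of the flash-attention call is meant to be an integer window size or `None`; a guard of the form `sliding_window is not None` lets `False` through and then treats it as a numeric window size. A minimal sketch of that failure mode (the helper and guard below are illustrative assumptions, not the library's actual code):

# Illustrative only: a hypothetical guard mirroring the common "is not None" pattern.
from typing import Optional, Tuple

def window_size(sliding_window: Optional[int], kv_len: int) -> Tuple[int, int]:
    # Apply windowing only when a window size is set and the sequence exceeds it.
    if sliding_window is not None and kv_len > sliding_window:
        return (sliding_window, sliding_window)
    return (-1, -1)  # flash-attn's convention for "no sliding window"

print(window_size(None, 64))   # (-1, -1): windowing correctly disabled
print(window_size(False, 64))  # (False, False): False slips past the None check,
                               # i.e. a zero-sized window instead of no window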
48 changes: 45 additions & 3 deletions tests/models/idefics2/test_modeling_idefics2.py
@@ -29,7 +29,14 @@
     is_torch_available,
     is_vision_available,
 )
-from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
+from transformers.testing_utils import (
+    require_bitsandbytes,
+    require_flash_attn,
+    require_torch,
+    require_torch_gpu,
+    slow,
+    torch_device,
+)
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -491,13 +498,13 @@ def tearDown(self):
         torch.cuda.empty_cache()
 
     @slow
+    @unittest.skip("Test hits OOM on CI - https://github.com/huggingface/transformers/issues/32288")
     def test_integration_test(self):
         model = Idefics2ForConditionalGeneration.from_pretrained(
             "HuggingFaceM4/idefics2-8b-base",
             torch_dtype=torch.bfloat16,
             device_map="auto",
         )
-        model.to(torch_device)
 
         # Create inputs
         text = "<image>In this image, we see"
@@ -517,7 +524,8 @@ def test_integration_test(self):
     def test_integration_test_4bit(self):
         # Let' s make sure we test the preprocessing to replace what is used
         model = Idefics2ForConditionalGeneration.from_pretrained(
-            "HuggingFaceM4/idefics2-8b-base", load_in_4bit=True, device_map="auto"
+            "HuggingFaceM4/idefics2-8b-base",
+            load_in_4bit=True,
         )
 
         # Create pixel inputs
@@ -530,3 +538,37 @@ def test_integration_test_4bit(self):
 
         expected_generated_text = "In this image, we see the Statue of Liberty, the Hudson River,"
         self.assertEqual(generated_texts[0], expected_generated_text)
+
+    @require_flash_attn
+    @require_torch_gpu
+    @require_bitsandbytes
+    def test_flash_attn_2_eager_equivalence(self):
+        # Create inputs
+        text = "<image>In this image, we see"
+        images = self.image1
+        inputs = self.processor(text=text, images=images, return_tensors="pt", padding=True)
+        inputs.to(torch_device)
+
+        # Eager model
+        model_eager = Idefics2ForConditionalGeneration.from_pretrained(
+            "HuggingFaceM4/idefics2-8b-base",
+            attn_implementation="eager",
+            load_in_4bit=True,
+        )
+        generated_ids_eager = model_eager.generate(**inputs, max_new_tokens=10)
+        generated_texts_eager = self.processor.batch_decode(generated_ids_eager, skip_special_tokens=True)
+
+        del model_eager
+
+        # Flash Attention 2 model
+        model_flash_attention_2 = Idefics2ForConditionalGeneration.from_pretrained(
+            "HuggingFaceM4/idefics2-8b-base",
+            attn_implementation="flash_attention_2",
+            load_in_4bit=True,
+        )
+        generated_ids_flash_attention_2 = model_flash_attention_2.generate(**inputs, max_new_tokens=10)
+        generated_texts_flash_attention_2 = self.processor.batch_decode(
+            generated_ids_flash_attention_2, skip_special_tokens=True
+        )
+
+        self.assertEqual(generated_texts_eager[0], generated_texts_flash_attention_2[0])
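For reference, a minimal usage sketch of the configuration this fix unblocks, Idefics2 generation with attn_implementation="flash_attention_2". It assumes flash-attn is installed and a CUDA GPU is available; the blank PIL image is only a placeholder for a real input:

# Minimal sketch, assuming flash-attn and a CUDA GPU are available.
import torch
from PIL import Image
from transformers import AutoProcessor, Idefics2ForConditionalGeneration

processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b-base")
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b-base",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

image = Image.new("RGB", (378, 378))  # placeholder image; use a real image in practice
inputs = processor(text="<image>In this image, we see", images=image, return_tensors="pt")
inputs = inputs.to(model.device)

generated_ids = model.generate(**inputs, max_new_tokens=10)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])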
