Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Idefics2] - Fix FA2 call for Perceiver layer #32275

Merged
merged 8 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/transformers/models/idefics2/modeling_idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,7 @@ def forward(
attention_mask,
q_len,
dropout=dropout_rate,
sliding_window=False,
sliding_window=None,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
)
Expand Down
48 changes: 45 additions & 3 deletions tests/models/idefics2/test_modeling_idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@
is_torch_available,
is_vision_available,
)
from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
require_torch,
require_torch_gpu,
slow,
torch_device,
)

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
Expand Down Expand Up @@ -491,13 +498,13 @@ def tearDown(self):
torch.cuda.empty_cache()

@slow
@unittest.skip("Test hits OOM on CI - https://github.com/huggingface/transformers/issues/32288")
def test_integration_test(self):
model = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b-base",
torch_dtype=torch.bfloat16,
device_map="auto",
)
model.to(torch_device)

# Create inputs
text = "<image>In this image, we see"
Expand All @@ -517,7 +524,8 @@ def test_integration_test(self):
def test_integration_test_4bit(self):
# Let' s make sure we test the preprocessing to replace what is used
model = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b-base", load_in_4bit=True, device_map="auto"
"HuggingFaceM4/idefics2-8b-base",
load_in_4bit=True,
)

# Create pixel inputs
Expand All @@ -530,3 +538,37 @@ def test_integration_test_4bit(self):

expected_generated_text = "In this image, we see the Statue of Liberty, the Hudson River,"
self.assertEqual(generated_texts[0], expected_generated_text)

@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
def test_flash_attn_2_eager_equivalence(self):
# Create inputs
text = "<image>In this image, we see"
images = self.image1
inputs = self.processor(text=text, images=images, return_tensors="pt", padding=True)
inputs.to(torch_device)

# Eager model
model_eager = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b-base",
attn_implementation="eager",
load_in_4bit=True,
)
generated_ids_eager = model_eager.generate(**inputs, max_new_tokens=10)
generated_texts_eager = self.processor.batch_decode(generated_ids_eager, skip_special_tokens=True)

del model_eager

# Flash Attention 2 model
model_flash_attention_2 = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b-base",
attn_implementation="flash_attention_2",
load_in_4bit=True,
)
generated_ids_flash_attention_2 = model_flash_attention_2.generate(**inputs, max_new_tokens=10)
generated_texts_flash_attention_2 = self.processor.batch_decode(
generated_ids_flash_attention_2, skip_special_tokens=True
)

self.assertEqual(generated_texts_eager[0], generated_texts_flash_attention_2[0])
Loading