LlamaMultiModal class bug fix (#16413)
Showing 9 changed files with 116 additions and 4 deletions.
Binary file removed (-237 Bytes): ..._llms/llama-index-multi-modal-llms-huggingface/tests/__pycache__/__init__.cpython-312.pyc
Binary file removed (-10.6 KB): ...gingface/tests/__pycache__/test_multi_modal_llms_huggingface.cpython-312-pytest-8.3.3.pyc
Binary file removed (-986 KB): ...modal_llms/llama-index-multi-modal-llms-huggingface/tests/test_images/2dogs.jpg
Binary file removed (-598 KB): ...modal_llms/llama-index-multi-modal-llms-huggingface/tests/test_images/5cats.jpg
Binary file removed (-830 KB): ...llms/llama-index-multi-modal-llms-huggingface/tests/test_images/girl_rabbit.jpg
Binary file removed (-222 KB): ...al_llms/llama-index-multi-modal-llms-huggingface/tests/test_images/man_read.jpg
112 changes: 112 additions & 0 deletions: ..._llms/llama-index-multi-modal-llms-huggingface/tests/test_multi_modal_llms_huggingface.py
@@ -1,7 +1,119 @@

import numpy as np
import pytest
import tempfile
import os

from PIL import Image
from unittest.mock import patch, MagicMock

from llama_index.core.schema import ImageDocument
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.multi_modal_llms.base import MultiModalLLM
from llama_index.multi_modal_llms.huggingface import HuggingFaceMultiModal


@pytest.fixture(scope="module")
def mock_model():
    with patch(
        "llama_index.multi_modal_llms.huggingface.base.AutoConfig"
    ) as mock_config, patch(
        "llama_index.multi_modal_llms.huggingface.base.Qwen2VLForConditionalGeneration"
    ) as mock_model_class, patch(
        "llama_index.multi_modal_llms.huggingface.base.AutoProcessor"
    ) as mock_processor:
        mock_config.from_pretrained.return_value = MagicMock(
            architectures=["Qwen2VLForConditionalGeneration"]
        )
        mock_model = mock_model_class.from_pretrained.return_value
        mock_processor = mock_processor.from_pretrained.return_value

        yield HuggingFaceMultiModal.from_model_name("Qwen/Qwen2-VL-2B-Instruct")


# Replace the existing 'model' fixture with this mock_model
@pytest.fixture(scope="module")
def model(mock_model):
    return mock_model


@pytest.fixture(scope="module")
def temp_image_path():
    # Create a white square image
    white_square = np.ones((100, 100, 3), dtype=np.uint8) * 255
    image = Image.fromarray(white_square)

    # Create a temporary file
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        image.save(temp_file, format="PNG")
        temp_path = temp_file.name

    yield temp_path

    # Clean up the temporary file after the test
    os.unlink(temp_path)


def test_class():
    names_of_base_classes = [b.__name__ for b in HuggingFaceMultiModal.__mro__]
    assert MultiModalLLM.__name__ in names_of_base_classes


def test_initialization(model):
    assert isinstance(model, HuggingFaceMultiModal)
    assert model.model_name == "Qwen/Qwen2-VL-2B-Instruct"


def test_metadata(model):
    metadata = model.metadata
    assert metadata.model_name == "Qwen/Qwen2-VL-2B-Instruct"
    assert metadata.context_window == 3900  # Default value
    assert metadata.num_output == 256  # Default value


def test_complete(model, temp_image_path):
    prompt = "Describe this image:"
    image_doc = ImageDocument(image_path=temp_image_path)

    # Mock the _prepare_messages and _generate methods
    model._prepare_messages = MagicMock(return_value={"mocked": "inputs"})
    model._generate = MagicMock(return_value="This is a mocked response.")

    response = model.complete(prompt, image_documents=[image_doc])

    assert response.text == "This is a mocked response."
    model._prepare_messages.assert_called_once()
    model._generate.assert_called_once_with({"mocked": "inputs"})


def test_chat(model, temp_image_path):
    messages = [ChatMessage(role="user", content="What's in this image?")]
    image_doc = ImageDocument(image_path=temp_image_path)

    # Mock the _prepare_messages and _generate methods
    model._prepare_messages = MagicMock(return_value={"mocked": "inputs"})
    model._generate = MagicMock(return_value="This is a mocked chat response.")

    response = model.chat(messages, image_documents=[image_doc])

    assert response.message.content == "This is a mocked chat response."
    model._prepare_messages.assert_called_once()
    model._generate.assert_called_once_with({"mocked": "inputs"})


@pytest.mark.asyncio()
@pytest.mark.parametrize(
    "method_name",
    [
        "astream_chat",
        "astream_complete",
        "acomplete",
        "achat",
    ],
)
async def test_unsupported_methods(model, method_name):
    with pytest.raises(NotImplementedError):
        method = getattr(model, method_name)
        if method_name in ["astream_chat", "achat"]:
            await method([])
        else:
            await method("prompt", [])
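
For context, here is a minimal usage sketch of the class these tests cover (not part of the commit). It mirrors the same calls the tests exercise via mocks, but assumes the Qwen2-VL checkpoint can actually be downloaded and loaded locally; the image path is a placeholder.

# Minimal usage sketch (assumption: checkpoint is loadable on this machine;
# "path/to/image.png" is a hypothetical placeholder).
from llama_index.core.schema import ImageDocument
from llama_index.multi_modal_llms.huggingface import HuggingFaceMultiModal

model = HuggingFaceMultiModal.from_model_name("Qwen/Qwen2-VL-2B-Instruct")
image_doc = ImageDocument(image_path="path/to/image.png")

# complete() and chat() are the synchronous entry points exercised above;
# the async variants (acomplete, achat, astream_*) raise NotImplementedError,
# as asserted in test_unsupported_methods.
response = model.complete("Describe this image:", image_documents=[image_doc])
print(response.text)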