Skip to content

Commit

Permalink
fix: ensure Image objects can be used directly
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Dec 17, 2024
1 parent 767b1d5 commit a8bd998
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
7 changes: 6 additions & 1 deletion ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Any,
Callable,
Literal,
List,
Mapping,
Optional,
Sequence,
Expand Down Expand Up @@ -1121,10 +1122,14 @@ async def ps(self) -> ProcessResponse:
)


def _massage_images(imagelikes: Sequence[Union[Image, Any]]) -> List[Image]:
return [image if isinstance(image, Image) else Image(value=image) for image in imagelikes]


def _copy_messages(messages: Optional[Sequence[Union[Mapping[str, Any], Message]]]) -> Iterator[Message]:
for message in messages or []:
yield Message.model_validate(
{k: [Image(value=image) for image in v] if k == 'images' else v for k, v in dict(message).items() if v},
{k: _massage_images(v) if k == 'images' else v for k, v in dict(message).items() if v},
)


Expand Down
23 changes: 18 additions & 5 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,11 @@ def generate():
assert part['message']['content'] == next(it)


def test_client_chat_images(httpserver: HTTPServer):
@pytest.mark.parametrize('messages_style', ('dict', 'model'))
@pytest.mark.parametrize('file_style', ('path', 'bytes'))
def test_client_chat_images(httpserver: HTTPServer, messages_style: str, file_style: str, tmp_path):
from ollama._types import Message, Image

httpserver.expect_ordered_request(
'/api/chat',
method='POST',
Expand Down Expand Up @@ -116,10 +120,19 @@ def test_client_chat_images(httpserver: HTTPServer):

client = Client(httpserver.url_for('/'))

response = client.chat(
'dummy',
messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [PNG_BYTES]}],
)
if file_style == 'bytes':
image_content = PNG_BYTES
elif file_style == 'path':
image_path = tmp_path / 'transparent.png'
image_path.write_bytes(PNG_BYTES)
image_content = str(image_path)

if messages_style:
messages = [Message(role='user', content='Why is the sky blue?', images=[Image(value=image_content)])]
else:
messages = [{'role': 'user', 'content': 'Why is the sky blue?', 'images': [image_content]}]

response = client.chat('dummy', messages=messages)
assert response['model'] == 'dummy'
assert response['message']['role'] == 'assistant'
assert response['message']['content'] == "I don't know."
Expand Down

0 comments on commit a8bd998

Please sign in to comment.