Skip to content

Commit

Permalink
feat(liteLLM): Implemented image support and corresponding tests (#900)
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyabsridhar authored Aug 19, 2024
1 parent 10e15cd commit f6d11eb
Show file tree
Hide file tree
Showing 4 changed files with 329 additions and 36 deletions.
2 changes: 1 addition & 1 deletion python/dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ ruff==0.6.1
pytest==8.3.2
pytest-asyncio==0.23.8
pytest-xdist==3.6.1
pytest-socket==0.7.0
pytest-socket==0.7.0
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,33 @@
"metadata": {},
"outputs": [],
"source": [
"# Perform image analysis by providing a url to the image and querying the LLM\n",
"litellm.completion(\n",
" model=\"gpt-4o\",\n",
" messages=[\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": [\n",
" {\"type\": \"text\", \"text\": \"What’s in this image?\"},\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\n",
" \"url\": \"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg\"\n",
" },\n",
" },\n",
" ],\n",
" }\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# First run pip install tenacity\n",
"litellm.completion_with_retries(\n",
" model=\"gpt-3.5-turbo\",\n",
" messages=[{\"content\": \"What's the highest grossing film ever\", \"role\": \"user\"}],\n",
Expand Down Expand Up @@ -181,6 +208,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Image generation using OpenAI\n",
"litellm.image_generation(model=\"dall-e-2\", prompt=\"cute baby otter\")"
]
},
Expand All @@ -193,6 +221,26 @@
"await litellm.aimage_generation(model=\"dall-e-2\", prompt=\"cute baby otter\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Image generation using Bedrock\n",
"# pip install boto3 first before importing\n",
"\n",
"os.getenv(\"AWS_ACCESS_KEY_ID\")\n",
"os.getenv(\"AWS_SECRET_ACCESS_KEY\")\n",
"os.getenv(\"AWS_SESSION_TOKEN\")\n",
"os.getenv(\"AWS_REGION\")\n",
"\n",
"litellm.image_generation(\n",
" model=\"bedrock/stability.stable-diffusion-xl-v1\",\n",
" prompt=\"blue sky with fluffy white clouds and green hills\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from enum import Enum
from functools import wraps
from typing import Any, Callable, Collection, Dict
from typing import Any, Callable, Collection, Dict, Iterable, Iterator, Mapping, Tuple, TypeVar

from openai.types.image import Image
from opentelemetry import context as context_api
Expand All @@ -26,6 +27,8 @@
from openinference.semconv.trace import (
EmbeddingAttributes,
ImageAttributes,
MessageAttributes,
MessageContentAttributes,
OpenInferenceSpanKindValues,
SpanAttributes,
)
Expand All @@ -37,6 +40,57 @@ def _set_span_attribute(span: trace_api.Span, name: str, value: AttributeValue)
span.set_attribute(name, value)


T = TypeVar("T", bound=type)


def is_iterable_of(lst: Iterable[object], tp: T) -> bool:
    """Report whether ``lst`` is an iterable whose every element is a ``tp``.

    Args:
        lst: Candidate value; anything that is not an ``Iterable`` gives
            ``False``.
        tp: Class each element is checked against with ``isinstance``.

    Returns:
        ``True`` when ``lst`` is iterable and all of its elements are
        instances of ``tp`` (an empty iterable therefore qualifies),
        otherwise ``False``.
    """
    if not isinstance(lst, Iterable):
        return False
    # Iterating advances ``lst``; stop at the first element of another type.
    for element in lst:
        if not isinstance(element, tp):
            return False
    return True


def _get_attributes_from_message_param(
    message: Mapping[str, Any],
) -> Iterator[Tuple[str, AttributeValue]]:
    """Yield span attribute ``(key, value)`` pairs describing one input message.

    Args:
        message: A chat message exposing ``.get``; its ``"role"`` and
            ``"content"`` entries are read. Objects without a ``get``
            attribute produce no attributes.

    Yields:
        ``MessageAttributes.MESSAGE_ROLE`` with the role (unwrapped to its
        ``.value`` when the role is an ``Enum``), followed by either
        ``MessageAttributes.MESSAGE_CONTENT`` for string content, or one
        ``"<MESSAGE_CONTENTS>.<index>.<key>"`` pair per attribute of each
        content part when the content is an iterable of dicts. Falsy roles
        and falsy content are skipped.
    """
    if not hasattr(message, "get"):
        return

    if role := message.get("role"):
        yield (
            MessageAttributes.MESSAGE_ROLE,
            role.value if isinstance(role, Enum) else role,
        )

    content = message.get("content")
    if not content:
        return

    if isinstance(content, str):
        yield MessageAttributes.MESSAGE_CONTENT, content
    elif is_iterable_of(content, dict):
        # enumerate() is consumed lazily; there is no need to materialize a
        # throwaway list of (index, part) pairs first.
        for index, part in enumerate(content):
            for key, value in _get_attributes_from_message_content(part):
                yield f"{MessageAttributes.MESSAGE_CONTENTS}.{index}.{key}", value


def _get_attributes_from_message_content(
    content: Mapping[str, Any],
) -> Iterator[Tuple[str, AttributeValue]]:
    """Yield span attribute pairs for a single content part of a message.

    Args:
        content: One content part. Parts whose ``"type"`` is ``"text"`` or
            ``"image_url"`` are handled; any other (or missing) type yields
            nothing.

    Yields:
        For a text part: the content type ``"text"`` and, when the ``"text"``
        entry is truthy, the text itself. For an image part: the content type
        ``"image"`` and, when the ``"image_url"`` entry is truthy, every
        attribute from ``_get_attributes_from_image`` prefixed with
        ``MessageContentAttributes.MESSAGE_CONTENT_IMAGE``.
    """
    # Read with .get() rather than pop(): a malformed part (missing "type",
    # "text" or "image_url") is skipped instead of raising KeyError, and since
    # nothing is removed there is no need to copy the caller's mapping.
    type_ = content.get("type")
    if type_ == "text":
        yield f"{MessageContentAttributes.MESSAGE_CONTENT_TYPE}", "text"
        if text := content.get("text"):
            yield f"{MessageContentAttributes.MESSAGE_CONTENT_TEXT}", text
    elif type_ == "image_url":
        yield f"{MessageContentAttributes.MESSAGE_CONTENT_TYPE}", "image"
        if image := content.get("image_url"):
            for key, value in _get_attributes_from_image(image):
                yield f"{MessageContentAttributes.MESSAGE_CONTENT_IMAGE}.{key}", value


def _get_attributes_from_image(
    image: Mapping[str, Any],
) -> Iterator[Tuple[str, AttributeValue]]:
    """Yield the span attribute pair carrying an image's URL, if it has one.

    Args:
        image: The ``"image_url"`` entry of an image content part. Normally a
            mapping with a ``"url"`` key; a plain string is treated as the URL
            itself. NOTE(review): confirm the string form against the message
            shapes LiteLLM accepts.

    Yields:
        ``(ImageAttributes.IMAGE_URL, url)`` when a truthy URL is found;
        nothing otherwise.
    """
    if isinstance(image, str):
        url = image
    elif hasattr(image, "get"):
        # .get() instead of pop(): a mapping without "url" yields nothing
        # rather than raising KeyError, and no defensive copy is needed.
        url = image.get("url")
    else:
        return

    if url:
        yield f"{ImageAttributes.IMAGE_URL}", url


def _instrument_func_type_completion(span: trace_api.Span, kwargs: Dict[str, Any]) -> None:
"""
Currently instruments the functions:
Expand All @@ -51,10 +105,12 @@ def _instrument_func_type_completion(span: trace_api.Span, kwargs: Dict[str, Any
_set_span_attribute(span, SpanAttributes.LLM_MODEL_NAME, kwargs.get("model", "unknown_model"))

if messages := kwargs.get("messages"):
_set_span_attribute(span, SpanAttributes.INPUT_VALUE, str(messages[0].get("content")))
for i, obj in enumerate(messages):
for key, value in obj.items():
_set_span_attribute(span, f"input.messages.{i}.{key}", value)
_set_span_attribute(span, SpanAttributes.INPUT_VALUE, json.dumps(messages))
for index, input_message in list(enumerate(messages)):
for key, value in _get_attributes_from_message_param(input_message):
_set_span_attribute(
span, f"{SpanAttributes.LLM_INPUT_MESSAGES}.{index}.{key}", value
)

invocation_params = {k: v for k, v in kwargs.items() if k not in ["model", "messages"]}
_set_span_attribute(
Expand Down Expand Up @@ -112,10 +168,12 @@ def _finalize_span(span: trace_api.Span, result: Any) -> None:
elif isinstance(result, ImageResponse):
if len(result.data) > 0:
if img_data := result.data[0]:
if isinstance(img_data, Image) and (url := img_data.url):
if isinstance(img_data, Image) and (url := (img_data.url or img_data.b64_json)):
_set_span_attribute(span, ImageAttributes.IMAGE_URL, url)
_set_span_attribute(span, SpanAttributes.OUTPUT_VALUE, url)
elif isinstance(img_data, dict) and (url := img_data.get("url")):
elif isinstance(img_data, dict) and (
url := (img_data.get("url") or img_data.get("b64_json"))
):
_set_span_attribute(span, ImageAttributes.IMAGE_URL, url)
_set_span_attribute(span, SpanAttributes.OUTPUT_VALUE, url)
if hasattr(result, "usage"):
Expand Down
Loading

0 comments on commit f6d11eb

Please sign in to comment.