From 9c2099913c1b4efa5dec33b0a13300c40e7ef4f8 Mon Sep 17 00:00:00 2001 From: Yifan Mai Date: Wed, 6 Mar 2024 09:25:48 -0800 Subject: [PATCH 1/4] Add multimodal support to Claude 3 models --- src/helm/clients/anthropic_client.py | 47 +++++++++++++++++++++++++--- src/helm/common/media_object.py | 1 + 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/src/helm/clients/anthropic_client.py b/src/helm/clients/anthropic_client.py index 07f38f4176..e7ac7c47f4 100644 --- a/src/helm/clients/anthropic_client.py +++ b/src/helm/clients/anthropic_client.py @@ -6,6 +6,7 @@ from helm.common.cache import CacheConfig from helm.common.hierarchical_logger import htrack_block, hlog +from helm.common.media_object import IMAGE_TYPE, TEXT_TYPE from helm.common.optional_dependencies import handle_module_not_found_error from helm.common.request import ( wrap_request_time, @@ -27,6 +28,8 @@ try: from anthropic import Anthropic from anthropic.types import MessageParam + from anthropic.types.image_block_param import ImageBlockParam + from anthropic.types.text_block_param import TextBlockParam import websocket except ModuleNotFoundError as e: handle_module_not_found_error(e, ["anthropic"]) @@ -242,16 +245,52 @@ def make_request(self, request: Request) -> RequestResult: messages: List[MessageParam] = [] system_message: Optional[MessageParam] = None - if request.messages and request.prompt: - raise AnthropicMessagesRequestError("Exactly one of Request.messages and Request.prompt should be set") - if request.messages: + + if request.messages is not None: messages = cast(List[MessageParam], request.messages) if messages[0]["role"] == "system": system_message = messages[0] messages = messages[1:] - else: + + elif request.prompt: messages = [{"role": "user", "content": request.prompt}] + elif request.multimodal_prompt is not None: + blocks = [] + for media_object in request.multimodal_prompt.media_objects: + if media_object.is_type(IMAGE_TYPE): + # TODO(#2439): Refactor out Request validation + if not media_object.location: + raise Exception("MediaObject of image type has missing location field value") + + from helm.common.images_utils import encode_base64 + + base64_image: str = encode_base64(media_object.location, format="JPEG") + image_block: ImageBlockParam = { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": base64_image, + }, + } + blocks.append(image_block) + if media_object.is_type(TEXT_TYPE): + # TODO(#2439): Refactor out Request validation + if media_object.text is None: + raise ValueError("MediaObject of text type has missing text field value") + text_block: TextBlockParam = { + "type": "text", + "text": media_object.text, + } + blocks.append(text_block) + + else: + # TODO(#2439): Refactor out Request validation + raise AnthropicMessagesRequestError( + "Exactly one of Request.messages, Request.prompt or Request.multimodel_prompt should be set" + ) + raw_request: AnthropicMessagesRequest = { "messages": messages, "model": request.model_engine, diff --git a/src/helm/common/media_object.py b/src/helm/common/media_object.py index ddee88e778..d3fdf8292a 100644 --- a/src/helm/common/media_object.py +++ b/src/helm/common/media_object.py @@ -5,6 +5,7 @@ from typing import List, Optional +IMAGE_TYPE = "image" TEXT_TYPE = "text" From 4a8c0d3679c5612d1d811899ba683f2a98921e18 Mon Sep 17 00:00:00 2001 From: Yifan Mai Date: Wed, 6 Mar 2024 09:33:46 -0800 Subject: [PATCH 2/4] Fix some bugs --- src/helm/clients/anthropic_client.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/helm/clients/anthropic_client.py b/src/helm/clients/anthropic_client.py index e7ac7c47f4..f8687e05ce 100644 --- a/src/helm/clients/anthropic_client.py +++ b/src/helm/clients/anthropic_client.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, TypedDict, cast +from typing import Any, Dict, List, Optional, TypedDict, Union, cast import json import requests import time @@ -256,7 +256,7 @@ def make_request(self, request: Request) -> RequestResult: messages = [{"role": "user", "content": request.prompt}] elif request.multimodal_prompt is not None: - blocks = [] + blocks: List[Union[TextBlockParam, ImageBlockParam]] = [] for media_object in request.multimodal_prompt.media_objects: if media_object.is_type(IMAGE_TYPE): # TODO(#2439): Refactor out Request validation @@ -284,6 +284,7 @@ def make_request(self, request: Request) -> RequestResult: "text": media_object.text, } blocks.append(text_block) + messages = [{"role": "user", "content": blocks}] else: # TODO(#2439): Refactor out Request validation From 8a0d3d350a9d7f9a427a3c8237ebeb62e9643a9c Mon Sep 17 00:00:00 2001 From: Yifan Mai Date: Wed, 6 Mar 2024 11:13:02 -0800 Subject: [PATCH 3/4] Improve validation --- src/helm/clients/anthropic_client.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/helm/clients/anthropic_client.py b/src/helm/clients/anthropic_client.py index f8687e05ce..343365ab5f 100644 --- a/src/helm/clients/anthropic_client.py +++ b/src/helm/clients/anthropic_client.py @@ -247,15 +247,22 @@ def make_request(self, request: Request) -> RequestResult: system_message: Optional[MessageParam] = None if request.messages is not None: + # TODO(#2439): Refactor out Request validation + if request.multimodal_prompt is not None or request.prompt: + raise AnthropicMessagesRequestError( + "Exactly one of Request.messages, Request.prompt or Request.multimodel_prompt should be set" + ) messages = cast(List[MessageParam], request.messages) if messages[0]["role"] == "system": system_message = messages[0] messages = messages[1:] - elif request.prompt: - messages = [{"role": "user", "content": request.prompt}] - elif request.multimodal_prompt is not None: + # TODO(#2439): Refactor out Request validation + if request.messages is not None or request.prompt: + raise AnthropicMessagesRequestError( + "Exactly one of Request.messages, Request.prompt or Request.multimodel_prompt should be set" + ) blocks: List[Union[TextBlockParam, ImageBlockParam]] = [] for media_object in request.multimodal_prompt.media_objects: if media_object.is_type(IMAGE_TYPE): @@ -287,10 +294,7 @@ def make_request(self, request: Request) -> RequestResult: messages = [{"role": "user", "content": blocks}] else: - # TODO(#2439): Refactor out Request validation - raise AnthropicMessagesRequestError( - "Exactly one of Request.messages, Request.prompt or Request.multimodel_prompt should be set" - ) + messages = [{"role": "user", "content": request.prompt}] raw_request: AnthropicMessagesRequest = { "messages": messages, From 6b4ce9c5e3d329e2e572a6cac1692e40c0d0141b Mon Sep 17 00:00:00 2001 From: Yifan Mai Date: Wed, 6 Mar 2024 11:46:19 -0800 Subject: [PATCH 4/4] Add VISION_LANGUAGE_MODEL_TAG tag --- src/helm/config/model_metadata.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/helm/config/model_metadata.yaml b/src/helm/config/model_metadata.yaml index 3af6a7d3e8..c2e0831960 100644 --- a/src/helm/config/model_metadata.yaml +++ b/src/helm/config/model_metadata.yaml @@ -235,7 +235,7 @@ models: creator_organization_name: Anthropic access: limited release_date: 2024-03-04 - tags: [TEXT_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG] + tags: [TEXT_MODEL_TAG, VISION_LANGUAGE_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG] - name: anthropic/claude-3-opus-20240229 display_name: Claude 3 Opus (20240229) @@ -243,7 +243,7 @@ models: creator_organization_name: Anthropic access: limited release_date: 2024-03-04 - tags: [TEXT_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG] + tags: [TEXT_MODEL_TAG, VISION_LANGUAGE_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG] # DEPRECATED: Please do not use. - name: anthropic/stanford-online-all-v4-s3