From f10ff6c9ac7e6b0d9cf4c70e5b93161bbe6f45c6 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Tue, 15 Oct 2024 19:19:23 -0700 Subject: [PATCH 1/3] Fixes support for cross region inference. Added cohere to disable_streaming. --- .../chat_models/bedrock_converse.py | 13 ++++++++----- .../chat_models/test_bedrock_converse.py | 18 +++++++++++++++++- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index d51d90fa..4bbe4b23 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -400,15 +400,18 @@ class Joke(BaseModel): @model_validator(mode="before") @classmethod def set_disable_streaming(cls, values: Dict) -> Any: - values["provider"] = ( - values.get("provider") - or (values.get("model_id", values["model"])).split(".")[0] + model_id = values.get("model_id", values.get("model")) + model_parts = model_id.split(".") + values["provider"] = values.get("provider") or ( + model_parts[-2] if len(model_parts) > 1 else model_parts[0] ) - # As of 08/05/24 only Anthropic models support streamed tool calling + # As of 09/15/24 Anthropic and Cohere models support streamed tool calling if "disable_streaming" not in values: values["disable_streaming"] = ( - False if "anthropic" in values["provider"] else "tool_calling" + False + if values["provider"] in ["anthropic", "cohere"] + else "tool_calling" ) return values diff --git a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py index 0b45bdcf..f19e4bca 100644 --- a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py +++ b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py @@ -1,7 +1,7 @@ """Test chat model integration.""" import base64 -from typing import Dict, List, Tuple, Type, cast +from typing import Dict, List, Tuple, Type, Union, cast import pytest from langchain_core.language_models import BaseChatModel @@ -399,3 +399,19 @@ def test_standard_tracing_params() -> None: "ls_temperature": 0.1, "ls_max_tokens": 10, } + + +@pytest.mark.parametrize( + "model_id, disable_streaming", + [ + ("anthropic.claude-3-5-sonnet-20240620-v1:0", False), + ("us.anthropic.claude-3-haiku-20240307-v1:0", False), + ("cohere.command-r-v1:0", False), + ("meta.llama3-1-405b-instruct-v1:0", "tool_calling"), + ], +) +def test_set_disable_streaming( + model_id: str, disable_streaming: Union[bool, str] +) -> None: + llm = ChatBedrockConverse(model=model_id) + assert llm.disable_streaming == disable_streaming From ce4c80e4857c074ba9f5b8514a41446ba0ad9b5b Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Tue, 15 Oct 2024 19:22:28 -0700 Subject: [PATCH 2/3] Fixed unit tests for CI. --- libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py index f19e4bca..37211bcd 100644 --- a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py +++ b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py @@ -413,5 +413,5 @@ def test_standard_tracing_params() -> None: def test_set_disable_streaming( model_id: str, disable_streaming: Union[bool, str] ) -> None: - llm = ChatBedrockConverse(model=model_id) + llm = ChatBedrockConverse(model=model_id, region_name="us-west-2") assert llm.disable_streaming == disable_streaming From d84466c1fa67c022476bfd22e42a4fa9ba973c09 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Tue, 15 Oct 2024 21:15:49 -0700 Subject: [PATCH 3/3] Added standard integration tests for Cohere. --- .../chat_models/test_bedrock_converse.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py index d50ec031..41f4694a 100644 --- a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py +++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py @@ -48,6 +48,28 @@ def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None super().test_tool_message_histories_list_content(model) +class TestBedrockCohereStandard(ChatModelIntegrationTests): + @property + def chat_model_class(self) -> Type[BaseChatModel]: + return ChatBedrockConverse + + @property + def chat_model_params(self) -> dict: + return {"model": "cohere.command-r-plus-v1:0"} + + @property + def standard_chat_model_params(self) -> dict: + return {"temperature": 0, "max_tokens": 100, "stop": []} + + @pytest.mark.xfail(reason="Cohere models don't support tool_choice.") + def test_structured_few_shot_examples(self, model: BaseChatModel) -> None: + pass + + @pytest.mark.xfail(reason="Cohere models don't support tool_choice.") + def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None: + pass + + def test_structured_output_snake_case() -> None: model = ChatBedrockConverse( model="anthropic.claude-3-sonnet-20240229-v1:0", temperature=0