diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index d51d90fa..4bbe4b23 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -400,15 +400,18 @@ class Joke(BaseModel): @model_validator(mode="before") @classmethod def set_disable_streaming(cls, values: Dict) -> Any: - values["provider"] = ( - values.get("provider") - or (values.get("model_id", values["model"])).split(".")[0] + model_id = values.get("model_id", values.get("model")) + model_parts = model_id.split(".") + values["provider"] = values.get("provider") or ( + model_parts[-2] if len(model_parts) > 1 else model_parts[0] ) - # As of 08/05/24 only Anthropic models support streamed tool calling + # As of 09/15/24 Anthropic and Cohere models support streamed tool calling if "disable_streaming" not in values: values["disable_streaming"] = ( - False if "anthropic" in values["provider"] else "tool_calling" + False + if values["provider"] in ["anthropic", "cohere"] + else "tool_calling" ) return values diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py index d50ec031..41f4694a 100644 --- a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py +++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py @@ -48,6 +48,28 @@ def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None super().test_tool_message_histories_list_content(model) +class TestBedrockCohereStandard(ChatModelIntegrationTests): + @property + def chat_model_class(self) -> Type[BaseChatModel]: + return ChatBedrockConverse + + @property + def chat_model_params(self) -> dict: + return {"model": "cohere.command-r-plus-v1:0"} + + @property + def standard_chat_model_params(self) -> dict: + return {"temperature": 0, "max_tokens": 100, "stop": []} + + @pytest.mark.xfail(reason="Cohere models don't support tool_choice.") + def test_structured_few_shot_examples(self, model: BaseChatModel) -> None: + pass + + @pytest.mark.xfail(reason="Cohere models don't support tool_choice.") + def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None: + pass + + def test_structured_output_snake_case() -> None: model = ChatBedrockConverse( model="anthropic.claude-3-sonnet-20240229-v1:0", temperature=0 diff --git a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py index 0b45bdcf..37211bcd 100644 --- a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py +++ b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py @@ -1,7 +1,7 @@ """Test chat model integration.""" import base64 -from typing import Dict, List, Tuple, Type, cast +from typing import Dict, List, Tuple, Type, Union, cast import pytest from langchain_core.language_models import BaseChatModel @@ -399,3 +399,19 @@ def test_standard_tracing_params() -> None: "ls_temperature": 0.1, "ls_max_tokens": 10, } + + +@pytest.mark.parametrize( + "model_id, disable_streaming", + [ + ("anthropic.claude-3-5-sonnet-20240620-v1:0", False), + ("us.anthropic.claude-3-haiku-20240307-v1:0", False), + ("cohere.command-r-v1:0", False), + ("meta.llama3-1-405b-instruct-v1:0", "tool_calling"), + ], +) +def test_set_disable_streaming( + model_id: str, disable_streaming: Union[bool, str] +) -> None: + llm = ChatBedrockConverse(model=model_id, region_name="us-west-2") + assert llm.disable_streaming == disable_streaming