diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py
index ff8a3e26..4a9eb835 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py
@@ -54,6 +54,19 @@
 _BM = TypeVar("_BM", bound=BaseModel)
 _DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
 
+AWS_REGIONS = [
+    "us",
+    "sa",
+    "me",
+    "il",
+    "eu",
+    "cn",
+    "ca",
+    "ap",
+    "af",
+    "us-gov",
+]
+
 
 class ChatBedrockConverse(BaseChatModel):
     """Bedrock chat model integration built on the Bedrock converse API.
@@ -360,10 +373,19 @@ class Config:
 
     @root_validator(pre=True)
     def set_disable_streaming(cls, values: Dict) -> Dict:
-        values["provider"] = (
-            values.get("provider")
-            or (values.get("model_id", values["model"])).split(".")[0]
-        )
+        if "provider" not in values:
+            model_id = values.get("model_id", values["model"])
+            if not model_id:
+                raise ValueError("model_id must be provided")
+
+            if model_id.startswith("arn"):
+                raise ValueError(
+                    "Model provider should be supplied when passing a model ARN as "
+                    "model_id"
+                )
+
+            parts = model_id.split(".", maxsplit=2)
+            values["provider"] = parts[1] if _model_is_inference(model_id) else parts[0]
 
         # As of 08/05/24 only Anthropic models support streamed tool calling
         if "disable_streaming" not in values:
@@ -997,3 +1019,9 @@ def _format_openai_image_url(image_url: str) -> Dict:
         "format": match.group("media_type"),
         "source": {"bytes": _b64str_to_bytes(match.group("data"))},
     }
+
+
+def _model_is_inference(model_id: str) -> bool:
+    parts = model_id.split(".")
+
+    return True if parts[0] in AWS_REGIONS else False
diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py
index 172688ec..6bb17e44 100644
--- a/libs/aws/langchain_aws/llms/bedrock.py
+++ b/libs/aws/langchain_aws/llms/bedrock.py
@@ -42,6 +42,18 @@
 ALTERNATION_ERROR = (
     "Error: Prompt must alternate between '\n\nHuman:' and '\n\nAssistant:'."
 )
 
+AWS_REGIONS = [
+    "us",
+    "sa",
+    "me",
+    "il",
+    "eu",
+    "cn",
+    "ca",
+    "ap",
+    "af",
+    "us-gov",
+]
 
 def _add_newlines_before_ha(input_text: str) -> str:
@@ -625,15 +637,10 @@ def _get_provider(self) -> str:
                 "model_id"
             )
 
-        # If model_id has region prefixed to them,
-        # for example eu.anthropic.claude-3-haiku-20240307-v1:0,
-        # provider is the second part, otherwise, the first part
         parts = self.model_id.split(".", maxsplit=2)
-        return (
-            parts[1]
-            if (len(parts) > 1 and parts[0].lower() in {"eu", "us", "ap", "sa"})
-            else parts[0]
-        )
+        # inference models are in the format of <region>.<provider>.<model>
+        # the others are in the format of <provider>.<model>
+        return parts[1] if self._model_is_inference else parts[0]
 
     def _get_model(self) -> str:
         return self.model_id.split(".", maxsplit=1)[-1]
@@ -642,6 +649,12 @@ def _get_model(self) -> str:
     def _model_is_anthropic(self) -> bool:
         return self._get_provider() == "anthropic"
 
+    @property
+    def _model_is_inference(self) -> bool:
+        parts = self.model_id.split(".")
+
+        return True if parts[0] in AWS_REGIONS else False
+
     @property
     def _guardrails_enabled(self) -> bool:
         """
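
For reference, here is a standalone sketch of the provider-detection logic this patch introduces. The names mirror the diff, but the free function `_get_provider` below is a simplified stand-in for the `BedrockBase` method, written here for illustration only:

```python
# Minimal sketch, assuming the AWS_REGIONS list and the split-on-"."
# parsing from this patch. These are standalone stand-ins, not the
# langchain_aws API.
AWS_REGIONS = ["us", "sa", "me", "il", "eu", "cn", "ca", "ap", "af", "us-gov"]


def _model_is_inference(model_id: str) -> bool:
    # Cross-region inference profile IDs carry a geo prefix, e.g.
    # "eu.anthropic.claude-3-haiku-20240307-v1:0"; bare model IDs do not.
    return model_id.split(".")[0] in AWS_REGIONS


def _get_provider(model_id: str) -> str:
    # inference profiles: <region>.<provider>.<model> -> provider is parts[1]
    # plain model IDs:    <provider>.<model>          -> provider is parts[0]
    parts = model_id.split(".", maxsplit=2)
    return parts[1] if _model_is_inference(model_id) else parts[0]


assert _get_provider("anthropic.claude-3-haiku-20240307-v1:0") == "anthropic"
assert _get_provider("eu.anthropic.claude-3-haiku-20240307-v1:0") == "anthropic"
assert not _model_is_inference("meta.llama3-1-70b-instruct-v1:0")
```

The bare membership test is equivalent to the patch's `return True if parts[0] in AWS_REGIONS else False`, and the geo-prefix check is what lets both `set_disable_streaming` and `_get_provider` resolve the same provider whether or not the model ID is region-prefixed.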