feat(community): Add support for Bedrock cross-region inference models #187

Closed
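This PR teaches the Bedrock chat and LLM classes to recognize cross-region inference profile model ids, which prefix the provider with a geographic code (for example "us." or "eu."). A minimal usage sketch, assuming valid AWS credentials and access to such an inference profile (the model id and region below are illustrative, not taken from the diff):

from langchain_aws import ChatBedrockConverse

# With this change the provider is derived from the segment after the geo
# prefix ("anthropic" here) instead of the prefix itself ("us").
llm = ChatBedrockConverse(
    model="us.anthropic.claude-3-haiku-20240307-v1:0",
    region_name="us-east-1",
)
llm.invoke("Hello")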
36 changes: 32 additions & 4 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
@@ -54,6 +54,19 @@
_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]

# Prefixes that may appear before the provider segment in Bedrock cross-region
# inference profile model ids, e.g. "us.anthropic.claude-3-haiku-20240307-v1:0".
AWS_REGIONS = [
"us",
"sa",
"me",
"il",
"eu",
"cn",
"ca",
"ap",
"af",
"us-gov",
]


class ChatBedrockConverse(BaseChatModel):
"""Bedrock chat model integration built on the Bedrock converse API.
@@ -360,10 +373,19 @@ class Config:

@root_validator(pre=True)
def set_disable_streaming(cls, values: Dict) -> Dict:
values["provider"] = (
values.get("provider")
or (values.get("model_id", values["model"])).split(".")[0]
)
if "provider" not in values:
model_id = values.get("model_id", values["model"])
if not model_id:
raise ValueError("model_id must be provided")

if model_id.startswith("arn"):
raise ValueError(
"Model provider should be supplied when passing a model ARN as "
"model_id"
)

# Cross-region inference profile ids are <geo>.<provider>.<model>;
# other model ids are <provider>.<model>.
parts = model_id.split(".", maxsplit=2)
values["provider"] = parts[1] if _model_is_inference(model_id) else parts[0]

# As of 08/05/24 only Anthropic models support streamed tool calling
if "disable_streaming" not in values:
@@ -997,3 +1019,9 @@ def _format_openai_image_url(image_url: str) -> Dict:
"format": match.group("media_type"),
"source": {"bytes": _b64str_to_bytes(match.group("data"))},
}


def _model_is_inference(model_id: str) -> bool:
parts = model_id.split(".")

# A geo prefix before the provider marks a cross-region inference profile id.
return parts[0] in AWS_REGIONS
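For illustration, a self-contained sketch of the parsing rule added above (provider_of is a hypothetical helper, not part of the diff; the model ids are examples):

AWS_REGIONS = ["us", "sa", "me", "il", "eu", "cn", "ca", "ap", "af", "us-gov"]


def _model_is_inference(model_id: str) -> bool:
    # A geo prefix before the provider marks a cross-region inference profile.
    return model_id.split(".")[0] in AWS_REGIONS


def provider_of(model_id: str) -> str:
    # Geo-prefixed ids yield the second segment, plain ids the first.
    parts = model_id.split(".", maxsplit=2)
    return parts[1] if _model_is_inference(model_id) else parts[0]


assert provider_of("anthropic.claude-3-haiku-20240307-v1:0") == "anthropic"
assert provider_of("us.anthropic.claude-3-haiku-20240307-v1:0") == "anthropic"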
29 changes: 21 additions & 8 deletions libs/aws/langchain_aws/llms/bedrock.py
@@ -42,6 +42,18 @@
ALTERNATION_ERROR = (
"Error: Prompt must alternate between '\n\nHuman:' and '\n\nAssistant:'."
)
# Prefixes that may appear before the provider segment in Bedrock cross-region
# inference profile model ids, e.g. "us.anthropic.claude-3-haiku-20240307-v1:0".
AWS_REGIONS = [
"us",
"sa",
"me",
"il",
"eu",
"cn",
"ca",
"ap",
"af",
"us-gov",
]


def _add_newlines_before_ha(input_text: str) -> str:
@@ -625,15 +637,10 @@ def _get_provider(self) -> str:
"model_id"
)

# If model_id has region prefixed to them,
# for example eu.anthropic.claude-3-haiku-20240307-v1:0,
# provider is the second part, otherwise, the first part
parts = self.model_id.split(".", maxsplit=2)
return (
parts[1]
if (len(parts) > 1 and parts[0].lower() in {"eu", "us", "ap", "sa"})
else parts[0]
)
# cross-region inference profile ids are in the format <region>.<provider>.<model_id>;
# other model ids are in the format <provider>.<model_id>
return parts[1] if self._model_is_inference else parts[0]

def _get_model(self) -> str:
return self.model_id.split(".", maxsplit=1)[-1]
@@ -642,6 +649,12 @@ def _get_model(self) -> str:
def _model_is_anthropic(self) -> bool:
return self._get_provider() == "anthropic"

@property
def _model_is_inference(self) -> bool:
parts = self.model_id.split(".")

# A geo prefix before the provider marks a cross-region inference profile id.
return parts[0] in AWS_REGIONS

@property
def _guardrails_enabled(self) -> bool:
"""
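The same rule now applies to the base Bedrock classes in llms/bedrock.py, so provider-specific handling (for example Anthropic message formatting and guardrails checks) is picked up for geo-prefixed ids as well. A minimal sketch, assuming credentials and profile access are configured (model id and region are illustrative):

from langchain_aws import ChatBedrock

llm = ChatBedrock(
    model_id="eu.anthropic.claude-3-haiku-20240307-v1:0",
    region_name="eu-west-1",
)
llm.invoke("Hello")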