Skip to content

Commit

Permalink
[ref] use one method to get boto client for aws bedrock (#11506)
Browse files Browse the repository at this point in the history
  • Loading branch information
warren830 authored and iamjoel committed Dec 16, 2024
1 parent 7633817 commit a926cd4
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 27 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import boto3
from botocore.config import Config


def get_bedrock_client(service_name, credentials=None):
client_config = Config(region_name=credentials["aws_region"])
aws_access_key_id = credentials["aws_access_key_id"]
aws_secret_access_key = credentials["aws_secret_access_key"]
if aws_access_key_id and aws_secret_access_key:
# use aksk to call bedrock
client = boto3.client(
service_name=service_name,
config=client_config,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
else:
# use iam without aksk to call
client = boto3.client(service_name=service_name, config=client_config)

return client
9 changes: 2 additions & 7 deletions api/core/model_runtime/model_providers/bedrock/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client

logger = logging.getLogger(__name__)
ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
Expand Down Expand Up @@ -173,13 +174,7 @@ def _generate_with_converse(
:param stream: is stream response
:return: full response or stream response chunk generator result
"""
bedrock_client = boto3.client(
service_name="bedrock-runtime",
aws_access_key_id=credentials.get("aws_access_key_id"),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
region_name=credentials["aws_region"],
)

bedrock_client = get_bedrock_client("bedrock-runtime", credentials)
system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)

Expand Down
12 changes: 2 additions & 10 deletions api/core/model_runtime/model_providers/bedrock/rerank/rerank.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from typing import Optional

import boto3
from botocore.config import Config

from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
Expand All @@ -14,6 +11,7 @@
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client


class BedrockRerankModel(RerankModel):
Expand Down Expand Up @@ -48,13 +46,7 @@ def _invoke(
return RerankResult(model=model, docs=docs)

# initialize client
client_config = Config(region_name=credentials["aws_region"])
bedrock_runtime = boto3.client(
service_name="bedrock-agent-runtime",
config=client_config,
aws_access_key_id=credentials.get("aws_access_key_id", ""),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
)
bedrock_runtime = get_bedrock_client("bedrock-agent-runtime", credentials)
queries = [{"type": "TEXT", "textQuery": {"text": query}}]
text_sources = []
for text in docs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import time
from typing import Optional

import boto3
from botocore.config import Config
from botocore.exceptions import (
ClientError,
EndpointConnectionError,
Expand All @@ -25,6 +23,7 @@
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client

logger = logging.getLogger(__name__)

Expand All @@ -48,14 +47,7 @@ def _invoke(
:param input_type: input type
:return: embeddings result
"""
client_config = Config(region_name=credentials["aws_region"])

bedrock_runtime = boto3.client(
service_name="bedrock-runtime",
config=client_config,
aws_access_key_id=credentials.get("aws_access_key_id"),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
)
bedrock_runtime = get_bedrock_client("bedrock-runtime", credentials)

embeddings = []
token_usage = 0
Expand Down

0 comments on commit a926cd4

Please sign in to comment.