diff --git a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py new file mode 100644 index 00000000000000..a19ffbb20a6a9e --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py @@ -0,0 +1,21 @@ +import boto3 +from botocore.config import Config + + +def get_bedrock_client(service_name, credentials=None): + client_config = Config(region_name=credentials["aws_region"]) + aws_access_key_id = credentials["aws_access_key_id"] + aws_secret_access_key = credentials["aws_secret_access_key"] + if aws_access_key_id and aws_secret_access_key: + # use aksk to call bedrock + client = boto3.client( + service_name=service_name, + config=client_config, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + ) + else: + # use iam without aksk to call + client = boto3.client(service_name=service_name, config=client_config) + + return client diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index e6e8a765ee9e05..75ed7ad62404cb 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -40,6 +40,7 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client logger = logging.getLogger(__name__) ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. @@ -173,13 +174,7 @@ def _generate_with_converse( :param stream: is stream response :return: full response or stream response chunk generator result """ - bedrock_client = boto3.client( - service_name="bedrock-runtime", - aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key"), - region_name=credentials["aws_region"], - ) - + bedrock_client = get_bedrock_client("bedrock-runtime", credentials) system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages) inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop) diff --git a/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py b/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py index 397f65e8c960c8..e134db646f3d39 100644 --- a/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py @@ -1,8 +1,5 @@ from typing import Optional -import boto3 -from botocore.config import Config - from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, @@ -14,6 +11,7 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.rerank_model import RerankModel +from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client class BedrockRerankModel(RerankModel): @@ -48,13 +46,7 @@ def _invoke( return RerankResult(model=model, docs=docs) # initialize client - client_config = Config(region_name=credentials["aws_region"]) - bedrock_runtime = boto3.client( - service_name="bedrock-agent-runtime", - config=client_config, - aws_access_key_id=credentials.get("aws_access_key_id", ""), - aws_secret_access_key=credentials.get("aws_secret_access_key"), - ) + bedrock_runtime = get_bedrock_client("bedrock-agent-runtime", credentials) queries = [{"type": "TEXT", "textQuery": {"text": query}}] text_sources = [] for text in docs: diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py index 2f998d8bdaee90..5505797f7658b3 100644 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py @@ -3,8 +3,6 @@ import time from typing import Optional -import boto3 -from botocore.config import Config from botocore.exceptions import ( ClientError, EndpointConnectionError, @@ -25,6 +23,7 @@ InvokeServerUnavailableError, ) from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client logger = logging.getLogger(__name__) @@ -48,14 +47,7 @@ def _invoke( :param input_type: input type :return: embeddings result """ - client_config = Config(region_name=credentials["aws_region"]) - - bedrock_runtime = boto3.client( - service_name="bedrock-runtime", - config=client_config, - aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key"), - ) + bedrock_runtime = get_bedrock_client("bedrock-runtime", credentials) embeddings = [] token_usage = 0