Skip to content

Commit

Permalink
feat: support aws proxy_url
Browse files Browse the repository at this point in the history
  • Loading branch information
wxg0103 committed Dec 24, 2024
1 parent f1cca66 commit 0cae0c8
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def encryption_dict(self, model: Dict[str, object]):
region_name = forms.TextInputField('Region Name', required=True)
access_key_id = forms.TextInputField('Access Key ID', required=True)
secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)
base_url = forms.TextInputField('Proxy URL', required=False)

def get_model_params_setting_form(self, model_name):
return BedrockLLMModelParams()
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import List, Dict

from botocore.config import Config
from langchain_community.chat_models import BedrockChat
from setting.models_provider.base_model_provider import MaxKBBaseModel

Expand Down Expand Up @@ -33,19 +35,34 @@ def is_cache_model():
return False

def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
streaming: bool = False, **kwargs):
streaming: bool = False, config: Config = None, **kwargs):
super().__init__(model_id=model_id, region_name=region_name,
credentials_profile_name=credentials_profile_name, streaming=streaming, **kwargs)
credentials_profile_name=credentials_profile_name, streaming=streaming, config=config,
**kwargs)

@classmethod
def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
**model_kwargs) -> 'BedrockModel':
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)

config = {}
# 判断model_kwargs是否包含 base_url 且不为空
if 'base_url' in model_credential and model_credential['base_url']:
proxy_url = model_credential['base_url']
config = Config(
proxies={
'http': proxy_url,
'https': proxy_url
},
connect_timeout=60,
read_timeout=60
)

return cls(
model_id=model_name,
region_name=model_credential['region_name'],
credentials_profile_name=model_credential['credentials_profile_name'],
streaming=model_kwargs.pop('streaming', True),
model_kwargs=optional_params
model_kwargs=optional_params,
config=config
)

0 comments on commit 0cae0c8

Please sign in to comment.