feat: add anthropic claude-2.1 support #1591

Merged · 1 commit · Nov 21, 2023
20 changes: 17 additions & 3 deletions api/core/model_providers/providers/anthropic_provider.py
@@ -32,9 +32,12 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
         if model_type == ModelType.TEXT_GENERATION:
             return [
                 {
-                    'id': 'claude-instant-1',
-                    'name': 'claude-instant-1',
+                    'id': 'claude-2.1',
+                    'name': 'claude-2.1',
                     'mode': ModelMode.CHAT.value,
+                    'features': [
+                        ModelFeature.AGENT_THOUGHT.value
+                    ]
                 },
                 {
                     'id': 'claude-2',
@@ -44,6 +47,11 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
                         ModelFeature.AGENT_THOUGHT.value
                     ]
                 },
+                {
+                    'id': 'claude-instant-1',
+                    'name': 'claude-instant-1',
+                    'mode': ModelMode.CHAT.value,
+                },
             ]
         else:
             return []
@@ -73,12 +81,18 @@ def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
         :param model_type:
         :return:
         """
+        model_max_tokens = {
+            'claude-instant-1': 100000,
+            'claude-2': 100000,
+            'claude-2.1': 200000,
+        }
+
         return ModelKwargsRules(
             temperature=KwargRule[float](min=0, max=1, default=1, precision=2),
             top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256, precision=0),
+            max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=model_max_tokens.get(model_name, 100000), default=256, precision=0),
         )

     @classmethod
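
As a quick illustration of the new per-model cap (a minimal sketch, not part of the diff; the dictionary and the 100000 fallback are copied from the change above), the max_tokens rule resolves like this:

model_max_tokens = {
    'claude-instant-1': 100000,
    'claude-2': 100000,
    'claude-2.1': 200000,
}

# Known models get their own cap; unknown names fall back to the previous 100k.
for model_name in ('claude-2', 'claude-2.1', 'some-future-model'):
    print(model_name, model_max_tokens.get(model_name, 100000))
# claude-2 100000
# claude-2.1 200000
# some-future-model 100000
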
10 changes: 8 additions & 2 deletions api/core/model_providers/rules/anthropic.json
@@ -23,8 +23,14 @@
     "currency": "USD"
   },
   "claude-2": {
-    "prompt": "11.02",
-    "completion": "32.68",
+    "prompt": "8.00",
+    "completion": "24.00",
+    "unit": "0.000001",
+    "currency": "USD"
+  },
+  "claude-2.1": {
+    "prompt": "8.00",
+    "completion": "24.00",
     "unit": "0.000001",
     "currency": "USD"
   }
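
To get a sense of what the new claude-2.1 entry implies, assuming the prompt/completion figures are per-token rates scaled by unit (that reading is an assumption here, not stated in the diff), a rough cost estimate looks like this:

from decimal import Decimal

# Rough cost sketch for claude-2.1 under the assumed reading of the rule file:
# cost = tokens * unit * rate. Figures are taken from the JSON above.
unit = Decimal("0.000001")
prompt_rate = Decimal("8.00")
completion_rate = Decimal("24.00")

prompt_tokens, completion_tokens = 1000, 500
cost = prompt_tokens * unit * prompt_rate + completion_tokens * unit * completion_rate
print(cost)  # 0.02 USD for 1,000 prompt + 500 completion tokens
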
18 changes: 15 additions & 3 deletions api/core/third_party/langchain/llms/anthropic_llm.py
@@ -1,7 +1,7 @@
 from typing import Dict

-from httpx import Limits
 from langchain.chat_models import ChatAnthropic
+from langchain.schema import ChatMessage, BaseMessage, HumanMessage, AIMessage, SystemMessage
 from langchain.utils import get_from_dict_or_env, check_package_version
 from pydantic import root_validator

@@ -29,8 +29,7 @@ def validate_environment(cls, values: Dict) -> Dict:
             base_url=values["anthropic_api_url"],
             api_key=values["anthropic_api_key"],
             timeout=values["default_request_timeout"],
-            max_retries=0,
-            connection_pool_limits=Limits(max_connections=200, max_keepalive_connections=100),
+            max_retries=0
         )
         values["async_client"] = anthropic.AsyncAnthropic(
             base_url=values["anthropic_api_url"],
@@ -46,3 +45,16 @@ def validate_environment(cls, values: Dict) -> Dict:
                 "Please it install it with `pip install anthropic`."
             )
         return values
+
+    def _convert_one_message_to_text(self, message: BaseMessage) -> str:
+        if isinstance(message, ChatMessage):
+            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
+        elif isinstance(message, HumanMessage):
+            message_text = f"{self.HUMAN_PROMPT} {message.content}"
+        elif isinstance(message, AIMessage):
+            message_text = f"{self.AI_PROMPT} {message.content}"
+        elif isinstance(message, SystemMessage):
+            message_text = f"{message.content}"
+        else:
+            raise ValueError(f"Got unknown type {message}")
+        return message_text
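
For context, the override above renders chat messages into Anthropic's text-completion prompt format. A minimal standalone sketch of what a short conversation becomes, assuming the standard Anthropic markers HUMAN_PROMPT = "\n\nHuman:" and AI_PROMPT = "\n\nAssistant:" (the diff itself only references self.HUMAN_PROMPT / self.AI_PROMPT):

from langchain.schema import AIMessage, HumanMessage, SystemMessage

HUMAN_PROMPT = "\n\nHuman:"   # assumed value of the SDK's HUMAN_PROMPT constant
AI_PROMPT = "\n\nAssistant:"  # assumed value of the SDK's AI_PROMPT constant

messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="What changed in claude-2.1?"),
]

# Same branching as _convert_one_message_to_text above, collapsed for brevity.
parts = []
for m in messages:
    if isinstance(m, SystemMessage):
        parts.append(m.content)
    elif isinstance(m, HumanMessage):
        parts.append(f"{HUMAN_PROMPT} {m.content}")
    elif isinstance(m, AIMessage):
        parts.append(f"{AI_PROMPT} {m.content}")

# A trailing AI_PROMPT cues the model to answer, mirroring LangChain's usual prompt assembly.
print("".join(parts) + AI_PROMPT)
# You are a helpful assistant.
#
# Human: What changed in claude-2.1?
#
# Assistant:
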
2 changes: 1 addition & 1 deletion api/requirements.txt
@@ -35,7 +35,7 @@ docx2txt==0.8
 pypdfium2==4.16.0
 resend~=0.5.1
 pyjwt~=2.6.0
-anthropic~=0.3.4
+anthropic~=0.7.2
 newspaper3k==0.2.8
 google-api-python-client==2.90.0
 wikipedia==1.4.0
@@ -31,12 +31,12 @@ def mock_chat_generate_invalid(messages: List[BaseMessage],
                                run_manager: Optional[CallbackManagerForLLMRun] = None,
                                **kwargs: Any):
     raise anthropic.APIStatusError('Invalid credentials',
-                                   request=httpx._models.Request(
-                                       method='POST',
-                                       url='https://api.anthropic.com/v1/completions',
-                                   ),
                                    response=httpx._models.Response(
                                        status_code=401,
+                                       request=httpx._models.Request(
+                                           method='POST',
+                                           url='https://api.anthropic.com/v1/completions',
+                                       )
                                    ),
                                    body=None
                                    )
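
The mock restructuring tracks the anthropic SDK bump in requirements.txt: in the 0.7.x client, APIStatusError no longer takes the request as its own keyword; the httpx.Request is attached to the Response instead. A minimal sketch of the new shape, assuming the 0.7.x signature APIStatusError(message, response=..., body=...):

import anthropic
import httpx

# The Request now lives on the Response rather than on the error itself (assumed 0.7.x behaviour).
request = httpx.Request(method='POST', url='https://api.anthropic.com/v1/completions')
response = httpx.Response(status_code=401, request=request)
error = anthropic.APIStatusError('Invalid credentials', response=response, body=None)
print(error.status_code)  # 401
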