Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

anthropic: less pydantic for client #28823

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions libs/partners/anthropic/langchain_anthropic/chat_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import re
import warnings
from functools import cached_property
from operator import itemgetter
from typing import (
Any,
Expand Down Expand Up @@ -68,11 +69,10 @@
BaseModel,
ConfigDict,
Field,
PrivateAttr,
SecretStr,
model_validator,
)
from typing_extensions import NotRequired, Self
from typing_extensions import NotRequired

from langchain_anthropic.output_parsers import extract_tool_calls

Expand Down Expand Up @@ -541,9 +541,6 @@ class Joke(BaseModel):
populate_by_name=True,
)

_client: anthropic.Client = PrivateAttr(default=None) # type: ignore[assignment]
_async_client: anthropic.AsyncClient = PrivateAttr(default=None) # type: ignore[assignment]

Comment on lines -544 to -546
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these create mypy issues because we want the type to be not-None, but by default it's None

model: str = Field(alias="model_name")
"""Model name to use."""

Expand Down Expand Up @@ -661,13 +658,11 @@ def build_extra(cls, values: Dict) -> Any:
values = _build_model_kwargs(values, all_required_field_names)
return values

@model_validator(mode="after")
def post_init(self) -> Self:
api_key = self.anthropic_api_key.get_secret_value()
api_url = self.anthropic_api_url
@cached_property
def _client_params(self) -> Dict[str, Any]:
client_params: Dict[str, Any] = {
"api_key": api_key,
"base_url": api_url,
"api_key": self.anthropic_api_key.get_secret_value(),
"base_url": self.anthropic_api_url,
"max_retries": self.max_retries,
"default_headers": (self.default_headers or None),
}
Expand All @@ -677,9 +672,15 @@ def post_init(self) -> Self:
if self.default_request_timeout is None or self.default_request_timeout > 0:
client_params["timeout"] = self.default_request_timeout

self._client = anthropic.Client(**client_params)
self._async_client = anthropic.AsyncClient(**client_params)
return self
return client_params

@cached_property
def _client(self) -> anthropic.Client:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we don't have it at top-level, it doesn't matter, and also removes a validator

return anthropic.Client(**self._client_params)

@cached_property
def _async_client(self) -> anthropic.AsyncClient:
return anthropic.AsyncClient(**self._client_params)

def _get_request_payload(
self,
Expand Down
Loading