diff --git a/pyproject.toml b/pyproject.toml index 3c82469..7d5561f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,9 +76,12 @@ format = { chain = [ "lint" = { chain = [ "check:ruff", "typecheck", + "check:importable", ]} -"check:ruff" = "ruff ." -"fix:ruff" = "ruff --fix ." +"check:ruff" = "ruff check ." +"fix:ruff" = "ruff check --fix ." + +"check:importable" = "python -c 'import groq'" typecheck = { chain = [ "typecheck:pyright", @@ -162,6 +165,11 @@ reportPrivateUsage = false line-length = 120 output-format = "grouped" target-version = "py37" + +[tool.ruff.format] +docstring-code-format = true + +[tool.ruff.lint] select = [ # isort "I", @@ -190,10 +198,6 @@ unfixable = [ "T201", "T203", ] -ignore-init-module-imports = true - -[tool.ruff.format] -docstring-code-format = true [tool.ruff.lint.flake8-tidy-imports.banned-api] "functools.lru_cache".msg = "This function does not retain type information for the wrapped function's arguments; The `lru_cache` function from `_utils` should be used instead" @@ -205,7 +209,7 @@ combine-as-imports = true extra-standard-library = ["typing_extensions"] known-first-party = ["groq", "tests"] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "bin/**.py" = ["T201", "T203"] "scripts/**.py" = ["T201", "T203"] "tests/**.py" = ["T201", "T203"] diff --git a/requirements-dev.lock b/requirements-dev.lock index 3ee1ee4..6b7d864 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -70,7 +70,7 @@ pydantic-core==2.18.2 # via pydantic pygments==2.18.0 # via rich -pyright==1.1.364 +pyright==1.1.374 pytest==7.1.1 # via pytest-asyncio pytest-asyncio==0.21.1 @@ -80,7 +80,7 @@ pytz==2023.3.post1 # via dirty-equals respx==0.20.2 rich==13.7.1 -ruff==0.1.9 +ruff==0.5.6 setuptools==68.2.2 # via nodeenv six==1.16.0 diff --git a/scripts/mock b/scripts/mock index f586157..d2814ae 100755 --- a/scripts/mock +++ b/scripts/mock @@ -21,7 +21,7 @@ echo "==> Starting mock server with URL ${URL}" # Run prism mock on the given spec if [ "$1" == "--daemon" ]; then - npm exec --package=@stainless-api/prism-cli@5.8.4 -- prism mock "$URL" &> .prism.log & + npm exec --package=@stainless-api/prism-cli@5.8.5 -- prism mock "$URL" &> .prism.log & # Wait for server to come online echo -n "Waiting for server" @@ -37,5 +37,5 @@ if [ "$1" == "--daemon" ]; then echo else - npm exec --package=@stainless-api/prism-cli@5.8.4 -- prism mock "$URL" + npm exec --package=@stainless-api/prism-cli@5.8.5 -- prism mock "$URL" fi diff --git a/src/groq/_base_client.py b/src/groq/_base_client.py index 4b3ddd1..9207153 100644 --- a/src/groq/_base_client.py +++ b/src/groq/_base_client.py @@ -124,16 +124,14 @@ def __init__( self, *, url: URL, - ) -> None: - ... + ) -> None: ... @overload def __init__( self, *, params: Query, - ) -> None: - ... + ) -> None: ... def __init__( self, @@ -166,8 +164,7 @@ def has_next_page(self) -> bool: return False return self.next_page_info() is not None - def next_page_info(self) -> Optional[PageInfo]: - ... + def next_page_info(self) -> Optional[PageInfo]: ... def _get_page_items(self) -> Iterable[_T]: # type: ignore[empty-body] ... @@ -903,8 +900,7 @@ def request( *, stream: Literal[True], stream_cls: Type[_StreamT], - ) -> _StreamT: - ... + ) -> _StreamT: ... @overload def request( @@ -914,8 +910,7 @@ def request( remaining_retries: Optional[int] = None, *, stream: Literal[False] = False, - ) -> ResponseT: - ... + ) -> ResponseT: ... @overload def request( @@ -926,8 +921,7 @@ def request( *, stream: bool = False, stream_cls: Type[_StreamT] | None = None, - ) -> ResponseT | _StreamT: - ... + ) -> ResponseT | _StreamT: ... def request( self, @@ -1049,6 +1043,7 @@ def _request( response=response, stream=stream, stream_cls=stream_cls, + retries_taken=options.get_max_retries(self.max_retries) - retries, ) def _retry_request( @@ -1090,6 +1085,7 @@ def _process_response( response: httpx.Response, stream: bool, stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None, + retries_taken: int = 0, ) -> ResponseT: origin = get_origin(cast_to) or cast_to @@ -1107,6 +1103,7 @@ def _process_response( stream=stream, stream_cls=stream_cls, options=options, + retries_taken=retries_taken, ), ) @@ -1120,6 +1117,7 @@ def _process_response( stream=stream, stream_cls=stream_cls, options=options, + retries_taken=retries_taken, ) if bool(response.request.headers.get(RAW_RESPONSE_HEADER)): return cast(ResponseT, api_response) @@ -1152,8 +1150,7 @@ def get( cast_to: Type[ResponseT], options: RequestOptions = {}, stream: Literal[False] = False, - ) -> ResponseT: - ... + ) -> ResponseT: ... @overload def get( @@ -1164,8 +1161,7 @@ def get( options: RequestOptions = {}, stream: Literal[True], stream_cls: type[_StreamT], - ) -> _StreamT: - ... + ) -> _StreamT: ... @overload def get( @@ -1176,8 +1172,7 @@ def get( options: RequestOptions = {}, stream: bool, stream_cls: type[_StreamT] | None = None, - ) -> ResponseT | _StreamT: - ... + ) -> ResponseT | _StreamT: ... def get( self, @@ -1203,8 +1198,7 @@ def post( options: RequestOptions = {}, files: RequestFiles | None = None, stream: Literal[False] = False, - ) -> ResponseT: - ... + ) -> ResponseT: ... @overload def post( @@ -1217,8 +1211,7 @@ def post( files: RequestFiles | None = None, stream: Literal[True], stream_cls: type[_StreamT], - ) -> _StreamT: - ... + ) -> _StreamT: ... @overload def post( @@ -1231,8 +1224,7 @@ def post( files: RequestFiles | None = None, stream: bool, stream_cls: type[_StreamT] | None = None, - ) -> ResponseT | _StreamT: - ... + ) -> ResponseT | _StreamT: ... def post( self, @@ -1465,8 +1457,7 @@ async def request( *, stream: Literal[False] = False, remaining_retries: Optional[int] = None, - ) -> ResponseT: - ... + ) -> ResponseT: ... @overload async def request( @@ -1477,8 +1468,7 @@ async def request( stream: Literal[True], stream_cls: type[_AsyncStreamT], remaining_retries: Optional[int] = None, - ) -> _AsyncStreamT: - ... + ) -> _AsyncStreamT: ... @overload async def request( @@ -1489,8 +1479,7 @@ async def request( stream: bool, stream_cls: type[_AsyncStreamT] | None = None, remaining_retries: Optional[int] = None, - ) -> ResponseT | _AsyncStreamT: - ... + ) -> ResponseT | _AsyncStreamT: ... async def request( self, @@ -1610,6 +1599,7 @@ async def _request( response=response, stream=stream, stream_cls=stream_cls, + retries_taken=options.get_max_retries(self.max_retries) - retries, ) async def _retry_request( @@ -1649,6 +1639,7 @@ async def _process_response( response: httpx.Response, stream: bool, stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None, + retries_taken: int = 0, ) -> ResponseT: origin = get_origin(cast_to) or cast_to @@ -1666,6 +1657,7 @@ async def _process_response( stream=stream, stream_cls=stream_cls, options=options, + retries_taken=retries_taken, ), ) @@ -1679,6 +1671,7 @@ async def _process_response( stream=stream, stream_cls=stream_cls, options=options, + retries_taken=retries_taken, ) if bool(response.request.headers.get(RAW_RESPONSE_HEADER)): return cast(ResponseT, api_response) @@ -1701,8 +1694,7 @@ async def get( cast_to: Type[ResponseT], options: RequestOptions = {}, stream: Literal[False] = False, - ) -> ResponseT: - ... + ) -> ResponseT: ... @overload async def get( @@ -1713,8 +1705,7 @@ async def get( options: RequestOptions = {}, stream: Literal[True], stream_cls: type[_AsyncStreamT], - ) -> _AsyncStreamT: - ... + ) -> _AsyncStreamT: ... @overload async def get( @@ -1725,8 +1716,7 @@ async def get( options: RequestOptions = {}, stream: bool, stream_cls: type[_AsyncStreamT] | None = None, - ) -> ResponseT | _AsyncStreamT: - ... + ) -> ResponseT | _AsyncStreamT: ... async def get( self, @@ -1750,8 +1740,7 @@ async def post( files: RequestFiles | None = None, options: RequestOptions = {}, stream: Literal[False] = False, - ) -> ResponseT: - ... + ) -> ResponseT: ... @overload async def post( @@ -1764,8 +1753,7 @@ async def post( options: RequestOptions = {}, stream: Literal[True], stream_cls: type[_AsyncStreamT], - ) -> _AsyncStreamT: - ... + ) -> _AsyncStreamT: ... @overload async def post( @@ -1778,8 +1766,7 @@ async def post( options: RequestOptions = {}, stream: bool, stream_cls: type[_AsyncStreamT] | None = None, - ) -> ResponseT | _AsyncStreamT: - ... + ) -> ResponseT | _AsyncStreamT: ... async def post( self, diff --git a/src/groq/_compat.py b/src/groq/_compat.py index c919b5a..21fe694 100644 --- a/src/groq/_compat.py +++ b/src/groq/_compat.py @@ -7,7 +7,7 @@ import pydantic from pydantic.fields import FieldInfo -from ._types import StrBytesIntFloat +from ._types import IncEx, StrBytesIntFloat _T = TypeVar("_T") _ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel) @@ -133,17 +133,20 @@ def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str: def model_dump( model: pydantic.BaseModel, *, + exclude: IncEx = None, exclude_unset: bool = False, exclude_defaults: bool = False, ) -> dict[str, Any]: if PYDANTIC_V2: return model.model_dump( + exclude=exclude, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, ) return cast( "dict[str, Any]", model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + exclude=exclude, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, ), @@ -159,22 +162,19 @@ def model_parse(model: type[_ModelT], data: Any) -> _ModelT: # generic models if TYPE_CHECKING: - class GenericModel(pydantic.BaseModel): - ... + class GenericModel(pydantic.BaseModel): ... else: if PYDANTIC_V2: # there no longer needs to be a distinction in v2 but # we still have to create our own subclass to avoid # inconsistent MRO ordering errors - class GenericModel(pydantic.BaseModel): - ... + class GenericModel(pydantic.BaseModel): ... else: import pydantic.generics - class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): - ... + class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ... # cached properties @@ -193,26 +193,21 @@ class typed_cached_property(Generic[_T]): func: Callable[[Any], _T] attrname: str | None - def __init__(self, func: Callable[[Any], _T]) -> None: - ... + def __init__(self, func: Callable[[Any], _T]) -> None: ... @overload - def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: - ... + def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ... @overload - def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: - ... + def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ... def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self: raise NotImplementedError() - def __set_name__(self, owner: type[Any], name: str) -> None: - ... + def __set_name__(self, owner: type[Any], name: str) -> None: ... # __set__ is not defined at runtime, but @cached_property is designed to be settable - def __set__(self, instance: object, value: _T) -> None: - ... + def __set__(self, instance: object, value: _T) -> None: ... else: try: from functools import cached_property as cached_property diff --git a/src/groq/_files.py b/src/groq/_files.py index 0d2022a..715cc20 100644 --- a/src/groq/_files.py +++ b/src/groq/_files.py @@ -39,13 +39,11 @@ def assert_is_file_content(obj: object, *, key: str | None = None) -> None: @overload -def to_httpx_files(files: None) -> None: - ... +def to_httpx_files(files: None) -> None: ... @overload -def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: - ... +def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ... def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None: @@ -83,13 +81,11 @@ def _read_file_content(file: FileContent) -> HttpxFileContent: @overload -async def async_to_httpx_files(files: None) -> None: - ... +async def async_to_httpx_files(files: None) -> None: ... @overload -async def async_to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: - ... +async def async_to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ... async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None: diff --git a/src/groq/_models.py b/src/groq/_models.py index eb7ce3b..5148d5a 100644 --- a/src/groq/_models.py +++ b/src/groq/_models.py @@ -406,6 +406,15 @@ def build( return cast(_BaseModelT, construct_type(type_=base_model_cls, value=kwargs)) +def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T: + """Loose coercion to the expected type with construction of nested values. + + Note: the returned value from this function is not guaranteed to match the + given type. + """ + return cast(_T, construct_type(value=value, type_=type_)) + + def construct_type(*, value: object, type_: object) -> object: """Loose coercion to the expected type with construction of nested values. diff --git a/src/groq/_response.py b/src/groq/_response.py index c4bc069..d7ae9cd 100644 --- a/src/groq/_response.py +++ b/src/groq/_response.py @@ -55,6 +55,9 @@ class BaseAPIResponse(Generic[R]): http_response: httpx.Response + retries_taken: int + """The number of retries made. If no retries happened this will be `0`""" + def __init__( self, *, @@ -64,6 +67,7 @@ def __init__( stream: bool, stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None, options: FinalRequestOptions, + retries_taken: int = 0, ) -> None: self._cast_to = cast_to self._client = client @@ -72,6 +76,7 @@ def __init__( self._stream_cls = stream_cls self._options = options self.http_response = raw + self.retries_taken = retries_taken @property def headers(self) -> httpx.Headers: @@ -255,12 +260,10 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T: class APIResponse(BaseAPIResponse[R]): @overload - def parse(self, *, to: type[_T]) -> _T: - ... + def parse(self, *, to: type[_T]) -> _T: ... @overload - def parse(self) -> R: - ... + def parse(self) -> R: ... def parse(self, *, to: type[_T] | None = None) -> R | _T: """Returns the rich python representation of this response's data. @@ -359,12 +362,10 @@ def iter_lines(self) -> Iterator[str]: class AsyncAPIResponse(BaseAPIResponse[R]): @overload - async def parse(self, *, to: type[_T]) -> _T: - ... + async def parse(self, *, to: type[_T]) -> _T: ... @overload - async def parse(self) -> R: - ... + async def parse(self) -> R: ... async def parse(self, *, to: type[_T] | None = None) -> R | _T: """Returns the rich python representation of this response's data. diff --git a/src/groq/_types.py b/src/groq/_types.py index b970455..f85d73b 100644 --- a/src/groq/_types.py +++ b/src/groq/_types.py @@ -111,8 +111,7 @@ class NotGiven: For example: ```py - def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: - ... + def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ... get(timeout=1) # 1s timeout @@ -162,16 +161,14 @@ def build( *, response: Response, data: object, - ) -> _T: - ... + ) -> _T: ... Headers = Mapping[str, Union[str, Omit]] class HeadersLikeProtocol(Protocol): - def get(self, __key: str) -> str | None: - ... + def get(self, __key: str) -> str | None: ... HeadersLike = Union[Headers, HeadersLikeProtocol] diff --git a/src/groq/_utils/_proxy.py b/src/groq/_utils/_proxy.py index c46a62a..ffd883e 100644 --- a/src/groq/_utils/_proxy.py +++ b/src/groq/_utils/_proxy.py @@ -59,5 +59,4 @@ def __as_proxied__(self) -> T: return cast(T, self) @abstractmethod - def __load__(self) -> T: - ... + def __load__(self) -> T: ... diff --git a/src/groq/_utils/_reflection.py b/src/groq/_utils/_reflection.py index 9a53c7b..89aa712 100644 --- a/src/groq/_utils/_reflection.py +++ b/src/groq/_utils/_reflection.py @@ -34,7 +34,7 @@ def assert_signatures_in_sync( if custom_param.annotation != source_param.annotation: errors.append( - f"types for the `{name}` param are do not match; source={repr(source_param.annotation)} checking={repr(source_param.annotation)}" + f"types for the `{name}` param are do not match; source={repr(source_param.annotation)} checking={repr(custom_param.annotation)}" ) continue diff --git a/src/groq/_utils/_utils.py b/src/groq/_utils/_utils.py index 34797c2..2fc5a1c 100644 --- a/src/groq/_utils/_utils.py +++ b/src/groq/_utils/_utils.py @@ -211,20 +211,17 @@ def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]: Example usage: ```py @overload - def foo(*, a: str) -> str: - ... + def foo(*, a: str) -> str: ... @overload - def foo(*, b: bool) -> str: - ... + def foo(*, b: bool) -> str: ... # This enforces the same constraints that a static type checker would # i.e. that either a or b must be passed to the function @required_args(["a"], ["b"]) - def foo(*, a: str | None = None, b: bool | None = None) -> str: - ... + def foo(*, a: str | None = None, b: bool | None = None) -> str: ... ``` """ @@ -286,18 +283,15 @@ def wrapper(*args: object, **kwargs: object) -> object: @overload -def strip_not_given(obj: None) -> None: - ... +def strip_not_given(obj: None) -> None: ... @overload -def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: - ... +def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: ... @overload -def strip_not_given(obj: object) -> object: - ... +def strip_not_given(obj: object) -> object: ... def strip_not_given(obj: object | None) -> object: diff --git a/src/groq/types/chat/chat_completion_content_part_param.py b/src/groq/types/chat/chat_completion_content_part_param.py index f9b5f71..e0c6e48 100644 --- a/src/groq/types/chat/chat_completion_content_part_param.py +++ b/src/groq/types/chat/chat_completion_content_part_param.py @@ -3,10 +3,13 @@ from __future__ import annotations from typing import Union +from typing_extensions import TypeAlias from .chat_completion_content_part_text_param import ChatCompletionContentPartTextParam from .chat_completion_content_part_image_param import ChatCompletionContentPartImageParam __all__ = ["ChatCompletionContentPartParam"] -ChatCompletionContentPartParam = Union[ChatCompletionContentPartTextParam, ChatCompletionContentPartImageParam] +ChatCompletionContentPartParam: TypeAlias = Union[ + ChatCompletionContentPartTextParam, ChatCompletionContentPartImageParam +] diff --git a/src/groq/types/chat/chat_completion_message_param.py b/src/groq/types/chat/chat_completion_message_param.py index a3644a5..ec65d94 100644 --- a/src/groq/types/chat/chat_completion_message_param.py +++ b/src/groq/types/chat/chat_completion_message_param.py @@ -3,6 +3,7 @@ from __future__ import annotations from typing import Union +from typing_extensions import TypeAlias from .chat_completion_tool_message_param import ChatCompletionToolMessageParam from .chat_completion_user_message_param import ChatCompletionUserMessageParam @@ -12,7 +13,7 @@ __all__ = ["ChatCompletionMessageParam"] -ChatCompletionMessageParam = Union[ +ChatCompletionMessageParam: TypeAlias = Union[ ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam, ChatCompletionAssistantMessageParam, diff --git a/src/groq/types/chat/chat_completion_role.py b/src/groq/types/chat/chat_completion_role.py index 1fd8388..c2ebef7 100644 --- a/src/groq/types/chat/chat_completion_role.py +++ b/src/groq/types/chat/chat_completion_role.py @@ -1,7 +1,7 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. -from typing_extensions import Literal +from typing_extensions import Literal, TypeAlias __all__ = ["ChatCompletionRole"] -ChatCompletionRole = Literal["system", "user", "assistant", "tool", "function"] +ChatCompletionRole: TypeAlias = Literal["system", "user", "assistant", "tool", "function"] diff --git a/src/groq/types/chat/chat_completion_tool_choice_option_param.py b/src/groq/types/chat/chat_completion_tool_choice_option_param.py index 1d3c250..7dedf04 100644 --- a/src/groq/types/chat/chat_completion_tool_choice_option_param.py +++ b/src/groq/types/chat/chat_completion_tool_choice_option_param.py @@ -3,10 +3,12 @@ from __future__ import annotations from typing import Union -from typing_extensions import Literal +from typing_extensions import Literal, TypeAlias from .chat_completion_named_tool_choice_param import ChatCompletionNamedToolChoiceParam __all__ = ["ChatCompletionToolChoiceOptionParam"] -ChatCompletionToolChoiceOptionParam = Union[Literal["none", "auto", "required"], ChatCompletionNamedToolChoiceParam] +ChatCompletionToolChoiceOptionParam: TypeAlias = Union[ + Literal["none", "auto", "required"], ChatCompletionNamedToolChoiceParam +] diff --git a/src/groq/types/chat/chat_completion_tool_param.py b/src/groq/types/chat/chat_completion_tool_param.py index 0cf6ea7..6c2b1a3 100644 --- a/src/groq/types/chat/chat_completion_tool_param.py +++ b/src/groq/types/chat/chat_completion_tool_param.py @@ -4,13 +4,13 @@ from typing_extensions import Literal, Required, TypedDict -from ...types import shared_params +from ..shared_params.function_definition import FunctionDefinition __all__ = ["ChatCompletionToolParam"] class ChatCompletionToolParam(TypedDict, total=False): - function: Required[shared_params.FunctionDefinition] + function: Required[FunctionDefinition] type: Required[Literal["function"]] """The type of the tool. Currently, only `function` is supported.""" diff --git a/src/groq/types/chat/completion_create_params.py b/src/groq/types/chat/completion_create_params.py index ef01720..0157906 100644 --- a/src/groq/types/chat/completion_create_params.py +++ b/src/groq/types/chat/completion_create_params.py @@ -3,11 +3,11 @@ from __future__ import annotations from typing import Dict, List, Union, Iterable, Optional -from typing_extensions import Literal, Required, TypedDict +from typing_extensions import Literal, Required, TypeAlias, TypedDict -from ...types import shared_params from .chat_completion_tool_param import ChatCompletionToolParam from .chat_completion_message_param import ChatCompletionMessageParam +from ..shared_params.function_parameters import FunctionParameters from .chat_completion_tool_choice_option_param import ChatCompletionToolChoiceOptionParam from .chat_completion_function_call_option_param import ChatCompletionFunctionCallOptionParam @@ -176,7 +176,7 @@ class CompletionCreateParams(TypedDict, total=False): """ -FunctionCall = Union[Literal["none", "auto", "required"], ChatCompletionFunctionCallOptionParam] +FunctionCall: TypeAlias = Union[Literal["none", "auto", "required"], ChatCompletionFunctionCallOptionParam] class Function(TypedDict, total=False): @@ -193,7 +193,7 @@ class Function(TypedDict, total=False): how to call the function. """ - parameters: shared_params.FunctionParameters + parameters: FunctionParameters """The parameters the functions accepts, described as a JSON Schema object. See the docs on [tool use](/docs/tool-use) for examples, and the diff --git a/src/groq/types/shared/function_parameters.py b/src/groq/types/shared/function_parameters.py index c9524e4..a3d83e3 100644 --- a/src/groq/types/shared/function_parameters.py +++ b/src/groq/types/shared/function_parameters.py @@ -1,7 +1,8 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from typing import Dict +from typing_extensions import TypeAlias __all__ = ["FunctionParameters"] -FunctionParameters = Dict[str, object] +FunctionParameters: TypeAlias = Dict[str, object] diff --git a/src/groq/types/shared_params/function_definition.py b/src/groq/types/shared_params/function_definition.py index cbac7b7..0799867 100644 --- a/src/groq/types/shared_params/function_definition.py +++ b/src/groq/types/shared_params/function_definition.py @@ -4,7 +4,7 @@ from typing_extensions import Required, TypedDict -from ...types import shared_params +from .function_parameters import FunctionParameters __all__ = ["FunctionDefinition"] @@ -23,7 +23,7 @@ class FunctionDefinition(TypedDict, total=False): how to call the function. """ - parameters: shared_params.FunctionParameters + parameters: FunctionParameters """The parameters the functions accepts, described as a JSON Schema object. See the docs on [tool use](/docs/tool-use) for examples, and the diff --git a/src/groq/types/shared_params/function_parameters.py b/src/groq/types/shared_params/function_parameters.py index 5b40efb..45fc742 100644 --- a/src/groq/types/shared_params/function_parameters.py +++ b/src/groq/types/shared_params/function_parameters.py @@ -3,7 +3,8 @@ from __future__ import annotations from typing import Dict +from typing_extensions import TypeAlias __all__ = ["FunctionParameters"] -FunctionParameters = Dict[str, object] +FunctionParameters: TypeAlias = Dict[str, object] diff --git a/tests/api_resources/chat/test_completions.py b/tests/api_resources/chat/test_completions.py index 4e51043..01c673f 100644 --- a/tests/api_resources/chat/test_completions.py +++ b/tests/api_resources/chat/test_completions.py @@ -46,18 +46,18 @@ def test_method_create_with_all_params(self, client: Groq) -> None: function_call="none", functions=[ { - "description": "description", "name": "name", + "description": "description", "parameters": {"foo": "bar"}, }, { - "description": "description", "name": "name", + "description": "description", "parameters": {"foo": "bar"}, }, { - "description": "description", "name": "name", + "description": "description", "parameters": {"foo": "bar"}, }, ], @@ -67,7 +67,7 @@ def test_method_create_with_all_params(self, client: Groq) -> None: n=1, parallel_tool_calls=True, presence_penalty=-2, - response_format={"type": "json_object"}, + response_format={"type": "text"}, seed=0, stop="\n", stream=False, @@ -76,28 +76,28 @@ def test_method_create_with_all_params(self, client: Groq) -> None: tool_choice="none", tools=[ { - "type": "function", "function": { - "description": "description", "name": "name", + "description": "description", "parameters": {"foo": "bar"}, }, + "type": "function", }, { - "type": "function", "function": { - "description": "description", "name": "name", + "description": "description", "parameters": {"foo": "bar"}, }, + "type": "function", }, { - "type": "function", "function": { - "description": "description", "name": "name", + "description": "description", "parameters": {"foo": "bar"}, }, + "type": "function", }, ], top_logprobs=0, @@ -175,18 +175,18 @@ async def test_method_create_with_all_params(self, async_client: AsyncGroq) -> N function_call="none", functions=[ { - "description": "description", "name": "name", + "description": "description", "parameters": {"foo": "bar"}, }, { - "description": "description", "name": "name", + "description": "description", "parameters": {"foo": "bar"}, }, { - "description": "description", "name": "name", + "description": "description", "parameters": {"foo": "bar"}, }, ], @@ -196,7 +196,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncGroq) -> N n=1, parallel_tool_calls=True, presence_penalty=-2, - response_format={"type": "json_object"}, + response_format={"type": "text"}, seed=0, stop="\n", stream=False, @@ -205,28 +205,28 @@ async def test_method_create_with_all_params(self, async_client: AsyncGroq) -> N tool_choice="none", tools=[ { - "type": "function", "function": { - "description": "description", "name": "name", + "description": "description", "parameters": {"foo": "bar"}, }, + "type": "function", }, { - "type": "function", "function": { - "description": "description", "name": "name", + "description": "description", "parameters": {"foo": "bar"}, }, + "type": "function", }, { - "type": "function", "function": { - "description": "description", "name": "name", + "description": "description", "parameters": {"foo": "bar"}, }, + "type": "function", }, ], top_logprobs=0, diff --git a/tests/test_client.py b/tests/test_client.py index 0983b1d..fe06746 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -17,6 +17,7 @@ from pydantic import ValidationError from groq import Groq, AsyncGroq, APIResponseValidationError +from groq._types import Omit from groq._models import BaseModel, FinalRequestOptions from groq._constants import RAW_RESPONSE_HEADER from groq._exceptions import GroqError, APIStatusError, APITimeoutError, APIResponseValidationError @@ -321,7 +322,8 @@ def test_validate_headers(self) -> None: assert request.headers.get("Authorization") == f"Bearer {api_key}" with pytest.raises(GroqError): - client2 = Groq(base_url=base_url, api_key=None, _strict_response_validation=True) + with update_env(**{"GROQ_API_KEY": Omit()}): + client2 = Groq(base_url=base_url, api_key=None, _strict_response_validation=True) _ = client2 def test_default_query_option(self) -> None: @@ -748,6 +750,35 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> Non assert _get_open_connections(self.client) == 0 + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("groq._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + def test_retries_taken(self, client: Groq, failures_before_success: int, respx_mock: MockRouter) -> None: + client = client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/openai/v1/chat/completions").mock(side_effect=retry_handler) + + response = client.chat.completions.with_raw_response.create( + messages=[ + { + "content": "content", + "role": "system", + } + ], + model="string", + ) + + assert response.retries_taken == failures_before_success + class TestAsyncGroq: client = AsyncGroq(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -1034,7 +1065,8 @@ def test_validate_headers(self) -> None: assert request.headers.get("Authorization") == f"Bearer {api_key}" with pytest.raises(GroqError): - client2 = AsyncGroq(base_url=base_url, api_key=None, _strict_response_validation=True) + with update_env(**{"GROQ_API_KEY": Omit()}): + client2 = AsyncGroq(base_url=base_url, api_key=None, _strict_response_validation=True) _ = client2 def test_default_query_option(self) -> None: @@ -1464,3 +1496,35 @@ async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) ) assert _get_open_connections(self.client) == 0 + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("groq._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + @pytest.mark.asyncio + async def test_retries_taken( + self, async_client: AsyncGroq, failures_before_success: int, respx_mock: MockRouter + ) -> None: + client = async_client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/openai/v1/chat/completions").mock(side_effect=retry_handler) + + response = await client.chat.completions.with_raw_response.create( + messages=[ + { + "content": "content", + "role": "system", + } + ], + model="string", + ) + + assert response.retries_taken == failures_before_success diff --git a/tests/test_deepcopy.py b/tests/test_deepcopy.py index bdc8954..e98f5a4 100644 --- a/tests/test_deepcopy.py +++ b/tests/test_deepcopy.py @@ -41,8 +41,7 @@ def test_nested_list() -> None: assert_different_identities(obj1[1], obj2[1]) -class MyObject: - ... +class MyObject: ... def test_ignores_other_types() -> None: diff --git a/tests/test_response.py b/tests/test_response.py index a89f629..551ebeb 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -19,16 +19,13 @@ from groq._base_client import FinalRequestOptions -class ConcreteBaseAPIResponse(APIResponse[bytes]): - ... +class ConcreteBaseAPIResponse(APIResponse[bytes]): ... -class ConcreteAPIResponse(APIResponse[List[str]]): - ... +class ConcreteAPIResponse(APIResponse[List[str]]): ... -class ConcreteAsyncAPIResponse(APIResponse[httpx.Response]): - ... +class ConcreteAsyncAPIResponse(APIResponse[httpx.Response]): ... def test_extract_response_type_direct_classes() -> None: @@ -56,8 +53,7 @@ def test_extract_response_type_binary_response() -> None: assert extract_response_type(AsyncBinaryAPIResponse) == bytes -class PydanticModel(pydantic.BaseModel): - ... +class PydanticModel(pydantic.BaseModel): ... def test_response_parse_mismatched_basemodel(client: Groq) -> None: diff --git a/tests/test_utils/test_typing.py b/tests/test_utils/test_typing.py index 84e628c..908203f 100644 --- a/tests/test_utils/test_typing.py +++ b/tests/test_utils/test_typing.py @@ -9,24 +9,19 @@ _T3 = TypeVar("_T3") -class BaseGeneric(Generic[_T]): - ... +class BaseGeneric(Generic[_T]): ... -class SubclassGeneric(BaseGeneric[_T]): - ... +class SubclassGeneric(BaseGeneric[_T]): ... -class BaseGenericMultipleTypeArgs(Generic[_T, _T2, _T3]): - ... +class BaseGenericMultipleTypeArgs(Generic[_T, _T2, _T3]): ... -class SubclassGenericMultipleTypeArgs(BaseGenericMultipleTypeArgs[_T, _T2, _T3]): - ... +class SubclassGenericMultipleTypeArgs(BaseGenericMultipleTypeArgs[_T, _T2, _T3]): ... -class SubclassDifferentOrderGenericMultipleTypeArgs(BaseGenericMultipleTypeArgs[_T2, _T, _T3]): - ... +class SubclassDifferentOrderGenericMultipleTypeArgs(BaseGenericMultipleTypeArgs[_T2, _T, _T3]): ... def test_extract_type_var() -> None: diff --git a/tests/utils.py b/tests/utils.py index 96a9cba..e519958 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -8,7 +8,7 @@ from datetime import date, datetime from typing_extensions import Literal, get_args, get_origin, assert_type -from groq._types import NoneType +from groq._types import Omit, NoneType from groq._utils import ( is_dict, is_list, @@ -139,11 +139,15 @@ def _assert_list_type(type_: type[object], value: object) -> None: @contextlib.contextmanager -def update_env(**new_env: str) -> Iterator[None]: +def update_env(**new_env: str | Omit) -> Iterator[None]: old = os.environ.copy() try: - os.environ.update(new_env) + for name, value in new_env.items(): + if isinstance(value, Omit): + os.environ.pop(name, None) + else: + os.environ[name] = value yield None finally: