Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(internal): codegen related update #103

Merged
merged 1 commit into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .devcontainer/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ FROM mcr.microsoft.com/vscode/devcontainers/python:0-${VARIANT}

USER vscode

RUN curl -sSf https://rye.astral.sh/get | RYE_VERSION="0.24.0" RYE_INSTALL_OPTION="--yes" bash
RUN curl -sSf https://rye.astral.sh/get | RYE_VERSION="0.35.0" RYE_INSTALL_OPTION="--yes" bash
ENV PATH=/home/vscode/.rye/shims:$PATH

RUN echo "[[ -d .venv ]] && source .venv/bin/activate" >> /home/vscode/.bashrc
5 changes: 3 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ on:
pull_request:
branches:
- main
- next

jobs:
lint:
Expand All @@ -21,7 +22,7 @@ jobs:
curl -sSf https://rye.astral.sh/get | bash
echo "$HOME/.rye/shims" >> $GITHUB_PATH
env:
RYE_VERSION: 0.24.0
RYE_VERSION: '0.35.0'
RYE_INSTALL_OPTION: '--yes'

- name: Install dependencies
Expand All @@ -41,7 +42,7 @@ jobs:
curl -sSf https://rye.astral.sh/get | bash
echo "$HOME/.rye/shims" >> $GITHUB_PATH
env:
RYE_VERSION: 0.24.0
RYE_VERSION: '0.35.0'
RYE_INSTALL_OPTION: '--yes'

- name: Bootstrap
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/publish-pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ jobs:
curl -sSf https://rye.astral.sh/get | bash
echo "$HOME/.rye/shims" >> $GITHUB_PATH
env:
RYE_VERSION: 0.24.0
RYE_INSTALL_OPTION: "--yes"
RYE_VERSION: '0.35.0'
RYE_INSTALL_OPTION: '--yes'

- name: Publish to PyPI
run: |
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.prism.log
.vscode
_dev

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ It is generated with [Stainless](https://www.stainlessapi.com/).

## Documentation

The REST API documentation can be found [on console.groq.com](https://console.groq.com/docs). The full API of this library can be found in [api.md](api.md).
The REST API documentation can be found on [console.groq.com](https://console.groq.com/docs). The full API of this library can be found in [api.md](api.md).

## Installation

Expand Down
16 changes: 16 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ dev-dependencies = [
"nox",
"dirty-equals>=0.6.0",
"importlib-metadata>=6.7.0",
"rich>=13.7.1",

]

Expand Down Expand Up @@ -99,6 +100,21 @@ include = [
[tool.hatch.build.targets.wheel]
packages = ["src/groq"]

[tool.hatch.build.targets.sdist]
# Basically everything except hidden files/directories (such as .github, .devcontainers, .python-version, etc)
include = [
"/*.toml",
"/*.json",
"/*.lock",
"/*.md",
"/mypy.ini",
"/noxfile.py",
"bin/*",
"examples/*",
"src/*",
"tests/*",
]

[tool.hatch.metadata.hooks.fancy-pypi-readme]
content-type = "text/markdown"

Expand Down
12 changes: 10 additions & 2 deletions requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
-e file:.
annotated-types==0.6.0
# via pydantic
anyio==4.1.0
anyio==4.4.0
# via groq
# via httpx
argcomplete==3.1.2
Expand Down Expand Up @@ -45,7 +45,11 @@ idna==3.4
importlib-metadata==7.0.0
iniconfig==2.0.0
# via pytest
mypy==1.7.1
markdown-it-py==3.0.0
# via rich
mdurl==0.1.2
# via markdown-it-py
mypy==1.10.1
mypy-extensions==1.0.0
# via mypy
nodeenv==1.8.0
Expand All @@ -64,6 +68,8 @@ pydantic==2.7.1
# via groq
pydantic-core==2.18.2
# via pydantic
pygments==2.18.0
# via rich
pyright==1.1.364
pytest==7.1.1
# via pytest-asyncio
Expand All @@ -73,6 +79,7 @@ python-dateutil==2.8.2
pytz==2023.3.post1
# via dirty-equals
respx==0.20.2
rich==13.7.1
ruff==0.1.9
setuptools==68.2.2
# via nodeenv
Expand All @@ -87,6 +94,7 @@ tomli==2.0.1
# via mypy
# via pytest
typing-extensions==4.8.0
# via anyio
# via groq
# via mypy
# via pydantic
Expand Down
3 changes: 2 additions & 1 deletion requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
-e file:.
annotated-types==0.6.0
# via pydantic
anyio==4.1.0
anyio==4.4.0
# via groq
# via httpx
certifi==2023.7.22
Expand Down Expand Up @@ -39,6 +39,7 @@ sniffio==1.3.0
# via groq
# via httpx
typing-extensions==4.8.0
# via anyio
# via groq
# via pydantic
# via pydantic-core
54 changes: 40 additions & 14 deletions src/groq/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
HttpxSendArgs,
AsyncTransport,
RequestOptions,
HttpxRequestFiles,
ModelBuilderProtocol,
)
from ._utils import is_dict, is_list, asyncify, is_given, lru_cache, is_mapping
Expand Down Expand Up @@ -459,6 +460,7 @@ def _build_request(
headers = self._build_headers(options)
params = _merge_mappings(self.default_query, options.params)
content_type = headers.get("Content-Type")
files = options.files

# If the given Content-Type header is multipart/form-data then it
# has to be removed so that httpx can generate the header with
Expand All @@ -472,14 +474,23 @@ def _build_request(
headers.pop("Content-Type")

# As we are now sending multipart/form-data instead of application/json
# we need to tell httpx to use it, https://www.python-httpx.org/advanced/#multipart-file-encoding
# we need to tell httpx to use it, https://www.python-httpx.org/advanced/clients/#multipart-file-encoding
if json_data:
if not is_dict(json_data):
raise TypeError(
f"Expected query input to be a dictionary for multipart requests but got {type(json_data)} instead."
)
kwargs["data"] = self._serialize_multipartform(json_data)

# httpx determines whether or not to send a "multipart/form-data"
# request based on the truthiness of the "files" argument.
# This gets around that issue by generating a dict value that
# evaluates to true.
#
# https://github.com/encode/httpx/discussions/2399#discussioncomment-3814186
if not files:
files = cast(HttpxRequestFiles, ForceMultipartDict())

# TODO: report this error to httpx
return self._client.build_request( # pyright: ignore[reportUnknownMemberType]
headers=headers,
Expand All @@ -492,7 +503,7 @@ def _build_request(
# https://github.com/microsoft/pyright/issues/3526#event-6715453066
params=self.qs.stringify(cast(Mapping[str, Any], params)) if params else None,
json=json_data,
files=options.files,
files=files,
**kwargs,
)

Expand Down Expand Up @@ -868,9 +879,9 @@ def __exit__(
def _prepare_options(
self,
options: FinalRequestOptions, # noqa: ARG002
) -> None:
) -> FinalRequestOptions:
"""Hook for mutating the given options"""
return None
return options

def _prepare_request(
self,
Expand Down Expand Up @@ -944,8 +955,13 @@ def _request(
stream: bool,
stream_cls: type[_StreamT] | None,
) -> ResponseT | _StreamT:
# create a copy of the options we were given so that if the
# options are mutated later & we then retry, the retries are
# given the original options
input_options = model_copy(options)

cast_to = self._maybe_override_cast_to(cast_to, options)
self._prepare_options(options)
options = self._prepare_options(options)

retries = self._remaining_retries(remaining_retries, options)
request = self._build_request(options)
Expand All @@ -968,7 +984,7 @@ def _request(

if retries > 0:
return self._retry_request(
options,
input_options,
cast_to,
retries,
stream=stream,
Expand All @@ -983,7 +999,7 @@ def _request(

if retries > 0:
return self._retry_request(
options,
input_options,
cast_to,
retries,
stream=stream,
Expand Down Expand Up @@ -1011,7 +1027,7 @@ def _request(
if retries > 0 and self._should_retry(err.response):
err.response.close()
return self._retry_request(
options,
input_options,
cast_to,
retries,
err.response.headers,
Expand Down Expand Up @@ -1426,9 +1442,9 @@ async def __aexit__(
async def _prepare_options(
self,
options: FinalRequestOptions, # noqa: ARG002
) -> None:
) -> FinalRequestOptions:
"""Hook for mutating the given options"""
return None
return options

async def _prepare_request(
self,
Expand Down Expand Up @@ -1507,8 +1523,13 @@ async def _request(
# execute it earlier while we are in an async context
self._platform = await asyncify(get_platform)()

# create a copy of the options we were given so that if the
# options are mutated later & we then retry, the retries are
# given the original options
input_options = model_copy(options)

cast_to = self._maybe_override_cast_to(cast_to, options)
await self._prepare_options(options)
options = await self._prepare_options(options)

retries = self._remaining_retries(remaining_retries, options)
request = self._build_request(options)
Expand All @@ -1529,7 +1550,7 @@ async def _request(

if retries > 0:
return await self._retry_request(
options,
input_options,
cast_to,
retries,
stream=stream,
Expand All @@ -1544,7 +1565,7 @@ async def _request(

if retries > 0:
return await self._retry_request(
options,
input_options,
cast_to,
retries,
stream=stream,
Expand All @@ -1567,7 +1588,7 @@ async def _request(
if retries > 0 and self._should_retry(err.response):
await err.response.aclose()
return await self._retry_request(
options,
input_options,
cast_to,
retries,
err.response.headers,
Expand Down Expand Up @@ -1863,6 +1884,11 @@ def make_request_options(
return options


class ForceMultipartDict(Dict[str, None]):
def __bool__(self) -> bool:
return True


class OtherPlatform:
def __init__(self, name: str) -> None:
self.name = name
Expand Down
6 changes: 3 additions & 3 deletions src/groq/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
return model.__fields__ # type: ignore


def model_copy(model: _ModelT) -> _ModelT:
def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT:
if PYDANTIC_V2:
return model.model_copy()
return model.copy() # type: ignore
return model.model_copy(deep=deep)
return model.copy(deep=deep) # type: ignore


def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
Expand Down
Loading