Skip to content

Commit

Permalink
chore(pydantic v1): exclude specific properties when rich printing (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie authored Sep 26, 2024
1 parent 70edb21 commit af535ce
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 9 deletions.
10 changes: 9 additions & 1 deletion src/openai/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

import os
import inspect
from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast
from typing import TYPE_CHECKING, Any, Type, Tuple, Union, Generic, TypeVar, Callable, Optional, cast
from datetime import date, datetime
from typing_extensions import (
Unpack,
Literal,
ClassVar,
Protocol,
Required,
Sequence,
ParamSpec,
TypedDict,
TypeGuard,
Expand Down Expand Up @@ -72,6 +73,8 @@

P = ParamSpec("P")

ReprArgs = Sequence[Tuple[Optional[str], Any]]


@runtime_checkable
class _ConfigProtocol(Protocol):
Expand All @@ -94,6 +97,11 @@ def model_fields_set(self) -> set[str]:
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
extra: Any = pydantic.Extra.allow # type: ignore

@override
def __repr_args__(self) -> ReprArgs:
# we don't want these attributes to be included when something like `rich.print` is used
return [arg for arg in super().__repr_args__() if arg[0] not in {"_request_id", "__exclude_fields__"}]

if TYPE_CHECKING:
_request_id: Optional[str] = None
"""The ID of the request, returned via the X-Request-ID header. Useful for debugging requests and reporting issues to OpenAI.
Expand Down
11 changes: 3 additions & 8 deletions tests/lib/chat/_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations

import io
import inspect
from typing import Any, Iterable
from typing_extensions import TypeAlias

import rich
import pytest
import pydantic

from ...utils import rich_print_str

ReprArgs: TypeAlias = "Iterable[tuple[str | None, Any]]"


Expand All @@ -26,12 +26,7 @@ def __repr_args__(self: pydantic.BaseModel) -> ReprArgs:
with monkeypatch.context() as m:
m.setattr(pydantic.BaseModel, "__repr_args__", __repr_args__)

buf = io.StringIO()

console = rich.console.Console(file=buf, width=120)
console.print(obj)

string = buf.getvalue()
string = rich_print_str(obj)

# we remove all `fn_name.<locals>.` occurences
# so that we can share the same snapshots between
Expand Down
4 changes: 4 additions & 0 deletions tests/test_legacy_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from openai._base_client import FinalRequestOptions
from openai._legacy_response import LegacyAPIResponse

from .utils import rich_print_str


class PydanticModel(pydantic.BaseModel): ...

Expand Down Expand Up @@ -85,6 +87,8 @@ def test_response_basemodel_request_id(client: OpenAI) -> None:
assert obj.foo == "hello!"
assert obj.bar == 2
assert obj.to_dict() == {"foo": "hello!", "bar": 2}
assert "_request_id" not in rich_print_str(obj)
assert "__exclude_fields__" not in rich_print_str(obj)


def test_response_parse_annotated_type(client: OpenAI) -> None:
Expand Down
4 changes: 4 additions & 0 deletions tests/test_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from openai._streaming import Stream
from openai._base_client import FinalRequestOptions

from .utils import rich_print_str


class ConcreteBaseAPIResponse(APIResponse[bytes]): ...

Expand Down Expand Up @@ -175,6 +177,8 @@ def test_response_basemodel_request_id(client: OpenAI) -> None:
assert obj.foo == "hello!"
assert obj.bar == 2
assert obj.to_dict() == {"foo": "hello!", "bar": 2}
assert "_request_id" not in rich_print_str(obj)
assert "__exclude_fields__" not in rich_print_str(obj)


@pytest.mark.asyncio
Expand Down
13 changes: 13 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import io
import os
import inspect
import traceback
Expand All @@ -8,6 +9,8 @@
from datetime import date, datetime
from typing_extensions import Literal, get_args, get_origin, assert_type

import rich

from openai._types import Omit, NoneType
from openai._utils import (
is_dict,
Expand Down Expand Up @@ -138,6 +141,16 @@ def _assert_list_type(type_: type[object], value: object) -> None:
assert_type(inner_type, entry) # type: ignore


def rich_print_str(obj: object) -> str:
"""Like `rich.print()` but returns the string instead"""
buf = io.StringIO()

console = rich.console.Console(file=buf, width=120)
console.print(obj)

return buf.getvalue()


@contextlib.contextmanager
def update_env(**new_env: str | Omit) -> Iterator[None]:
old = os.environ.copy()
Expand Down

0 comments on commit af535ce

Please sign in to comment.