Skip to content

Commit

Permalink
fix(json schema): remove None defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie committed Aug 19, 2024
1 parent 798c6cb commit 161f1e5
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 4 deletions.
7 changes: 7 additions & 0 deletions src/openai/lib/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pydantic

from .._types import NOT_GIVEN
from .._utils import is_dict as _is_dict, is_list
from .._compat import model_json_schema

Expand Down Expand Up @@ -76,6 +77,12 @@ def _ensure_strict_json_schema(
for i, entry in enumerate(all_of)
]

# strip `None` defaults as there's no meaningful distinction here
# the schema will still be `nullable` and the model will default
# to using `None` anyway
if json_schema.get("default", NOT_GIVEN) is None:
json_schema.pop("default")

# we can't use `$ref`s if there are also other properties defined, e.g.
# `{"$ref": "...", "description": "my description"}`
#
Expand Down
60 changes: 59 additions & 1 deletion tests/lib/chat/test_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import json
from enum import Enum
from typing import Any, Callable
from typing import Any, Callable, Optional
from typing_extensions import Literal, TypeVar

import httpx
Expand Down Expand Up @@ -135,6 +135,63 @@ class Location(BaseModel):
)


@pytest.mark.respx(base_url=base_url)
def test_parse_pydantic_model_optional_default(
client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch
) -> None:
class Location(BaseModel):
city: str
temperature: float
units: Optional[Literal["c", "f"]] = None

completion = _make_snapshot_request(
lambda c: c.beta.chat.completions.parse(
model="gpt-4o-2024-08-06",
messages=[
{
"role": "user",
"content": "What's the weather like in SF?",
},
],
response_format=Location,
),
content_snapshot=snapshot(
'{"id": "chatcmpl-9y39Q2jGzWmeEZlm5CoNVOuQzcxP4", "object": "chat.completion", "created": 1724098820, "model": "gpt-4o-2024-08-06", "choices": [{"index": 0, "message": {"role": "assistant", "content": "{\\"city\\":\\"San Francisco\\",\\"temperature\\":62,\\"units\\":\\"f\\"}", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 17, "completion_tokens": 14, "total_tokens": 31}, "system_fingerprint": "fp_2a322c9ffc"}'
),
mock_client=client,
respx_mock=respx_mock,
)

assert print_obj(completion, monkeypatch) == snapshot(
"""\
ParsedChatCompletion[Location](
choices=[
ParsedChoice[Location](
finish_reason='stop',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[Location](
content='{"city":"San Francisco","temperature":62,"units":"f"}',
function_call=None,
parsed=Location(city='San Francisco', temperature=62.0, units='f'),
refusal=None,
role='assistant',
tool_calls=[]
)
)
],
created=1724098820,
id='chatcmpl-9y39Q2jGzWmeEZlm5CoNVOuQzcxP4',
model='gpt-4o-2024-08-06',
object='chat.completion',
service_tier=None,
system_fingerprint='fp_2a322c9ffc',
usage=CompletionUsage(completion_tokens=14, prompt_tokens=17, total_tokens=31)
)
"""
)


@pytest.mark.respx(base_url=base_url)
def test_parse_pydantic_model_enum(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
class Color(Enum):
Expand Down Expand Up @@ -320,6 +377,7 @@ def test_pydantic_tool_model_all_types(client: OpenAI, respx_mock: MockRouter, m
value=DynamicValue(column_name='expected_delivery_date')
)
],
name=None,
order_by=<OrderBy.asc: 'asc'>,
table_name=<Table.orders: 'orders'>
)
Expand Down
3 changes: 2 additions & 1 deletion tests/lib/schema_types/query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import List, Union
from typing import List, Union, Optional

from pydantic import BaseModel

Expand Down Expand Up @@ -45,6 +45,7 @@ class Condition(BaseModel):


class Query(BaseModel):
name: Optional[str] = None
table_name: Table
columns: List[Column]
conditions: List[Condition]
Expand Down
6 changes: 4 additions & 2 deletions tests/lib/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def test_most_types() -> None:
"Table": {"enum": ["orders", "customers", "products"], "title": "Table", "type": "string"},
},
"properties": {
"name": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Name"},
"table_name": {"$ref": "#/$defs/Table"},
"columns": {
"items": {"$ref": "#/$defs/Column"},
Expand All @@ -75,7 +76,7 @@ def test_most_types() -> None:
},
"order_by": {"$ref": "#/$defs/OrderBy"},
},
"required": ["table_name", "columns", "conditions", "order_by"],
"required": ["name", "table_name", "columns", "conditions", "order_by"],
"title": "Query",
"type": "object",
"additionalProperties": False,
Expand All @@ -91,6 +92,7 @@ def test_most_types() -> None:
"title": "Query",
"type": "object",
"properties": {
"name": {"title": "Name", "type": "string"},
"table_name": {"$ref": "#/definitions/Table"},
"columns": {"type": "array", "items": {"$ref": "#/definitions/Column"}},
"conditions": {
Expand All @@ -100,7 +102,7 @@ def test_most_types() -> None:
},
"order_by": {"$ref": "#/definitions/OrderBy"},
},
"required": ["table_name", "columns", "conditions", "order_by"],
"required": ["name", "table_name", "columns", "conditions", "order_by"],
"definitions": {
"Table": {
"title": "Table",
Expand Down

0 comments on commit 161f1e5

Please sign in to comment.