Skip to content

Commit

Permalink
openai[patch]: get output_type when using with_structured_output (#26307)
Browse files Browse the repository at this point in the history

- This allows pydantic to correctly resolve the annotations it needs when
using OpenAI's new `json_schema` param

Resolves issue: #26250

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
  • Loading branch information
3 people committed Sep 13, 2024
1 parent 0f2b32f commit 7fc9e99
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
14 changes: 7 additions & 7 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@
from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
Expand Down Expand Up @@ -1421,7 +1420,7 @@ class AnswerWithJustification(BaseModel):
strict=strict,
)
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
output_parser: Runnable = PydanticToolsParser(
tools=[schema], # type: ignore[list-item]
first_tool_only=True, # type: ignore[list-item]
)
Expand All @@ -1445,11 +1444,12 @@ class AnswerWithJustification(BaseModel):
strict = strict if strict is not None else True
response_format = _convert_to_openai_response_format(schema, strict=strict)
llm = self.bind(response_format=response_format)
output_parser = (
cast(Runnable, _oai_structured_outputs_parser)
if is_pydantic_schema
else JsonOutputParser()
)
if is_pydantic_schema:
output_parser = _oai_structured_outputs_parser.with_types(
output_type=cast(type, schema)
)
else:
output_parser = JsonOutputParser()
else:
raise ValueError(
f"Unrecognized method argument. Expected one of 'function_calling' or "
Expand Down
29 changes: 29 additions & 0 deletions libs/partners/openai/tests/unit_tests/chat_models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel as BaseModelV2

from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import (
Expand Down Expand Up @@ -694,3 +695,31 @@ def test_get_num_tokens_from_messages() -> None:
expected = 176
actual = llm.get_num_tokens_from_messages(messages)
assert expected == actual


# Schema fixture built on the ``BaseModel`` imported from
# ``langchain_core.pydantic_v1``; it has a single required integer field.
# NOTE: documented with a comment rather than a docstring on purpose --
# pydantic copies a class docstring into the generated JSON schema as
# "description", which would break the exact-match schema assertion in
# ``test_schema_from_with_structured_output``.
class Foo(BaseModel):
    bar: int


# Schema fixture built on ``BaseModelV2`` (``pydantic.BaseModel`` imported
# directly from ``pydantic``); same single required integer field as ``Foo``.
# NOTE: documented with a comment rather than a docstring on purpose --
# pydantic copies a class docstring into the generated JSON schema as
# "description", which would break the exact-match schema assertion in
# ``test_schema_from_with_structured_output``.
class FooV2(BaseModelV2):
    bar: int


@pytest.mark.parametrize("schema", [Foo, FooV2])
def test_schema_from_with_structured_output(schema: Type) -> None:
    """Check the output schema exposed by ``with_structured_output``.

    Runs once per parametrized model class (``Foo`` and ``FooV2``). For each,
    the runnable built with ``method="json_schema"`` must report an output
    schema describing that class: one required integer field ``bar`` and a
    title equal to the class name.

    Args:
        schema: The model class passed to ``with_structured_output``.
    """

    llm = ChatOpenAI()

    structured_llm = llm.with_structured_output(
        schema, method="json_schema", strict=True
    )

    # The title is derived from the parametrized class so that one expected
    # dict covers both fixtures.
    expected = {
        "properties": {"bar": {"title": "Bar", "type": "integer"}},
        "required": ["bar"],
        "title": schema.__name__,
        "type": "object",
    }
    actual = structured_llm.get_output_schema().schema()
    assert actual == expected

0 comments on commit 7fc9e99

Please sign in to comment.