Support structured outputs response format based on signature in JSON adapter #1881
@@ -2,13 +2,15 @@
 import enum
 import inspect
 import json
+import logging
 import textwrap
 from copy import deepcopy
 from typing import Any, Dict, KeysView, Literal, NamedTuple, get_args, get_origin

 import json_repair
 import litellm
+import pydantic
-from pydantic import TypeAdapter
+from pydantic import TypeAdapter, create_model
 from pydantic.fields import FieldInfo

 from dspy.adapters.base import Adapter
@@ -18,6 +20,8 @@
 from ..signatures.signature import SignatureMeta
 from ..signatures.utils import get_dspy_field_type

+_logger = logging.getLogger(__name__)
+

 class FieldInfoWithName(NamedTuple):
     name: str
@@ -35,7 +39,16 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
         try:
             provider = lm.model.split("/", 1)[0] or "openai"
             if "response_format" in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider):
-                outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"})
+                try:
+                    response_format = _get_structured_outputs_response_format(signature)
+                    outputs = lm(**inputs, **lm_kwargs, response_format=response_format)
+                except Exception as e:
+                    _logger.debug(
+                        "Failed to obtain response using signature-based structured outputs"
+                        " response format: Falling back to default 'json_object' response format."
+                        f" Exception: {e}"
+                    )
+                    outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"})
             else:
                 outputs = lm(**inputs, **lm_kwargs)

Review comment (on the `_logger.debug` call): Debated a warning, but it seems too spammy.

Review comment (on lines +47 to +50): We expect to hit this case for tuples, until there's support for
@@ -303,3 +316,50 @@ def format_signature_fields_for_instructions(role, fields: Dict[str, FieldInfo]):
    #         ", and then ending with the marker for `completed`.")

    return "\n\n".join(parts).strip()


def _get_structured_outputs_response_format(signature: SignatureMeta) -> pydantic.BaseModel:
    """
    Obtains the LiteLLM / OpenAI `response_format` parameter for generating structured outputs from
    an LM request, based on the output fields of the specified DSPy signature.

    Args:
        signature: The DSPy signature for which to obtain the `response_format` request parameter.

    Returns:
        A Pydantic model representing the `response_format` parameter for the LM request.
    """

    def filter_json_schema_extra(field_name: str, field_info: FieldInfo) -> FieldInfo:
        """
        Recursively filter the `json_schema_extra` of a FieldInfo to exclude DSPy internal attributes
        (e.g. `__dspy_field_type`) and remove descriptions that are placeholders for the field name.
        """
        field_copy = deepcopy(field_info)  # Make a copy to avoid mutating the original

        # Update `json_schema_extra` for the copied field
        if field_copy.json_schema_extra:
            field_copy.json_schema_extra = {
                key: value
                for key, value in field_info.json_schema_extra.items()
                if key not in ("desc", "__dspy_field_type")
            }
            field_desc = field_info.json_schema_extra.get("desc")
            if field_desc is not None and field_desc != f"${{{field_name}}}":
                field_copy.json_schema_extra["desc"] = field_desc

        # Handle nested fields
        if hasattr(field_copy.annotation, "__pydantic_model__"):
            # Recursively update fields of the nested model
            nested_model = field_copy.annotation.__pydantic_model__
            updated_fields = {
                key: filter_json_schema_extra(key, value) for key, value in nested_model.__fields__.items()
            }
            # Create a new model with the same name and updated fields
            field_copy.annotation = create_model(nested_model.__name__, **updated_fields)

        return field_copy

    output_pydantic_fields = {
        key: (value.annotation, filter_json_schema_extra(key, value)) for key, value in signature.output_fields.items()
    }
    return create_model("DSPyProgramOutputs", **output_pydantic_fields)

Review comment (on `filter_json_schema_extra`): This needs test coverage.

Review comment (on the nested-model handling): Curious - why do we need recursive handling?

Reply: In the vast majority of cases, hopefully not.... Though the following user error will silently produce this state, which is probably best to exclude from
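For illustration, here is a rough sketch (an editor's note, not code from the PR) of the kind of `response_format` model this helper is expected to build for a simple signature. The signature and field names below are hypothetical, and the exact JSON schema output can vary by Pydantic version.

```python
import pydantic
from pydantic import create_model

# Hypothetical approximation: for a signature with a single output field such as
# `answer: str = dspy.OutputField(desc="Short answer")`, the helper builds roughly
# this model from the signature's output fields, with DSPy-internal
# json_schema_extra keys (e.g. `__dspy_field_type`) stripped out.
DSPyProgramOutputs = create_model(
    "DSPyProgramOutputs",
    answer=(str, pydantic.Field(description="Short answer")),
)

print(DSPyProgramOutputs.model_json_schema())
# Roughly: {'properties': {'answer': {'description': 'Short answer', 'title': 'Answer',
#           'type': 'string'}}, 'required': ['answer'], 'title': 'DSPyProgramOutputs',
#           'type': 'object'}
```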
@@ -2,6 +2,7 @@
from typing import List

import pydantic
import pytest

import dspy
from tests.reliability.utils import assert_program_output_correct, known_failing_models
@@ -33,22 +34,29 @@ class QA(dspy.Signature):
     assert_program_output_correct(
         program_input=question,
         program_output=answer.comments,
-        grading_guidelines="The comments should be relevant to the answer",
+        grading_guidelines=(
+            "The comments should be relevant to the answer. They don't need to restate the answer explicitly."
+        ),
     )
     assert answer.certainty >= 0
     assert answer.certainty <= 1
     assert len(answer.comments) >= 2


-def test_color_classification_using_enum():
+@pytest.mark.parametrize("module", [dspy.Predict, dspy.ChainOfThought])
+def test_color_classification_using_enum(module):
     Color = Enum("Color", ["RED", "GREEN", "BLUE"])

     class Colorful(dspy.Signature):
         text: str = dspy.InputField()
         color: Color = dspy.OutputField()

-    program = dspy.Predict(Colorful)
-    color = program(text="The sky is blue").color
+    program = module(Colorful)
+    # Note: The precise text, including the trailing period, is important here for ensuring that
+    # the program is correctly extracting the color from the text; previous implementations have
+    # produced invalid enum responses for "The sky is blue.", but they have produced valid enum
+    # responses for "The sky is blue" (without the period).
+    color = program(text="The sky is blue.").color

     assert color == Color.BLUE

Review comment (on the `@pytest.mark.parametrize` line): CoT fails this test case with chat and json adapters on master. However, it passes on the PR branch :D
@@ -31,7 +31,6 @@ def assert_program_output_correct(
         grading_guidelines = [grading_guidelines]

     with judge_dspy_configuration():
-        print("GUIDELINES", grading_guidelines)
         for guideline_entry in grading_guidelines:
             judge_response = _get_judge_program()(
                 program_input=str(program_input),

Review comment (on the removed `print` call): Removing a leftover & unintentional debugging statement from test generation code.
@@ -0,0 +1,95 @@
from unittest import mock

import pydantic
import pytest
from pydantic import create_model

import dspy


def test_json_adapter_passes_structured_output_when_supported_by_model():
    class OutputField3(pydantic.BaseModel):
        subfield1: int = pydantic.Field(description="Int subfield 1", ge=0, le=10)
        subfield2: float = pydantic.Field(description="Float subfield 2")

    class TestSignature(dspy.Signature):
        input1: str = dspy.InputField()
        output1: str = dspy.OutputField()  # Description intentionally left blank
        output2: bool = dspy.OutputField(desc="Boolean output field")
        output3: OutputField3 = dspy.OutputField(desc="Nested output field")
        output4_unannotated = dspy.OutputField(desc="Unannotated output field")

    program = dspy.Predict(TestSignature)

    # Configure DSPy to use an OpenAI LM that supports structured outputs
    dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.JSONAdapter())
    with mock.patch("litellm.completion") as mock_completion:
        program(input1="Test input")

    def clean_schema_extra(field_name, field_info):
        attrs = dict(field_info.__repr_args__())
        if "json_schema_extra" in attrs:
            attrs["json_schema_extra"] = {
                k: v
                for k, v in attrs["json_schema_extra"].items()
                if k != "__dspy_field_type" and not (k == "desc" and v == f"${{{field_name}}}")
            }
        return attrs

    mock_completion.assert_called_once()
    _, call_kwargs = mock_completion.call_args
    response_format = call_kwargs.get("response_format")
    assert response_format is not None
    assert issubclass(response_format, pydantic.BaseModel)
    assert response_format.model_fields.keys() == {"output1", "output2", "output3", "output4_unannotated"}
    for field_name in response_format.model_fields:
        assert dict(response_format.model_fields[field_name].__repr_args__()) == clean_schema_extra(
            field_name=field_name,
            field_info=TestSignature.output_fields[field_name],
        )

    # Configure DSPy to use a model from a fake provider that doesn't support structured outputs
    dspy.configure(lm=dspy.LM(model="fakeprovider/fakemodel"), adapter=dspy.JSONAdapter())
    with mock.patch("litellm.completion") as mock_completion:
        program(input1="Test input")

    mock_completion.assert_called_once()
    _, call_kwargs = mock_completion.call_args
    assert "response_format" not in call_kwargs


def test_json_adapter_falls_back_when_structured_outputs_fails():
    class TestSignature(dspy.Signature):
        input1: str = dspy.InputField()
        output1: str = dspy.OutputField(desc="String output field")

    dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.JSONAdapter())
    program = dspy.Predict(TestSignature)
    with mock.patch("litellm.completion") as mock_completion:
        mock_completion.side_effect = [Exception("Bad structured outputs!"), mock_completion.return_value]
        program(input1="Test input")
        assert mock_completion.call_count == 2
        _, first_call_kwargs = mock_completion.call_args_list[0]
        assert issubclass(first_call_kwargs.get("response_format"), pydantic.BaseModel)
        _, second_call_kwargs = mock_completion.call_args_list[1]
        assert second_call_kwargs.get("response_format") == {"type": "json_object"}


def test_json_adapter_with_structured_outputs_does_not_mutate_original_signature():
    class OutputField3(pydantic.BaseModel):
        subfield1: int = pydantic.Field(description="Int subfield 1")
        subfield2: float = pydantic.Field(description="Float subfield 2")

    class TestSignature(dspy.Signature):
        input1: str = dspy.InputField()
        output1: str = dspy.OutputField()  # Description intentionally left blank
        output2: bool = dspy.OutputField(desc="Boolean output field")
        output3: OutputField3 = dspy.OutputField(desc="Nested output field")
        output4_unannotated = dspy.OutputField(desc="Unannotated output field")

    dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.JSONAdapter())
    program = dspy.Predict(TestSignature)
    with mock.patch("litellm.completion"):
        program(input1="Test input")

    assert program.signature.output_fields == TestSignature.output_fields
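As a side note, here is a minimal sketch (hypothetical field metadata, not code from the PR) of the property the last test guards: `filter_json_schema_extra` deep-copies each `FieldInfo` before editing it, so the copy used to build `response_format` can be cleaned in place without touching the signature's original output fields.

```python
from copy import deepcopy

import pydantic

# Hypothetical FieldInfo carrying DSPy-style metadata.
original = pydantic.Field(json_schema_extra={"__dspy_field_type": "output", "desc": "An answer"})

filtered = deepcopy(original)
filtered.json_schema_extra.pop("__dspy_field_type")  # edit the copy in place

# The copy is cleaned, while the original metadata stays intact.
assert filtered.json_schema_extra == {"desc": "An answer"}
assert original.json_schema_extra == {"__dspy_field_type": "output", "desc": "An answer"}
```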
Review comment: LM providers have differing levels of support for `response_format` fields. For example, Databricks doesn't support anyOf / allOf, but OpenAI does. A blanket try/catch seems appropriate here to start.
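To illustrate the anyOf point (an editor's sketch, not part of the PR): a union-typed output field produces an `anyOf` clause in the generated Pydantic JSON schema, which providers with stricter structured-output support may reject, hence the blanket fallback. The model and field names below are hypothetical.

```python
from typing import Union

import pydantic
from pydantic import create_model

# Hypothetical response_format model with a union-typed output field.
ResponseFormat = create_model(
    "DSPyProgramOutputs",
    value=(Union[int, str], pydantic.Field(description="Either an int or a string")),
)

schema = ResponseFormat.model_json_schema()
# Union types are emitted as `anyOf`, which some providers' structured-output
# implementations do not accept even though plain JSON mode works fine.
print(schema["properties"]["value"])
# Roughly: {'anyOf': [{'type': 'integer'}, {'type': 'string'}],
#           'description': 'Either an int or a string', 'title': 'Value'}
```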