Support structured outputs response format based on signature in JSON adapter #1881

Merged · 12 commits · Dec 10, 2024
64 changes: 62 additions & 2 deletions dspy/adapters/json_adapter.py
@@ -2,13 +2,15 @@
import enum
import inspect
import json
+import logging
import textwrap
from copy import deepcopy
from typing import Any, Dict, KeysView, Literal, NamedTuple, get_args, get_origin

import json_repair
import litellm
import pydantic
-from pydantic import TypeAdapter
+from pydantic import TypeAdapter, create_model
from pydantic.fields import FieldInfo

from dspy.adapters.base import Adapter
@@ -18,6 +20,8 @@
from ..signatures.signature import SignatureMeta
from ..signatures.utils import get_dspy_field_type

+_logger = logging.getLogger(__name__)


class FieldInfoWithName(NamedTuple):
    name: str
@@ -35,7 +39,16 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
        try:
            provider = lm.model.split("/", 1)[0] or "openai"
            if "response_format" in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider):
-                outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"})
+                try:
+                    response_format = _get_structured_outputs_response_format(signature)
+                    outputs = lm(**inputs, **lm_kwargs, response_format=response_format)
+                except Exception as e:
Comment from dbczumar (Collaborator, Author) on lines +42 to +45: LM providers have differing levels of support for response_format fields. For example, Databricks doesn't support anyOf / allOf, but OpenAI does. A blanket try/catch seems appropriate here to start.
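As an aside, the support check gating this block can be probed directly. A hypothetical snippet (not part of the PR), using the same litellm helper the adapter calls above:

# Hypothetical illustration: litellm reports which OpenAI-style request
# parameters a given model/provider pair supports.
import litellm

supported = litellm.get_supported_openai_params(model="gpt-4o", custom_llm_provider="openai")
print("response_format" in supported)  # Expected: True for OpenAI chat models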

+                    _logger.debug(
+                        "Failed to obtain response using signature-based structured outputs"
+                        " response format: Falling back to default 'json_object' response format."
+                        f" Exception: {e}"
+                    )

Comment from dbczumar (Collaborator, Author) on the _logger.debug call: Debated a warning, but it seems too spammy.
Comment from dbczumar (Collaborator, Author), Dec 3, 2024, on lines +47 to +50: We expect to hit this case for tuples until there's support for prefixItems in the OpenAI structured outputs API. Other vendors, e.g. Databricks, will likely lag even further behind (e.g. Databricks doesn't support anyOf currently, but OpenAI does), meaning that we could hit this case for additional output types.
+                    outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"})
            else:
                outputs = lm(**inputs, **lm_kwargs)
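To make the tuple case above concrete, here is a minimal sketch (mine, not from the PR) of how Pydantic v2 renders a fixed-length tuple as a prefixItems schema, which is what the OpenAI structured outputs API rejected at the time:

# Hypothetical illustration: a fixed-length tuple output type produces a
# JSON schema containing `prefixItems`.
from typing import Tuple

import pydantic

PairModel = pydantic.create_model("PairModel", pair=(Tuple[int, str], ...))
print(PairModel.model_json_schema()["properties"]["pair"])
# {'maxItems': 2, 'minItems': 2,
#  'prefixItems': [{'type': 'integer'}, {'type': 'string'}],
#  'title': 'Pair', 'type': 'array'}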

@@ -303,3 +316,50 @@ def format_signature_fields_for_instructions(role, fields: Dict[str, FieldInfo])
# ", and then ending with the marker for `completed`.")

return "\n\n".join(parts).strip()


+def _get_structured_outputs_response_format(signature: SignatureMeta) -> pydantic.BaseModel:
+    """
+    Obtains the LiteLLM / OpenAI `response_format` parameter for generating structured outputs from
+    an LM request, based on the output fields of the specified DSPy signature.
+
+    Args:
+        signature: The DSPy signature for which to obtain the `response_format` request parameter.
+
+    Returns:
+        A Pydantic model representing the `response_format` parameter for the LM request.
+    """
+
+    def filter_json_schema_extra(field_name: str, field_info: FieldInfo) -> FieldInfo:
Comment from dbczumar (Collaborator, Author) on filter_json_schema_extra: This needs test coverage.
"""
Recursively filter the `json_schema_extra` of a FieldInfo to exclude DSPy internal attributes
(e.g. `__dspy_field_type`) and remove descriptions that are placeholders for the field name.
"""
field_copy = deepcopy(field_info) # Make a copy to avoid mutating the original

# Update `json_schema_extra` for the copied field
if field_copy.json_schema_extra:
field_copy.json_schema_extra = {
key: value
for key, value in field_info.json_schema_extra.items()
if key not in ("desc", "__dspy_field_type")
}
field_desc = field_info.json_schema_extra.get("desc")
if field_desc is not None and field_desc != f"${{{field_name}}}":
field_copy.json_schema_extra["desc"] = field_desc

# Handle nested fields
if hasattr(field_copy.annotation, "__pydantic_model__"):
# Recursively update fields of the nested model
nested_model = field_copy.annotation.__pydantic_model__
updated_fields = {
key: filter_json_schema_extra(key, value) for key, value in nested_model.__fields__.items()
Comment from a Collaborator: Curious - why do we need recursive handling? nested_model.__fields__ should be Pydantic fields instead of DSPy fields; do they also have these DSPy internal attributes like __dspy_field_type?

Reply from dbczumar (Collaborator, Author), Dec 10, 2024: In the vast majority of cases, hopefully not... Though the following user error will silently produce this state, which is probably best to exclude from response_format because the program still runs:

import dspy
import pydantic

class Obj(pydantic.BaseModel):
    a: int = dspy.OutputField()
    b: str

class MySig(dspy.Signature):
    inp: str = dspy.InputField()
    outp: Obj = dspy.OutputField()

print(MySig.schema())
# {'$defs': {'Obj': {'properties': {'a': {'__dspy_field_type': 'output', 'title': 'A', 'type': 'integer'}, 'b': {'title': 'B', 'type': 'string'}}, 'required': ['a', 'b'], 'title': 'Obj', 'type': 'object'}}, 'description': 'Given the fields `inp`, produce the fields `outp`.', 'properties': {'inp': {'__dspy_field_type': 'input', 'desc': '${inp}', 'prefix': 'Inp:', 'title': 'Inp', 'type': 'string'}, 'outp': {'$ref': '#/$defs/Obj', '__dspy_field_type': 'output', 'desc': '${outp}', 'prefix': 'Outp:'}}, 'required': ['inp', 'outp'], 'title': 'MySig', 'type': 'object'}
+            # Create a new model with the same name and updated fields
+            field_copy.annotation = create_model(nested_model.__name__, **updated_fields)
+
+        return field_copy

+    output_pydantic_fields = {
+        key: (value.annotation, filter_json_schema_extra(key, value)) for key, value in signature.output_fields.items()
+    }
+    return create_model("DSPyProgramOutputs", **output_pydantic_fields)
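For orientation, a minimal usage sketch (mine, not from the PR) of what this helper produces for a simple signature; the resulting Pydantic model is what gets passed as response_format in __call__ above:

# Hypothetical usage sketch; assumes this module is importable as shown.
import dspy
from dspy.adapters.json_adapter import _get_structured_outputs_response_format

class QA(dspy.Signature):
    question: str = dspy.InputField()
    answer: str = dspy.OutputField(desc="A short answer")

response_format = _get_structured_outputs_response_format(QA)
print(list(response_format.model_fields))  # ['answer']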
2 changes: 1 addition & 1 deletion dspy/clients/lm.py
@@ -278,11 +278,11 @@ def func_cached(key: str, request: Dict[str, Any], *args, **kwargs):
    def wrapper(request: dict, *args, **kwargs):
        try:
            key = cache_key(request)
-            return func_cached(key, request, *args, **kwargs)
        except Exception:
            # If the cache key cannot be computed (e.g. because it contains a value that cannot
            # be converted to JSON), bypass the cache and call the target function directly
            return func(request, *args, **kwargs)
+        return func_cached(key, request, *args, **kwargs)

    return wrapper
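A note on the intent here (my reading; the PR doesn't state it): moving the return out of the try block means only the cache-key computation is guarded, so exceptions raised by the cached call itself propagate to the caller instead of silently triggering an uncached retry. The same pattern as a generic, hypothetical sketch:

# Generic sketch of the narrowed try/except pattern above (hypothetical names).
def call_with_cache(compute_key, cached_fn, uncached_fn, request):
    try:
        # Guard only the fallible cache-key computation
        key = compute_key(request)
    except Exception:
        # Key not computable (e.g. unserializable request values): bypass the cache
        return uncached_fn(request)
    # Errors raised by the underlying call now propagate to the caller
    return cached_fn(key, request)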

16 changes: 12 additions & 4 deletions tests/reliability/test_pydantic_models.py
@@ -2,6 +2,7 @@
from typing import List

import pydantic
+import pytest

import dspy
from tests.reliability.utils import assert_program_output_correct, known_failing_models
@@ -33,22 +34,29 @@ class QA(dspy.Signature):
    assert_program_output_correct(
        program_input=question,
        program_output=answer.comments,
-        grading_guidelines="The comments should be relevant to the answer",
+        grading_guidelines=(
+            "The comments should be relevant to the answer. They don't need to restate the answer explicitly."
+        ),
    )
    assert answer.certainty >= 0
    assert answer.certainty <= 1
    assert len(answer.comments) >= 2


-def test_color_classification_using_enum():
+@pytest.mark.parametrize("module", [dspy.Predict, dspy.ChainOfThought])
+def test_color_classification_using_enum(module):
    Color = Enum("Color", ["RED", "GREEN", "BLUE"])

    class Colorful(dspy.Signature):
        text: str = dspy.InputField()
        color: Color = dspy.OutputField()

-    program = dspy.Predict(Colorful)
-    color = program(text="The sky is blue").color
+    program = module(Colorful)
+    # Note: The precise text, including the trailing period, is important here for ensuring that
+    # the program is correctly extracting the color from the text; previous implementations have
+    # produced invalid enum responses for "The sky is blue.", but they have produced valid enum
+    # responses for "The sky is blue" (without the period).
+    color = program(text="The sky is blue.").color

    assert color == Color.BLUE

Comment from dbczumar (Collaborator, Author), Dec 3, 2024, on the parametrization: CoT fails this test case with chat and json adapters on master:

FAILED test_pydantic_models.py::test_color_classification_using_enum[llama-3.1-70b-instruct-ChainOfThought] - ValueError: Color.BLUE is not a valid name or value for the enum Color
1 failed, 1 passed, 24 skipped, 26 deselected, 2 warnings in 0.22s

However, it passes on the PR branch :D

1 change: 0 additions & 1 deletion tests/reliability/utils.py
@@ -31,7 +31,6 @@ def assert_program_output_correct(
        grading_guidelines = [grading_guidelines]

    with judge_dspy_configuration():
-        print("GUIDELINES", grading_guidelines)
        for guideline_entry in grading_guidelines:
            judge_response = _get_judge_program()(
                program_input=str(program_input),

Comment from dbczumar (Collaborator, Author) on the removed print: Removing a leftover & unintentional debugging statement from test generation code.
95 changes: 95 additions & 0 deletions tests/test_json_adapter.py
@@ -0,0 +1,95 @@
from unittest import mock

import pydantic
import pytest
from pydantic import create_model

import dspy


def test_json_adapter_passes_structured_output_when_supported_by_model():
    class OutputField3(pydantic.BaseModel):
        subfield1: int = pydantic.Field(description="Int subfield 1", ge=0, le=10)
        subfield2: float = pydantic.Field(description="Float subfield 2")

    class TestSignature(dspy.Signature):
        input1: str = dspy.InputField()
        output1: str = dspy.OutputField()  # Description intentionally left blank
        output2: bool = dspy.OutputField(desc="Boolean output field")
        output3: OutputField3 = dspy.OutputField(desc="Nested output field")
        output4_unannotated = dspy.OutputField(desc="Unannotated output field")

    program = dspy.Predict(TestSignature)

    # Configure DSPy to use an OpenAI LM that supports structured outputs
    dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.JSONAdapter())
    with mock.patch("litellm.completion") as mock_completion:
        program(input1="Test input")

    def clean_schema_extra(field_name, field_info):
        attrs = dict(field_info.__repr_args__())
        if "json_schema_extra" in attrs:
            attrs["json_schema_extra"] = {
                k: v
                for k, v in attrs["json_schema_extra"].items()
                if k != "__dspy_field_type" and not (k == "desc" and v == f"${{{field_name}}}")
            }
        return attrs

    mock_completion.assert_called_once()
    _, call_kwargs = mock_completion.call_args
    response_format = call_kwargs.get("response_format")
    assert response_format is not None
    assert issubclass(response_format, pydantic.BaseModel)
    assert response_format.model_fields.keys() == {"output1", "output2", "output3", "output4_unannotated"}
    for field_name in response_format.model_fields:
        assert dict(response_format.model_fields[field_name].__repr_args__()) == clean_schema_extra(
            field_name=field_name,
            field_info=TestSignature.output_fields[field_name],
        )

    # Configure DSPy to use a model from a fake provider that doesn't support structured outputs
    dspy.configure(lm=dspy.LM(model="fakeprovider/fakemodel"), adapter=dspy.JSONAdapter())
    with mock.patch("litellm.completion") as mock_completion:
        program(input1="Test input")

    mock_completion.assert_called_once()
    _, call_kwargs = mock_completion.call_args
    assert "response_format" not in call_kwargs


def test_json_adapter_falls_back_when_structured_outputs_fails():
    class TestSignature(dspy.Signature):
        input1: str = dspy.InputField()
        output1: str = dspy.OutputField(desc="String output field")

    dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.JSONAdapter())
    program = dspy.Predict(TestSignature)
    with mock.patch("litellm.completion") as mock_completion:
        mock_completion.side_effect = [Exception("Bad structured outputs!"), mock_completion.return_value]
        program(input1="Test input")
        assert mock_completion.call_count == 2
        _, first_call_kwargs = mock_completion.call_args_list[0]
        assert issubclass(first_call_kwargs.get("response_format"), pydantic.BaseModel)
        _, second_call_kwargs = mock_completion.call_args_list[1]
        assert second_call_kwargs.get("response_format") == {"type": "json_object"}


def test_json_adapter_with_structured_outputs_does_not_mutate_original_signature():
    class OutputField3(pydantic.BaseModel):
        subfield1: int = pydantic.Field(description="Int subfield 1")
        subfield2: float = pydantic.Field(description="Float subfield 2")

    class TestSignature(dspy.Signature):
        input1: str = dspy.InputField()
        output1: str = dspy.OutputField()  # Description intentionally left blank
        output2: bool = dspy.OutputField(desc="Boolean output field")
        output3: OutputField3 = dspy.OutputField(desc="Nested output field")
        output4_unannotated = dspy.OutputField(desc="Unannotated output field")

    dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.JSONAdapter())
    program = dspy.Predict(TestSignature)
    with mock.patch("litellm.completion"):
        program(input1="Test input")

    assert program.signature.output_fields == TestSignature.output_fields