Skip to content

Commit

Permalink
Support structured outputs response format based on signature in JSON…
Browse files Browse the repository at this point in the history
… adapter (#1881)

* Fix

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Fix

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* fix

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Debug

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Fix

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Here

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Here

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Update json_adapter.py

* Update json_adapter.py

* Update json_adapter.py

* Update json_adapter.py

---------

Signed-off-by: dbczumar <corey.zumar@databricks.com>
  • Loading branch information
dbczumar authored and isaacbmiller committed Dec 11, 2024
1 parent b0401e4 commit a0ec266
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 7 deletions.
64 changes: 62 additions & 2 deletions dspy/adapters/json_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import enum
import inspect
import json
import logging
import textwrap
from copy import deepcopy
from typing import Any, Dict, KeysView, Literal, NamedTuple, get_args, get_origin

import json_repair
import litellm
import pydantic
from pydantic import TypeAdapter
from pydantic import TypeAdapter, create_model
from pydantic.fields import FieldInfo

from dspy.adapters.base import Adapter
Expand All @@ -18,6 +20,8 @@
from ..signatures.signature import SignatureMeta
from ..signatures.utils import get_dspy_field_type

_logger = logging.getLogger(__name__)


class FieldInfoWithName(NamedTuple):
name: str
Expand All @@ -35,7 +39,16 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
try:
provider = lm.model.split("/", 1)[0] or "openai"
if "response_format" in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider):
outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"})
# First attempt: request structured outputs using a response format derived from the
# signature's output fields. Any failure here (building the format or the LM call itself)
# is treated as "structured outputs unavailable" and triggers the plain JSON-mode fallback.
try:
    response_format = _get_structured_outputs_response_format(signature)
    outputs = lm(**inputs, **lm_kwargs, response_format=response_format)
except Exception as e:
    # Bind the exception and pass it as a lazy %-argument so the actual failure reason is
    # logged; the previous message contained a literal "{e}" placeholder that was never
    # interpolated.
    _logger.debug(
        "Failed to obtain response using signature-based structured outputs"
        " response format: Falling back to default 'json_object' response format."
        " Exception: %s",
        e,
    )
    outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"})
else:
outputs = lm(**inputs, **lm_kwargs)

Expand Down Expand Up @@ -303,3 +316,50 @@ def format_signature_fields_for_instructions(role, fields: Dict[str, FieldInfo])
# ", and then ending with the marker for `completed`.")

return "\n\n".join(parts).strip()


def _get_structured_outputs_response_format(signature: SignatureMeta) -> "type[pydantic.BaseModel]":
    """
    Obtains the LiteLLM / OpenAI `response_format` parameter for generating structured outputs from
    an LM request, based on the output fields of the specified DSPy signature.

    Args:
        signature: The DSPy signature for which to obtain the `response_format` request parameter.

    Returns:
        A dynamically created Pydantic model class (not an instance) named `DSPyProgramOutputs`,
        with one field per output field of the signature.
    """

    def filter_json_schema_extra(field_name: str, field_info: FieldInfo) -> FieldInfo:
        """
        Return a copy of `field_info` whose `json_schema_extra` excludes DSPy internal attributes
        (e.g. `__dspy_field_type`) and drops a `desc` that is merely the `${field_name}`
        placeholder. The original `field_info` is never modified.
        """
        field_copy = deepcopy(field_info)  # Work on a copy to avoid mutating the original

        schema_extra = field_copy.json_schema_extra
        # Only dict-valued extras can be filtered key by key; any other non-empty value
        # (e.g. a callable) is left untouched on the copy instead of failing on `.items()`.
        if schema_extra and isinstance(schema_extra, dict):
            filtered_extra = {
                key: value for key, value in schema_extra.items() if key not in ("desc", "__dspy_field_type")
            }
            field_desc = schema_extra.get("desc")
            # Keep the description only when it carries real content rather than the
            # auto-generated `${field_name}` placeholder.
            if field_desc is not None and field_desc != f"${{{field_name}}}":
                filtered_extra["desc"] = field_desc
            field_copy.json_schema_extra = filtered_extra

        # Handle nested fields.
        # NOTE(review): `__pydantic_model__` / `__fields__` look like pydantic v1-era attributes;
        # confirm this branch is reachable with the pydantic version in use. Also note that the
        # caller below pairs the copy with the *original* annotation.
        if hasattr(field_copy.annotation, "__pydantic_model__"):
            # Recursively update fields of the nested model
            nested_model = field_copy.annotation.__pydantic_model__
            updated_fields = {
                key: filter_json_schema_extra(key, value) for key, value in nested_model.__fields__.items()
            }
            # Create a new model with the same name and updated fields
            field_copy.annotation = create_model(nested_model.__name__, **updated_fields)

        return field_copy

    # Each output field becomes `(annotation, filtered FieldInfo)` in the generated model.
    output_pydantic_fields = {
        key: (value.annotation, filter_json_schema_extra(key, value))
        for key, value in signature.output_fields.items()
    }
    return create_model("DSPyProgramOutputs", **output_pydantic_fields)
16 changes: 12 additions & 4 deletions tests/reliability/test_pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List

import pydantic
import pytest

import dspy
from tests.reliability.utils import assert_program_output_correct, known_failing_models
Expand Down Expand Up @@ -33,22 +34,29 @@ class QA(dspy.Signature):
assert_program_output_correct(
program_input=question,
program_output=answer.comments,
grading_guidelines="The comments should be relevant to the answer",
grading_guidelines=(
"The comments should be relevant to the answer. They don't need to restate the answer explicitly."
),
)
assert answer.certainty >= 0
assert answer.certainty <= 1
assert len(answer.comments) >= 2


def test_color_classification_using_enum():
@pytest.mark.parametrize("module", [dspy.Predict, dspy.ChainOfThought])
def test_color_classification_using_enum(module):
Color = Enum("Color", ["RED", "GREEN", "BLUE"])

class Colorful(dspy.Signature):
text: str = dspy.InputField()
color: Color = dspy.OutputField()

program = dspy.Predict(Colorful)
color = program(text="The sky is blue").color
program = module(Colorful)
# Note: The precise text, including the trailing period, is important here for ensuring that
# the program is correctly extracting the color from the text; previous implementations have
# produced invalid enum responses for "The sky is blue.", but they have produced valid enum
# responses for "The sky is blue" (without the period).
color = program(text="The sky is blue.").color

assert color == Color.BLUE

Expand Down
1 change: 0 additions & 1 deletion tests/reliability/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def assert_program_output_correct(
grading_guidelines = [grading_guidelines]

with judge_dspy_configuration():
print("GUIDELINES", grading_guidelines)
for guideline_entry in grading_guidelines:
judge_response = _get_judge_program()(
program_input=str(program_input),
Expand Down
95 changes: 95 additions & 0 deletions tests/test_json_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from unittest import mock

import pydantic
import pytest
from pydantic import create_model

import dspy


def test_json_adapter_passes_structured_output_when_supported_by_model():
    """
    Check that `JSONAdapter` sends a signature-derived Pydantic model as `response_format` when
    the model supports it, and sends no `response_format` at all for an unsupported provider.
    """

    class OutputField3(pydantic.BaseModel):
        subfield1: int = pydantic.Field(description="Int subfield 1", ge=0, le=10)
        subfield2: float = pydantic.Field(description="Float subfield 2")

    class TestSignature(dspy.Signature):
        input1: str = dspy.InputField()
        output1: str = dspy.OutputField()  # Description intentionally left blank
        output2: bool = dspy.OutputField(desc="Boolean output field")
        output3: OutputField3 = dspy.OutputField(desc="Nested output field")
        output4_unannotated = dspy.OutputField(desc="Unannotated output field")

    program = dspy.Predict(TestSignature)

    # Configure DSPy to use an OpenAI LM that supports structured outputs
    dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.JSONAdapter())
    with mock.patch("litellm.completion") as mock_completion:
        program(input1="Test input")

    def clean_schema_extra(field_name, field_info):
        # Expected view of a signature field after DSPy-internal keys and placeholder
        # descriptions have been stripped from its `json_schema_extra`.
        attrs = dict(field_info.__repr_args__())
        if "json_schema_extra" in attrs:
            attrs["json_schema_extra"] = {
                k: v
                for k, v in attrs["json_schema_extra"].items()
                if k != "__dspy_field_type" and not (k == "desc" and v == f"${{{field_name}}}")
            }
        return attrs

    mock_completion.assert_called_once()
    _, call_kwargs = mock_completion.call_args
    response_format = call_kwargs.get("response_format")
    assert response_format is not None
    assert issubclass(response_format, pydantic.BaseModel)
    assert response_format.model_fields.keys() == {"output1", "output2", "output3", "output4_unannotated"}
    for field_name in response_format.model_fields:
        assert dict(response_format.model_fields[field_name].__repr_args__()) == clean_schema_extra(
            field_name=field_name,
            field_info=TestSignature.output_fields[field_name],
        )

    # Configure DSPy to use a model from a fake provider that doesn't support structured outputs
    dspy.configure(lm=dspy.LM(model="fakeprovider/fakemodel"), adapter=dspy.JSONAdapter())
    with mock.patch("litellm.completion") as mock_completion:
        program(input1="Test input")

    mock_completion.assert_called_once()
    _, call_kwargs = mock_completion.call_args
    # Check for the *key name*: the previous form tested the model class stored in the
    # `response_format` variable, which can never be a kwargs key, so it always passed.
    assert "response_format" not in call_kwargs


def test_json_adapter_falls_back_when_structured_outputs_fails():
    """
    When the structured-outputs request raises, the adapter should retry once with the plain
    `{"type": "json_object"}` response format.
    """

    class TestSignature(dspy.Signature):
        input1: str = dspy.InputField()
        output1: str = dspy.OutputField(desc="String output field")

    dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.JSONAdapter())
    program = dspy.Predict(TestSignature)

    with mock.patch("litellm.completion") as mock_completion:
        # First call fails, second call succeeds with the mock's default return value.
        structured_outputs_error = Exception("Bad structured outputs!")
        mock_completion.side_effect = [structured_outputs_error, mock_completion.return_value]
        program(input1="Test input")

        assert mock_completion.call_count == 2
        recorded_calls = mock_completion.call_args_list

        # Attempt 1: signature-derived Pydantic model class.
        _, first_call_kwargs = recorded_calls[0]
        first_format = first_call_kwargs.get("response_format")
        assert issubclass(first_format, pydantic.BaseModel)

        # Attempt 2: generic JSON mode fallback.
        _, second_call_kwargs = recorded_calls[1]
        second_format = second_call_kwargs.get("response_format")
        assert second_format == {"type": "json_object"}


def test_json_adapter_with_structured_outputs_does_not_mutate_original_signature():
    """
    Check that building the structured-outputs response format leaves the signature's output
    fields, including their `json_schema_extra`, unchanged.
    """

    class OutputField3(pydantic.BaseModel):
        subfield1: int = pydantic.Field(description="Int subfield 1")
        subfield2: float = pydantic.Field(description="Float subfield 2")

    class TestSignature(dspy.Signature):
        input1: str = dspy.InputField()
        output1: str = dspy.OutputField()  # Description intentionally left blank
        output2: bool = dspy.OutputField(desc="Boolean output field")
        output3: OutputField3 = dspy.OutputField(desc="Nested output field")
        output4_unannotated = dspy.OutputField(desc="Unannotated output field")

    # Snapshot taken *before* the call. Comparing fields only after the call cannot reveal an
    # in-place mutation if both sides refer to the same field objects
    # (NOTE(review): confirm whether `program.signature` is `TestSignature` itself).
    extras_before = {
        name: dict(field.json_schema_extra or {}) for name, field in TestSignature.output_fields.items()
    }

    dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.JSONAdapter())
    program = dspy.Predict(TestSignature)
    with mock.patch("litellm.completion"):
        program(input1="Test input")

    assert program.signature.output_fields == TestSignature.output_fields

    extras_after = {
        name: dict(field.json_schema_extra or {}) for name, field in TestSignature.output_fields.items()
    }
    assert extras_after == extras_before

0 comments on commit a0ec266

Please sign in to comment.