-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support structured outputs response format based on signature in JSON…
… adapter (#1881) * Fix Signed-off-by: dbczumar <corey.zumar@databricks.com> * Fix Signed-off-by: dbczumar <corey.zumar@databricks.com> * fix Signed-off-by: dbczumar <corey.zumar@databricks.com> * Debug Signed-off-by: dbczumar <corey.zumar@databricks.com> * Fix Signed-off-by: dbczumar <corey.zumar@databricks.com> * Here Signed-off-by: dbczumar <corey.zumar@databricks.com> * Here Signed-off-by: dbczumar <corey.zumar@databricks.com> * Update json_adapter.py * Update json_adapter.py * Update json_adapter.py * Update json_adapter.py --------- Signed-off-by: dbczumar <corey.zumar@databricks.com>
- Loading branch information
1 parent
b0401e4
commit a0ec266
Showing
4 changed files
with
169 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
from unittest import mock | ||
|
||
import pydantic | ||
import pytest | ||
from pydantic import create_model | ||
|
||
import dspy | ||
|
||
|
||
def test_json_adapter_passes_structured_output_when_supported_by_model(): | ||
class OutputField3(pydantic.BaseModel): | ||
subfield1: int = pydantic.Field(description="Int subfield 1", ge=0, le=10) | ||
subfield2: float = pydantic.Field(description="Float subfield 2") | ||
|
||
class TestSignature(dspy.Signature): | ||
input1: str = dspy.InputField() | ||
output1: str = dspy.OutputField() # Description intentionally left blank | ||
output2: bool = dspy.OutputField(desc="Boolean output field") | ||
output3: OutputField3 = dspy.OutputField(desc="Nested output field") | ||
output4_unannotated = dspy.OutputField(desc="Unannotated output field") | ||
|
||
program = dspy.Predict(TestSignature) | ||
|
||
# Configure DSPy to use an OpenAI LM that supports structured outputs | ||
dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.JSONAdapter()) | ||
with mock.patch("litellm.completion") as mock_completion: | ||
program(input1="Test input") | ||
|
||
def clean_schema_extra(field_name, field_info): | ||
attrs = dict(field_info.__repr_args__()) | ||
if "json_schema_extra" in attrs: | ||
attrs["json_schema_extra"] = { | ||
k: v | ||
for k, v in attrs["json_schema_extra"].items() | ||
if k != "__dspy_field_type" and not (k == "desc" and v == f"${{{field_name}}}") | ||
} | ||
return attrs | ||
|
||
mock_completion.assert_called_once() | ||
_, call_kwargs = mock_completion.call_args | ||
response_format = call_kwargs.get("response_format") | ||
assert response_format is not None | ||
assert issubclass(response_format, pydantic.BaseModel) | ||
assert response_format.model_fields.keys() == {"output1", "output2", "output3", "output4_unannotated"} | ||
for field_name in response_format.model_fields: | ||
assert dict(response_format.model_fields[field_name].__repr_args__()) == clean_schema_extra( | ||
field_name=field_name, | ||
field_info=TestSignature.output_fields[field_name], | ||
) | ||
|
||
# Configure DSPy to use a model from a fake provider that doesn't support structured outputs | ||
dspy.configure(lm=dspy.LM(model="fakeprovider/fakemodel"), adapter=dspy.JSONAdapter()) | ||
with mock.patch("litellm.completion") as mock_completion: | ||
program(input1="Test input") | ||
|
||
mock_completion.assert_called_once() | ||
_, call_kwargs = mock_completion.call_args | ||
assert response_format not in call_kwargs | ||
|
||
|
||
def test_json_adapter_falls_back_when_structured_outputs_fails(): | ||
class TestSignature(dspy.Signature): | ||
input1: str = dspy.InputField() | ||
output1: str = dspy.OutputField(desc="String output field") | ||
|
||
dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.JSONAdapter()) | ||
program = dspy.Predict(TestSignature) | ||
with mock.patch("litellm.completion") as mock_completion: | ||
mock_completion.side_effect = [Exception("Bad structured outputs!"), mock_completion.return_value] | ||
program(input1="Test input") | ||
assert mock_completion.call_count == 2 | ||
_, first_call_kwargs = mock_completion.call_args_list[0] | ||
assert issubclass(first_call_kwargs.get("response_format"), pydantic.BaseModel) | ||
_, second_call_kwargs = mock_completion.call_args_list[1] | ||
assert second_call_kwargs.get("response_format") == {"type": "json_object"} | ||
|
||
|
||
def test_json_adapter_with_structured_outputs_does_not_mutate_original_signature(): | ||
class OutputField3(pydantic.BaseModel): | ||
subfield1: int = pydantic.Field(description="Int subfield 1") | ||
subfield2: float = pydantic.Field(description="Float subfield 2") | ||
|
||
class TestSignature(dspy.Signature): | ||
input1: str = dspy.InputField() | ||
output1: str = dspy.OutputField() # Description intentionally left blank | ||
output2: bool = dspy.OutputField(desc="Boolean output field") | ||
output3: OutputField3 = dspy.OutputField(desc="Nested output field") | ||
output4_unannotated = dspy.OutputField(desc="Unannotated output field") | ||
|
||
dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.JSONAdapter()) | ||
program = dspy.Predict(TestSignature) | ||
with mock.patch("litellm.completion"): | ||
program(input1="Test input") | ||
|
||
assert program.signature.output_fields == TestSignature.output_fields |