Support structured outputs response format based on signature in JSON adapter #1881
A Pydantic model representing the `response_format` parameter for the LM request. | ||
""" | ||
|
||
def filter_json_schema_extra(field_name: str, field_info: FieldInfo) -> FieldInfo: |
This needs test coverage
"Failed to obtain response using signature-based structured outputs" | ||
" response format: Falling back to default 'json_object' response format." | ||
" Exception: {e}" | ||
) |
We expect to hit this case for tuples until the OpenAI structured outputs API supports `prefixItems`. Other vendors, e.g. Databricks, will likely lag even further behind (Databricks doesn't currently support `anyOf`, but OpenAI does), meaning that we could hit this case for additional output types.
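To make the fallback condition concrete: a fixed-length tuple is typically expressed in JSON Schema with `prefixItems`, so a schema containing that keyword would be rejected by a provider that lacks support for it. The sketch below walks a schema dict and reports which keywords from an assumed unsupported set appear in it. The helper name `find_unsupported_keywords` and the per-provider keyword sets are illustrative assumptions drawn from the comment above, not part of this PR and not authoritative vendor limits.

```python
from typing import Any

# Assumed, illustrative keyword sets based on the review comment; not vendor documentation.
ASSUMED_UNSUPPORTED = {
    "openai": {"prefixItems"},
    "databricks": {"prefixItems", "anyOf"},
}


def find_unsupported_keywords(schema: Any, unsupported: set[str]) -> set[str]:
    """Recursively collect every unsupported keyword that appears in a JSON schema.

    Simplification: a property that is literally named like a keyword
    (e.g. a field called "anyOf") would be reported as a false positive.
    """
    found: set[str] = set()
    if isinstance(schema, dict):
        for key, value in schema.items():
            if key in unsupported:
                found.add(key)
            found |= find_unsupported_keywords(value, unsupported)
    elif isinstance(schema, list):
        for item in schema:
            found |= find_unsupported_keywords(item, unsupported)
    return found


# Hand-written schema for an output field typed as tuple[int, str].
tuple_schema = {
    "type": "object",
    "properties": {
        "pair": {
            "type": "array",
            "prefixItems": [{"type": "integer"}, {"type": "string"}],
            "minItems": 2,
            "maxItems": 2,
        }
    },
    "required": ["pair"],
}

print(find_unsupported_keywords(tuple_schema, ASSUMED_UNSUPPORTED["openai"]))
```

A pre-check like this could avoid a failed request, but it would need per-provider maintenance, which may be why a catch-and-fall-back approach is simpler to start with.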
@@ -31,7 +31,6 @@ def assert_program_output_correct( | |||
grading_guidelines = [grading_guidelines] | |||
|
|||
with judge_dspy_configuration(): | |||
print("GUIDELINES", grading_guidelines) |
Removing a leftover & unintentional debugging statement from test generation code
) | ||
assert answer.certainty >= 0 | ||
assert answer.certainty <= 1 | ||
assert len(answer.comments) >= 2 | ||
|
||
|
||
def test_color_classification_using_enum(): | ||
@pytest.mark.parametrize("module", [dspy.Predict, dspy.ChainOfThought]) |
CoT fails this test case with chat and json adapters on master:
FAILED test_pydantic_models.py::test_color_classification_using_enum[llama-3.1-70b-instruct-ChainOfThought] - ValueError: Color.BLUE is not a valid name or value for the enum Color
================================================ 1 failed, 1 passed, 24 skipped, 26 deselected, 2 warnings in 0.22s
However, it passes on the PR branch :D
dspy/clients/lm.py
Outdated
@@ -212,7 +216,7 @@ def copy(self, **kwargs): | |||
return new_instance | |||
|
|||
|
|||
@functools.lru_cache(maxsize=None) | |||
# @functools.lru_cache(maxsize=None) |
This is a hack to get the implementation working end to end for test purposes (see also https://github.com/stanfordnlp/dspy/pull/1881/files#r1867211773). We need a proper fix before merge, e.g. #1862 (though it's not 100% clear to me why we need LRU caching here in the first place on top of the caching that LiteLLM is already providing)
cc @okhat I'm sure I'm missing something here - let me know if there's additional context motivating this `lru_cache`.
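For context on why the decorator interacts with the request format: `functools.lru_cache` hashes its arguments, so a cached function accepts a JSON string but not a plain dict. If the request is passed as a dict rather than a `ujson.dumps()` string (which appears to be what the temporary change to `__call__` elsewhere in this PR does), a cached completion function raises `TypeError`. The function below is a stand-in for illustration, not the actual cached completion function in `dspy/clients/lm.py`.

```python
import functools
import json


@functools.lru_cache(maxsize=None)
def cached_completion(request):
    """Stand-in for a cached completion call; it just echoes the model name."""
    payload = json.loads(request) if isinstance(request, str) else request
    return payload["model"]


request = {"model": "some-model", "messages": [{"role": "user", "content": "hi"}]}

# A JSON string is hashable, so caching works.
print(cached_completion(json.dumps(request)))

# A dict is not hashable, so lru_cache rejects it before the function body runs.
try:
    cached_completion(request)
except TypeError as exc:
    print(f"TypeError: {exc}")
```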
dspy/clients/lm.py
Outdated
@@ -92,7 +92,7 @@ def __call__(self, prompt=None, messages=None, **kwargs): | |||
completion = cached_litellm_text_completion if cache else litellm_text_completion | |||
|
|||
response = completion( | |||
request=ujson.dumps(dict(model=self.model, messages=messages, **kwargs)), |
When `response_format` is a Pydantic model (recommended by LiteLLM), `ujson.dumps()` fails because Pydantic models are not directly serializable with it. This line diff is a temporary hack to get the implementation working end-to-end for test purposes. We need a proper solution before merge.
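One possible direction, sketched under assumptions: convert the model class to its JSON schema while serializing the request, so the serialized request stays a plain string. The `default` hook below relies on the `model_json_schema()` classmethod that Pydantic v2 models expose. `FakeAnswerModel` is a stand-in so the sketch runs without Pydantic installed, the standard-library `json` module is used in place of `ujson`, and `serialize_request` is a hypothetical helper, not the fix that was eventually merged.

```python
import json


class FakeAnswerModel:
    """Stand-in for a Pydantic v2 model class (exposes model_json_schema)."""

    @classmethod
    def model_json_schema(cls):
        return {
            "title": "FakeAnswerModel",
            "type": "object",
            "properties": {"answer": {"type": "string"}},
            "required": ["answer"],
        }


def _to_jsonable(obj):
    """json.dumps fallback: replace model classes with their JSON schema."""
    if isinstance(obj, type) and hasattr(obj, "model_json_schema"):
        return obj.model_json_schema()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def serialize_request(**request):
    # sort_keys keeps the string stable, so equal requests yield equal strings.
    return json.dumps(request, default=_to_jsonable, sort_keys=True)


request_str = serialize_request(
    model="some-model",
    messages=[{"role": "user", "content": "hi"}],
    response_format=FakeAnswerModel,
)
print(request_str)
```

The trade-off is that the receiving side gets a schema dict rather than the model class, so it would need to rebuild or separately pass the class if LiteLLM should receive the Pydantic model itself.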
Hi @dbczumar, please let me know if this approach is worthy of a PR; I am happy to contribute. Custom adapter using this approach.
5bfba54
to
5af146d
Compare
try: | ||
response_format = _get_structured_outputs_response_format(signature) | ||
outputs = lm(**inputs, **lm_kwargs, response_format=response_format) | ||
except Exception: |
LM providers have differing levels of support for `response_format` fields. For example, Databricks doesn't support `anyOf` / `allOf`, but OpenAI does. A blanket try/catch seems appropriate here to start.
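A runnable sketch of this fallback pattern, using a stub in place of a real LM: the helper first tries the signature-derived format and, on any exception, retries with the generic `{"type": "json_object"}` format. `call_with_fallback` and `picky_lm` are hypothetical names for illustration; the log message mirrors the one in this PR's diff.

```python
import logging

_logger = logging.getLogger(__name__)


def call_with_fallback(lm, inputs, structured_format):
    """Try the signature-derived response_format, then fall back to json_object."""
    try:
        return lm(**inputs, response_format=structured_format)
    except Exception as e:
        # Debug level keeps expected provider gaps from flooding the logs.
        _logger.debug(
            "Failed to obtain response using signature-based structured outputs"
            " response format: Falling back to default 'json_object' response format."
            f" Exception: {e}"
        )
        return lm(**inputs, response_format={"type": "json_object"})


def picky_lm(prompt, response_format):
    """Hypothetical LM stub that rejects any schema containing anyOf."""
    if "anyOf" in str(response_format):
        raise ValueError("anyOf is not supported")
    return {"prompt": prompt, "used_format": response_format}


schema_with_any_of = {"anyOf": [{"type": "string"}, {"type": "null"}]}
result = call_with_fallback(picky_lm, {"prompt": "hi"}, schema_with_any_of)
print(result["used_format"])
```

Catching `Exception` broadly also retries on errors unrelated to the schema (for example, a transient network failure), which costs one extra request in those cases.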
response_format = _get_structured_outputs_response_format(signature) | ||
outputs = lm(**inputs, **lm_kwargs, response_format=response_format) | ||
except Exception: | ||
_logger.debug( |
Debated a warning, but it seems too spammy
# Recursively update fields of the nested model | ||
nested_model = field_copy.annotation.__pydantic_model__ | ||
updated_fields = { | ||
key: filter_json_schema_extra(key, value) for key, value in nested_model.__fields__.items() |
Curious - why do we need recursive handling? `nested_model.__fields__` should be Pydantic fields instead of DSPy fields; do they also have these DSPy internal attributes like `__dspy_field_type`?
In the vast majority of cases, hopefully not...
However, the following user error silently produces this state. Because the program still runs, it is probably best to exclude these attributes from `response_format`:

import dspy
import pydantic

class Obj(pydantic.BaseModel):
    # User error: a DSPy field declared inside a plain Pydantic model
    a: int = dspy.OutputField()
    b: str

class MySig(dspy.Signature):
    inp: str = dspy.InputField()
    outp: Obj = dspy.OutputField()

print(MySig.schema())
{'$defs': {'Obj': {'properties': {'a': {'__dspy_field_type': 'output', 'title': 'A', 'type': 'integer'}, 'b': {'title': 'B', 'type': 'string'}}, 'required': ['a', 'b'], 'title': 'Obj', 'type': 'object'}}, 'description': 'Given the fields `inp`, produce the fields `outp`.', 'properties': {'inp': {'__dspy_field_type': 'input', 'desc': '${inp}', 'prefix': 'Inp:', 'title': 'Inp', 'type': 'string'}, 'outp': {'$ref': '#/$defs/Obj', '__dspy_field_type': 'output', 'desc': '${outp}', 'prefix': 'Outp:'}}, 'required': ['inp', 'outp'], 'title': 'MySig', 'type': 'object'}
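To illustrate why the handling has to be recursive, here is a minimal sketch that strips DSPy-specific keys from a schema dict at every depth, using an abbreviated copy of the `MySig.schema()` output above. It operates on the generated schema rather than on `FieldInfo` objects, so it is a different technique from the PR's `filter_json_schema_extra`. The set of keys treated as DSPy-internal (`__dspy_field_type`, `desc`, `prefix`) is an assumption based on that output, and `strip_dspy_keys` is a hypothetical helper.

```python
import json

# Assumed DSPy-internal keys, taken from the schema output shown above.
DSPY_INTERNAL_KEYS = {"__dspy_field_type", "desc", "prefix"}
# Under these keywords, dict keys are user-chosen names, not schema keywords.
NAME_MAPS = {"properties", "$defs"}


def strip_dspy_keys(node, is_name_map=False):
    """Return a copy of a JSON schema with DSPy-internal keys removed at every depth."""
    if isinstance(node, dict):
        return {
            key: strip_dspy_keys(value, is_name_map=(not is_name_map and key in NAME_MAPS))
            for key, value in node.items()
            # Never drop user-chosen field names, even one called "desc" or "prefix".
            if is_name_map or key not in DSPY_INTERNAL_KEYS
        }
    if isinstance(node, list):
        return [strip_dspy_keys(item) for item in node]
    return node


schema = {
    "$defs": {
        "Obj": {
            "properties": {
                "a": {"__dspy_field_type": "output", "title": "A", "type": "integer"},
                "b": {"title": "B", "type": "string"},
            },
            "required": ["a", "b"],
            "title": "Obj",
            "type": "object",
        }
    },
    "properties": {
        "outp": {
            "$ref": "#/$defs/Obj",
            "__dspy_field_type": "output",
            "desc": "${outp}",
            "prefix": "Outp:",
        }
    },
    "required": ["outp"],
    "title": "MySig",
    "type": "object",
}

cleaned = strip_dspy_keys(schema)
print(json.dumps(cleaned, indent=2))
```

A non-recursive pass over the top-level `properties` would leave the `__dspy_field_type` key on `Obj.a` in `$defs` untouched, which is the state the example above produces.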
… adapter (#1881) * Fix Signed-off-by: dbczumar <corey.zumar@databricks.com> * Fix Signed-off-by: dbczumar <corey.zumar@databricks.com> * fix Signed-off-by: dbczumar <corey.zumar@databricks.com> * Debug Signed-off-by: dbczumar <corey.zumar@databricks.com> * Fix Signed-off-by: dbczumar <corey.zumar@databricks.com> * Here Signed-off-by: dbczumar <corey.zumar@databricks.com> * Here Signed-off-by: dbczumar <corey.zumar@databricks.com> * Update json_adapter.py * Update json_adapter.py * Update json_adapter.py * Update json_adapter.py --------- Signed-off-by: dbczumar <corey.zumar@databricks.com>
Support structured outputs response format based on signature in JSON adapter