From c4f1f9542876011233e2e374f49f7fbe84fce6d6 Mon Sep 17 00:00:00 2001 From: Corey Zumar <39497902+dbczumar@users.noreply.github.com> Date: Mon, 25 Nov 2024 19:26:31 -0800 Subject: [PATCH] Adapters: Support JSON serialization of all pydantic types (e.g. datetimes, enums, etc.) (#1853) * Add Signed-off-by: dbczumar * fix Signed-off-by: dbczumar * fix Signed-off-by: dbczumar * fix Signed-off-by: dbczumar * fix Signed-off-by: dbczumar --------- Signed-off-by: dbczumar --- dspy/adapters/chat_adapter.py | 86 +---------- dspy/adapters/json_adapter.py | 118 ++++++--------- dspy/adapters/utils.py | 137 ++++++++++++++++++ tests/functional/test_signature_typed.py | 10 +- tests/predict/test_predict.py | 113 +++++++++++++++ .../test_many_types_1/inputs/input1.json | 4 +- .../test_many_types_1/inputs/input2.json | 4 +- .../generated/test_many_types_1/program.py | 2 +- 8 files changed, 304 insertions(+), 170 deletions(-) create mode 100644 dspy/adapters/utils.py diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index a8ae380cde..14dc3b7f1f 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -14,7 +14,7 @@ from dsp.adapters.base_template import Field from dspy.adapters.base import Adapter -from dspy.adapters.image_utils import Image, encode_image +from dspy.adapters.utils import find_enum_member, format_field_value from dspy.signatures.field import OutputField from dspy.signatures.signature import Signature, SignatureMeta from dspy.signatures.utils import get_dspy_field_type @@ -115,86 +115,6 @@ def format_fields(self, signature, values, role): return format_fields(fields_with_values) -def format_blob(blob): - if "\n" not in blob and "«" not in blob and "»" not in blob: - return f"«{blob}»" - - modified_blob = blob.replace("\n", "\n ") - return f"«««\n {modified_blob}\n»»»" - - -def format_input_list_field_value(value: List[Any]) -> str: - """ - Formats the value of an input field of type List[Any]. - - Args: - value: The value of the list-type input field. - Returns: - A string representation of the input field's list value. - """ - if len(value) == 0: - return "N/A" - if len(value) == 1: - return format_blob(value[0]) - - return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(value)]) - - -def _serialize_for_json(value): - if isinstance(value, pydantic.BaseModel): - return value.model_dump() - elif isinstance(value, list): - return [_serialize_for_json(item) for item in value] - elif isinstance(value, dict): - return {key: _serialize_for_json(val) for key, val in value.items()} - else: - return value - - -def _format_field_value(field_info: FieldInfo, value: Any, assume_text=True) -> Union[str, dict]: - """ - Formats the value of the specified field according to the field's DSPy type (input or output), - annotation (e.g. str, int, etc.), and the type of the value itself. - - Args: - field_info: Information about the field, including its DSPy field type and annotation. - value: The value of the field. - Returns: - The formatted value of the field, represented as a string. - """ - string_value = None - if isinstance(value, list) and field_info.annotation is str: - # If the field has no special type requirements, format it as a nice numbered list for the LM. - string_value = format_input_list_field_value(value) - elif isinstance(value, pydantic.BaseModel) or isinstance(value, dict) or isinstance(value, list): - string_value = json.dumps(_serialize_for_json(value), ensure_ascii=False) - else: - string_value = str(value) - - if assume_text: - return string_value - elif isinstance(value, Image) or field_info.annotation == Image: - # This validation should happen somewhere else - # Safe to import PIL here because it's only imported when an image is actually being formatted - try: - import PIL - except ImportError: - raise ImportError("PIL is required to format images; Run `pip install pillow` to install it.") - image_value = value - if not isinstance(image_value, Image): - if isinstance(image_value, dict) and "url" in image_value: - image_value = image_value["url"] - elif isinstance(image_value, str): - image_value = encode_image(image_value) - elif isinstance(image_value, PIL.Image.Image): - image_value = encode_image(image_value) - assert isinstance(image_value, str) - image_value = Image(url=image_value) - return {"type": "image_url", "image_url": image_value.model_dump()} - else: - return {"type": "text", "text": string_value} - - def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=True) -> Union[str, List[dict]]: """ Formats the values of the specified fields according to the field's DSPy type (input or output), @@ -209,7 +129,7 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text= """ output = [] for field, field_value in fields_with_values.items(): - formatted_field_value = _format_field_value(field_info=field.info, value=field_value, assume_text=assume_text) + formatted_field_value = format_field_value(field_info=field.info, value=field_value, assume_text=assume_text) if assume_text: output.append(f"[[ ## {field.name} ## ]]\n{formatted_field_value}") else: @@ -231,7 +151,7 @@ def parse_value(value, annotation): parsed_value = value if isinstance(annotation, enum.EnumMeta): - parsed_value = annotation[value] + return find_enum_member(annotation, value) elif isinstance(value, str): try: parsed_value = json.loads(value) diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py index 1930bf3cb7..afe2e8c4e9 100644 --- a/dspy/adapters/json_adapter.py +++ b/dspy/adapters/json_adapter.py @@ -1,26 +1,29 @@ import ast -import json import enum import inspect -import litellm -import pydantic +import json import textwrap -import json_repair - +from typing import Any, Dict, KeysView, Literal, NamedTuple, get_args, get_origin +import json_repair +import litellm +import pydantic from pydantic import TypeAdapter from pydantic.fields import FieldInfo -from typing import Any, Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin from dspy.adapters.base import Adapter +from dspy.adapters.utils import find_enum_member, format_field_value, serialize_for_json + from ..adapters.image_utils import Image from ..signatures.signature import SignatureMeta from ..signatures.utils import get_dspy_field_type + class FieldInfoWithName(NamedTuple): name: str info: FieldInfo + class JSONAdapter(Adapter): def __init__(self): pass @@ -28,12 +31,11 @@ def __init__(self): def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True): inputs = self.format(signature, demos, inputs) inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs) - - + try: - provider = lm.model.split('/', 1)[0] or "openai" - if 'response_format' in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider): - outputs = lm(**inputs, **lm_kwargs, response_format={ "type": "json_object" }) + provider = lm.model.split("/", 1)[0] or "openai" + if "response_format" in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider): + outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"}) else: outputs = lm(**inputs, **lm_kwargs) @@ -44,11 +46,12 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True): for output in outputs: value = self.parse(signature, output, _parse_values=_parse_values) - assert set(value.keys()) == set(signature.output_fields.keys()), f"Expected {signature.output_fields.keys()} but got {value.keys()}" + assert set(value.keys()) == set( + signature.output_fields.keys() + ), f"Expected {signature.output_fields.keys()} but got {value.keys()}" values.append(value) - - return values + return values def format(self, signature, demos, inputs): messages = [] @@ -71,7 +74,7 @@ def format(self, signature, demos, inputs): messages.append(format_turn(signature, demo, role="assistant", incomplete=demo in incomplete_demos)) messages.append(format_turn(signature, inputs, role="user")) - + return messages def parse(self, signature, completion, _parse_values=True): @@ -90,7 +93,7 @@ def parse(self, signature, completion, _parse_values=True): def format_turn(self, signature, values, role, incomplete=False): return format_turn(signature, values, role, incomplete) - + def format_fields(self, signature, values, role): fields_with_values = { FieldInfoWithName(name=field_name, info=field_info): values.get( @@ -101,16 +104,16 @@ def format_fields(self, signature, values, role): } return format_fields(role=role, fields_with_values=fields_with_values) - + def parse_value(value, annotation): if annotation is str: return str(value) - + parsed_value = value if isinstance(annotation, enum.EnumMeta): - parsed_value = annotation[value] + parsed_value = find_enum_member(annotation, value) elif isinstance(value, str): try: parsed_value = json.loads(value) @@ -119,45 +122,10 @@ def parse_value(value, annotation): parsed_value = ast.literal_eval(value) except (ValueError, SyntaxError): parsed_value = value - - return TypeAdapter(annotation).validate_python(parsed_value) - -def format_blob(blob): - if "\n" not in blob and "«" not in blob and "»" not in blob: - return f"«{blob}»" - - modified_blob = blob.replace("\n", "\n ") - return f"«««\n {modified_blob}\n»»»" - - -def format_input_list_field_value(value: List[Any]) -> str: - """ - Formats the value of an input field of type List[Any]. - - Args: - value: The value of the list-type input field. - Returns: - A string representation of the input field's list value. - """ - if len(value) == 0: - return "N/A" - if len(value) == 1: - return format_blob(value[0]) - - return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(value)]) + return TypeAdapter(annotation).validate_python(parsed_value) -def _serialize_for_json(value): - if isinstance(value, pydantic.BaseModel): - return value.model_dump() - elif isinstance(value, list): - return [_serialize_for_json(item) for item in value] - elif isinstance(value, dict): - return {key: _serialize_for_json(val) for key, val in value.items()} - else: - return value - def _format_field_value(field_info: FieldInfo, value: Any) -> str: """ Formats the value of the specified field according to the field's DSPy type (input or output), @@ -169,17 +137,10 @@ def _format_field_value(field_info: FieldInfo, value: Any) -> str: Returns: The formatted value of the field, represented as a string. """ - - if isinstance(value, list) and field_info.annotation is str: - # If the field has no special type requirements, format it as a nice numbere list for the LM. - return format_input_list_field_value(value) if field_info.annotation is Image: raise NotImplementedError("Images are not yet supported in JSON mode.") - elif isinstance(value, pydantic.BaseModel) or isinstance(value, dict) or isinstance(value, list): - return json.dumps(_serialize_for_json(value)) - else: - return str(value) + return format_field_value(field_info=field_info, value=value, assume_text=True) def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -> str: @@ -197,9 +158,8 @@ def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) - if role == "assistant": d = fields_with_values.items() - d = {k.name: _serialize_for_json(v) for k, v in d} - - return json.dumps(_serialize_for_json(d), indent=2) + d = {k.name: v for k, v in d} + return json.dumps(serialize_for_json(d), indent=2) output = [] for field, field_value in fields_with_values.items(): @@ -246,15 +206,19 @@ def format_turn(signature: SignatureMeta, values: Dict[str, Any], role, incomple field_name, "Not supplied for this particular example." ) for field_name, field_info in fields.items() - } + }, ) content.append(formatted_fields) if role == "user": + def type_info(v): - return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \ - if v.annotation is not str else "" - + return ( + f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" + if v.annotation is not str + else "" + ) + # TODO: Consider if not incomplete: content.append( "Respond with a JSON object in the following order of fields: " @@ -297,15 +261,15 @@ def prepare_instructions(signature: SignatureMeta): def field_metadata(field_name, field_info): type_ = field_info.annotation - if get_dspy_field_type(field_info) == 'input' or type_ is str: + if get_dspy_field_type(field_info) == "input" or type_ is str: desc = "" elif type_ is bool: desc = "must be True or False" elif type_ in (int, float): desc = f"must be a single {type_.__name__} value" elif inspect.isclass(type_) and issubclass(type_, enum.Enum): - desc= f"must be one of: {'; '.join(type_.__members__)}" - elif hasattr(type_, '__origin__') and type_.__origin__ is Literal: + desc = f"must be one of: {'; '.join(type_.__members__)}" + elif hasattr(type_, "__origin__") and type_.__origin__ is Literal: desc = f"must be one of: {'; '.join([str(x) for x in type_.__args__])}" else: desc = "must be pareseable according to the following JSON schema: " @@ -320,13 +284,13 @@ def format_signature_fields_for_instructions(role, fields: Dict[str, FieldInfo]) fields_with_values={ FieldInfoWithName(name=field_name, info=field_info): field_metadata(field_name, field_info) for field_name, field_info in fields.items() - } + }, ) - + parts.append("Inputs will have the following structure:") - parts.append(format_signature_fields_for_instructions('user', signature.input_fields)) + parts.append(format_signature_fields_for_instructions("user", signature.input_fields)) parts.append("Outputs will be a JSON object with the following fields.") - parts.append(format_signature_fields_for_instructions('assistant', signature.output_fields)) + parts.append(format_signature_fields_for_instructions("assistant", signature.output_fields)) # parts.append(format_fields({BuiltInCompletedOutputFieldInfo: ""})) instructions = textwrap.dedent(signature.instructions) diff --git a/dspy/adapters/utils.py b/dspy/adapters/utils.py new file mode 100644 index 0000000000..9cddc1fe86 --- /dev/null +++ b/dspy/adapters/utils.py @@ -0,0 +1,137 @@ +import json +from typing import Any, List, Union + +from pydantic import TypeAdapter +from pydantic.fields import FieldInfo + +from .image_utils import Image, encode_image + + +def serialize_for_json(value: Any) -> Any: + """ + Formats the specified value so that it can be serialized as a JSON string. + + Args: + value: The value to format as a JSON string. + Returns: + The formatted value, which is serializable as a JSON string. + """ + # Attempt to format the value as a JSON-compatible object using pydantic, falling back to + # a string representation of the value if that fails (e.g. if the value contains an object + # that pydantic doesn't recognize or can't serialize) + try: + return TypeAdapter(type(value)).dump_python(value, mode="json") + except Exception: + return str(value) + + +def format_field_value(field_info: FieldInfo, value: Any, assume_text=True) -> Union[str, dict]: + """ + Formats the value of the specified field according to the field's DSPy type (input or output), + annotation (e.g. str, int, etc.), and the type of the value itself. + + Args: + field_info: Information about the field, including its DSPy field type and annotation. + value: The value of the field. + Returns: + The formatted value of the field, represented as a string. + """ + string_value = None + if isinstance(value, list) and field_info.annotation is str: + # If the field has no special type requirements, format it as a nice numbered list for the LM. + string_value = _format_input_list_field_value(value) + else: + jsonable_value = serialize_for_json(value) + if isinstance(jsonable_value, dict) or isinstance(jsonable_value, list): + string_value = json.dumps(jsonable_value, ensure_ascii=False) + else: + # If the value is not a Python representation of a JSON object or Array + # (e.g. the value is a JSON string), just use the string representation of the value + # to avoid double-quoting the JSON string (which would hurt accuracy for certain + # tasks, e.g. tasks that rely on computing string length) + string_value = str(jsonable_value) + + if assume_text: + return string_value + elif isinstance(value, Image) or field_info.annotation == Image: + # This validation should happen somewhere else + # Safe to import PIL here because it's only imported when an image is actually being formatted + try: + import PIL + except ImportError: + raise ImportError("PIL is required to format images; Run `pip install pillow` to install it.") + image_value = value + if not isinstance(image_value, Image): + if isinstance(image_value, dict) and "url" in image_value: + image_value = image_value["url"] + elif isinstance(image_value, str): + image_value = encode_image(image_value) + elif isinstance(image_value, PIL.Image.Image): + image_value = encode_image(image_value) + assert isinstance(image_value, str) + image_value = Image(url=image_value) + return {"type": "image_url", "image_url": image_value.model_dump()} + else: + return {"type": "text", "text": string_value} + + +def find_enum_member(enum, identifier): + """ + Finds the enum member corresponding to the specified identifier, which may be the + enum member's name or value. + + Args: + enum: The enum to search for the member. + identifier: If the enum is explicitly-valued, this is the value of the enum member to find. + If the enum is auto-valued, this is the name of the enum member to find. + Returns: + The enum member corresponding to the specified identifier. + """ + # Check if the identifier is a valid enum member value *before* checking if it's a valid enum + # member name, since the identifier will be a value for explicitly-valued enums. This handles + # the (rare) case where an enum member value is the same as another enum member's name in + # an explicitly-valued enum + for member in enum: + if member.value == identifier: + return member + + # If the identifier is not a valid enum member value, check if it's a valid enum member name, + # since the identifier will be a member name for auto-valued enums + if identifier in enum.__members__: + return enum[identifier] + + raise ValueError(f"{identifier} is not a valid name or value for the enum {enum.__name__}") + + +def _format_input_list_field_value(value: List[Any]) -> str: + """ + Formats the value of an input field of type List[Any]. + + Args: + value: The value of the list-type input field. + Returns: + A string representation of the input field's list value. + """ + if len(value) == 0: + return "N/A" + if len(value) == 1: + return _format_blob(value[0]) + + return "\n".join([f"[{idx+1}] {_format_blob(txt)}" for idx, txt in enumerate(value)]) + + +def _format_blob(blob: str) -> str: + """ + Formats the specified text blobs so that an LM can parse it correctly within a list + of multiple text blobs. + + Args: + blob: The text blob to format. + Returns: + The formatted text blob. + """ + if "\n" not in blob and "«" not in blob and "»" not in blob: + return f"«{blob}»" + + modified_blob = blob.replace("\n", "\n ") + return f"«««\n {modified_blob}\n»»»" diff --git a/tests/functional/test_signature_typed.py b/tests/functional/test_signature_typed.py index 7cd1c1fcfd..cdc0ef9722 100644 --- a/tests/functional/test_signature_typed.py +++ b/tests/functional/test_signature_typed.py @@ -4,7 +4,7 @@ import pytest import dspy -from dspy.adapters.chat_adapter import _format_field_value +from dspy.adapters.utils import format_field_value from dspy.functional import TypedPredictor from dspy.signatures.signature import signature_to_template @@ -116,7 +116,7 @@ class MySignature(dspy.Signature): instance = build_model_instance() parsed_instance = parser(instance.model_dump_json()) - formatted_instance = _format_field_value( + formatted_instance = format_field_value( field_info=dspy.OutputField(), value=instance.model_dump_json(), ) @@ -136,7 +136,7 @@ class MySignature(dspy.Signature): parsed_instance = parser(instance.model_dump_json()) assert parsed_instance == instance, f"{instance} != {parsed_instance}" - formatted_instance = _format_field_value( + formatted_instance = format_field_value( field_info=dspy.OutputField(), value=instance.model_dump_json(), ) @@ -160,7 +160,7 @@ class MySignature(dspy.Signature): instance = NestedModel(model=build_model_instance()) parsed_instance = parser(instance.model_dump_json()) - formatted_instance = _format_field_value( + formatted_instance = format_field_value( field_info=dspy.OutputField(), value=instance.model_dump_json(), ) @@ -191,7 +191,7 @@ class MySignature(dspy.Signature): parsed_instance = parser('{"string": "foobar", "number": 42, "floating": 3.14, "boolean": true}') assert parsed_instance == instance, f"{instance} != {parsed_instance}" - formatted_instance = _format_field_value( + formatted_instance = format_field_value( field_info=dspy.OutputField(), value=ujson.dumps(asdict(instance)), ) diff --git a/tests/predict/test_predict.py b/tests/predict/test_predict.py index 7f3aa43306..97544043ed 100644 --- a/tests/predict/test_predict.py +++ b/tests/predict/test_predict.py @@ -1,6 +1,9 @@ import copy +import enum +from datetime import datetime import pydantic +import pytest import ujson import dspy @@ -197,6 +200,116 @@ def test_multi_output2(): assert results.completions.answer2[1] == "my 3 answer" +def test_datetime_inputs_and_outputs(): + # Define a model for datetime inputs and outputs + class TimedEvent(pydantic.BaseModel): + event_name: str + event_time: datetime + + class TimedSignature(dspy.Signature): + events: list[TimedEvent] = dspy.InputField() + summary: str = dspy.OutputField() + next_event_time: datetime = dspy.OutputField() + + program = Predict(TimedSignature) + + lm = DummyLM( + [ + { + "reasoning": "Processed datetime inputs", + "summary": "All events are processed", + "next_event_time": "2024-11-27T14:00:00", + } + ] + ) + dspy.settings.configure(lm=lm) + + output = program( + events=[ + TimedEvent(event_name="Event 1", event_time=datetime(2024, 11, 25, 10, 0, 0)), + TimedEvent(event_name="Event 2", event_time=datetime(2024, 11, 25, 15, 30, 0)), + ] + ) + assert output.summary == "All events are processed" + assert output.next_event_time == datetime(2024, 11, 27, 14, 0, 0) + + +def test_explicitly_valued_enum_inputs_and_outputs(): + class Status(enum.Enum): + PENDING = "pending" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + + class StatusSignature(dspy.Signature): + current_status: Status = dspy.InputField() + next_status: Status = dspy.OutputField() + + program = Predict(StatusSignature) + + lm = DummyLM( + [ + { + "reasoning": "The current status is 'PENDING', advancing to 'IN_PROGRESS'.", + "next_status": "in_progress", + } + ] + ) + dspy.settings.configure(lm=lm) + + output = program(current_status=Status.PENDING) + assert output.next_status == Status.IN_PROGRESS + + +def test_enum_inputs_and_outputs_with_shared_names_and_values(): + class TicketStatus(enum.Enum): + OPEN = "CLOSED" + CLOSED = "RESOLVED" + RESOLVED = "OPEN" + + class TicketStatusSignature(dspy.Signature): + current_status: TicketStatus = dspy.InputField() + next_status: TicketStatus = dspy.OutputField() + + program = Predict(TicketStatusSignature) + + # Mock reasoning and output + lm = DummyLM( + [ + { + "reasoning": "The ticket is currently 'OPEN', transitioning to 'CLOSED'.", + "next_status": "RESOLVED", # Refers to TicketStatus.CLOSED by value + } + ] + ) + dspy.settings.configure(lm=lm) + + output = program(current_status=TicketStatus.OPEN) + assert output.next_status == TicketStatus.CLOSED # By value + + +def test_auto_valued_enum_inputs_and_outputs(): + Status = enum.Enum("Status", ["PENDING", "IN_PROGRESS", "COMPLETED"]) + + class StatusSignature(dspy.Signature): + current_status: Status = dspy.InputField() + next_status: Status = dspy.OutputField() + + program = Predict(StatusSignature) + + lm = DummyLM( + [ + { + "reasoning": "The current status is 'PENDING', advancing to 'IN_PROGRESS'.", + "next_status": "IN_PROGRESS", # Use the auto-assigned value for IN_PROGRESS + } + ] + ) + dspy.settings.configure(lm=lm) + + output = program(current_status=Status.PENDING) + assert output.next_status == Status.IN_PROGRESS + + def test_named_predictors(): class MyModule(dspy.Module): def __init__(self): diff --git a/tests/reliability/complex_types/generated/test_many_types_1/inputs/input1.json b/tests/reliability/complex_types/generated/test_many_types_1/inputs/input1.json index 91b2530aa0..63044a2510 100644 --- a/tests/reliability/complex_types/generated/test_many_types_1/inputs/input1.json +++ b/tests/reliability/complex_types/generated/test_many_types_1/inputs/input1.json @@ -1,11 +1,11 @@ { "assertions": [ - "The 'processedTupleField' should be a tuple containing a string and a number. Note that 'processedNestedObjectField.tupleField' should NOT actually be a tuple.", + "The 'processedTupleField' should be a tuple containing a string and a number.", "The 'processedEnumField' should be one of the allowed enum values: 'option1', 'option2', or 'option3'.", "The 'processedDatetimeField' should be a date-time", "The 'processedLiteralField' should be exactly 'literalValue'.", "The 'processedObjectField' should contain 'subField1' (string), 'subField2' (number), and an additional boolean field 'additionalField'.", - "The 'processedNestedObjectField' should contain 'tupleField' (which is actually a list with a string and a number - the name is misleading), 'enumField' (one of the allowed enum values), 'datetimeField' (string formatted as date-time), 'literalField' (exactly 'literalValue'), and an additional boolean field 'additionalField'." + "The 'processedNestedObjectField' should contain 'tupleField' as a tuple with a string and float, 'enumField' (one of the allowed enum values), 'datetimeField' (string formatted as date-time), 'literalField' (exactly 'literalValue'), and an additional boolean field 'additionalField'." ], "input": { "datetimeField": "2023-10-12T07:20:50.52Z", diff --git a/tests/reliability/complex_types/generated/test_many_types_1/inputs/input2.json b/tests/reliability/complex_types/generated/test_many_types_1/inputs/input2.json index bca9c80f98..cae7c21a5a 100644 --- a/tests/reliability/complex_types/generated/test_many_types_1/inputs/input2.json +++ b/tests/reliability/complex_types/generated/test_many_types_1/inputs/input2.json @@ -1,11 +1,11 @@ { "assertions": [ - "The 'processedTupleField' should be an tuple with exactly two elements: the first element being a string and the second element being a number. Note that 'processedNestedObjectField.tupleField' should NOT actually be a tuple", + "The 'processedTupleField' should be an tuple with exactly two elements: the first element being a string and the second element being a number.", "The 'processedEnumField' should be one of the predefined options: 'option1', 'option2', or 'option3'.", "The 'processedDatetimeField' should be a date-time", "The 'processedLiteralField' should be the enum 'literalValue'.", "The 'processedObjectField' should be an object containing 'subField1' as a string, 'subField2' as a number, and an 'additionalField' as a boolean.", - "The 'processedNestedObjectField' should be an object containing 'tupleField' as a list (NOT a tuple) with exactly two elements (a string and a number), 'enumField' as one of the predefined options (option1, option2, or option3), 'datetimeField' as a 'date-time' object, 'literalField' as the string 'literalValue', and an 'additionalField' as a boolean." + "The 'processedNestedObjectField' should be an object containing 'tupleField' as a tuple with a string and float, 'enumField' as one of the predefined options (option1, option2, or option3), 'datetimeField' as a 'date-time' object, 'literalField' as the string 'literalValue', and an 'additionalField' as a boolean." ], "input": { "datetimeField": "2023-10-01T12:00:00Z", diff --git a/tests/reliability/complex_types/generated/test_many_types_1/program.py b/tests/reliability/complex_types/generated/test_many_types_1/program.py index 52798903bf..49332a8038 100644 --- a/tests/reliability/complex_types/generated/test_many_types_1/program.py +++ b/tests/reliability/complex_types/generated/test_many_types_1/program.py @@ -76,7 +76,7 @@ class LiteralField(Enum): class ProcessedNestedObjectField(BaseModel): - tupleField: List[Union[str, float]] = Field(..., max_items=2, min_items=2) + tupleField: Tuple[str, float] enumField: EnumField datetimeField: datetime literalField: LiteralField