Skip to content

Commit

Permalink
Bugfix/signature replace and pydantic 2.10 (#1855)
Browse files Browse the repository at this point in the history
* feat(dspy): add datamodel-code-generator to dev reqs

* fix(dspy): fix signature replace for pydantic v2.10

* fix(dspy): fix signature replace for pydantic v2.10
  • Loading branch information
mikeedjones authored Nov 25, 2024
1 parent 2aa6f01 commit ff6f5a8
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 14 deletions.
45 changes: 33 additions & 12 deletions dspy/signatures/signature.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast
import inspect
import logging
import re
import types
import typing
Expand All @@ -11,8 +12,9 @@
from pydantic.fields import FieldInfo

import dsp
from dspy.signatures.field import InputField, OutputField, new_to_old_field
from dspy.adapters.image_utils import Image
from dspy.signatures.field import InputField, OutputField, new_to_old_field


def signature_to_template(signature, adapter=None) -> dsp.Template:
"""Convert from new to legacy format."""
Expand Down Expand Up @@ -242,8 +244,8 @@ class Signature(BaseModel, metaclass=SignatureMeta):
@classmethod
@contextmanager
def replace(
cls: "Signature",
new_signature: "Signature",
cls,
new_signature: "Type[Signature]",
validate_new_signature: bool = True,
) -> typing.Generator[None, None, None]:
"""Replace the signature with an updated version.
Expand All @@ -262,16 +264,35 @@ def replace(
f"Field '{field}' is missing from the updated signature '{new_signature.__class__}.",
)

class OldSignature(cls, Signature):
class OldSignature(cls):
pass

replace_fields = ["__doc__", "model_fields", "model_extra", "model_config"]
for field in replace_fields:
setattr(cls, field, getattr(new_signature, field))
def swap_attributes(source: Type[Signature]):
unhandled = {}

for attr in ["__doc__", "__pydantic_fields__", "model_fields", "model_extra", "model_config"]:
try:
setattr(cls, attr, getattr(source, attr))
except AttributeError as exc:
if attr in ("__pydantic_fields__", "model_fields"):
version = "< 2.10" if attr == "__pydantic_fields__" else ">= 2.10"
logging.debug(f"Model attribute {attr} not replaced, expected with pydantic {version}")
unhandled[attr] = exc
else:
raise exc

# if neither of the attributes were replaced, raise an error to prevent silent failures
if set(unhandled.keys()) >= {"model_fields", "__pydantic_fields__"}:
raise ValueError("Failed to replace either model_fields or __pydantic_fields__") from (
unhandled.get("model_fields") or unhandled.get("__pydantic_fields__")
)

swap_attributes(new_signature)
cls.model_rebuild(force=True)

yield
for field in replace_fields:
setattr(cls, field, getattr(OldSignature, field))

swap_attributes(OldSignature)
cls.model_rebuild(force=True)


Expand Down Expand Up @@ -383,7 +404,7 @@ def _parse_type_node(node, names=None) -> Any:
without using structural pattern matching introduced in Python 3.10.
"""

if names is None:
names = typing.__dict__

Expand All @@ -401,7 +422,7 @@ def _parse_type_node(node, names=None) -> Any:
id_ = node.id
if id_ in names:
return names[id_]

for type_ in [int, str, float, bool, list, tuple, dict, Image]:
if type_.__name__ == id_:
return type_
Expand All @@ -420,7 +441,7 @@ def _parse_type_node(node, names=None) -> Any:
keys = [kw.arg for kw in node.keywords]
values = [kw.value.value for kw in node.keywords]
return Field(**dict(zip(keys, values)))

if isinstance(node, ast.Attribute) and node.attr == "Image":
return Image

Expand Down
5 changes: 3 additions & 2 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
black==24.2.0
datamodel-code-generator==0.26.3
litellm[proxy]==1.51.0
pillow==10.4.0
pre-commit==3.7.0
pytest==8.3.3
pytest-env==1.1.3
pytest-mock==3.12.0
ruff==0.3.0
torch==2.2.1
transformers==4.38.2
pillow==10.4.0
litellm[proxy]==1.51.0

0 comments on commit ff6f5a8

Please sign in to comment.