diff --git a/hugr-py/src/hugr/serialization/serial_hugr.py b/hugr-py/src/hugr/serialization/serial_hugr.py index 9be4aa69c..a8c104937 100644 --- a/hugr-py/src/hugr/serialization/serial_hugr.py +++ b/hugr-py/src/hugr/serialization/serial_hugr.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, Field, ConfigDict -from .ops import NodeID, OpType, classes +from .ops import NodeID, OpType, classes as ops_classes from .tys import model_rebuild import hugr @@ -37,7 +37,9 @@ def get_version(cls) -> str: @classmethod def _pydantic_rebuild(cls, config: ConfigDict = ConfigDict(), **kwargs): - model_rebuild([(cls.__name__, cls)] + classes, config=config, **kwargs) + my_classes = dict(ops_classes) + my_classes[cls.__name__] = cls + model_rebuild(my_classes, config=config, **kwargs) class Config: title = "Hugr" diff --git a/hugr-py/src/hugr/serialization/testing_hugr.py b/hugr-py/src/hugr/serialization/testing_hugr.py index 5bac1f114..59db4b80d 100644 --- a/hugr-py/src/hugr/serialization/testing_hugr.py +++ b/hugr-py/src/hugr/serialization/testing_hugr.py @@ -1,7 +1,7 @@ from pydantic import ConfigDict from typing import Literal from .tys import Type, SumType, PolyFuncType, ConfiguredBaseModel, model_rebuild -from .ops import Value, OpType, classes +from .ops import Value, OpType, classes as ops_classes class TestingHugr(ConfiguredBaseModel): @@ -22,7 +22,9 @@ def get_version(cls) -> str: @classmethod def _pydantic_rebuild(cls, config: ConfigDict = ConfigDict(), **kwargs): - model_rebuild([(cls.__name__, cls)] + classes, config=config, **kwargs) + my_classes = dict(ops_classes) + my_classes[cls.__name__] = cls + model_rebuild(my_classes, config=config, **kwargs) class Config: title = "HugrTesting" diff --git a/hugr-py/src/hugr/serialization/tys.py b/hugr-py/src/hugr/serialization/tys.py index bc39e56ef..b75a614cd 100644 --- a/hugr-py/src/hugr/serialization/tys.py +++ b/hugr-py/src/hugr/serialization/tys.py @@ -1,7 +1,7 @@ import inspect import sys from enum import Enum -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, Literal, Optional, Union, Mapping from pydantic import ( BaseModel, @@ -354,16 +354,16 @@ class Signature(ConfiguredBaseModel): def model_rebuild( - classes: list[tuple[str, Any]], + classes: Mapping[str, type], config: ConfigDict = ConfigDict(), **kwargs, ): new_config = default_model_config.copy() new_config.update(config) - for c in {k: v for (k, v) in classes}.values(): + for c in classes.values(): if issubclass(c, ConfiguredBaseModel): c.set_model_config(new_config) c.model_rebuild(**kwargs) -model_rebuild(classes) +model_rebuild(dict(classes))