Skip to content

Commit

Permalink
cleanup model rebuilding
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed May 2, 2024
1 parent 50621b5 commit 57525e1
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
6 changes: 4 additions & 2 deletions hugr-py/src/hugr/serialization/serial_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pydantic import BaseModel, Field, ConfigDict

from .ops import NodeID, OpType, classes
from .ops import NodeID, OpType, classes as ops_classes
from .tys import model_rebuild
import hugr

Expand Down Expand Up @@ -37,7 +37,9 @@ def get_version(cls) -> str:

@classmethod
def _pydantic_rebuild(cls, config: ConfigDict = ConfigDict(), **kwargs):
model_rebuild([(cls.__name__, cls)] + classes, config=config, **kwargs)
my_classes = dict(ops_classes)
my_classes[cls.__name__] = cls
model_rebuild(my_classes, config=config, **kwargs)

class Config:
title = "Hugr"
Expand Down
6 changes: 4 additions & 2 deletions hugr-py/src/hugr/serialization/testing_hugr.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pydantic import ConfigDict
from typing import Literal
from .tys import Type, SumType, PolyFuncType, ConfiguredBaseModel, model_rebuild
from .ops import Value, OpType, classes
from .ops import Value, OpType, classes as ops_classes


class TestingHugr(ConfiguredBaseModel):
Expand All @@ -22,7 +22,9 @@ def get_version(cls) -> str:

@classmethod
def _pydantic_rebuild(cls, config: ConfigDict = ConfigDict(), **kwargs):
model_rebuild([(cls.__name__, cls)] + classes, config=config, **kwargs)
my_classes = dict(ops_classes)
my_classes[cls.__name__] = cls
model_rebuild(my_classes, config=config, **kwargs)

class Config:
title = "HugrTesting"
8 changes: 4 additions & 4 deletions hugr-py/src/hugr/serialization/tys.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
import sys
from enum import Enum
from typing import Annotated, Any, Literal, Optional, Union
from typing import Annotated, Any, Literal, Optional, Union, Mapping

from pydantic import (
BaseModel,
Expand Down Expand Up @@ -354,16 +354,16 @@ class Signature(ConfiguredBaseModel):


def model_rebuild(
classes: list[tuple[str, Any]],
classes: Mapping[str, type],
config: ConfigDict = ConfigDict(),
**kwargs,
):
new_config = default_model_config.copy()
new_config.update(config)
for c in {k: v for (k, v) in classes}.values():
for c in classes.values():
if issubclass(c, ConfiguredBaseModel):
c.set_model_config(new_config)
c.model_rebuild(**kwargs)


model_rebuild(classes)
model_rebuild(dict(classes))

0 comments on commit 57525e1

Please sign in to comment.