Skip to content

Commit

Permalink
refactor(hugr-py)!: make serialization (module/methods) private (#1477)
Browse files Browse the repository at this point in the history
Closes #1464

BREAKING CHANGE: `hugr.serialization` module and `to_serial` methods are
now internal only.
  • Loading branch information
ss2165 committed Aug 28, 2024
1 parent 0dc2c9c commit 49a5bad
Show file tree
Hide file tree
Showing 24 changed files with 814 additions and 815 deletions.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class BaseValue(ABC, ConfiguredBaseModel):
def deserialize(self) -> val.Value: ...


class ExtensionValue(BaseValue):
class CustomValue(BaseValue):
"""An extension constant value, that can check it is of a given [CustomType]."""

v: Literal["Extension"] = Field(default="Extension", title="ValueTag")
Expand All @@ -127,11 +127,11 @@ class FunctionValue(BaseValue):
hugr: Any

def deserialize(self) -> val.Value:
from hugr._serialization.serial_hugr import SerialHugr
from hugr.hugr import Hugr
from hugr.serialization.serial_hugr import SerialHugr

# pydantic stores the serialized dictionary because of the "Any" annotation
return val.Function(Hugr.from_serial(SerialHugr(**self.hugr)))
return val.Function(Hugr._from_serial(SerialHugr(**self.hugr)))


class TupleValue(BaseValue):
Expand Down Expand Up @@ -172,9 +172,7 @@ def deserialize(self) -> val.Value:
class Value(RootModel):
"""A constant Value."""

root: ExtensionValue | FunctionValue | TupleValue | SumValue = Field(
discriminator="v"
)
root: CustomValue | FunctionValue | TupleValue | SumValue = Field(discriminator="v")

model_config = ConfigDict(json_schema_extra={"required": ["v"]})

Expand Down Expand Up @@ -501,7 +499,7 @@ def deserialize(self) -> ops.CFG:
ControlFlowOp = Conditional | TailLoop | CFG


class Extension(DataflowOp):
class ExtensionOp(DataflowOp):
"""A user-defined operation that can be downcasted by the extensions that define
it.
"""
Expand Down Expand Up @@ -649,7 +647,7 @@ class OpType(RootModel):
| CallIndirect
| LoadConstant
| LoadFunction
| Extension
| ExtensionOp
| Noop
| MakeTuple
| UnpackTuple
Expand Down
File renamed without changes.
46 changes: 23 additions & 23 deletions hugr-py/src/hugr/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from semver import Version

import hugr.serialization.extension as ext_s
import hugr._serialization.extension as ext_s
from hugr import ops, tys, val
from hugr.utils import ser_it

Expand Down Expand Up @@ -43,11 +43,11 @@ class ExplicitBound:

bound: tys.TypeBound

def to_serial(self) -> ext_s.ExplicitBound:
def _to_serial(self) -> ext_s.ExplicitBound:
return ext_s.ExplicitBound(bound=self.bound)

def to_serial_root(self) -> ext_s.TypeDefBound:
return ext_s.TypeDefBound(root=self.to_serial())
def _to_serial_root(self) -> ext_s.TypeDefBound:
return ext_s.TypeDefBound(root=self._to_serial())


@dataclass
Expand All @@ -63,11 +63,11 @@ class FromParamsBound:

indices: list[int]

def to_serial(self) -> ext_s.FromParamsBound:
def _to_serial(self) -> ext_s.FromParamsBound:
return ext_s.FromParamsBound(indices=self.indices)

def to_serial_root(self) -> ext_s.TypeDefBound:
return ext_s.TypeDefBound(root=self.to_serial())
def _to_serial_root(self) -> ext_s.TypeDefBound:
return ext_s.TypeDefBound(root=self._to_serial())


@dataclass
Expand Down Expand Up @@ -128,13 +128,13 @@ class TypeDef(ExtensionObject):
#: The type bound of the type.
bound: ExplicitBound | FromParamsBound

def to_serial(self) -> ext_s.TypeDef:
def _to_serial(self) -> ext_s.TypeDef:
return ext_s.TypeDef(
extension=self.get_extension().name,
name=self.name,
description=self.description,
params=ser_it(self.params),
bound=ext_s.TypeDefBound(root=self.bound.to_serial()),
bound=ext_s.TypeDefBound(root=self.bound._to_serial()),
)

def instantiate(self, args: Sequence[tys.TypeArg]) -> tys.ExtType:
Expand All @@ -155,7 +155,7 @@ class FixedHugr:
#: HUGR defining operation lowering.
hugr: Hugr

def to_serial(self) -> ext_s.FixedHugr:
def _to_serial(self) -> ext_s.FixedHugr:
return ext_s.FixedHugr(extensions=self.extensions, hugr=self.hugr)


Expand Down Expand Up @@ -200,17 +200,17 @@ class OpDef(ExtensionObject):
#: Lowerings of the operation.
lower_funcs: list[FixedHugr] = field(default_factory=list, repr=False)

def to_serial(self) -> ext_s.OpDef:
def _to_serial(self) -> ext_s.OpDef:
return ext_s.OpDef(
extension=self.get_extension().name,
name=self.name,
description=self.description,
misc=self.misc,
signature=self.signature.poly_func.to_serial()
signature=self.signature.poly_func._to_serial()
if self.signature.poly_func
else None,
binary=self.signature.binary,
lower_funcs=[f.to_serial() for f in self.lower_funcs],
lower_funcs=[f._to_serial() for f in self.lower_funcs],
)


Expand All @@ -223,11 +223,11 @@ class ExtensionValue(ExtensionObject):
#: Value payload.
val: val.Value

def to_serial(self) -> ext_s.ExtensionValue:
def _to_serial(self) -> ext_s.ExtensionValue:
return ext_s.ExtensionValue(
extension=self.get_extension().name,
name=self.name,
typed_value=self.val.to_serial_root(),
typed_value=self.val._to_serial_root(),
)


Expand Down Expand Up @@ -257,14 +257,14 @@ class NotFound(Exception):

name: str

def to_serial(self) -> ext_s.Extension:
def _to_serial(self) -> ext_s.Extension:
return ext_s.Extension(
name=self.name,
version=self.version, # type: ignore[arg-type]
extension_reqs=self.extension_reqs,
types={k: v.to_serial() for k, v in self.types.items()},
values={k: v.to_serial() for k, v in self.values.items()},
operations={k: v.to_serial() for k, v in self.operations.items()},
types={k: v._to_serial() for k, v in self.types.items()},
values={k: v._to_serial() for k, v in self.values.items()},
operations={k: v._to_serial() for k, v in self.operations.items()},
)

def add_op_def(self, op_def: OpDef) -> OpDef:
Expand Down Expand Up @@ -465,11 +465,11 @@ class Package:
#: Extensions included in the package.
extensions: list[Extension] = field(default_factory=list)

def to_serial(self) -> ext_s.Package:
def _to_serial(self) -> ext_s.Package:
return ext_s.Package(
modules=[m.to_serial() for m in self.modules],
extensions=[e.to_serial() for e in self.extensions],
modules=[m._to_serial() for m in self.modules],
extensions=[e._to_serial() for e in self.extensions],
)

def to_json(self) -> str:
return self.to_serial().model_dump_json()
return self._to_serial().model_dump_json()
18 changes: 9 additions & 9 deletions hugr-py/src/hugr/hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
overload,
)

from hugr._serialization.ops import OpType as SerialOp
from hugr._serialization.serial_hugr import SerialHugr
from hugr.node_port import (
Direction,
InPort,
Expand All @@ -26,8 +28,6 @@
_SubPort,
)
from hugr.ops import Call, Const, Custom, DataflowOp, Module, Op
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.tys import Kind, Type, ValueKind
from hugr.utils import BiMap
from hugr.val import Value
Expand All @@ -54,8 +54,8 @@ class NodeData:
children: list[Node] = field(default_factory=list, repr=False)
metadata: dict[str, Any] = field(default_factory=dict)

def to_serial(self, node: Node) -> SerialOp:
o = self.op.to_serial(self.parent if self.parent else node)
def _to_serial(self, node: Node) -> SerialOp:
o = self.op._to_serial(self.parent if self.parent else node)

return SerialOp(root=o) # type: ignore[arg-type]

Expand Down Expand Up @@ -601,7 +601,7 @@ def insert_hugr(self, hugr: Hugr, parent: ToNode | None = None) -> dict[Node, No
)
return mapping

def to_serial(self) -> SerialHugr:
def _to_serial(self) -> SerialHugr:
"""Serialize the HUGR."""
node_it = (node for node in self._nodes if node is not None)

Expand All @@ -614,7 +614,7 @@ def _serialize_link(

return SerialHugr(
# non contiguous indices will be erased
nodes=[node.to_serial(Node(idx, {})) for idx, node in enumerate(node_it)],
nodes=[node._to_serial(Node(idx, {})) for idx, node in enumerate(node_it)],
edges=[_serialize_link(link) for link in self._links.items()],
metadata=[node.metadata if node.metadata else None for node in node_it],
)
Expand Down Expand Up @@ -644,7 +644,7 @@ def resolve_extensions(self, registry: ext.ExtensionRegistry) -> Hugr:
return self

@classmethod
def from_serial(cls, serial: SerialHugr) -> Hugr:
def _from_serial(cls, serial: SerialHugr) -> Hugr:
"""Load a HUGR from a serialized form."""
assert serial.nodes, "Empty Hugr is invalid"

Expand Down Expand Up @@ -685,14 +685,14 @@ def get_meta(idx: int) -> dict[str, Any]:

def to_json(self) -> str:
"""Serialize the HUGR to a JSON string."""
return self.to_serial().to_json()
return self._to_serial().to_json()

@classmethod
def load_json(cls, json_str: str) -> Hugr:
"""Deserialize a JSON string into a HUGR."""
json_dict = json.loads(json_str)
serial = SerialHugr.load_json(json_dict)
return cls.from_serial(serial)
return cls._from_serial(serial)

def render_dot(self, palette: str | None = None) -> gv.Digraph:
"""Render the HUGR to a graphviz Digraph.
Expand Down
Loading

0 comments on commit 49a5bad

Please sign in to comment.