Skip to content

Commit

Permalink
Support and test type serialization (quantumlib#4693)
Browse files Browse the repository at this point in the history
Prior to this PR, Cirq supported serialization of _instances_ of Cirq types, but not the types themselves. This PR adds serialization support for Cirq types, with the format:
```
{
    'cirq_type': 'type',
    'typename': $NAME
}
```
where `$NAME` is the `cirq_type` of the object in its JSON representation. For type T, `$NAME` is usually `T.__name__`, but some types (mostly in `cirq_google`) do not follow this rule. The `json_cirq_type` protocol and `_json_cirq_type_` magic method are provided to handle this.

It is worth noting that this PR explicitly **does not** support serialization of non-Cirq types (e.g. python builtins, sympy and numpy objects) despite instances of these objects being serializable in Cirq. This support can be added to `json_cirq_type` and `_cirq_object_hook` in `json_serialization.py` if we decide it is necessary; I left it out of this PR as it is not required by the motivating changes behind this PR (quantumlib#4640 and sub-PRs).
  • Loading branch information
95-martin-orion authored Nov 23, 2021
1 parent 32bd4c1 commit 2576520
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 12 deletions.
4 changes: 4 additions & 0 deletions cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@
circuit_diagram_info,
CircuitDiagramInfo,
CircuitDiagramInfoArgs,
cirq_type_from_json,
commutes,
control_keys,
decompose,
Expand All @@ -520,10 +521,13 @@
has_mixture,
has_stabilizer_effect,
has_unitary,
HasJSONNamespace,
inverse,
is_measurement,
is_parameterized,
JsonResolver,
json_cirq_type,
json_namespace,
json_serializable_dataclass,
dataclass_json_dict,
kraus,
Expand Down
9 changes: 5 additions & 4 deletions cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def _parallel_gate_op(gate, qubits):
'HPowGate': cirq.HPowGate,
'ISwapPowGate': cirq.ISwapPowGate,
'IdentityGate': cirq.IdentityGate,
'IdentityOperation': _identity_operation_from_dict,
'InitObsSetting': cirq.work.InitObsSetting,
'KrausChannel': cirq.KrausChannel,
'LinearDict': cirq.LinearDict,
Expand All @@ -115,7 +114,6 @@ def _parallel_gate_op(gate, qubits):
'_PauliY': cirq.ops.pauli_gates._PauliY,
'_PauliZ': cirq.ops.pauli_gates._PauliZ,
'ParamResolver': cirq.ParamResolver,
'ParallelGateOperation': _parallel_gate_op, # Removed in v0.14
'ParallelGate': cirq.ParallelGate,
'PauliMeasurementGate': cirq.PauliMeasurementGate,
'PauliString': cirq.PauliString,
Expand All @@ -134,7 +132,6 @@ def _parallel_gate_op(gate, qubits):
'RepetitionsStoppingCriteria': cirq.work.RepetitionsStoppingCriteria,
'ResetChannel': cirq.ResetChannel,
'SingleQubitCliffordGate': cirq.SingleQubitCliffordGate,
'SingleQubitMatrixGate': single_qubit_matrix_gate,
'SingleQubitPauliStringGateOperation': cirq.SingleQubitPauliStringGateOperation,
'SingleQubitReadoutCalibrationResult': cirq.experiments.SingleQubitReadoutCalibrationResult,
'StabilizerStateChForm': cirq.StabilizerStateChForm,
Expand All @@ -147,7 +144,6 @@ def _parallel_gate_op(gate, qubits):
'Rx': cirq.Rx,
'Ry': cirq.Ry,
'Rz': cirq.Rz,
'TwoQubitMatrixGate': two_qubit_matrix_gate,
'_UnconstrainedDevice': cirq.devices.unconstrained_device._UnconstrainedDevice,
'VarianceStoppingCriteria': cirq.work.VarianceStoppingCriteria,
'VirtualTag': cirq.VirtualTag,
Expand All @@ -163,6 +159,11 @@ def _parallel_gate_op(gate, qubits):
'YYPowGate': cirq.YYPowGate,
'ZPowGate': cirq.ZPowGate,
'ZZPowGate': cirq.ZZPowGate,
# Old types, only supported for backwards-compatibility
'IdentityOperation': _identity_operation_from_dict,
'ParallelGateOperation': _parallel_gate_op, # Removed in v0.14
'SingleQubitMatrixGate': single_qubit_matrix_gate,
'TwoQubitMatrixGate': two_qubit_matrix_gate,
# not a cirq class, but treated as one:
'pandas.DataFrame': pd.DataFrame,
'pandas.Index': pd.Index,
Expand Down
4 changes: 4 additions & 0 deletions cirq/protocols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,13 @@
inverse,
)
from cirq.protocols.json_serialization import (
cirq_type_from_json,
DEFAULT_RESOLVERS,
HasJSONNamespace,
JsonResolver,
json_serializable_dataclass,
json_cirq_type,
json_namespace,
to_json_gzip,
read_json_gzip,
to_json,
Expand Down
116 changes: 109 additions & 7 deletions cirq/protocols/json_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,23 @@ def _json_dict_(self) -> Union[None, NotImplementedType, Dict[Any, Any]]:
pass


class HasJSONNamespace(Protocol):
"""An object which prepends a namespace to its JSON cirq_type.
Classes which implement this method have the following cirq_type format:
f"{obj._json_namespace_()}.{obj.__class__.__name__}
Classes outside of Cirq or its submodules MUST implement this method to be
used in type serialization.
"""

@doc_private
@classmethod
def _json_namespace_(cls) -> str:
pass


def obj_to_dict_helper(
obj: Any, attribute_names: Iterable[str], namespace: Optional[str] = None
) -> Dict[str, Any]:
Expand Down Expand Up @@ -350,13 +367,7 @@ def _cirq_object_hook(d, resolvers: Sequence[JsonResolver], context_map: Dict[st
if d['cirq_type'] == '_ContextualSerialization':
return _ContextualSerialization.deserialize_with_context(**d)

for resolver in resolvers:
cls = resolver(d['cirq_type'])
if cls is not None:
break
else:
raise ValueError(f"Could not resolve type '{d['cirq_type']}' during deserialization")

cls = factory_from_json(d['cirq_type'], resolvers=resolvers)
from_json_dict = getattr(cls, '_from_json_dict_', None)
if from_json_dict is not None:
return from_json_dict(**d)
Expand Down Expand Up @@ -505,6 +516,97 @@ def get_serializable_by_keys(obj: Any) -> List[SerializableByKey]:
return []


def json_namespace(type_obj: Type) -> str:
"""Returns a namespace for JSON serialization of `type_obj`.
Types can provide custom namespaces with `_json_namespace_`; otherwise, a
Cirq type will not include a namespace in its cirq_type. Non-Cirq types
must provide a namespace for serialization in Cirq.
Args:
type_obj: Type to retrieve the namespace from.
Returns:
The namespace to prepend `type_obj` with in its JSON cirq_type.
Raises:
ValueError: if `type_obj` is not a Cirq type and does not explicitly
define its namespace with _json_namespace_.
"""
if hasattr(type_obj, '_json_namespace_'):
return type_obj._json_namespace_()
if type_obj.__module__.startswith('cirq'):
return ''
raise ValueError(f'{type_obj} is not a Cirq type, and does not define _json_namespace_.')


def json_cirq_type(type_obj: Type) -> str:
"""Returns a string type for JSON serialization of `type_obj`.
This method is not part of the base serialization path. Together with
`cirq_type_from_json`, it can be used to provide type-object serialization
for classes that need it.
"""
namespace = json_namespace(type_obj)
if namespace:
return f'{namespace}.{type_obj.__name__}'
return type_obj.__name__


def factory_from_json(
type_str: str, resolvers: Optional[Sequence[JsonResolver]] = None
) -> ObjectFactory:
"""Returns a factory for constructing objects of type `type_str`.
DEFAULT_RESOLVERS is updated dynamically as cirq submodules are imported.
Args:
type_str: string representation of the type to deserialize.
resolvers: list of JsonResolvers to use in type resolution. If this is
left blank, DEFAULT_RESOLVERS will be used.
Returns:
An ObjectFactory that can be called to construct an object whose type
matches the name `type_str`.
Raises:
ValueError: if type_str does not have a match in `resolvers`.
"""
resolvers = resolvers if resolvers is not None else DEFAULT_RESOLVERS
for resolver in resolvers:
cirq_type = resolver(type_str)
if cirq_type is not None:
return cirq_type
raise ValueError(f"Could not resolve type '{type_str}' during deserialization")


def cirq_type_from_json(type_str: str, resolvers: Optional[Sequence[JsonResolver]] = None) -> Type:
"""Returns a type object for JSON deserialization of `type_str`.
This method is not part of the base deserialization path. Together with
`json_cirq_type`, it can be used to provide type-object deserialization
for classes that need it.
Args:
type_str: string representation of the type to deserialize.
resolvers: list of JsonResolvers to use in type resolution. If this is
left blank, DEFAULT_RESOLVERS will be used.
Returns:
The type object T for which json_cirq_type(T) matches `type_str`.
Raises:
ValueError: if type_str does not have a match in `resolvers`, or if the
match found is a factory method instead of a type.
"""
cirq_type = factory_from_json(type_str, resolvers)
if isinstance(cirq_type, type):
return cirq_type
# We assume that if factory_from_json returns a factory, there is not
# another resolver which resolves `type_str` to a type object.
raise ValueError(f"Type {type_str} maps to a factory method instead of a type.")


# pylint: disable=function-redefined
@overload
def to_json(
Expand Down
69 changes: 68 additions & 1 deletion cirq/protocols/json_serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pathlib
import sys
import warnings
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Type
from unittest import mock

import numpy as np
Expand Down Expand Up @@ -534,6 +534,73 @@ def test_json_test_data_coverage(mod_spec: ModuleJsonTestSpec, cirq_obj_name: st
)


@dataclasses.dataclass
class SerializableTypeObject:
test_type: Type

def _json_dict_(self):
return {
'cirq_type': 'SerializableTypeObject',
'test_type': json_serialization.json_cirq_type(self.test_type),
}

@classmethod
def _from_json_dict_(cls, test_type, **kwargs):
return cls(json_serialization.cirq_type_from_json(test_type))


@pytest.mark.parametrize(
'mod_spec,cirq_obj_name,cls',
_list_public_classes_for_tested_modules(),
)
def test_type_serialization(mod_spec: ModuleJsonTestSpec, cirq_obj_name: str, cls):
if cirq_obj_name in mod_spec.tested_elsewhere:
pytest.skip("Tested elsewhere.")

if cirq_obj_name in mod_spec.not_yet_serializable:
return pytest.xfail(reason="Not serializable (yet)")

if cls is None:
pytest.skip(f'No serialization for None-mapped type: {cirq_obj_name}')

try:
typename = cirq.json_cirq_type(cls)
except ValueError as e:
pytest.skip(f'No serialization for non-Cirq type: {str(e)}')

def custom_resolver(name):
if name == 'SerializableTypeObject':
return SerializableTypeObject

sto = SerializableTypeObject(cls)
test_resolvers = [custom_resolver] + cirq.DEFAULT_RESOLVERS
expected_json = (
f'{{\n "cirq_type": "SerializableTypeObject",\n' f' "test_type": "{typename}"\n}}'
)
assert cirq.to_json(sto) == expected_json
assert cirq.read_json(json_text=expected_json, resolvers=test_resolvers) == sto
assert_json_roundtrip_works(sto, resolvers=test_resolvers)


def test_invalid_type_deserialize():
def custom_resolver(name):
if name == 'SerializableTypeObject':
return SerializableTypeObject

test_resolvers = [custom_resolver] + cirq.DEFAULT_RESOLVERS
invalid_json = (
f'{{\n "cirq_type": "SerializableTypeObject",\n' f' "test_type": "bad_type"\n}}'
)
with pytest.raises(ValueError, match='Could not resolve type'):
_ = cirq.read_json(json_text=invalid_json, resolvers=test_resolvers)

factory_json = (
f'{{\n "cirq_type": "SerializableTypeObject",\n' f' "test_type": "sympy.Add"\n}}'
)
with pytest.raises(ValueError, match='maps to a factory method'):
_ = cirq.read_json(json_text=factory_json, resolvers=test_resolvers)


def test_to_from_strings():
x_json_text = """{
"cirq_type": "_PauliX",
Expand Down
1 change: 1 addition & 0 deletions cirq/protocols/json_test_data/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@
'SimulatesFinalState',
'NamedTopology',
# protocols:
'HasJSONNamespace',
'SupportsActOn',
'SupportsActOnQubits',
'SupportsApplyChannel',
Expand Down

0 comments on commit 2576520

Please sign in to comment.