From a85d481ea1515e8355b5f4c9560e56ca7ed78292 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Fri, 15 Jul 2022 22:11:11 +0100 Subject: [PATCH] Created SumOfProducts subclass of ControlValues (#5755) This class allows the creation of control values that can't be factored into simple products hence solving #4512 --- cirq/__init__.py | 1 + cirq/json_resolver_cache.py | 1 + cirq/ops/__init__.py | 2 +- cirq/ops/common_gates.py | 28 ++++ cirq/ops/control_values.py | 124 ++++++++++++++++-- cirq/ops/control_values_test.py | 95 +++++++++++++- cirq/ops/controlled_gate_test.py | 3 - .../json_test_data/SumOfProducts.json | 4 + .../json_test_data/SumOfProducts.repr | 1 + 9 files changed, 241 insertions(+), 18 deletions(-) create mode 100644 cirq/protocols/json_test_data/SumOfProducts.json create mode 100644 cirq/protocols/json_test_data/SumOfProducts.repr diff --git a/cirq/__init__.py b/cirq/__init__.py index 85211ad664c..d966c8fade2 100644 --- a/cirq/__init__.py +++ b/cirq/__init__.py @@ -301,6 +301,7 @@ SQRT_ISWAP_INV, SWAP, SwapPowGate, + SumOfProducts, T, TaggedOperation, ThreeQubitDiagonalGate, diff --git a/cirq/json_resolver_cache.py b/cirq/json_resolver_cache.py index bc50de01da1..c8be4911e56 100644 --- a/cirq/json_resolver_cache.py +++ b/cirq/json_resolver_cache.py @@ -213,6 +213,7 @@ def _symmetricalqidpair(qids): 'SqrtIswapTargetGateset': cirq.SqrtIswapTargetGateset, 'StabilizerStateChForm': cirq.StabilizerStateChForm, 'StatePreparationChannel': cirq.StatePreparationChannel, + 'SumOfProducts': cirq.SumOfProducts, 'SwapPowGate': cirq.SwapPowGate, 'SympyCondition': cirq.SympyCondition, 'TaggedOperation': cirq.TaggedOperation, diff --git a/cirq/ops/__init__.py b/cirq/ops/__init__.py index 6e4ff6d8393..3be21389028 100644 --- a/cirq/ops/__init__.py +++ b/cirq/ops/__init__.py @@ -210,4 +210,4 @@ from cirq.ops.state_preparation_channel import StatePreparationChannel -from cirq.ops.control_values import AbstractControlValues, ProductOfSums +from cirq.ops.control_values import AbstractControlValues, ProductOfSums, SumOfProducts diff --git a/cirq/ops/common_gates.py b/cirq/ops/common_gates.py index 7ec501db1f4..dfbe1716622 100644 --- a/cirq/ops/common_gates.py +++ b/cirq/ops/common_gates.py @@ -233,10 +233,17 @@ def controlled( A `cirq.ControlledGate` (or `cirq.CXPowGate` if possible) representing `self` controlled by the given control values and qubits. """ + if control_values and not isinstance(control_values, cv.AbstractControlValues): + control_values = cv.ProductOfSums( + tuple( + (val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values + ) + ) result = super().controlled(num_controls, control_values, control_qid_shape) if ( self._global_shift == 0 and isinstance(result, controlled_gate.ControlledGate) + and isinstance(result.control_values, cv.ProductOfSums) and result.control_values[-1] == (1,) and result.control_qid_shape[-1] == 2 ): @@ -680,10 +687,17 @@ def controlled( A `cirq.ControlledGate` (or `cirq.CZPowGate` if possible) representing `self` controlled by the given control values and qubits. """ + if control_values and not isinstance(control_values, cv.AbstractControlValues): + control_values = cv.ProductOfSums( + tuple( + (val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values + ) + ) result = super().controlled(num_controls, control_values, control_qid_shape) if ( self._global_shift == 0 and isinstance(result, controlled_gate.ControlledGate) + and isinstance(result.control_values, cv.ProductOfSums) and result.control_values[-1] == (1,) and result.control_qid_shape[-1] == 2 ): @@ -1116,10 +1130,17 @@ def controlled( A `cirq.ControlledGate` (or `cirq.CCZPowGate` if possible) representing `self` controlled by the given control values and qubits. """ + if control_values and not isinstance(control_values, cv.AbstractControlValues): + control_values = cv.ProductOfSums( + tuple( + (val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values + ) + ) result = super().controlled(num_controls, control_values, control_qid_shape) if ( self._global_shift == 0 and isinstance(result, controlled_gate.ControlledGate) + and isinstance(result.control_values, cv.ProductOfSums) and result.control_values[-1] == (1,) and result.control_qid_shape[-1] == 2 ): @@ -1315,10 +1336,17 @@ def controlled( A `cirq.ControlledGate` (or `cirq.CCXPowGate` if possible) representing `self` controlled by the given control values and qubits. """ + if control_values and not isinstance(control_values, cv.AbstractControlValues): + control_values = cv.ProductOfSums( + tuple( + (val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values + ) + ) result = super().controlled(num_controls, control_values, control_qid_shape) if ( self._global_shift == 0 and isinstance(result, controlled_gate.ControlledGate) + and isinstance(result.control_values, cv.ProductOfSums) and result.control_values[-1] == (1,) and result.control_qid_shape[-1] == 2 ): diff --git a/cirq/ops/control_values.py b/cirq/ops/control_values.py index cffeee4ea0f..c83742ada6f 100644 --- a/cirq/ops/control_values.py +++ b/cirq/ops/control_values.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc -from typing import Union, Tuple, List, TYPE_CHECKING, Any, Dict, Generator, cast, Iterator +from typing import Union, Tuple, List, TYPE_CHECKING, Any, Dict, Generator, cast, Iterator, Optional from dataclasses import dataclass import itertools @@ -53,7 +53,7 @@ def _expand(self) -> Iterator[Tuple[int, ...]]: """Expands the (possibly compressed) internal representation into a sum of products representation.""" # pylint: disable=line-too-long @abc.abstractmethod - def diagram_repr(self) -> str: + def diagram_repr(self, label: Optional[str] = None) -> str: """Returns a string representation to be used in circuit diagrams.""" @abc.abstractmethod @@ -92,12 +92,6 @@ def _are_ones(self) -> bool: def _json_dict_(self) -> Dict[str, Any]: pass - @abc.abstractmethod - def __getitem__( - self, key: Union[slice, int] - ) -> Union['AbstractControlValues', Tuple[int, ...]]: - pass - def __iter__(self) -> Generator[Tuple[int, ...], None, None]: for assignment in self._expand(): yield assignment @@ -154,7 +148,9 @@ def validate(self, qid_shapes: Union[Tuple[int, ...], List[int]]) -> None: def _are_ones(self) -> bool: return frozenset(self._internal_representation) == {(1,)} - def diagram_repr(self) -> str: + def diagram_repr(self, label: Optional[str] = None) -> str: + if label: + return label if self._are_ones(): return 'C' * self._number_variables() @@ -174,9 +170,117 @@ def __getitem__( def _json_dict_(self) -> Dict[str, Any]: return {'_internal_representation': self._internal_representation} - def __and__(self, other: AbstractControlValues) -> 'ProductOfSums': + def __and__(self, other: AbstractControlValues) -> AbstractControlValues: + if isinstance(other, SumOfProducts): + return SumOfProducts(tuple(p for p in self)) & other if not isinstance(other, ProductOfSums): raise TypeError( f'And operation not supported between types ProductOfSums and {type(other)}' ) return type(self)(self._internal_representation + other._internal_representation) + + +@dataclass(frozen=True, eq=False) +class SumOfProducts(AbstractControlValues): + """Represents control values as a union of n-bit tuples. + + `SumOfProducts` representation describes the control values as a union + of n-bit tuples, where each n-bit tuple represents an allowed assignment + of bits for which the control should be activated. This expanded + representation allows us to create control values combinations which + cannot be factored as a `ProductOfSums` representation. + + For example: + + 1) `(|00><00| + |11><11|) X + (|01><01| + |10><10|) I` represents an + operator which flips the third qubit if the first two qubits + are `00` or `11`, and does nothing otherwise. + This can be constructed as + >>> xor_control_values = cirq.SumOfProducts(((0, 0), (1, 1))) + >>> q0, q1, q2 = cirq.LineQubit.range(3) + >>> xor_cop = cirq.X(q2).controlled_by(q0, q1, control_values=xor_control_values) + + 2) `(|00><00| + |01><01| + |10><10|) X + (|11><11|) I` represents an + operators which flips the third qubit if the `nand` of first two + qubits is `1` (i.e. first two qubits are either `00`, `01` or `10`), + and does nothing otherwise. This can be constructed as: + + >>> nand_control_values = cirq.SumOfProducts(((0, 0), (0, 1), (1, 0))) + >>> q0, q1, q2 = cirq.LineQubit.range(3) + >>> nan_cop = cirq.X(q2).controlled_by(q0, q1, control_values=nand_control_values) + """ + + _internal_representation: Tuple[Tuple[int, ...], ...] + + def __post_init__(self): + if not len(self._internal_representation): + raise ValueError('SumOfProducts can\'t be empty.') + num_qubits = len(self._internal_representation[0]) + for p in self._internal_representation: + if len(p) != num_qubits: + raise ValueError( + f'size mismatch between different products of {self._internal_representation}' + ) + if len(self._internal_representation) != len( + set(map(tuple, self._internal_representation)) + ): + raise ValueError('SumOfProducts can\'t have duplicate products.') + + def _identifier(self) -> Tuple[Tuple[int, ...], ...]: + return self._internal_representation + + def _expand(self) -> Iterator[Tuple[int, ...]]: + """Returns the combinations tracked by the object.""" + self = cast('SumOfProducts', self) + return iter(self._internal_representation) + + def __repr__(self) -> str: + return f'cirq.SumOfProducts({str(self._identifier())})' + + def _number_variables(self) -> int: + return len(self._internal_representation[0]) + + def __len__(self) -> int: + return self._number_variables() + + def __hash__(self) -> int: + return hash(self._internal_representation) + + def validate(self, qid_shapes: Union[Tuple[int, ...], List[int]]) -> None: + for vals in self._internal_representation: + if len(qid_shapes) != len(vals): + raise ValueError( + f'number of values in product {vals} doesn\'t equal number' + f' of qubits(={len(qid_shapes)})' + ) + + for i, v in enumerate(vals): + if not (0 <= v and v < qid_shapes[i]): + raise ValueError( + f'Control values <{v}> in combination {vals} is outside' + f' of range for control qubit number <{i}>.' + ) + + def _are_ones(self) -> bool: + return frozenset(self._internal_representation) == {(1,) * self._number_variables()} + + def diagram_repr(self, label: Optional[str] = None) -> str: + if label: + return label + return ','.join(map(lambda p: ''.join(map(str, p)), self._internal_representation)) + + def _json_dict_(self) -> Dict[str, Any]: + return {'_internal_representation': self._internal_representation} + + def __and__(self, other: AbstractControlValues) -> 'SumOfProducts': + if isinstance(other, ProductOfSums): + other = SumOfProducts(tuple(p for p in other)) + if not isinstance(other, SumOfProducts): + raise TypeError( + f'And operation not supported between types SumOfProducts and {type(other)}' + ) + combined = map( + lambda p: tuple(itertools.chain(*p)), + itertools.product(self._internal_representation, other._internal_representation), + ) + return SumOfProducts(tuple(combined)) diff --git a/cirq/ops/control_values_test.py b/cirq/ops/control_values_test.py index f2c482d8526..bc773fbea7f 100644 --- a/cirq/ops/control_values_test.py +++ b/cirq/ops/control_values_test.py @@ -25,22 +25,63 @@ def test_init_productOfSum(): ((((0, 1), (1, 0))), {(0, 0), (0, 1), (1, 0), (1, 1)}), ] for control_values, want in tests: - print(control_values) got = {c for c in cv.ProductOfSums(control_values)} eq.add_equality_group(got, want) +def test_init_SumOfProducts(): + eq = cirq.testing.EqualsTester() + tests = [ + (((1,),), {(1,)}), + (((0, 1), (1, 0)), {(0, 1), (1, 0)}), # XOR + (((0, 0), (0, 1), (1, 0)), {(0, 0), (0, 1), (1, 0)}), # NAND + ] + for control_values, want in tests: + got = {c for c in cv.SumOfProducts(control_values)} + eq.add_equality_group(got, want) + + with pytest.raises(ValueError): + _ = cv.SumOfProducts([]) + + # size mismatch + with pytest.raises(ValueError): + _ = cv.SumOfProducts([[1], [1, 0]]) + + # can't have duplicates + with pytest.raises(ValueError): + _ = cv.SumOfProducts([[1, 0], [0, 1], [1, 0]]) + + def test_and_operation(): eq = cirq.testing.EqualsTester() - originals = [((1,),), ((0, 1), (1,)), (((0, 1), (1, 0)))] - for control_values1 in originals: - for control_values2 in originals: + product_of_sums_data = [((1,),), ((0, 1), (1,)), (((0, 1), (1, 0)))] + for control_values1 in product_of_sums_data: + for control_values2 in product_of_sums_data: control_vals1 = cv.ProductOfSums(control_values1) control_vals2 = cv.ProductOfSums(control_values2) want = [v1 + v2 for v1 in control_vals1 for v2 in control_vals2] got = [c for c in control_vals1 & control_vals2] eq.add_equality_group(got, want) + sum_of_products_data = [((1,),), ((0, 1),), ((0, 0), (0, 1), (1, 0))] + eq = cirq.testing.EqualsTester() + for control_values1 in sum_of_products_data: + for control_values2 in sum_of_products_data: + control_vals1 = cv.SumOfProducts(control_values1) + control_vals2 = cv.SumOfProducts(control_values2) + want = [v1 + v2 for v1 in control_vals1 for v2 in control_vals2] + got = [c for c in control_vals1 & control_vals2] + eq.add_equality_group(got, want) + + pos = cv.ProductOfSums(((1,), (0,))) + sop = cv.SumOfProducts(((1, 0), (0, 1))) + assert tuple(p for p in pos & sop) == ((1, 0, 1, 0), (1, 0, 0, 1)) + + assert tuple(p for p in sop & pos) == ((1, 0, 1, 0), (0, 1, 1, 0)) + + with pytest.raises(TypeError): + _ = sop & 1 + def test_and_supported_types(): CV = cv.ProductOfSums((1,)) @@ -52,3 +93,49 @@ def test_repr(): product_of_sums_data = [((1,),), ((0, 1), (1,)), (((0, 1), (1, 0)))] for t in map(cv.ProductOfSums, product_of_sums_data): cirq.testing.assert_equivalent_repr(t) + + sum_of_products_data = [((1,),), ((0, 1),), ((0, 0), (0, 1), (1, 0))] + for t in map(cv.SumOfProducts, sum_of_products_data): + cirq.testing.assert_equivalent_repr(t) + + +def test_validate(): + control_val = cv.SumOfProducts(((1, 2), (0, 1))) + + _ = control_val.validate([2, 3]) + + with pytest.raises(ValueError): + _ = control_val.validate([2, 2]) + + # number of qubits != number of control values. + with pytest.raises(ValueError): + _ = control_val.validate([2]) + + +def test_len(): + data = [((1,),), ((0, 1),), ((0, 0), (0, 1), (1, 0))] + for vals in data: + c = cv.SumOfProducts(vals) + assert len(c) == len(vals[0]) + + +def test_hash(): + data = [((1,),), ((0, 1),), ((0, 0), (0, 1), (1, 0))] + assert len(set(map(hash, map(cv.SumOfProducts, data)))) == 3 + + +def test_are_ones(): + data = [((1,),), ((0, 1),), ((0, 0), (0, 1), (1, 0)), ((1, 1, 1, 1),)] + are_ones = [True, False, False] + for vals, want in zip(data, are_ones): + c = cv.SumOfProducts(vals) + assert c._are_ones() == want + + +def test_diagram_repr(): + c = cv.SumOfProducts(((1, 0), (0, 1))) + assert c.diagram_repr() == '10,01' + + assert c.diagram_repr('xor') == 'xor' + + assert cv.ProductOfSums(((1,), (0,))).diagram_repr('10') == '10' diff --git a/cirq/ops/controlled_gate_test.py b/cirq/ops/controlled_gate_test.py index d1669e69d30..a345e013eba 100644 --- a/cirq/ops/controlled_gate_test.py +++ b/cirq/ops/controlled_gate_test.py @@ -633,9 +633,6 @@ def _are_ones(self): def _json_dict_(self): pass - def __getitem__(self, idx): - pass - def test_decompose_applies_only_to_ProductOfSums(): g = cirq.ControlledGate(cirq.X, control_values=MockControlValues()) diff --git a/cirq/protocols/json_test_data/SumOfProducts.json b/cirq/protocols/json_test_data/SumOfProducts.json new file mode 100644 index 00000000000..27653f60b61 --- /dev/null +++ b/cirq/protocols/json_test_data/SumOfProducts.json @@ -0,0 +1,4 @@ +{ + "_internal_representation": [[1, 0], [1, 2]], + "cirq_type": "SumOfProducts" +} diff --git a/cirq/protocols/json_test_data/SumOfProducts.repr b/cirq/protocols/json_test_data/SumOfProducts.repr new file mode 100644 index 00000000000..d487d5ef040 --- /dev/null +++ b/cirq/protocols/json_test_data/SumOfProducts.repr @@ -0,0 +1 @@ +cirq.SumOfProducts([[1, 0], [1, 2]]) \ No newline at end of file