Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Created SumOfProducts subclass of ControlValues #5755

Merged
merged 10 commits into from
Jul 15, 2022
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@
SQRT_ISWAP_INV,
SWAP,
SwapPowGate,
SumOfProducts,
T,
TaggedOperation,
ThreeQubitDiagonalGate,
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def _symmetricalqidpair(qids):
'SqrtIswapTargetGateset': cirq.SqrtIswapTargetGateset,
'StabilizerStateChForm': cirq.StabilizerStateChForm,
'StatePreparationChannel': cirq.StatePreparationChannel,
'SumOfProducts': cirq.SumOfProducts,
'SwapPowGate': cirq.SwapPowGate,
'SympyCondition': cirq.SympyCondition,
'TaggedOperation': cirq.TaggedOperation,
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,4 +210,4 @@

from cirq.ops.state_preparation_channel import StatePreparationChannel

from cirq.ops.control_values import AbstractControlValues, ProductOfSums
from cirq.ops.control_values import AbstractControlValues, ProductOfSums, SumOfProducts
28 changes: 28 additions & 0 deletions cirq-core/cirq/ops/common_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,17 @@ def controlled(
A `cirq.ControlledGate` (or `cirq.CXPowGate` if possible) representing
`self` controlled by the given control values and qubits.
"""
if control_values and not isinstance(control_values, cv.AbstractControlValues):
control_values = cv.ProductOfSums(
tuple(
(val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values
)
)
result = super().controlled(num_controls, control_values, control_qid_shape)
if (
self._global_shift == 0
and isinstance(result, controlled_gate.ControlledGate)
and isinstance(result.control_values, cv.ProductOfSums)
and result.control_values[-1] == (1,)
and result.control_qid_shape[-1] == 2
):
Expand Down Expand Up @@ -680,10 +687,17 @@ def controlled(
A `cirq.ControlledGate` (or `cirq.CZPowGate` if possible) representing
`self` controlled by the given control values and qubits.
"""
if control_values and not isinstance(control_values, cv.AbstractControlValues):
control_values = cv.ProductOfSums(
tuple(
(val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values
)
)
result = super().controlled(num_controls, control_values, control_qid_shape)
if (
self._global_shift == 0
and isinstance(result, controlled_gate.ControlledGate)
and isinstance(result.control_values, cv.ProductOfSums)
and result.control_values[-1] == (1,)
and result.control_qid_shape[-1] == 2
):
Expand Down Expand Up @@ -1116,10 +1130,17 @@ def controlled(
A `cirq.ControlledGate` (or `cirq.CCZPowGate` if possible) representing
`self` controlled by the given control values and qubits.
"""
if control_values and not isinstance(control_values, cv.AbstractControlValues):
control_values = cv.ProductOfSums(
tuple(
(val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values
)
)
result = super().controlled(num_controls, control_values, control_qid_shape)
if (
self._global_shift == 0
and isinstance(result, controlled_gate.ControlledGate)
and isinstance(result.control_values, cv.ProductOfSums)
and result.control_values[-1] == (1,)
and result.control_qid_shape[-1] == 2
):
Expand Down Expand Up @@ -1315,10 +1336,17 @@ def controlled(
A `cirq.ControlledGate` (or `cirq.CCXPowGate` if possible) representing
`self` controlled by the given control values and qubits.
"""
if control_values and not isinstance(control_values, cv.AbstractControlValues):
control_values = cv.ProductOfSums(
tuple(
(val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values
)
)
result = super().controlled(num_controls, control_values, control_qid_shape)
if (
self._global_shift == 0
and isinstance(result, controlled_gate.ControlledGate)
and isinstance(result.control_values, cv.ProductOfSums)
and result.control_values[-1] == (1,)
and result.control_qid_shape[-1] == 2
):
Expand Down
124 changes: 114 additions & 10 deletions cirq-core/cirq/ops/control_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Union, Tuple, List, TYPE_CHECKING, Any, Dict, Generator, cast, Iterator
from typing import Union, Tuple, List, TYPE_CHECKING, Any, Dict, Generator, cast, Iterator, Optional
from dataclasses import dataclass

import itertools
Expand Down Expand Up @@ -53,7 +53,7 @@ def _expand(self) -> Iterator[Tuple[int, ...]]:
"""Expands the (possibly compressed) internal representation into a sum of products representation.""" # pylint: disable=line-too-long

@abc.abstractmethod
def diagram_repr(self) -> str:
def diagram_repr(self, label: Optional[str] = None) -> str:
"""Returns a string representation to be used in circuit diagrams."""

@abc.abstractmethod
Expand Down Expand Up @@ -92,12 +92,6 @@ def _are_ones(self) -> bool:
def _json_dict_(self) -> Dict[str, Any]:
pass

@abc.abstractmethod
def __getitem__(
self, key: Union[slice, int]
) -> Union['AbstractControlValues', Tuple[int, ...]]:
pass

def __iter__(self) -> Generator[Tuple[int, ...], None, None]:
for assignment in self._expand():
yield assignment
Expand Down Expand Up @@ -154,7 +148,9 @@ def validate(self, qid_shapes: Union[Tuple[int, ...], List[int]]) -> None:
def _are_ones(self) -> bool:
return frozenset(self._internal_representation) == {(1,)}

def diagram_repr(self) -> str:
def diagram_repr(self, label: Optional[str] = None) -> str:
if label:
return label
if self._are_ones():
return 'C' * self._number_variables()

Expand All @@ -174,9 +170,117 @@ def __getitem__(
def _json_dict_(self) -> Dict[str, Any]:
return {'_internal_representation': self._internal_representation}

def __and__(self, other: AbstractControlValues) -> 'ProductOfSums':
def __and__(self, other: AbstractControlValues) -> AbstractControlValues:
if isinstance(other, SumOfProducts):
return SumOfProducts(tuple(p for p in self)) & other
if not isinstance(other, ProductOfSums):
raise TypeError(
f'And operation not supported between types ProductOfSums and {type(other)}'
)
return type(self)(self._internal_representation + other._internal_representation)


@dataclass(frozen=True, eq=False)
class SumOfProducts(AbstractControlValues):
"""Represents control values as a union of n-bit tuples.

`SumOfProducts` representation describes the control values as a union
of n-bit tuples, where each n-bit tuple represents an allowed assignment
of bits for which the control should be activated. This expanded
representation allows us to create control values combinations which
cannot be factored as a `ProductOfSums` representation.

For example:

1) `(|00><00| + |11><11|) X + (|01><01| + |10><10|) I` represents an
operator which flips the third qubit if the first two qubits
are `00` or `11`, and does nothing otherwise.
This can be constructed as
>>> xor_control_values = cirq.SumOfProducts(((0, 0), (1, 1)))
>>> q0, q1, q2 = cirq.LineQubit.range(3)
>>> xor_cop = cirq.X(q2).controlled_by(q0, q1, control_values=xor_control_values)

2) `(|00><00| + |01><01| + |10><10|) X + (|11><11|) I` represents an
operators which flips the third qubit if the `nand` of first two
qubits is `1` (i.e. first two qubits are either `00`, `01` or `10`),
and does nothing otherwise. This can be constructed as:

>>> nand_control_values = cirq.SumOfProducts(((0, 0), (0, 1), (1, 0)))
>>> q0, q1, q2 = cirq.LineQubit.range(3)
>>> nan_cop = cirq.X(q2).controlled_by(q0, q1, control_values=nand_control_values)
"""

_internal_representation: Tuple[Tuple[int, ...], ...]
NoureldinYosri marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a post init method which verifies that length of all nested tuples in self._internal_representation is the same? This is assumed to be true in the implementation of methods like _number_variables below

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added, I also test for the uniqueness of products in it


def __post_init__(self):
if not len(self._internal_representation):
raise ValueError('SumOfProducts can\'t be empty.')
num_qubits = len(self._internal_representation[0])
for p in self._internal_representation:
if len(p) != num_qubits:
raise ValueError(
f'size mismatch between different products of {self._internal_representation}'
)
if len(self._internal_representation) != len(
set(map(tuple, self._internal_representation))
):
raise ValueError('SumOfProducts can\'t have duplicate products.')

def _identifier(self) -> Tuple[Tuple[int, ...], ...]:
return self._internal_representation

def _expand(self) -> Iterator[Tuple[int, ...]]:
"""Returns the combinations tracked by the object."""
self = cast('SumOfProducts', self)
NoureldinYosri marked this conversation as resolved.
Show resolved Hide resolved
return iter(self._internal_representation)

def __repr__(self) -> str:
return f'cirq.SumOfProducts({str(self._identifier())})'

def _number_variables(self) -> int:
return len(self._internal_representation[0])

def __len__(self) -> int:
return self._number_variables()

def __hash__(self) -> int:
return hash(self._internal_representation)

def validate(self, qid_shapes: Union[Tuple[int, ...], List[int]]) -> None:
for vals in self._internal_representation:
if len(qid_shapes) != len(vals):
raise ValueError(
f'number of values in product {vals} doesn\'t equal number'
f' of qubits(={len(qid_shapes)})'
)

for i, v in enumerate(vals):
if not (0 <= v and v < qid_shapes[i]):
raise ValueError(
f'Control values <{v}> in combination {vals} is outside'
f' of range for control qubit number <{i}>.'
)

def _are_ones(self) -> bool:
return frozenset(self._internal_representation) == {(1,) * self._number_variables()}

def diagram_repr(self, label: Optional[str] = None) -> str:
if label:
return label
return ','.join(map(lambda p: ''.join(map(str, p)), self._internal_representation))

def _json_dict_(self) -> Dict[str, Any]:
return {'_internal_representation': self._internal_representation}

def __and__(self, other: AbstractControlValues) -> 'SumOfProducts':
if isinstance(other, ProductOfSums):
other = SumOfProducts(tuple(p for p in other))
if not isinstance(other, SumOfProducts):
raise TypeError(
f'And operation not supported between types SumOfProducts and {type(other)}'
)
NoureldinYosri marked this conversation as resolved.
Show resolved Hide resolved
Comment on lines +278 to +281
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are not doing the the linked-list style concatenation discussed earlier? Will that come in a follow-up PR?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would wait until next PR, I want to introduct the SumOfProducts class before Cirq 1.0, because it has full expressive power (i.e. fixes the original issue), the linked structure can be introducted later without affecting users

the difference between the current state and what will happen when the linked structure is introduced is expressions like this ((x xor y) and (z and w)) and can now be represented by a SumOfProducts objects that has 8 products, while when we introduce the linked structure then we will only need to store 5 products (2 for the first and 3 for the second) and similarly for larger expressions where we would could in some cases store an exponenetial number of products in pseudopolynomial space.

combined = map(
lambda p: tuple(itertools.chain(*p)),
itertools.product(self._internal_representation, other._internal_representation),
)
return SumOfProducts(tuple(combined))
95 changes: 91 additions & 4 deletions cirq-core/cirq/ops/control_values_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,63 @@ def test_init_productOfSum():
((((0, 1), (1, 0))), {(0, 0), (0, 1), (1, 0), (1, 1)}),
]
for control_values, want in tests:
print(control_values)
got = {c for c in cv.ProductOfSums(control_values)}
eq.add_equality_group(got, want)


def test_init_SumOfProducts():
eq = cirq.testing.EqualsTester()
tests = [
(((1,),), {(1,)}),
(((0, 1), (1, 0)), {(0, 1), (1, 0)}), # XOR
(((0, 0), (0, 1), (1, 0)), {(0, 0), (0, 1), (1, 0)}), # NAND
]
for control_values, want in tests:
got = {c for c in cv.SumOfProducts(control_values)}
eq.add_equality_group(got, want)

with pytest.raises(ValueError):
_ = cv.SumOfProducts([])

# size mismatch
with pytest.raises(ValueError):
_ = cv.SumOfProducts([[1], [1, 0]])

# can't have duplicates
with pytest.raises(ValueError):
_ = cv.SumOfProducts([[1, 0], [0, 1], [1, 0]])


def test_and_operation():
eq = cirq.testing.EqualsTester()
originals = [((1,),), ((0, 1), (1,)), (((0, 1), (1, 0)))]
for control_values1 in originals:
for control_values2 in originals:
product_of_sums_data = [((1,),), ((0, 1), (1,)), (((0, 1), (1, 0)))]
for control_values1 in product_of_sums_data:
for control_values2 in product_of_sums_data:
control_vals1 = cv.ProductOfSums(control_values1)
control_vals2 = cv.ProductOfSums(control_values2)
want = [v1 + v2 for v1 in control_vals1 for v2 in control_vals2]
got = [c for c in control_vals1 & control_vals2]
eq.add_equality_group(got, want)

sum_of_products_data = [((1,),), ((0, 1),), ((0, 0), (0, 1), (1, 0))]
eq = cirq.testing.EqualsTester()
for control_values1 in sum_of_products_data:
for control_values2 in sum_of_products_data:
control_vals1 = cv.SumOfProducts(control_values1)
control_vals2 = cv.SumOfProducts(control_values2)
want = [v1 + v2 for v1 in control_vals1 for v2 in control_vals2]
got = [c for c in control_vals1 & control_vals2]
eq.add_equality_group(got, want)

pos = cv.ProductOfSums(((1,), (0,)))
sop = cv.SumOfProducts(((1, 0), (0, 1)))
assert tuple(p for p in pos & sop) == ((1, 0, 1, 0), (1, 0, 0, 1))

assert tuple(p for p in sop & pos) == ((1, 0, 1, 0), (0, 1, 1, 0))

with pytest.raises(TypeError):
_ = sop & 1


def test_and_supported_types():
CV = cv.ProductOfSums((1,))
Expand All @@ -52,3 +93,49 @@ def test_repr():
product_of_sums_data = [((1,),), ((0, 1), (1,)), (((0, 1), (1, 0)))]
for t in map(cv.ProductOfSums, product_of_sums_data):
cirq.testing.assert_equivalent_repr(t)

sum_of_products_data = [((1,),), ((0, 1),), ((0, 0), (0, 1), (1, 0))]
for t in map(cv.SumOfProducts, sum_of_products_data):
cirq.testing.assert_equivalent_repr(t)


def test_validate():
control_val = cv.SumOfProducts(((1, 2), (0, 1)))

_ = control_val.validate([2, 3])

with pytest.raises(ValueError):
_ = control_val.validate([2, 2])

# number of qubits != number of control values.
with pytest.raises(ValueError):
_ = control_val.validate([2])


def test_len():
data = [((1,),), ((0, 1),), ((0, 0), (0, 1), (1, 0))]
for vals in data:
c = cv.SumOfProducts(vals)
assert len(c) == len(vals[0])


def test_hash():
data = [((1,),), ((0, 1),), ((0, 0), (0, 1), (1, 0))]
assert len(set(map(hash, map(cv.SumOfProducts, data)))) == 3


def test_are_ones():
data = [((1,),), ((0, 1),), ((0, 0), (0, 1), (1, 0)), ((1, 1, 1, 1),)]
are_ones = [True, False, False]
for vals, want in zip(data, are_ones):
c = cv.SumOfProducts(vals)
assert c._are_ones() == want


def test_diagram_repr():
c = cv.SumOfProducts(((1, 0), (0, 1)))
assert c.diagram_repr() == '10,01'

assert c.diagram_repr('xor') == 'xor'

assert cv.ProductOfSums(((1,), (0,))).diagram_repr('10') == '10'
3 changes: 0 additions & 3 deletions cirq-core/cirq/ops/controlled_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,9 +633,6 @@ def _are_ones(self):
def _json_dict_(self):
pass

def __getitem__(self, idx):
pass


def test_decompose_applies_only_to_ProductOfSums():
g = cirq.ControlledGate(cirq.X, control_values=MockControlValues())
Expand Down
4 changes: 4 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/SumOfProducts.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"_internal_representation": [[1, 0], [1, 2]],
"cirq_type": "SumOfProducts"
}
1 change: 1 addition & 0 deletions cirq-core/cirq/protocols/json_test_data/SumOfProducts.repr
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.SumOfProducts([[1, 0], [1, 2]])