Skip to content

Commit

Permalink
Created SumOfProducts subclass of ControlValues (quantumlib#5755)
Browse files Browse the repository at this point in the history
This class allows the creation of control values that can't be factored into simple products hence solving quantumlib#4512
  • Loading branch information
NoureldinYosri authored Jul 15, 2022
1 parent 7f58492 commit a85d481
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 18 deletions.
1 change: 1 addition & 0 deletions cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@
SQRT_ISWAP_INV,
SWAP,
SwapPowGate,
SumOfProducts,
T,
TaggedOperation,
ThreeQubitDiagonalGate,
Expand Down
1 change: 1 addition & 0 deletions cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def _symmetricalqidpair(qids):
'SqrtIswapTargetGateset': cirq.SqrtIswapTargetGateset,
'StabilizerStateChForm': cirq.StabilizerStateChForm,
'StatePreparationChannel': cirq.StatePreparationChannel,
'SumOfProducts': cirq.SumOfProducts,
'SwapPowGate': cirq.SwapPowGate,
'SympyCondition': cirq.SympyCondition,
'TaggedOperation': cirq.TaggedOperation,
Expand Down
2 changes: 1 addition & 1 deletion cirq/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,4 +210,4 @@

from cirq.ops.state_preparation_channel import StatePreparationChannel

from cirq.ops.control_values import AbstractControlValues, ProductOfSums
from cirq.ops.control_values import AbstractControlValues, ProductOfSums, SumOfProducts
28 changes: 28 additions & 0 deletions cirq/ops/common_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,17 @@ def controlled(
A `cirq.ControlledGate` (or `cirq.CXPowGate` if possible) representing
`self` controlled by the given control values and qubits.
"""
if control_values and not isinstance(control_values, cv.AbstractControlValues):
control_values = cv.ProductOfSums(
tuple(
(val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values
)
)
result = super().controlled(num_controls, control_values, control_qid_shape)
if (
self._global_shift == 0
and isinstance(result, controlled_gate.ControlledGate)
and isinstance(result.control_values, cv.ProductOfSums)
and result.control_values[-1] == (1,)
and result.control_qid_shape[-1] == 2
):
Expand Down Expand Up @@ -680,10 +687,17 @@ def controlled(
A `cirq.ControlledGate` (or `cirq.CZPowGate` if possible) representing
`self` controlled by the given control values and qubits.
"""
if control_values and not isinstance(control_values, cv.AbstractControlValues):
control_values = cv.ProductOfSums(
tuple(
(val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values
)
)
result = super().controlled(num_controls, control_values, control_qid_shape)
if (
self._global_shift == 0
and isinstance(result, controlled_gate.ControlledGate)
and isinstance(result.control_values, cv.ProductOfSums)
and result.control_values[-1] == (1,)
and result.control_qid_shape[-1] == 2
):
Expand Down Expand Up @@ -1116,10 +1130,17 @@ def controlled(
A `cirq.ControlledGate` (or `cirq.CCZPowGate` if possible) representing
`self` controlled by the given control values and qubits.
"""
if control_values and not isinstance(control_values, cv.AbstractControlValues):
control_values = cv.ProductOfSums(
tuple(
(val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values
)
)
result = super().controlled(num_controls, control_values, control_qid_shape)
if (
self._global_shift == 0
and isinstance(result, controlled_gate.ControlledGate)
and isinstance(result.control_values, cv.ProductOfSums)
and result.control_values[-1] == (1,)
and result.control_qid_shape[-1] == 2
):
Expand Down Expand Up @@ -1315,10 +1336,17 @@ def controlled(
A `cirq.ControlledGate` (or `cirq.CCXPowGate` if possible) representing
`self` controlled by the given control values and qubits.
"""
if control_values and not isinstance(control_values, cv.AbstractControlValues):
control_values = cv.ProductOfSums(
tuple(
(val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values
)
)
result = super().controlled(num_controls, control_values, control_qid_shape)
if (
self._global_shift == 0
and isinstance(result, controlled_gate.ControlledGate)
and isinstance(result.control_values, cv.ProductOfSums)
and result.control_values[-1] == (1,)
and result.control_qid_shape[-1] == 2
):
Expand Down
124 changes: 114 additions & 10 deletions cirq/ops/control_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Union, Tuple, List, TYPE_CHECKING, Any, Dict, Generator, cast, Iterator
from typing import Union, Tuple, List, TYPE_CHECKING, Any, Dict, Generator, cast, Iterator, Optional
from dataclasses import dataclass

import itertools
Expand Down Expand Up @@ -53,7 +53,7 @@ def _expand(self) -> Iterator[Tuple[int, ...]]:
"""Expands the (possibly compressed) internal representation into a sum of products representation.""" # pylint: disable=line-too-long

@abc.abstractmethod
def diagram_repr(self) -> str:
def diagram_repr(self, label: Optional[str] = None) -> str:
"""Returns a string representation to be used in circuit diagrams."""

@abc.abstractmethod
Expand Down Expand Up @@ -92,12 +92,6 @@ def _are_ones(self) -> bool:
def _json_dict_(self) -> Dict[str, Any]:
pass

@abc.abstractmethod
def __getitem__(
self, key: Union[slice, int]
) -> Union['AbstractControlValues', Tuple[int, ...]]:
pass

def __iter__(self) -> Generator[Tuple[int, ...], None, None]:
for assignment in self._expand():
yield assignment
Expand Down Expand Up @@ -154,7 +148,9 @@ def validate(self, qid_shapes: Union[Tuple[int, ...], List[int]]) -> None:
def _are_ones(self) -> bool:
return frozenset(self._internal_representation) == {(1,)}

def diagram_repr(self) -> str:
def diagram_repr(self, label: Optional[str] = None) -> str:
if label:
return label
if self._are_ones():
return 'C' * self._number_variables()

Expand All @@ -174,9 +170,117 @@ def __getitem__(
def _json_dict_(self) -> Dict[str, Any]:
return {'_internal_representation': self._internal_representation}

def __and__(self, other: AbstractControlValues) -> 'ProductOfSums':
def __and__(self, other: AbstractControlValues) -> AbstractControlValues:
if isinstance(other, SumOfProducts):
return SumOfProducts(tuple(p for p in self)) & other
if not isinstance(other, ProductOfSums):
raise TypeError(
f'And operation not supported between types ProductOfSums and {type(other)}'
)
return type(self)(self._internal_representation + other._internal_representation)


@dataclass(frozen=True, eq=False)
class SumOfProducts(AbstractControlValues):
"""Represents control values as a union of n-bit tuples.
`SumOfProducts` representation describes the control values as a union
of n-bit tuples, where each n-bit tuple represents an allowed assignment
of bits for which the control should be activated. This expanded
representation allows us to create control values combinations which
cannot be factored as a `ProductOfSums` representation.
For example:
1) `(|00><00| + |11><11|) X + (|01><01| + |10><10|) I` represents an
operator which flips the third qubit if the first two qubits
are `00` or `11`, and does nothing otherwise.
This can be constructed as
>>> xor_control_values = cirq.SumOfProducts(((0, 0), (1, 1)))
>>> q0, q1, q2 = cirq.LineQubit.range(3)
>>> xor_cop = cirq.X(q2).controlled_by(q0, q1, control_values=xor_control_values)
2) `(|00><00| + |01><01| + |10><10|) X + (|11><11|) I` represents an
operators which flips the third qubit if the `nand` of first two
qubits is `1` (i.e. first two qubits are either `00`, `01` or `10`),
and does nothing otherwise. This can be constructed as:
>>> nand_control_values = cirq.SumOfProducts(((0, 0), (0, 1), (1, 0)))
>>> q0, q1, q2 = cirq.LineQubit.range(3)
>>> nan_cop = cirq.X(q2).controlled_by(q0, q1, control_values=nand_control_values)
"""

_internal_representation: Tuple[Tuple[int, ...], ...]

def __post_init__(self):
if not len(self._internal_representation):
raise ValueError('SumOfProducts can\'t be empty.')
num_qubits = len(self._internal_representation[0])
for p in self._internal_representation:
if len(p) != num_qubits:
raise ValueError(
f'size mismatch between different products of {self._internal_representation}'
)
if len(self._internal_representation) != len(
set(map(tuple, self._internal_representation))
):
raise ValueError('SumOfProducts can\'t have duplicate products.')

def _identifier(self) -> Tuple[Tuple[int, ...], ...]:
return self._internal_representation

def _expand(self) -> Iterator[Tuple[int, ...]]:
"""Returns the combinations tracked by the object."""
self = cast('SumOfProducts', self)
return iter(self._internal_representation)

def __repr__(self) -> str:
return f'cirq.SumOfProducts({str(self._identifier())})'

def _number_variables(self) -> int:
return len(self._internal_representation[0])

def __len__(self) -> int:
return self._number_variables()

def __hash__(self) -> int:
return hash(self._internal_representation)

def validate(self, qid_shapes: Union[Tuple[int, ...], List[int]]) -> None:
for vals in self._internal_representation:
if len(qid_shapes) != len(vals):
raise ValueError(
f'number of values in product {vals} doesn\'t equal number'
f' of qubits(={len(qid_shapes)})'
)

for i, v in enumerate(vals):
if not (0 <= v and v < qid_shapes[i]):
raise ValueError(
f'Control values <{v}> in combination {vals} is outside'
f' of range for control qubit number <{i}>.'
)

def _are_ones(self) -> bool:
return frozenset(self._internal_representation) == {(1,) * self._number_variables()}

def diagram_repr(self, label: Optional[str] = None) -> str:
if label:
return label
return ','.join(map(lambda p: ''.join(map(str, p)), self._internal_representation))

def _json_dict_(self) -> Dict[str, Any]:
return {'_internal_representation': self._internal_representation}

def __and__(self, other: AbstractControlValues) -> 'SumOfProducts':
if isinstance(other, ProductOfSums):
other = SumOfProducts(tuple(p for p in other))
if not isinstance(other, SumOfProducts):
raise TypeError(
f'And operation not supported between types SumOfProducts and {type(other)}'
)
combined = map(
lambda p: tuple(itertools.chain(*p)),
itertools.product(self._internal_representation, other._internal_representation),
)
return SumOfProducts(tuple(combined))
95 changes: 91 additions & 4 deletions cirq/ops/control_values_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,63 @@ def test_init_productOfSum():
((((0, 1), (1, 0))), {(0, 0), (0, 1), (1, 0), (1, 1)}),
]
for control_values, want in tests:
print(control_values)
got = {c for c in cv.ProductOfSums(control_values)}
eq.add_equality_group(got, want)


def test_init_SumOfProducts():
eq = cirq.testing.EqualsTester()
tests = [
(((1,),), {(1,)}),
(((0, 1), (1, 0)), {(0, 1), (1, 0)}), # XOR
(((0, 0), (0, 1), (1, 0)), {(0, 0), (0, 1), (1, 0)}), # NAND
]
for control_values, want in tests:
got = {c for c in cv.SumOfProducts(control_values)}
eq.add_equality_group(got, want)

with pytest.raises(ValueError):
_ = cv.SumOfProducts([])

# size mismatch
with pytest.raises(ValueError):
_ = cv.SumOfProducts([[1], [1, 0]])

# can't have duplicates
with pytest.raises(ValueError):
_ = cv.SumOfProducts([[1, 0], [0, 1], [1, 0]])


def test_and_operation():
eq = cirq.testing.EqualsTester()
originals = [((1,),), ((0, 1), (1,)), (((0, 1), (1, 0)))]
for control_values1 in originals:
for control_values2 in originals:
product_of_sums_data = [((1,),), ((0, 1), (1,)), (((0, 1), (1, 0)))]
for control_values1 in product_of_sums_data:
for control_values2 in product_of_sums_data:
control_vals1 = cv.ProductOfSums(control_values1)
control_vals2 = cv.ProductOfSums(control_values2)
want = [v1 + v2 for v1 in control_vals1 for v2 in control_vals2]
got = [c for c in control_vals1 & control_vals2]
eq.add_equality_group(got, want)

sum_of_products_data = [((1,),), ((0, 1),), ((0, 0), (0, 1), (1, 0))]
eq = cirq.testing.EqualsTester()
for control_values1 in sum_of_products_data:
for control_values2 in sum_of_products_data:
control_vals1 = cv.SumOfProducts(control_values1)
control_vals2 = cv.SumOfProducts(control_values2)
want = [v1 + v2 for v1 in control_vals1 for v2 in control_vals2]
got = [c for c in control_vals1 & control_vals2]
eq.add_equality_group(got, want)

pos = cv.ProductOfSums(((1,), (0,)))
sop = cv.SumOfProducts(((1, 0), (0, 1)))
assert tuple(p for p in pos & sop) == ((1, 0, 1, 0), (1, 0, 0, 1))

assert tuple(p for p in sop & pos) == ((1, 0, 1, 0), (0, 1, 1, 0))

with pytest.raises(TypeError):
_ = sop & 1


def test_and_supported_types():
CV = cv.ProductOfSums((1,))
Expand All @@ -52,3 +93,49 @@ def test_repr():
product_of_sums_data = [((1,),), ((0, 1), (1,)), (((0, 1), (1, 0)))]
for t in map(cv.ProductOfSums, product_of_sums_data):
cirq.testing.assert_equivalent_repr(t)

sum_of_products_data = [((1,),), ((0, 1),), ((0, 0), (0, 1), (1, 0))]
for t in map(cv.SumOfProducts, sum_of_products_data):
cirq.testing.assert_equivalent_repr(t)


def test_validate():
control_val = cv.SumOfProducts(((1, 2), (0, 1)))

_ = control_val.validate([2, 3])

with pytest.raises(ValueError):
_ = control_val.validate([2, 2])

# number of qubits != number of control values.
with pytest.raises(ValueError):
_ = control_val.validate([2])


def test_len():
data = [((1,),), ((0, 1),), ((0, 0), (0, 1), (1, 0))]
for vals in data:
c = cv.SumOfProducts(vals)
assert len(c) == len(vals[0])


def test_hash():
data = [((1,),), ((0, 1),), ((0, 0), (0, 1), (1, 0))]
assert len(set(map(hash, map(cv.SumOfProducts, data)))) == 3


def test_are_ones():
data = [((1,),), ((0, 1),), ((0, 0), (0, 1), (1, 0)), ((1, 1, 1, 1),)]
are_ones = [True, False, False]
for vals, want in zip(data, are_ones):
c = cv.SumOfProducts(vals)
assert c._are_ones() == want


def test_diagram_repr():
c = cv.SumOfProducts(((1, 0), (0, 1)))
assert c.diagram_repr() == '10,01'

assert c.diagram_repr('xor') == 'xor'

assert cv.ProductOfSums(((1,), (0,))).diagram_repr('10') == '10'
3 changes: 0 additions & 3 deletions cirq/ops/controlled_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,9 +633,6 @@ def _are_ones(self):
def _json_dict_(self):
pass

def __getitem__(self, idx):
pass


def test_decompose_applies_only_to_ProductOfSums():
g = cirq.ControlledGate(cirq.X, control_values=MockControlValues())
Expand Down
4 changes: 4 additions & 0 deletions cirq/protocols/json_test_data/SumOfProducts.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"_internal_representation": [[1, 0], [1, 2]],
"cirq_type": "SumOfProducts"
}
1 change: 1 addition & 0 deletions cirq/protocols/json_test_data/SumOfProducts.repr
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.SumOfProducts([[1, 0], [1, 2]])

0 comments on commit a85d481

Please sign in to comment.