Skip to content

Commit

Permalink
Extend default decomposition of cirq.ControlledGate and `cirq.Contr…
Browse files Browse the repository at this point in the history
…olledOperation` to end in X/Y/Z/CZ target gateset (#5091)

When decomposed, controlled gates and operations simply fall back on the decomposition of underlying sub_gate / sub_operation and return apply appropriate controls to each decomposed operation. 

If we can ensure that all underlying gates / operations decompose to X/Y/Z/CZ target gateset, then their controlled versions will decompose to:
 - Multi controlled single qubit rotations (corresponding to (X/Y/Z).controlled_by(...)) OR
 - Multi controlled CZs, which is also equivalent to a multi controlled single qubit rotation (Z.controlled_by(...))

In Cirq, we have an analytical method to decompose a multi controlled rotation into X/Y/Z/CZ - `cirq.decompose_multi_controlled_rotation`, which is now used in the `_decompose_` method of controlled gates. 

However, there are many corner cases and limitations of the current approach, which are dealt appropriately in this PR to enable a "best-effort" decomposition of controlled gates to the cirq target gateset. Some of the limitations are:
 - If decomposition of sub_gate / sub_operation ignores global phase, then the controlled operation cannot directly rely on decomposing the sub operation. An explicit check is added to not fallback on sub_gate if sub_gate is a  MatrixGate.
 - `decompose_multi_controlled_rotation` works only for qubits (doesn't work for qudits) and when all control_values are 1. Appropriate logic is added to extend its functionality to handle control_values which are 0 or (0, 1). 
 - We have explicit types for a few important controlled gates, like `CCZ`, `CZ`, `CCX`, `CX` etc. in cirq. Appropriate type conversion logic is added to smartly infer the types of equivalent gates (eg: Controlled(sub_gate=CZ) should be inferred as CCZ) such that their decompositions can be used for decomposing the controlled gates. 


This is definitely the most tricky one to get right and I've added appropriate tests to cover the different cases. 


Part of #4858
  • Loading branch information
tanujkhattar authored Mar 17, 2022
1 parent 90e70c6 commit 1f47082
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 43 deletions.
40 changes: 39 additions & 1 deletion cirq-core/cirq/ops/controlled_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
if TYPE_CHECKING:
import cirq

controlled_gate_decomposition = _import.LazyLoader(
'controlled_gate_decomposition', globals(), 'cirq.transformers.analytical_decompositions'
)
common_gates = _import.LazyLoader('common_gates', globals(), 'cirq.ops')
line_qubit = _import.LazyLoader('line_qubit', globals(), 'cirq.devices')


Expand Down Expand Up @@ -156,6 +160,40 @@ def _qid_shape_(self) -> Tuple[int, ...]:
return self.control_qid_shape + protocols.qid_shape(self.sub_gate)

def _decompose_(self, qubits):
if (
protocols.has_unitary(self.sub_gate)
and protocols.num_qubits(self.sub_gate) == 1
and self._qid_shape_() == (2,) * len(self._qid_shape_())
):
control_qubits = list(qubits[: self.num_controls()])
invert_ops: List['cirq.Operation'] = []
for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]):
if set(cvals) == {0}:
invert_ops.append(common_gates.X(cqbit))
elif set(cvals) == {0, 1}:
control_qubits.remove(cqbit)
decomposed_ops = controlled_gate_decomposition.decompose_multi_controlled_rotation(
protocols.unitary(self.sub_gate), control_qubits, qubits[-1]
)
return invert_ops + decomposed_ops + invert_ops

if isinstance(self.sub_gate, common_gates.CZPowGate):
z_sub_gate = common_gates.ZPowGate(
exponent=self.sub_gate.exponent, global_shift=self.sub_gate.global_shift
)
kwargs = {
'num_controls': self.num_controls() + 1,
'control_values': self.control_values + (1,),
'control_qid_shape': self.control_qid_shape + (2,),
}
controlled_z = (
z_sub_gate.controlled(**kwargs)
if protocols.is_parameterized(self)
else ControlledGate(z_sub_gate, **kwargs)
)
if self != controlled_z:
return protocols.decompose_once_with_qubits(controlled_z, qubits, NotImplemented)

if isinstance(self.sub_gate, matrix_gates.MatrixGate):
# Default decompositions of 2/3 qubit `cirq.MatrixGate` ignores global phase, which is
# local phase in the controlled variant and hence cannot be ignored.
Expand All @@ -170,7 +208,7 @@ def _decompose_(self, qubits):
decomposed: List['cirq.Operation'] = []
for op in result:
decomposed.append(
cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values)
op.controlled_by(*qubits[: self.num_controls()], control_values=self.control_values)
)
return decomposed

Expand Down
61 changes: 42 additions & 19 deletions cirq-core/cirq/ops/controlled_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,28 @@ def __repr__(self):


class GateAllocatingNewSpaceForResult(cirq.SingleQubitGate):
def __init__(self):
self._matrix = cirq.testing.random_unitary(2, random_state=4321)

def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> Union[np.ndarray, NotImplementedType]:
assert len(args.axes) == 1
a = args.axes[0]
seed = cast(Tuple[Union[int, slice, 'ellipsis'], ...], (slice(None),))
zero = seed * a + (0, Ellipsis)
one = seed * a + (1, Ellipsis)
result = np.zeros(args.target_tensor.shape, args.target_tensor.dtype)
result[zero] = args.target_tensor[zero] * 2 + args.target_tensor[one] * 3
result[one] = args.target_tensor[zero] * 5 + args.target_tensor[one] * 7
result[zero] = (
args.target_tensor[zero] * self._matrix[0][0]
+ args.target_tensor[one] * self._matrix[0][1]
)
result[one] = (
args.target_tensor[zero] * self._matrix[1][0]
+ args.target_tensor[one] * self._matrix[1][1]
)
return result

def _unitary_(self):
return np.array([[2, 3], [5, 7]])
return self._matrix

def __eq__(self, other):
return isinstance(other, type(self))
Expand Down Expand Up @@ -316,28 +325,42 @@ def test_unitary():


@pytest.mark.parametrize(
'gate',
'gate, should_decompose_to_target',
[
cirq.X,
cirq.X ** 0.5,
cirq.rx(np.pi),
cirq.rx(np.pi / 2),
cirq.Z,
cirq.H,
cirq.CNOT,
cirq.SWAP,
cirq.CCZ,
cirq.ControlledGate(cirq.ControlledGate(cirq.CCZ)),
GateUsingWorkspaceForApplyUnitary(),
GateAllocatingNewSpaceForResult(),
cirq.IdentityGate(qid_shape=(3, 4)),
(cirq.X, True),
(cirq.X ** 0.5, True),
(cirq.rx(np.pi), True),
(cirq.rx(np.pi / 2), True),
(cirq.Z, True),
(cirq.H, True),
(cirq.CNOT, True),
(cirq.SWAP, True),
(cirq.CCZ, True),
(cirq.ControlledGate(cirq.ControlledGate(cirq.CCZ)), True),
(GateUsingWorkspaceForApplyUnitary(), True),
(GateAllocatingNewSpaceForResult(), True),
(cirq.IdentityGate(qid_shape=(3, 4)), True),
(
cirq.ControlledGate(
cirq.XXPowGate(exponent=0.25, global_shift=-0.5),
num_controls=2,
control_values=(1, (1, 0)),
),
True,
),
# Single qudit gate with dimension 4.
cirq.MatrixGate(np.kron(*(cirq.unitary(cirq.H),) * 2)),
(cirq.MatrixGate(np.kron(*(cirq.unitary(cirq.H),) * 2), qid_shape=(4,)), False),
(cirq.MatrixGate(cirq.testing.random_unitary(4, random_state=1234)), False),
(cirq.XX ** sympy.Symbol("s"), True),
(cirq.CZ ** sympy.Symbol("s"), True),
],
)
def test_controlled_gate_is_consistent(gate: cirq.Gate):
def test_controlled_gate_is_consistent(gate: cirq.Gate, should_decompose_to_target):
cgate = cirq.ControlledGate(gate)
cirq.testing.assert_implements_consistent_protocols(cgate)
cirq.testing.assert_decompose_ends_at_default_gateset(
cgate, ignore_known_gates=not should_decompose_to_target
)


def test_pow_inverse():
Expand Down
15 changes: 13 additions & 2 deletions cirq-core/cirq/ops/controlled_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from cirq import protocols, qis, value
from cirq._compat import deprecated
from cirq.ops import raw_types, gate_operation, controlled_gate
from cirq.ops import raw_types, gate_operation, controlled_gate, matrix_gates
from cirq.type_workarounds import NotImplementedType

if TYPE_CHECKING:
Expand Down Expand Up @@ -130,11 +130,22 @@ def with_qubits(self, *new_qubits):
)

def _decompose_(self):
result = protocols.decompose_once_with_qubits(self.gate, self.qubits, NotImplemented)
if result is not NotImplemented:
return result

if isinstance(self.sub_operation.gate, matrix_gates.MatrixGate):
# Default decompositions of 2/3 qubit `cirq.MatrixGate` ignores global phase, which is
# local phase in the controlled variant and hence cannot be ignored.
return NotImplemented

result = protocols.decompose_once(self.sub_operation, NotImplemented)
if result is NotImplemented:
return NotImplemented

return [ControlledOperation(self.controls, op, self.control_values) for op in result]
return [
op.controlled_by(*self.controls, control_values=self.control_values) for op in result
]

def _value_equality_values_(self):
return (frozenset(zip(self.controls, self.control_values)), self.sub_operation)
Expand Down
92 changes: 75 additions & 17 deletions cirq-core/cirq/ops/controlled_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,28 @@ def __repr__(self):


class GateAllocatingNewSpaceForResult(cirq.SingleQubitGate):
def __init__(self):
self._matrix = cirq.testing.random_unitary(2, random_state=1234)

def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> Union[np.ndarray, NotImplementedType]:
assert len(args.axes) == 1
a = args.axes[0]
seed = cast(Tuple[Union[int, slice, 'ellipsis'], ...], (slice(None),))
zero = seed * a + (0, Ellipsis)
one = seed * a + (1, Ellipsis)
result = np.zeros(args.target_tensor.shape, args.target_tensor.dtype)
result[zero] = args.target_tensor[zero] * 2 + args.target_tensor[one] * 3
result[one] = args.target_tensor[zero] * 5 + args.target_tensor[one] * 7
result[zero] = (
args.target_tensor[zero] * self._matrix[0][0]
+ args.target_tensor[one] * self._matrix[0][1]
)
result[one] = (
args.target_tensor[zero] * self._matrix[1][0]
+ args.target_tensor[one] * self._matrix[1][1]
)
return result

def _unitary_(self):
return np.array([[2, 3], [5, 7]])
return self._matrix

def __eq__(self, other):
return isinstance(other, type(self))
Expand Down Expand Up @@ -297,33 +306,82 @@ class UndiagrammableGate(cirq.SingleQubitGate):


@pytest.mark.parametrize(
'gate',
'gate, should_decompose_to_target',
[
cirq.X(cirq.NamedQubit('q1')),
cirq.X(cirq.NamedQubit('q1')) ** 0.5,
cirq.rx(np.pi)(cirq.NamedQubit('q1')),
cirq.rx(np.pi / 2)(cirq.NamedQubit('q1')),
cirq.Z(cirq.NamedQubit('q1')),
cirq.H(cirq.NamedQubit('q1')),
cirq.CNOT(cirq.NamedQubit('q1'), cirq.NamedQubit('q2')),
cirq.SWAP(cirq.NamedQubit('q1'), cirq.NamedQubit('q2')),
cirq.CCZ(cirq.NamedQubit('q1'), cirq.NamedQubit('q2'), cirq.NamedQubit('q3')),
cirq.ControlledGate(cirq.ControlledGate(cirq.CCZ))(*cirq.LineQubit.range(5)),
GateUsingWorkspaceForApplyUnitary()(cirq.NamedQubit('q1')),
GateAllocatingNewSpaceForResult()(cirq.NamedQubit('q1')),
(cirq.X(cirq.NamedQubit('q1')), True),
(cirq.X(cirq.NamedQubit('q1')) ** 0.5, True),
(cirq.rx(np.pi)(cirq.NamedQubit('q1')), True),
(cirq.rx(np.pi / 2)(cirq.NamedQubit('q1')), True),
(cirq.Z(cirq.NamedQubit('q1')), True),
(cirq.H(cirq.NamedQubit('q1')), True),
(cirq.CNOT(cirq.NamedQubit('q1'), cirq.NamedQubit('q2')), True),
(cirq.SWAP(cirq.NamedQubit('q1'), cirq.NamedQubit('q2')), True),
(cirq.CCZ(cirq.NamedQubit('q1'), cirq.NamedQubit('q2'), cirq.NamedQubit('q3')), True),
(cirq.ControlledGate(cirq.ControlledGate(cirq.CCZ))(*cirq.LineQubit.range(5)), True),
(GateUsingWorkspaceForApplyUnitary()(cirq.NamedQubit('q1')), True),
(GateAllocatingNewSpaceForResult()(cirq.NamedQubit('q1')), True),
(
cirq.MatrixGate(np.kron(*(cirq.unitary(cirq.H),) * 2), qid_shape=(4,)).on(
cirq.NamedQid("q", 4)
),
False,
),
(
cirq.MatrixGate(cirq.testing.random_unitary(4, random_state=1234)).on(
cirq.NamedQubit('q1'), cirq.NamedQubit('q2')
),
False,
),
(cirq.XX(cirq.NamedQubit('q1'), cirq.NamedQubit('q2')) ** sympy.Symbol("s"), True),
(cirq.DiagonalGate(sympy.symbols("s1, s2")).on(cirq.NamedQubit("q")), False),
],
)
def test_controlled_operation_is_consistent(gate: cirq.GateOperation):
def test_controlled_operation_is_consistent(
gate: cirq.GateOperation, should_decompose_to_target: bool
):
cb = cirq.NamedQubit('ctr')
cgate = cirq.ControlledOperation([cb], gate)
cirq.testing.assert_implements_consistent_protocols(cgate)
cirq.testing.assert_decompose_ends_at_default_gateset(
cgate, ignore_known_gates=not should_decompose_to_target
)

cgate = cirq.ControlledOperation([cb], gate, control_values=[0])
cirq.testing.assert_implements_consistent_protocols(cgate)
cirq.testing.assert_decompose_ends_at_default_gateset(
cgate, ignore_known_gates=(not should_decompose_to_target or cirq.is_parameterized(gate))
)

cgate = cirq.ControlledOperation([cb], gate, control_values=[(0, 1)])
cirq.testing.assert_implements_consistent_protocols(cgate)
cirq.testing.assert_decompose_ends_at_default_gateset(
cgate, ignore_known_gates=(not should_decompose_to_target or cirq.is_parameterized(gate))
)

cb3 = cb.with_dimension(3)
cgate = cirq.ControlledOperation([cb3], gate, control_values=[(0, 2)])
cirq.testing.assert_implements_consistent_protocols(cgate)
cirq.testing.assert_decompose_ends_at_default_gateset(cgate)


def test_controlled_circuit_operation_is_consistent():
op = cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.XXPowGate(exponent=0.25, global_shift=-0.5).on(*cirq.LineQubit.range(2))
)
)
cb = cirq.NamedQubit('ctr')
cop = cirq.ControlledOperation([cb], op)
cirq.testing.assert_implements_consistent_protocols(cop, exponents=(-1, 1, 2))
cirq.testing.assert_decompose_ends_at_default_gateset(cop)

cop = cirq.ControlledOperation([cb], op, control_values=[0])
cirq.testing.assert_implements_consistent_protocols(cop, exponents=(-1, 1, 2))
cirq.testing.assert_decompose_ends_at_default_gateset(cop)

cop = cirq.ControlledOperation([cb], op, control_values=[(0, 1)])
cirq.testing.assert_implements_consistent_protocols(cop, exponents=(-1, 1, 2))
cirq.testing.assert_decompose_ends_at_default_gateset(cop)


@pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once])
Expand Down
14 changes: 10 additions & 4 deletions cirq-core/cirq/testing/consistent_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,22 @@ def _known_gate_with_no_decomposition(val: Any):
"""Checks whether `val` is a known gate with no default decomposition to default gateset."""
if isinstance(val, ops.MatrixGate):
return protocols.qid_shape(val) not in [(2,), (2,) * 2, (2,) * 3]
if isinstance(val, ops.ControlledGate):
if protocols.is_parameterized(val):
return True
if isinstance(val.sub_gate, ops.MatrixGate) and protocols.num_qubits(val.sub_gate) > 1:
return True
if val.control_qid_shape != (2,) * val.num_controls():
return True
return _known_gate_with_no_decomposition(val.sub_gate)
return False


def assert_decompose_ends_at_default_gateset(val: Any):
def assert_decompose_ends_at_default_gateset(val: Any, ignore_known_gates: bool = True):
"""Asserts that cirq.decompose(val) ends at default cirq gateset or a known gate."""
if _known_gate_with_no_decomposition(val):
return # coverage: ignore
args = () if isinstance(val, ops.Operation) else (tuple(devices.LineQid.for_gate(val)),)
dec_once = protocols.decompose_once(val, [val(*args[0]) if args else val], *args)
for op in [*ops.flatten_to_ops(protocols.decompose(d) for d in dec_once)]:
assert _known_gate_with_no_decomposition(op.gate) or (
assert (_known_gate_with_no_decomposition(op.gate) and ignore_known_gates) or (
op in protocols.decompose_protocol.DECOMPOSE_TARGET_GATESET
), f'{val} decomposed to {op}, which is not part of default cirq target gateset.'

0 comments on commit 1f47082

Please sign in to comment.