diff --git a/cirq-core/cirq/ops/controlled_gate.py b/cirq-core/cirq/ops/controlled_gate.py index 75e2ae03ccb..53cc732100c 100644 --- a/cirq-core/cirq/ops/controlled_gate.py +++ b/cirq-core/cirq/ops/controlled_gate.py @@ -36,6 +36,10 @@ if TYPE_CHECKING: import cirq +controlled_gate_decomposition = _import.LazyLoader( + 'controlled_gate_decomposition', globals(), 'cirq.transformers.analytical_decompositions' +) +common_gates = _import.LazyLoader('common_gates', globals(), 'cirq.ops') line_qubit = _import.LazyLoader('line_qubit', globals(), 'cirq.devices') @@ -156,6 +160,40 @@ def _qid_shape_(self) -> Tuple[int, ...]: return self.control_qid_shape + protocols.qid_shape(self.sub_gate) def _decompose_(self, qubits): + if ( + protocols.has_unitary(self.sub_gate) + and protocols.num_qubits(self.sub_gate) == 1 + and self._qid_shape_() == (2,) * len(self._qid_shape_()) + ): + control_qubits = list(qubits[: self.num_controls()]) + invert_ops: List['cirq.Operation'] = [] + for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]): + if set(cvals) == {0}: + invert_ops.append(common_gates.X(cqbit)) + elif set(cvals) == {0, 1}: + control_qubits.remove(cqbit) + decomposed_ops = controlled_gate_decomposition.decompose_multi_controlled_rotation( + protocols.unitary(self.sub_gate), control_qubits, qubits[-1] + ) + return invert_ops + decomposed_ops + invert_ops + + if isinstance(self.sub_gate, common_gates.CZPowGate): + z_sub_gate = common_gates.ZPowGate( + exponent=self.sub_gate.exponent, global_shift=self.sub_gate.global_shift + ) + kwargs = { + 'num_controls': self.num_controls() + 1, + 'control_values': self.control_values + (1,), + 'control_qid_shape': self.control_qid_shape + (2,), + } + controlled_z = ( + z_sub_gate.controlled(**kwargs) + if protocols.is_parameterized(self) + else ControlledGate(z_sub_gate, **kwargs) + ) + if self != controlled_z: + return protocols.decompose_once_with_qubits(controlled_z, qubits, NotImplemented) + if isinstance(self.sub_gate, matrix_gates.MatrixGate): # Default decompositions of 2/3 qubit `cirq.MatrixGate` ignores global phase, which is # local phase in the controlled variant and hence cannot be ignored. @@ -170,7 +208,7 @@ def _decompose_(self, qubits): decomposed: List['cirq.Operation'] = [] for op in result: decomposed.append( - cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values) + op.controlled_by(*qubits[: self.num_controls()], control_values=self.control_values) ) return decomposed diff --git a/cirq-core/cirq/ops/controlled_gate_test.py b/cirq-core/cirq/ops/controlled_gate_test.py index b0545b9de5c..67495a5f47a 100644 --- a/cirq-core/cirq/ops/controlled_gate_test.py +++ b/cirq-core/cirq/ops/controlled_gate_test.py @@ -39,6 +39,9 @@ def __repr__(self): class GateAllocatingNewSpaceForResult(cirq.SingleQubitGate): + def __init__(self): + self._matrix = cirq.testing.random_unitary(2, random_state=4321) + def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> Union[np.ndarray, NotImplementedType]: assert len(args.axes) == 1 a = args.axes[0] @@ -46,12 +49,18 @@ def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> Union[np.ndarray, NotI zero = seed * a + (0, Ellipsis) one = seed * a + (1, Ellipsis) result = np.zeros(args.target_tensor.shape, args.target_tensor.dtype) - result[zero] = args.target_tensor[zero] * 2 + args.target_tensor[one] * 3 - result[one] = args.target_tensor[zero] * 5 + args.target_tensor[one] * 7 + result[zero] = ( + args.target_tensor[zero] * self._matrix[0][0] + + args.target_tensor[one] * self._matrix[0][1] + ) + result[one] = ( + args.target_tensor[zero] * self._matrix[1][0] + + args.target_tensor[one] * self._matrix[1][1] + ) return result def _unitary_(self): - return np.array([[2, 3], [5, 7]]) + return self._matrix def __eq__(self, other): return isinstance(other, type(self)) @@ -316,28 +325,42 @@ def test_unitary(): @pytest.mark.parametrize( - 'gate', + 'gate, should_decompose_to_target', [ - cirq.X, - cirq.X ** 0.5, - cirq.rx(np.pi), - cirq.rx(np.pi / 2), - cirq.Z, - cirq.H, - cirq.CNOT, - cirq.SWAP, - cirq.CCZ, - cirq.ControlledGate(cirq.ControlledGate(cirq.CCZ)), - GateUsingWorkspaceForApplyUnitary(), - GateAllocatingNewSpaceForResult(), - cirq.IdentityGate(qid_shape=(3, 4)), + (cirq.X, True), + (cirq.X ** 0.5, True), + (cirq.rx(np.pi), True), + (cirq.rx(np.pi / 2), True), + (cirq.Z, True), + (cirq.H, True), + (cirq.CNOT, True), + (cirq.SWAP, True), + (cirq.CCZ, True), + (cirq.ControlledGate(cirq.ControlledGate(cirq.CCZ)), True), + (GateUsingWorkspaceForApplyUnitary(), True), + (GateAllocatingNewSpaceForResult(), True), + (cirq.IdentityGate(qid_shape=(3, 4)), True), + ( + cirq.ControlledGate( + cirq.XXPowGate(exponent=0.25, global_shift=-0.5), + num_controls=2, + control_values=(1, (1, 0)), + ), + True, + ), # Single qudit gate with dimension 4. - cirq.MatrixGate(np.kron(*(cirq.unitary(cirq.H),) * 2)), + (cirq.MatrixGate(np.kron(*(cirq.unitary(cirq.H),) * 2), qid_shape=(4,)), False), + (cirq.MatrixGate(cirq.testing.random_unitary(4, random_state=1234)), False), + (cirq.XX ** sympy.Symbol("s"), True), + (cirq.CZ ** sympy.Symbol("s"), True), ], ) -def test_controlled_gate_is_consistent(gate: cirq.Gate): +def test_controlled_gate_is_consistent(gate: cirq.Gate, should_decompose_to_target): cgate = cirq.ControlledGate(gate) cirq.testing.assert_implements_consistent_protocols(cgate) + cirq.testing.assert_decompose_ends_at_default_gateset( + cgate, ignore_known_gates=not should_decompose_to_target + ) def test_pow_inverse(): diff --git a/cirq-core/cirq/ops/controlled_operation.py b/cirq-core/cirq/ops/controlled_operation.py index 1fb2dca8475..0cee798322e 100644 --- a/cirq-core/cirq/ops/controlled_operation.py +++ b/cirq-core/cirq/ops/controlled_operation.py @@ -30,7 +30,7 @@ from cirq import protocols, qis, value from cirq._compat import deprecated -from cirq.ops import raw_types, gate_operation, controlled_gate +from cirq.ops import raw_types, gate_operation, controlled_gate, matrix_gates from cirq.type_workarounds import NotImplementedType if TYPE_CHECKING: @@ -130,11 +130,22 @@ def with_qubits(self, *new_qubits): ) def _decompose_(self): + result = protocols.decompose_once_with_qubits(self.gate, self.qubits, NotImplemented) + if result is not NotImplemented: + return result + + if isinstance(self.sub_operation.gate, matrix_gates.MatrixGate): + # Default decompositions of 2/3 qubit `cirq.MatrixGate` ignores global phase, which is + # local phase in the controlled variant and hence cannot be ignored. + return NotImplemented + result = protocols.decompose_once(self.sub_operation, NotImplemented) if result is NotImplemented: return NotImplemented - return [ControlledOperation(self.controls, op, self.control_values) for op in result] + return [ + op.controlled_by(*self.controls, control_values=self.control_values) for op in result + ] def _value_equality_values_(self): return (frozenset(zip(self.controls, self.control_values)), self.sub_operation) diff --git a/cirq-core/cirq/ops/controlled_operation_test.py b/cirq-core/cirq/ops/controlled_operation_test.py index 16c314b45fb..18198871cab 100644 --- a/cirq-core/cirq/ops/controlled_operation_test.py +++ b/cirq-core/cirq/ops/controlled_operation_test.py @@ -40,6 +40,9 @@ def __repr__(self): class GateAllocatingNewSpaceForResult(cirq.SingleQubitGate): + def __init__(self): + self._matrix = cirq.testing.random_unitary(2, random_state=1234) + def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> Union[np.ndarray, NotImplementedType]: assert len(args.axes) == 1 a = args.axes[0] @@ -47,12 +50,18 @@ def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> Union[np.ndarray, NotI zero = seed * a + (0, Ellipsis) one = seed * a + (1, Ellipsis) result = np.zeros(args.target_tensor.shape, args.target_tensor.dtype) - result[zero] = args.target_tensor[zero] * 2 + args.target_tensor[one] * 3 - result[one] = args.target_tensor[zero] * 5 + args.target_tensor[one] * 7 + result[zero] = ( + args.target_tensor[zero] * self._matrix[0][0] + + args.target_tensor[one] * self._matrix[0][1] + ) + result[one] = ( + args.target_tensor[zero] * self._matrix[1][0] + + args.target_tensor[one] * self._matrix[1][1] + ) return result def _unitary_(self): - return np.array([[2, 3], [5, 7]]) + return self._matrix def __eq__(self, other): return isinstance(other, type(self)) @@ -297,33 +306,82 @@ class UndiagrammableGate(cirq.SingleQubitGate): @pytest.mark.parametrize( - 'gate', + 'gate, should_decompose_to_target', [ - cirq.X(cirq.NamedQubit('q1')), - cirq.X(cirq.NamedQubit('q1')) ** 0.5, - cirq.rx(np.pi)(cirq.NamedQubit('q1')), - cirq.rx(np.pi / 2)(cirq.NamedQubit('q1')), - cirq.Z(cirq.NamedQubit('q1')), - cirq.H(cirq.NamedQubit('q1')), - cirq.CNOT(cirq.NamedQubit('q1'), cirq.NamedQubit('q2')), - cirq.SWAP(cirq.NamedQubit('q1'), cirq.NamedQubit('q2')), - cirq.CCZ(cirq.NamedQubit('q1'), cirq.NamedQubit('q2'), cirq.NamedQubit('q3')), - cirq.ControlledGate(cirq.ControlledGate(cirq.CCZ))(*cirq.LineQubit.range(5)), - GateUsingWorkspaceForApplyUnitary()(cirq.NamedQubit('q1')), - GateAllocatingNewSpaceForResult()(cirq.NamedQubit('q1')), + (cirq.X(cirq.NamedQubit('q1')), True), + (cirq.X(cirq.NamedQubit('q1')) ** 0.5, True), + (cirq.rx(np.pi)(cirq.NamedQubit('q1')), True), + (cirq.rx(np.pi / 2)(cirq.NamedQubit('q1')), True), + (cirq.Z(cirq.NamedQubit('q1')), True), + (cirq.H(cirq.NamedQubit('q1')), True), + (cirq.CNOT(cirq.NamedQubit('q1'), cirq.NamedQubit('q2')), True), + (cirq.SWAP(cirq.NamedQubit('q1'), cirq.NamedQubit('q2')), True), + (cirq.CCZ(cirq.NamedQubit('q1'), cirq.NamedQubit('q2'), cirq.NamedQubit('q3')), True), + (cirq.ControlledGate(cirq.ControlledGate(cirq.CCZ))(*cirq.LineQubit.range(5)), True), + (GateUsingWorkspaceForApplyUnitary()(cirq.NamedQubit('q1')), True), + (GateAllocatingNewSpaceForResult()(cirq.NamedQubit('q1')), True), + ( + cirq.MatrixGate(np.kron(*(cirq.unitary(cirq.H),) * 2), qid_shape=(4,)).on( + cirq.NamedQid("q", 4) + ), + False, + ), + ( + cirq.MatrixGate(cirq.testing.random_unitary(4, random_state=1234)).on( + cirq.NamedQubit('q1'), cirq.NamedQubit('q2') + ), + False, + ), + (cirq.XX(cirq.NamedQubit('q1'), cirq.NamedQubit('q2')) ** sympy.Symbol("s"), True), + (cirq.DiagonalGate(sympy.symbols("s1, s2")).on(cirq.NamedQubit("q")), False), ], ) -def test_controlled_operation_is_consistent(gate: cirq.GateOperation): +def test_controlled_operation_is_consistent( + gate: cirq.GateOperation, should_decompose_to_target: bool +): cb = cirq.NamedQubit('ctr') cgate = cirq.ControlledOperation([cb], gate) cirq.testing.assert_implements_consistent_protocols(cgate) + cirq.testing.assert_decompose_ends_at_default_gateset( + cgate, ignore_known_gates=not should_decompose_to_target + ) cgate = cirq.ControlledOperation([cb], gate, control_values=[0]) cirq.testing.assert_implements_consistent_protocols(cgate) + cirq.testing.assert_decompose_ends_at_default_gateset( + cgate, ignore_known_gates=(not should_decompose_to_target or cirq.is_parameterized(gate)) + ) + + cgate = cirq.ControlledOperation([cb], gate, control_values=[(0, 1)]) + cirq.testing.assert_implements_consistent_protocols(cgate) + cirq.testing.assert_decompose_ends_at_default_gateset( + cgate, ignore_known_gates=(not should_decompose_to_target or cirq.is_parameterized(gate)) + ) cb3 = cb.with_dimension(3) cgate = cirq.ControlledOperation([cb3], gate, control_values=[(0, 2)]) cirq.testing.assert_implements_consistent_protocols(cgate) + cirq.testing.assert_decompose_ends_at_default_gateset(cgate) + + +def test_controlled_circuit_operation_is_consistent(): + op = cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.XXPowGate(exponent=0.25, global_shift=-0.5).on(*cirq.LineQubit.range(2)) + ) + ) + cb = cirq.NamedQubit('ctr') + cop = cirq.ControlledOperation([cb], op) + cirq.testing.assert_implements_consistent_protocols(cop, exponents=(-1, 1, 2)) + cirq.testing.assert_decompose_ends_at_default_gateset(cop) + + cop = cirq.ControlledOperation([cb], op, control_values=[0]) + cirq.testing.assert_implements_consistent_protocols(cop, exponents=(-1, 1, 2)) + cirq.testing.assert_decompose_ends_at_default_gateset(cop) + + cop = cirq.ControlledOperation([cb], op, control_values=[(0, 1)]) + cirq.testing.assert_implements_consistent_protocols(cop, exponents=(-1, 1, 2)) + cirq.testing.assert_decompose_ends_at_default_gateset(cop) @pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once]) diff --git a/cirq-core/cirq/testing/consistent_decomposition.py b/cirq-core/cirq/testing/consistent_decomposition.py index 313c074f14a..c66eead8c2f 100644 --- a/cirq-core/cirq/testing/consistent_decomposition.py +++ b/cirq-core/cirq/testing/consistent_decomposition.py @@ -53,16 +53,22 @@ def _known_gate_with_no_decomposition(val: Any): """Checks whether `val` is a known gate with no default decomposition to default gateset.""" if isinstance(val, ops.MatrixGate): return protocols.qid_shape(val) not in [(2,), (2,) * 2, (2,) * 3] + if isinstance(val, ops.ControlledGate): + if protocols.is_parameterized(val): + return True + if isinstance(val.sub_gate, ops.MatrixGate) and protocols.num_qubits(val.sub_gate) > 1: + return True + if val.control_qid_shape != (2,) * val.num_controls(): + return True + return _known_gate_with_no_decomposition(val.sub_gate) return False -def assert_decompose_ends_at_default_gateset(val: Any): +def assert_decompose_ends_at_default_gateset(val: Any, ignore_known_gates: bool = True): """Asserts that cirq.decompose(val) ends at default cirq gateset or a known gate.""" - if _known_gate_with_no_decomposition(val): - return # coverage: ignore args = () if isinstance(val, ops.Operation) else (tuple(devices.LineQid.for_gate(val)),) dec_once = protocols.decompose_once(val, [val(*args[0]) if args else val], *args) for op in [*ops.flatten_to_ops(protocols.decompose(d) for d in dec_once)]: - assert _known_gate_with_no_decomposition(op.gate) or ( + assert (_known_gate_with_no_decomposition(op.gate) and ignore_known_gates) or ( op in protocols.decompose_protocol.DECOMPOSE_TARGET_GATESET ), f'{val} decomposed to {op}, which is not part of default cirq target gateset.'