From f0e92ba5fae5945d408e05dbf46d0a8577fea4a0 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Mon, 7 Feb 2022 00:23:42 -0800 Subject: [PATCH 1/4] Add cirq.eject_phased_paulis transformer to replace cirq.EjectPhasedPaulis --- cirq-core/cirq/__init__.py | 1 + cirq-core/cirq/ion/ion_decomposition.py | 2 +- .../cirq/optimizers/eject_phased_paulis.py | 307 +--------- .../optimizers/eject_phased_paulis_test.py | 4 +- .../optimizers/merge_interactions_test.py | 2 +- .../merge_interactions_to_sqrt_iswap_test.py | 2 +- cirq-core/cirq/transformers/__init__.py | 2 + .../two_qubit_to_cz.py | 4 +- .../cirq/transformers/eject_phased_paulis.py | 330 +++++++++++ .../transformers/eject_phased_paulis_test.py | 560 ++++++++++++++++++ .../optimizers/optimize_for_sycamore.py | 2 +- docs/tutorials/google/spin_echoes.ipynb | 6 +- 12 files changed, 913 insertions(+), 309 deletions(-) create mode 100644 cirq-core/cirq/transformers/eject_phased_paulis.py create mode 100644 cirq-core/cirq/transformers/eject_phased_paulis_test.py diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 640612f63c3..c48d89303b2 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -362,6 +362,7 @@ decompose_two_qubit_interaction_into_four_fsim_gates, drop_empty_moments, drop_negligible_operations, + eject_phased_paulis, expand_composite, is_negligible_turn, map_moments, diff --git a/cirq-core/cirq/ion/ion_decomposition.py b/cirq-core/cirq/ion/ion_decomposition.py index 0dfeda160b3..9c16cb17942 100644 --- a/cirq-core/cirq/ion/ion_decomposition.py +++ b/cirq-core/cirq/ion/ion_decomposition.py @@ -53,7 +53,7 @@ def two_qubit_matrix_to_ion_operations( def _cleanup_operations(operations: List[ops.Operation]): circuit = circuits.Circuit(operations) optimizers.merge_single_qubit_gates.merge_single_qubit_gates_into_phased_x_z(circuit) - optimizers.eject_phased_paulis.EjectPhasedPaulis().optimize_circuit(circuit) + circuit = transformers.eject_phased_paulis(circuit) optimizers.eject_z.EjectZ().optimize_circuit(circuit) circuit = circuits.Circuit(circuit.all_operations(), strategy=circuits.InsertStrategy.EARLIEST) return list(circuit.all_operations()) diff --git a/cirq-core/cirq/optimizers/eject_phased_paulis.py b/cirq-core/cirq/optimizers/eject_phased_paulis.py index c6c84ee8587..a846a532ea6 100644 --- a/cirq-core/cirq/optimizers/eject_phased_paulis.py +++ b/cirq-core/cirq/optimizers/eject_phased_paulis.py @@ -15,27 +15,11 @@ """Pushes 180 degree rotations around axes in the XY plane later in the circuit. """ -from typing import Optional, cast, TYPE_CHECKING, Iterable, Tuple, Dict, List -import sympy - -from cirq import circuits, ops, value, protocols -from cirq.transformers.analytical_decompositions import single_qubit_decompositions - -if TYPE_CHECKING: - import cirq - - -class _OptimizerState: - def __init__(self): - # The phases of the W gates currently being pushed along each qubit. - self.held_w_phases: Dict[ops.Qid, value.TParamVal] = {} - - # Accumulated commands to batch-apply to the circuit later. - self.deletions: List[Tuple[int, ops.Operation]] = [] - self.inline_intos: List[Tuple[int, ops.Operation]] = [] - self.insertions: List[Tuple[int, ops.Operation]] = [] +from cirq import circuits, transformers +from cirq._compat import deprecated_class +@deprecated_class(deadline='v1.0', fix='Use cirq.eject_phased_paulis instead.') class EjectPhasedPaulis: """Pushes X, Y, and PhasedX gates towards the end of the circuit. @@ -60,283 +44,8 @@ def __init__(self, tolerance: float = 1e-8, eject_parameterized: bool = False) - self.eject_parameterized = eject_parameterized def optimize_circuit(self, circuit: circuits.Circuit): - state = _OptimizerState() - - for moment_index, moment in enumerate(circuit): - for op in moment.operations: - affected = [q for q in op.qubits if q in state.held_w_phases] - - # Collect, phase, and merge Ws. - w = _try_get_known_phased_pauli(op, no_symbolic=not self.eject_parameterized) - if w is not None: - if single_qubit_decompositions.is_negligible_turn( - (w[0] - 1) / 2, self.tolerance - ): - _potential_cross_whole_w(moment_index, op, self.tolerance, state) - else: - _potential_cross_partial_w(moment_index, op, state) - continue - - if not affected: - continue - - # Absorb Z rotations. - t = _try_get_known_z_half_turns(op, no_symbolic=not self.eject_parameterized) - if t is not None: - _absorb_z_into_w(moment_index, op, state) - continue - - # Dump coherent flips into measurement bit flips. - if isinstance(op.gate, ops.MeasurementGate): - _dump_into_measurement(moment_index, op, state) - - # Cross CZs using kickback. - if ( - _try_get_known_cz_half_turns(op, no_symbolic=not self.eject_parameterized) - is not None - ): - if len(affected) == 1: - _single_cross_over_cz(moment_index, op, affected[0], state) - else: - _double_cross_over_cz(op, state) - continue - - # Don't know how to handle this situation. Dump the gates. - _dump_held(op.qubits, moment_index, state) - - # Put anything that's still held at the end of the circuit. - _dump_held(state.held_w_phases.keys(), len(circuit), state) - - circuit.batch_remove(state.deletions) - circuit.batch_insert_into(state.inline_intos) - circuit.batch_insert(state.insertions) - - -def _absorb_z_into_w(moment_index: int, op: ops.Operation, state: _OptimizerState) -> None: - """Absorbs a Z^t gate into a W(a) flip. - - [Where W(a) is shorthand for PhasedX(phase_exponent=a).] - - Uses the following identity: - ───W(a)───Z^t─── - ≡ ───W(a)───────────Z^t/2──────────Z^t/2─── (split Z) - ≡ ───W(a)───W(a)───Z^-t/2───W(a)───Z^t/2─── (flip Z) - ≡ ───W(a)───W(a)──────────W(a+t/2)───────── (phase W) - ≡ ────────────────────────W(a+t/2)───────── (cancel Ws) - ≡ ───W(a+t/2)─── - """ - t = cast(value.TParamVal, _try_get_known_z_half_turns(op)) - q = op.qubits[0] - state.held_w_phases[q] += t / 2 - state.deletions.append((moment_index, op)) - - -def _dump_held(qubits: Iterable[ops.Qid], moment_index: int, state: _OptimizerState): - # Note: sorting is to avoid non-determinism in the insertion order. - for q in sorted(qubits): - p = state.held_w_phases.get(q) - if p is not None: - dump_op = ops.PhasedXPowGate(phase_exponent=p).on(q) - state.insertions.append((moment_index, dump_op)) - state.held_w_phases.pop(q, None) - - -def _dump_into_measurement(moment_index: int, op: ops.Operation, state: _OptimizerState) -> None: - measurement = cast(ops.MeasurementGate, cast(ops.GateOperation, op).gate) - new_measurement = measurement.with_bits_flipped( - *[i for i, q in enumerate(op.qubits) if q in state.held_w_phases] - ).on(*op.qubits) - for q in op.qubits: - state.held_w_phases.pop(q, None) - state.deletions.append((moment_index, op)) - state.inline_intos.append((moment_index, new_measurement)) - - -def _potential_cross_whole_w( - moment_index: int, op: ops.Operation, tolerance: float, state: _OptimizerState -) -> None: - """Grabs or cancels a held W gate against an existing W gate. - - [Where W(a) is shorthand for PhasedX(phase_exponent=a).] - - Uses the following identity: - ───W(a)───W(b)─── - ≡ ───Z^-a───X───Z^a───Z^-b───X───Z^b─── - ≡ ───Z^-a───Z^-a───Z^b───X───X───Z^b─── - ≡ ───Z^-a───Z^-a───Z^b───Z^b─── - ≡ ───Z^2(b-a)─── - """ - state.deletions.append((moment_index, op)) - - _, phase_exponent = cast( - Tuple[value.TParamVal, value.TParamVal], _try_get_known_phased_pauli(op) - ) - q = op.qubits[0] - a = state.held_w_phases.get(q, None) - b = phase_exponent - - if a is None: - # Collect the gate. - state.held_w_phases[q] = b - else: - # Cancel the gate. - del state.held_w_phases[q] - t = 2 * (b - a) - if not single_qubit_decompositions.is_negligible_turn(t / 2, tolerance): - leftover_phase = ops.Z(q) ** t - state.inline_intos.append((moment_index, leftover_phase)) - - -def _potential_cross_partial_w( - moment_index: int, op: ops.Operation, state: _OptimizerState -) -> None: - """Cross the held W over a partial W gate. - - [Where W(a) is shorthand for PhasedX(phase_exponent=a).] - - Uses the following identity: - ───W(a)───W(b)^t─── - ≡ ───Z^-a───X───Z^a───W(b)^t────── (expand W(a)) - ≡ ───Z^-a───X───W(b-a)^t───Z^a──── (move Z^a across, phasing axis) - ≡ ───Z^-a───W(a-b)^t───X───Z^a──── (move X across, negating axis angle) - ≡ ───W(2a-b)^t───Z^-a───X───Z^a─── (move Z^-a across, phasing axis) - ≡ ───W(2a-b)^t───W(a)─── - """ - a = state.held_w_phases.get(op.qubits[0], None) - if a is None: - return - exponent, phase_exponent = cast( - Tuple[value.TParamVal, value.TParamVal], _try_get_known_phased_pauli(op) - ) - new_op = ops.PhasedXPowGate(exponent=exponent, phase_exponent=2 * a - phase_exponent).on( - op.qubits[0] - ) - state.deletions.append((moment_index, op)) - state.inline_intos.append((moment_index, new_op)) - - -def _single_cross_over_cz( - moment_index: int, op: ops.Operation, qubit_with_w: 'cirq.Qid', state: _OptimizerState -) -> None: - """Crosses exactly one W flip over a partial CZ. - - [Where W(a) is shorthand for PhasedX(phase_exponent=a).] - - Uses the following identity: - - ──────────@───── - │ - ───W(a)───@^t─── - - - ≡ ───@──────O──────@──────────────────── - | | │ (split into on/off cases) - ───W(a)───W(a)───@^t────────────────── - - ≡ ───@─────────────@─────────────O────── - | │ | (off doesn't interact with on) - ───W(a)──────────@^t───────────W(a)─── - - ≡ ───────────Z^t───@──────@──────O────── - │ | | (crossing causes kickback) - ─────────────────@^-t───W(a)───W(a)─── (X Z^t X Z^-t = exp(pi t) I) - - ≡ ───────────Z^t───@──────────────────── - │ (merge on/off cases) - ─────────────────@^-t───W(a)────────── - - ≡ ───Z^t───@────────────── - │ - ─────────@^-t───W(a)──── - """ - t = cast(value.TParamVal, _try_get_known_cz_half_turns(op)) - other_qubit = op.qubits[0] if qubit_with_w == op.qubits[1] else op.qubits[1] - negated_cz = ops.CZ(*op.qubits) ** -t - kickback = ops.Z(other_qubit) ** t - - state.deletions.append((moment_index, op)) - state.inline_intos.append((moment_index, negated_cz)) - state.insertions.append((moment_index, kickback)) - - -def _double_cross_over_cz(op: ops.Operation, state: _OptimizerState) -> None: - """Crosses two W flips over a partial CZ. - - [Where W(a) is shorthand for PhasedX(phase_exponent=a).] - - Uses the following identity: - - ───W(a)───@───── - │ - ───W(b)───@^t─── - - - ≡ ──────────@────────────W(a)─── - │ (single-cross top W over CZ) - ───W(b)───@^-t─────────Z^t──── - - - ≡ ──────────@─────Z^-t───W(a)─── - │ (single-cross bottom W over CZ) - ──────────@^t───W(b)───Z^t──── - - - ≡ ──────────@─────W(a)───Z^t──── - │ (flip over Z^-t) - ──────────@^t───W(b)───Z^t──── - - - ≡ ──────────@─────W(a+t/2)────── - │ (absorb Zs into Ws) - ──────────@^t───W(b+t/2)────── - - ≡ ───@─────W(a+t/2)─── - │ - ───@^t───W(b+t/2)─── - """ - t = cast(value.TParamVal, _try_get_known_cz_half_turns(op)) - for q in op.qubits: - state.held_w_phases[q] = cast(value.TParamVal, state.held_w_phases[q]) + t / 2 - - -def _try_get_known_cz_half_turns( - op: ops.Operation, no_symbolic: bool = False -) -> Optional[value.TParamVal]: - if not isinstance(op, ops.GateOperation) or not isinstance(op.gate, ops.CZPowGate): - return None - h = op.gate.exponent - if no_symbolic and isinstance(h, sympy.Basic): - return None - return h - - -def _try_get_known_phased_pauli( - op: ops.Operation, no_symbolic: bool = False -) -> Optional[Tuple[value.TParamVal, value.TParamVal]]: - if (no_symbolic and protocols.is_parameterized(op)) or not isinstance(op, ops.GateOperation): - return None - gate = op.gate - - if isinstance(gate, ops.PhasedXPowGate): - e = gate.exponent - p = gate.phase_exponent - elif isinstance(gate, ops.YPowGate): - e = gate.exponent - p = 0.5 - elif isinstance(gate, ops.XPowGate): - e = gate.exponent - p = 0.0 - else: - return None - return value.canonicalize_half_turns(e), value.canonicalize_half_turns(p) - - -def _try_get_known_z_half_turns( - op: ops.Operation, no_symbolic: bool = False -) -> Optional[value.TParamVal]: - if not isinstance(op, ops.GateOperation) or not isinstance(op.gate, ops.ZPowGate): - return None - h = op.gate.exponent - if no_symbolic and isinstance(h, sympy.Basic): - return None - return h + circuit._moments = [ + *transformers.eject_phased_paulis( + circuit, atol=self.tolerance, eject_parameterized=self.eject_parameterized + ) + ] diff --git a/cirq-core/cirq/optimizers/eject_phased_paulis_test.py b/cirq-core/cirq/optimizers/eject_phased_paulis_test.py index ad3f3bd67c8..04c1c2b1f98 100644 --- a/cirq-core/cirq/optimizers/eject_phased_paulis_test.py +++ b/cirq-core/cirq/optimizers/eject_phased_paulis_test.py @@ -26,9 +26,11 @@ def assert_optimizes( compare_unitaries: bool = True, eject_parameterized: bool = False, ): - opt = cirq.EjectPhasedPaulis(eject_parameterized=eject_parameterized) + with cirq.testing.assert_deprecated("Use cirq.eject_phased_paulis", deadline='v1.0'): + opt = cirq.EjectPhasedPaulis(eject_parameterized=eject_parameterized) circuit = before.copy() + expected = cirq.drop_empty_moments(expected) opt.optimize_circuit(circuit) # They should have equivalent effects. diff --git a/cirq-core/cirq/optimizers/merge_interactions_test.py b/cirq-core/cirq/optimizers/merge_interactions_test.py index f1dc637748e..49b352009aa 100644 --- a/cirq-core/cirq/optimizers/merge_interactions_test.py +++ b/cirq-core/cirq/optimizers/merge_interactions_test.py @@ -28,7 +28,6 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit): # Ignore differences that would be caught by follow-up optimizations. followup_optimizations: List[Callable[[cirq.Circuit], None]] = [ cirq.merge_single_qubit_gates_into_phased_x_z, - cirq.EjectPhasedPaulis().optimize_circuit, cirq.EjectZ().optimize_circuit, ] for post in followup_optimizations: @@ -36,6 +35,7 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit): post(expected) followup_transformers: List[cirq.TRANSFORMER] = [ + cirq.eject_phased_paulis, cirq.drop_negligible_operations, cirq.drop_empty_moments, ] diff --git a/cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap_test.py b/cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap_test.py index 3dbbcb4fdc2..772218ea1e4 100644 --- a/cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap_test.py +++ b/cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap_test.py @@ -39,7 +39,6 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit, **kwargs): # Ignore differences that would be caught by follow-up optimizations. followup_optimizations: List[Callable[[cirq.Circuit], None]] = [ cirq.merge_single_qubit_gates_into_phased_x_z, - cirq.EjectPhasedPaulis().optimize_circuit, cirq.EjectZ().optimize_circuit, ] for post in followup_optimizations: @@ -47,6 +46,7 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit, **kwargs): post(expected) followup_transformers: List[cirq.TRANSFORMER] = [ + cirq.eject_phased_paulis, cirq.drop_negligible_operations, cirq.drop_empty_moments, ] diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index 70fb448f9dd..30b305a9d33 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -45,6 +45,8 @@ from cirq.transformers.expand_composite import expand_composite +from cirq.transformers.eject_phased_paulis import eject_phased_paulis + from cirq.transformers.drop_empty_moments import drop_empty_moments from cirq.transformers.drop_negligible_operations import drop_negligible_operations diff --git a/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_cz.py b/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_cz.py index df2509a2440..ce35271ab39 100644 --- a/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_cz.py +++ b/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_cz.py @@ -23,9 +23,9 @@ from cirq import ops, linalg, protocols, circuits from cirq.transformers.analytical_decompositions import single_qubit_decompositions +from cirq.transformers.eject_phased_paulis import eject_phased_paulis from cirq.optimizers import ( eject_z, - eject_phased_paulis, merge_single_qubit_gates, ) @@ -164,7 +164,7 @@ def _xx_yy_zz_interaction_via_full_czs( def _cleanup_operations(operations: Sequence[ops.Operation]): circuit = circuits.Circuit(operations) merge_single_qubit_gates.merge_single_qubit_gates_into_phased_x_z(circuit) - eject_phased_paulis.EjectPhasedPaulis().optimize_circuit(circuit) + circuit = eject_phased_paulis(circuit) eject_z.EjectZ().optimize_circuit(circuit) circuit = circuits.Circuit(circuit.all_operations(), strategy=circuits.InsertStrategy.EARLIEST) return list(circuit.all_operations()) diff --git a/cirq-core/cirq/transformers/eject_phased_paulis.py b/cirq-core/cirq/transformers/eject_phased_paulis.py new file mode 100644 index 00000000000..3e299779a2a --- /dev/null +++ b/cirq-core/cirq/transformers/eject_phased_paulis.py @@ -0,0 +1,330 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transformer pass that pushes 180° rotations around axes in the XY plane later in the circuit.""" + +from typing import Optional, cast, TYPE_CHECKING, Iterable, Tuple, Dict +import sympy + +from cirq import circuits, ops, value, protocols +from cirq.transformers import transformer_api, transformer_primitives +from cirq.transformers.analytical_decompositions import single_qubit_decompositions + +if TYPE_CHECKING: + import cirq + + +@transformer_api.transformer +def eject_phased_paulis( + circuit: 'cirq.AbstractCircuit', + *, + context: Optional['cirq.TransformerContext'] = None, + atol: float = 1e-8, + eject_parameterized: bool = False, +) -> 'cirq.Circuit': + """Transformer pass to push X, Y, and PhasedX gates towards the end of the circuit. + + As the gates get pushed, they may absorb Z gates, cancel against other + X, Y, or PhasedX gates with exponent=1, get merged into measurements (as + output bit flips), and cause phase kickback operations across CZs (which can + then be removed by the EjectZ optimization). + Args: + circuit: Input circuit to transform. + context: `cirq.TransformerContext` storing common configurable options for transformers. + atol: Maximum absolute error tolerance. The optimization is permitted to simply drop + negligible combinations gates with a threshold determined by this tolerance. + eject_parameterized: If True, the optimization will attempt to eject parameterized gates + as well. This may result in other gates parameterized by symbolic expressions. + Returns: + Copy of the transformed input circuit. + """ + held_w_phases: Dict[ops.Qid, value.TParamVal] = {} + tags_to_ignore = set(context.tags_to_ignore) if context else set() + + def map_func(op: 'cirq.Operation', _: int) -> 'cirq.OP_TREE': + # Dump if `op` marked with a no compile tag. + if set(op.tags) & tags_to_ignore: + return [_dump_held(op.qubits, held_w_phases), op] + + # Collect, phase, and merge Ws. + w = _try_get_known_phased_pauli(op, no_symbolic=not eject_parameterized) + if w is not None: + return ( + _potential_cross_whole_w(op, atol, held_w_phases) + if single_qubit_decompositions.is_negligible_turn((w[0] - 1) / 2, atol) + else _potential_cross_partial_w(op, held_w_phases) + ) + + affected = [q for q in op.qubits if q in held_w_phases] + if not affected: + return op + + # Absorb Z rotations. + t = _try_get_known_z_half_turns(op, no_symbolic=not eject_parameterized) + if t is not None: + return _absorb_z_into_w(op, held_w_phases) + + # Dump coherent flips into measurement bit flips. + if isinstance(op.gate, ops.MeasurementGate): + return _dump_into_measurement(op, held_w_phases) + + # Cross CZs using kickback. + if _try_get_known_cz_half_turns(op, no_symbolic=not eject_parameterized) is not None: + return ( + _single_cross_over_cz(op, affected[0]) + if len(affected) == 1 + else _double_cross_over_cz(op, held_w_phases) + ) + + # Don't know how to handle this situation. Dump the gates. + return [_dump_held(op.qubits, held_w_phases), op] + + # Map operations and put anything that's still held at the end of the circuit. + return circuits.Circuit( + transformer_primitives.map_operations_and_unroll(circuit, map_func), + _dump_held(held_w_phases.keys(), held_w_phases), + ) + + +def _absorb_z_into_w( + op: ops.Operation, held_w_phases: Dict[ops.Qid, value.TParamVal] +) -> 'cirq.OP_TREE': + """Absorbs a Z^t gate into a W(a) flip. + + [Where W(a) is shorthand for PhasedX(phase_exponent=a).] + + Uses the following identity: + ───W(a)───Z^t─── + ≡ ───W(a)───────────Z^t/2──────────Z^t/2─── (split Z) + ≡ ───W(a)───W(a)───Z^-t/2───W(a)───Z^t/2─── (flip Z) + ≡ ───W(a)───W(a)──────────W(a+t/2)───────── (phase W) + ≡ ────────────────────────W(a+t/2)───────── (cancel Ws) + ≡ ───W(a+t/2)─── + """ + t = cast(value.TParamVal, _try_get_known_z_half_turns(op)) + q = op.qubits[0] + held_w_phases[q] += t / 2 + return [] + + +def _dump_held( + qubits: Iterable[ops.Qid], held_w_phases: Dict[ops.Qid, value.TParamVal] +) -> 'cirq.OP_TREE': + # Note: sorting is to avoid non-determinism in the insertion order. + for q in sorted(qubits): + p = held_w_phases.get(q) + if p is not None: + dump_op = ops.PhasedXPowGate(phase_exponent=p).on(q) + yield dump_op + held_w_phases.pop(q, None) + + +def _dump_into_measurement( + op: ops.Operation, held_w_phases: Dict[ops.Qid, value.TParamVal] +) -> 'cirq.OP_TREE': + measurement = cast(ops.MeasurementGate, cast(ops.GateOperation, op).gate) + new_measurement = measurement.with_bits_flipped( + *[i for i, q in enumerate(op.qubits) if q in held_w_phases] + ).on(*op.qubits) + for q in op.qubits: + held_w_phases.pop(q, None) + return new_measurement + + +def _potential_cross_whole_w( + op: ops.Operation, + atol: float, + held_w_phases: Dict[ops.Qid, value.TParamVal], +) -> 'cirq.OP_TREE': + """Grabs or cancels a held W gate against an existing W gate. + + [Where W(a) is shorthand for PhasedX(phase_exponent=a).] + + Uses the following identity: + ───W(a)───W(b)─── + ≡ ───Z^-a───X───Z^a───Z^-b───X───Z^b─── + ≡ ───Z^-a───Z^-a───Z^b───X───X───Z^b─── + ≡ ───Z^-a───Z^-a───Z^b───Z^b─── + ≡ ───Z^2(b-a)─── + """ + _, phase_exponent = cast( + Tuple[value.TParamVal, value.TParamVal], _try_get_known_phased_pauli(op) + ) + q = op.qubits[0] + a = held_w_phases.get(q, None) + b = phase_exponent + + if a is None: + # Collect the gate. + held_w_phases[q] = b + else: + # Cancel the gate. + del held_w_phases[q] + t = 2 * (b - a) + if not single_qubit_decompositions.is_negligible_turn(t / 2, atol): + return ops.Z(q) ** t + return [] + + +def _potential_cross_partial_w( + op: ops.Operation, + held_w_phases: Dict[ops.Qid, value.TParamVal], +) -> 'cirq.OP_TREE': + """Cross the held W over a partial W gate. + + [Where W(a) is shorthand for PhasedX(phase_exponent=a).] + + Uses the following identity: + ───W(a)───W(b)^t─── + ≡ ───Z^-a───X───Z^a───W(b)^t────── (expand W(a)) + ≡ ───Z^-a───X───W(b-a)^t───Z^a──── (move Z^a across, phasing axis) + ≡ ───Z^-a───W(a-b)^t───X───Z^a──── (move X across, negating axis angle) + ≡ ───W(2a-b)^t───Z^-a───X───Z^a─── (move Z^-a across, phasing axis) + ≡ ───W(2a-b)^t───W(a)─── + """ + a = held_w_phases.get(op.qubits[0], None) + if a is None: + return op + exponent, phase_exponent = cast( + Tuple[value.TParamVal, value.TParamVal], _try_get_known_phased_pauli(op) + ) + new_op = ops.PhasedXPowGate(exponent=exponent, phase_exponent=2 * a - phase_exponent).on( + op.qubits[0] + ) + return new_op + + +def _single_cross_over_cz(op: ops.Operation, qubit_with_w: 'cirq.Qid') -> 'cirq.OP_TREE': + """Crosses exactly one W flip over a partial CZ. + + [Where W(a) is shorthand for PhasedX(phase_exponent=a).] + + Uses the following identity: + + ──────────@───── + │ + ───W(a)───@^t─── + + + ≡ ───@──────O──────@──────────────────── + | | │ (split into on/off cases) + ───W(a)───W(a)───@^t────────────────── + + ≡ ───@─────────────@─────────────O────── + | │ | (off doesn't interact with on) + ───W(a)──────────@^t───────────W(a)─── + + ≡ ───────────Z^t───@──────@──────O────── + │ | | (crossing causes kickback) + ─────────────────@^-t───W(a)───W(a)─── (X Z^t X Z^-t = exp(pi t) I) + + ≡ ───────────Z^t───@──────────────────── + │ (merge on/off cases) + ─────────────────@^-t───W(a)────────── + + ≡ ───Z^t───@────────────── + │ + ─────────@^-t───W(a)──── + """ + t = cast(value.TParamVal, _try_get_known_cz_half_turns(op)) + other_qubit = op.qubits[0] if qubit_with_w == op.qubits[1] else op.qubits[1] + negated_cz = ops.CZ(*op.qubits) ** -t + kickback = ops.Z(other_qubit) ** t + return [kickback, negated_cz] + + +def _double_cross_over_cz( + op: ops.Operation, held_w_phases: Dict[ops.Qid, value.TParamVal] +) -> 'cirq.OP_TREE': + """Crosses two W flips over a partial CZ. + + [Where W(a) is shorthand for PhasedX(phase_exponent=a).] + + Uses the following identity: + + ───W(a)───@───── + │ + ───W(b)───@^t─── + + + ≡ ──────────@────────────W(a)─── + │ (single-cross top W over CZ) + ───W(b)───@^-t─────────Z^t──── + + + ≡ ──────────@─────Z^-t───W(a)─── + │ (single-cross bottom W over CZ) + ──────────@^t───W(b)───Z^t──── + + + ≡ ──────────@─────W(a)───Z^t──── + │ (flip over Z^-t) + ──────────@^t───W(b)───Z^t──── + + + ≡ ──────────@─────W(a+t/2)────── + │ (absorb Zs into Ws) + ──────────@^t───W(b+t/2)────── + + ≡ ───@─────W(a+t/2)─── + │ + ───@^t───W(b+t/2)─── + """ + t = cast(value.TParamVal, _try_get_known_cz_half_turns(op)) + for q in op.qubits: + held_w_phases[q] = cast(value.TParamVal, held_w_phases[q]) + t / 2 + return op + + +def _try_get_known_cz_half_turns( + op: ops.Operation, no_symbolic: bool = False +) -> Optional[value.TParamVal]: + if not isinstance(op.gate, ops.CZPowGate): + return None + h = op.gate.exponent + if no_symbolic and isinstance(h, sympy.Basic): + return None + return h + + +def _try_get_known_phased_pauli( + op: ops.Operation, no_symbolic: bool = False +) -> Optional[Tuple[value.TParamVal, value.TParamVal]]: + if no_symbolic and protocols.is_parameterized(op): + return None + gate = op.gate + + if isinstance(gate, ops.PhasedXPowGate): + e = gate.exponent + p = gate.phase_exponent + elif isinstance(gate, ops.YPowGate): + e = gate.exponent + p = 0.5 + elif isinstance(gate, ops.XPowGate): + e = gate.exponent + p = 0.0 + else: + return None + return value.canonicalize_half_turns(e), value.canonicalize_half_turns(p) + + +def _try_get_known_z_half_turns( + op: ops.Operation, no_symbolic: bool = False +) -> Optional[value.TParamVal]: + if not isinstance(op.gate, ops.ZPowGate): + return None + h = op.gate.exponent + if no_symbolic and isinstance(h, sympy.Basic): + return None + return h diff --git a/cirq-core/cirq/transformers/eject_phased_paulis_test.py b/cirq-core/cirq/transformers/eject_phased_paulis_test.py new file mode 100644 index 00000000000..de5ba512dfc --- /dev/null +++ b/cirq-core/cirq/transformers/eject_phased_paulis_test.py @@ -0,0 +1,560 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Iterable, cast + +import numpy as np +import pytest +import sympy + +import cirq + + +def assert_optimizes( + before: cirq.Circuit, + expected: cirq.Circuit, + compare_unitaries: bool = True, + eject_parameterized: bool = False, + *, + with_context: bool = False, +): + context = cirq.TransformerContext(tags_to_ignore=("nocompile",)) if with_context else None + circuit = cirq.eject_phased_paulis( + before, eject_parameterized=eject_parameterized, context=context + ) + + # They should have equivalent effects. + if compare_unitaries: + if cirq.is_parameterized(circuit): + for a in (0, 0.1, 0.5, -1.0, np.pi, np.pi / 2): + params = {'x': a, 'y': a / 2, 'z': -2 * a} + ( + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( + cirq.resolve_parameters(circuit, params), + cirq.resolve_parameters(expected, params), + 1e-8, + ) + ) + else: + ( + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( + circuit, expected, 1e-8 + ) + ) + + # And match the expected circuit. + assert circuit == expected, ( + "Circuit wasn't optimized as expected.\n" + "INPUT:\n" + "{}\n" + "\n" + "EXPECTED OUTPUT:\n" + "{}\n" + "\n" + "ACTUAL OUTPUT:\n" + "{}\n" + "\n" + "EXPECTED OUTPUT (detailed):\n" + "{!r}\n" + "\n" + "ACTUAL OUTPUT (detailed):\n" + "{!r}" + ).format(before, expected, circuit, expected, circuit) + + # And it should be idempotent. + circuit = cirq.eject_phased_paulis( + circuit, eject_parameterized=eject_parameterized, context=context + ) + assert circuit == expected + + +def quick_circuit(*moments: Iterable[cirq.OP_TREE]) -> cirq.Circuit: + return cirq.Circuit( + [cirq.Moment(cast(Iterable[cirq.Operation], cirq.flatten_op_tree(m))) for m in moments] + ) + + +def test_absorbs_z(): + q = cirq.NamedQubit('q') + x = sympy.Symbol('x') + + # Full Z. + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.125).on(q)], + [cirq.Z(q)], + ), + expected=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.625).on(q)], + ), + ) + + # Partial Z. + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.125).on(q)], + [cirq.S(q)], + ), + expected=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.375).on(q)], + ), + ) + + # parameterized Z. + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.125).on(q)], + [cirq.Z(q) ** x], + ), + expected=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.125 + x / 2).on(q)], + ), + eject_parameterized=True, + ) + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.125).on(q)], + [cirq.Z(q) ** (x + 1)], + ), + expected=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.625 + x / 2).on(q)], + ), + eject_parameterized=True, + ) + + # Multiple Zs. + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.125).on(q)], + [cirq.S(q)], + [cirq.T(q) ** -1], + ), + expected=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.25).on(q)], + ), + ) + + # Multiple Parameterized Zs. + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.125).on(q)], + [cirq.S(q) ** x], + [cirq.T(q) ** -x], + ), + expected=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.125 + x * 0.125).on(q)], + ), + eject_parameterized=True, + ) + + # Parameterized Phase and Partial Z + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=x).on(q)], + [cirq.S(q)], + ), + expected=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=x + 0.25).on(q)], + ), + eject_parameterized=True, + ) + + +def test_crosses_czs(): + a = cirq.NamedQubit('a') + b = cirq.NamedQubit('b') + x = sympy.Symbol('x') + y = sympy.Symbol('y') + z = sympy.Symbol('z') + + # Full CZ. + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.25).on(a)], + [cirq.CZ(a, b)], + ), + expected=quick_circuit( + [cirq.Z(b)], + [cirq.CZ(a, b)], + [cirq.PhasedXPowGate(phase_exponent=0.25).on(a)], + ), + ) + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.125).on(a)], + [cirq.CZ(b, a)], + ), + expected=quick_circuit( + [cirq.Z(b)], + [cirq.CZ(a, b)], + [cirq.PhasedXPowGate(phase_exponent=0.125).on(a)], + ), + ) + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=x).on(a)], + [cirq.CZ(b, a)], + ), + expected=quick_circuit( + [cirq.Z(b)], + [cirq.CZ(a, b)], + [cirq.PhasedXPowGate(phase_exponent=x).on(a)], + ), + eject_parameterized=True, + ) + + # Partial CZ. + assert_optimizes( + before=quick_circuit( + [cirq.X(a)], + [cirq.CZ(a, b) ** 0.25], + ), + expected=quick_circuit( + [cirq.Z(b) ** 0.25], + [cirq.CZ(a, b) ** -0.25], + [cirq.X(a)], + ), + ) + assert_optimizes( + before=quick_circuit( + [cirq.X(a)], + [cirq.CZ(a, b) ** x], + ), + expected=quick_circuit( + [cirq.Z(b) ** x], + [cirq.CZ(a, b) ** -x], + [cirq.X(a)], + ), + eject_parameterized=True, + ) + + # Double cross. + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.125).on(a)], + [cirq.PhasedXPowGate(phase_exponent=0.375).on(b)], + [cirq.CZ(a, b) ** 0.25], + ), + expected=quick_circuit( + [cirq.CZ(a, b) ** 0.25], + [ + cirq.PhasedXPowGate(phase_exponent=0.5).on(b), + cirq.PhasedXPowGate(phase_exponent=0.25).on(a), + ], + ), + ) + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=x).on(a)], + [cirq.PhasedXPowGate(phase_exponent=y).on(b)], + [cirq.CZ(a, b) ** z], + ), + expected=quick_circuit( + [cirq.CZ(a, b) ** z], + [ + cirq.PhasedXPowGate(phase_exponent=y + z / 2).on(b), + cirq.PhasedXPowGate(phase_exponent=x + z / 2).on(a), + ], + ), + eject_parameterized=True, + ) + + +def test_toggles_measurements(): + a = cirq.NamedQubit('a') + b = cirq.NamedQubit('b') + x = sympy.Symbol('x') + + # Single. + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.25).on(a)], + [cirq.measure(a, b)], + ), + expected=quick_circuit( + [cirq.measure(a, b, invert_mask=(True,))], + ), + ) + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.25).on(b)], + [cirq.measure(a, b)], + ), + expected=quick_circuit( + [cirq.measure(a, b, invert_mask=(False, True))], + ), + ) + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=x).on(b)], + [cirq.measure(a, b)], + ), + expected=quick_circuit( + [cirq.measure(a, b, invert_mask=(False, True))], + ), + eject_parameterized=True, + ) + + # Multiple. + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.25).on(a)], + [cirq.PhasedXPowGate(phase_exponent=0.25).on(b)], + [cirq.measure(a, b)], + ), + expected=quick_circuit( + [cirq.measure(a, b, invert_mask=(True, True))], + ), + ) + + # Xmon. + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.25).on(a)], + [cirq.measure(a, b, key='t')], + ), + expected=quick_circuit( + [cirq.measure(a, b, invert_mask=(True,), key='t')], + ), + ) + + +def test_cancels_other_full_w(): + q = cirq.NamedQubit('q') + x = sympy.Symbol('x') + y = sympy.Symbol('y') + + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.25).on(q)], + [cirq.PhasedXPowGate(phase_exponent=0.25).on(q)], + ), + expected=quick_circuit(), + ) + + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=x).on(q)], + [cirq.PhasedXPowGate(phase_exponent=x).on(q)], + ), + expected=quick_circuit(), + eject_parameterized=True, + ) + + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.25).on(q)], + [cirq.PhasedXPowGate(phase_exponent=0.125).on(q)], + ), + expected=quick_circuit( + [cirq.Z(q) ** -0.25], + ), + ) + + assert_optimizes( + before=quick_circuit( + [cirq.X(q)], + [cirq.PhasedXPowGate(phase_exponent=0.25).on(q)], + ), + expected=quick_circuit( + [cirq.Z(q) ** 0.5], + ), + ) + + assert_optimizes( + before=quick_circuit( + [cirq.Y(q)], + [cirq.PhasedXPowGate(phase_exponent=0.25).on(q)], + ), + expected=quick_circuit( + [cirq.Z(q) ** -0.5], + ), + ) + + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.25).on(q)], + [cirq.X(q)], + ), + expected=quick_circuit( + [cirq.Z(q) ** -0.5], + ), + ) + + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.25).on(q)], + [cirq.Y(q)], + ), + expected=quick_circuit( + [cirq.Z(q) ** 0.5], + ), + ) + + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=x).on(q)], + [cirq.PhasedXPowGate(phase_exponent=y).on(q)], + ), + expected=quick_circuit( + [cirq.Z(q) ** (2 * (y - x))], + ), + eject_parameterized=True, + ) + + +def test_phases_partial_ws(): + q = cirq.NamedQubit('q') + x = sympy.Symbol('x') + y = sympy.Symbol('y') + z = sympy.Symbol('z') + + assert_optimizes( + before=quick_circuit( + [cirq.X(q)], + [cirq.PhasedXPowGate(phase_exponent=0.25, exponent=0.5).on(q)], + ), + expected=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=-0.25, exponent=0.5).on(q)], + [cirq.X(q)], + ), + ) + + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.25).on(q)], + [cirq.X(q) ** 0.5], + ), + expected=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.5, exponent=0.5).on(q)], + [cirq.PhasedXPowGate(phase_exponent=0.25).on(q)], + ), + ) + + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.25).on(q)], + [cirq.PhasedXPowGate(phase_exponent=0.5, exponent=0.75).on(q)], + ), + expected=quick_circuit( + [cirq.X(q) ** 0.75], + [cirq.PhasedXPowGate(phase_exponent=0.25).on(q)], + ), + ) + + assert_optimizes( + before=quick_circuit( + [cirq.X(q)], [cirq.PhasedXPowGate(exponent=-0.25, phase_exponent=0.5).on(q)] + ), + expected=quick_circuit( + [cirq.PhasedXPowGate(exponent=-0.25, phase_exponent=-0.5).on(q)], + [cirq.X(q)], + ), + ) + + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=x).on(q)], + [cirq.PhasedXPowGate(phase_exponent=y, exponent=z).on(q)], + ), + expected=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=2 * x - y, exponent=z).on(q)], + [cirq.PhasedXPowGate(phase_exponent=x).on(q)], + ), + eject_parameterized=True, + ) + + +@pytest.mark.parametrize( + 'sym', + [ + sympy.Symbol('x'), + sympy.Symbol('x') + 1, + ], +) +def test_blocked_by_unknown_and_symbols(sym): + a = cirq.NamedQubit('a') + b = cirq.NamedQubit('b') + + assert_optimizes( + before=quick_circuit( + [cirq.X(a)], + [cirq.SWAP(a, b)], + [cirq.X(a)], + ), + expected=quick_circuit( + [cirq.X(a)], + [cirq.SWAP(a, b)], + [cirq.X(a)], + ), + ) + + assert_optimizes( + before=quick_circuit( + [cirq.X(a)], + [cirq.Z(a) ** sym], + [cirq.X(a)], + ), + expected=quick_circuit( + [cirq.X(a)], + [cirq.Z(a) ** sym], + [cirq.X(a)], + ), + compare_unitaries=False, + ) + + assert_optimizes( + before=quick_circuit( + [cirq.X(a)], + [cirq.CZ(a, b) ** sym], + [cirq.X(a)], + ), + expected=quick_circuit( + [cirq.X(a)], + [cirq.CZ(a, b) ** sym], + [cirq.X(a)], + ), + compare_unitaries=False, + ) + + +def test_blocked_by_nocompile_tag(): + a = cirq.NamedQubit('a') + b = cirq.NamedQubit('b') + + assert_optimizes( + before=quick_circuit( + [cirq.X(a)], + [cirq.CZ(a, b).with_tags("nocompile")], + [cirq.X(a)], + ), + expected=quick_circuit( + [cirq.X(a)], + [cirq.CZ(a, b).with_tags("nocompile")], + [cirq.X(a)], + ), + with_context=True, + ) + + +def test_zero_x_rotation(): + a = cirq.NamedQubit('a') + + assert_optimizes( + before=quick_circuit( + [cirq.rx(0)(a)], + ), + expected=quick_circuit( + [cirq.rx(0)(a)], + ), + ) diff --git a/cirq-google/cirq_google/optimizers/optimize_for_sycamore.py b/cirq-google/cirq_google/optimizers/optimize_for_sycamore.py index 58d6bce64c8..dc72b4aec5b 100644 --- a/cirq-google/cirq_google/optimizers/optimize_for_sycamore.py +++ b/cirq-google/cirq_google/optimizers/optimize_for_sycamore.py @@ -31,7 +31,6 @@ def _get_common_cleanup_optimizers(tolerance: float) -> List[Callable[[cirq.Circuit], None]]: return [ - cirq.EjectPhasedPaulis(tolerance=tolerance).optimize_circuit, cirq.EjectZ(tolerance=tolerance).optimize_circuit, ] @@ -166,6 +165,7 @@ def optimized_for_sycamore( for optimizer in opts: optimizer(copy) + copy = cirq.eject_phased_paulis(copy, atol=tolerance) copy = cirq.drop_negligible_operations(copy, atol=tolerance) ret = cirq.Circuit( diff --git a/docs/tutorials/google/spin_echoes.ipynb b/docs/tutorials/google/spin_echoes.ipynb index 969f3f65e1a..034c13cc20e 100644 --- a/docs/tutorials/google/spin_echoes.ipynb +++ b/docs/tutorials/google/spin_echoes.ipynb @@ -230,7 +230,7 @@ " # Run optimization.\n", " if with_optimization:\n", " cirq.MergeInteractionsToSqrtIswap().optimize_circuit(circuit)\n", - " cirq.EjectPhasedPaulis().optimize_circuit(circuit)\n", + " circuit = cirq.eject_phased_paulis(circuit)\n", " cirq.EjectZ().optimize_circuit(circuit)\n", " circuit = cirq.drop_negligible_operations(circuit)\n", " circuit = cirq.drop_empty_moments(circuit)\n", @@ -596,7 +596,7 @@ "id": "DiR_6NvV_mb6" }, "source": [ - "The `cirq.EjectPhasedPaulis` optimizer pushes `cirq.X`, `cirq.Y`, and `cirq.PhasedXPowGate` gates towards the end of the circuit." + "The `cirq.eject_phased_paulis` optimizer pushes `cirq.X`, `cirq.Y`, and `cirq.PhasedXPowGate` gates towards the end of the circuit." ] }, { @@ -645,7 +645,7 @@ } ], "source": [ - "cirq.EjectPhasedPaulis().optimize_circuit(circuit)\n", + "circuit = cirq.eject_phased_paulis(circuit)\n", "circuit" ] }, From 647e31c1911450ca2d51426ccdbcff53a1de592d Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Mon, 7 Feb 2022 15:12:16 -0800 Subject: [PATCH 2/4] Add CCO tests, support PhasedXZGates --- .../cirq/optimizers/eject_phased_paulis.py | 5 ++- .../cirq/transformers/eject_phased_paulis.py | 9 +++++ .../transformers/eject_phased_paulis_test.py | 34 +++++++++++++++++-- 3 files changed, 43 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/optimizers/eject_phased_paulis.py b/cirq-core/cirq/optimizers/eject_phased_paulis.py index a846a532ea6..468c90c6730 100644 --- a/cirq-core/cirq/optimizers/eject_phased_paulis.py +++ b/cirq-core/cirq/optimizers/eject_phased_paulis.py @@ -15,11 +15,10 @@ """Pushes 180 degree rotations around axes in the XY plane later in the circuit. """ -from cirq import circuits, transformers -from cirq._compat import deprecated_class +from cirq import _compat, circuits, transformers -@deprecated_class(deadline='v1.0', fix='Use cirq.eject_phased_paulis instead.') +@_compat.deprecated_class(deadline='v1.0', fix='Use cirq.eject_phased_paulis instead.') class EjectPhasedPaulis: """Pushes X, Y, and PhasedX gates towards the end of the circuit. diff --git a/cirq-core/cirq/transformers/eject_phased_paulis.py b/cirq-core/cirq/transformers/eject_phased_paulis.py index 3e299779a2a..dc8e9384276 100644 --- a/cirq-core/cirq/transformers/eject_phased_paulis.py +++ b/cirq-core/cirq/transformers/eject_phased_paulis.py @@ -16,6 +16,7 @@ from typing import Optional, cast, TYPE_CHECKING, Iterable, Tuple, Dict import sympy +import numpy as np from cirq import circuits, ops, value, protocols from cirq.transformers import transformer_api, transformer_primitives @@ -39,6 +40,7 @@ def eject_phased_paulis( X, Y, or PhasedX gates with exponent=1, get merged into measurements (as output bit flips), and cause phase kickback operations across CZs (which can then be removed by the EjectZ optimization). + Args: circuit: Input circuit to transform. context: `cirq.TransformerContext` storing common configurable options for transformers. @@ -314,6 +316,13 @@ def _try_get_known_phased_pauli( elif isinstance(gate, ops.XPowGate): e = gate.exponent p = 0.0 + elif ( + isinstance(gate, ops.PhasedXZGate) + and not protocols.is_parameterized(gate.z_exponent) + and np.isclose(gate.z_exponent, 0) + ): + e = gate.x_exponent + p = gate.axis_phase_exponent else: return None return value.canonicalize_half_turns(e), value.canonicalize_half_turns(p) diff --git a/cirq-core/cirq/transformers/eject_phased_paulis_test.py b/cirq-core/cirq/transformers/eject_phased_paulis_test.py index de5ba512dfc..3ae66ece7de 100644 --- a/cirq-core/cirq/transformers/eject_phased_paulis_test.py +++ b/cirq-core/cirq/transformers/eject_phased_paulis_test.py @@ -99,10 +99,10 @@ def test_absorbs_z(): ), ) - # Partial Z. + # Partial Z. PhasedXZGate with z_exponent = 0. assert_optimizes( before=quick_circuit( - [cirq.PhasedXPowGate(phase_exponent=0.125).on(q)], + [cirq.PhasedXZGate(x_exponent=1, axis_phase_exponent=0.125, z_exponent=0).on(q)], [cirq.S(q)], ), expected=quick_circuit( @@ -328,6 +328,36 @@ def test_toggles_measurements(): ), ) + # CCOs + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.25).on(a)], + [cirq.measure(a, key="m")], + [cirq.X(b).with_classical_controls("m")], + ), + expected=quick_circuit( + [cirq.measure(a, invert_mask=(True,), key="m")], + [cirq.X(b).with_classical_controls("m")], + ), + compare_unitaries=False, + ) + + +def test_eject_phased_xz(): + a = cirq.NamedQubit('a') + b = cirq.NamedQubit('b') + c = cirq.Circuit( + cirq.PhasedXZGate(x_exponent=1, z_exponent=0.5, axis_phase_exponent=0.5).on(a), + cirq.CZ(a, b) ** 0.25, + ) + c_expected = cirq.Circuit( + cirq.CZ(a, b) ** -0.25, cirq.PhasedXPowGate(phase_exponent=0.75).on(a), cirq.T(b) + ) + cirq.testing.assert_same_circuits( + cirq.eject_z(cirq.eject_phased_paulis(cirq.eject_z(c))), c_expected + ) + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(c, c_expected, 1e-8) + def test_cancels_other_full_w(): q = cirq.NamedQubit('q') From 89f5b1c5dad52642acfb541b94f1658cbaa9d65d Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Mon, 7 Feb 2022 15:18:44 -0800 Subject: [PATCH 3/4] Improve docstrings --- cirq-core/cirq/transformers/eject_phased_paulis.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cirq-core/cirq/transformers/eject_phased_paulis.py b/cirq-core/cirq/transformers/eject_phased_paulis.py index dc8e9384276..f7de297fe9a 100644 --- a/cirq-core/cirq/transformers/eject_phased_paulis.py +++ b/cirq-core/cirq/transformers/eject_phased_paulis.py @@ -39,7 +39,10 @@ def eject_phased_paulis( As the gates get pushed, they may absorb Z gates, cancel against other X, Y, or PhasedX gates with exponent=1, get merged into measurements (as output bit flips), and cause phase kickback operations across CZs (which can - then be removed by the EjectZ optimization). + then be removed by the `cirq.eject_z` transformation). + + `cirq.PhasedXZGate` with `z_exponent=0` are also supported. To eject `PhasedXZGates` with + arbitrary `z_exponent`, run `cirq.eject_z(cirq.eject_phased_paulis(cirq.eject_z(circuit)))`. Args: circuit: Input circuit to transform. From 90d76239a611977f40492f3d317f66a4de3d095d Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Mon, 7 Feb 2022 19:55:54 -0800 Subject: [PATCH 4/4] Support PhasedXZGates equivalent to z rotations and update docstrings --- .../cirq/transformers/eject_phased_paulis.py | 20 ++++++++++++++----- .../transformers/eject_phased_paulis_test.py | 11 ++++++++++ 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/transformers/eject_phased_paulis.py b/cirq-core/cirq/transformers/eject_phased_paulis.py index f7de297fe9a..a354fe9f6f1 100644 --- a/cirq-core/cirq/transformers/eject_phased_paulis.py +++ b/cirq-core/cirq/transformers/eject_phased_paulis.py @@ -34,15 +34,17 @@ def eject_phased_paulis( atol: float = 1e-8, eject_parameterized: bool = False, ) -> 'cirq.Circuit': - """Transformer pass to push X, Y, and PhasedX gates towards the end of the circuit. + """Transformer pass to push X, Y, PhasedX & (certain) PhasedXZ gates to the end of the circuit. As the gates get pushed, they may absorb Z gates, cancel against other X, Y, or PhasedX gates with exponent=1, get merged into measurements (as output bit flips), and cause phase kickback operations across CZs (which can then be removed by the `cirq.eject_z` transformation). - `cirq.PhasedXZGate` with `z_exponent=0` are also supported. To eject `PhasedXZGates` with - arbitrary `z_exponent`, run `cirq.eject_z(cirq.eject_phased_paulis(cirq.eject_z(circuit)))`. + `cirq.PhasedXZGate` with `z_exponent=0` (i.e. equivalent to PhasedXPow) or with `x_exponent=0` + and `axis_phase_exponent=0` (i.e. equivalent to ZPowGate) are also supported. + To eject `PhasedXZGates` with arbitrary x/z/axis exponents, run + `cirq.eject_z(cirq.eject_phased_paulis(cirq.eject_z(circuit)))`. Args: circuit: Input circuit to transform. @@ -334,9 +336,17 @@ def _try_get_known_phased_pauli( def _try_get_known_z_half_turns( op: ops.Operation, no_symbolic: bool = False ) -> Optional[value.TParamVal]: - if not isinstance(op.gate, ops.ZPowGate): + g = op.gate + if ( + isinstance(g, ops.PhasedXZGate) + and np.isclose(g.x_exponent, 0) + and np.isclose(g.axis_phase_exponent, 0) + ): + h = g.z_exponent + elif isinstance(g, ops.ZPowGate): + h = g.exponent + else: return None - h = op.gate.exponent if no_symbolic and isinstance(h, sympy.Basic): return None return h diff --git a/cirq-core/cirq/transformers/eject_phased_paulis_test.py b/cirq-core/cirq/transformers/eject_phased_paulis_test.py index 3ae66ece7de..5f4663da5f4 100644 --- a/cirq-core/cirq/transformers/eject_phased_paulis_test.py +++ b/cirq-core/cirq/transformers/eject_phased_paulis_test.py @@ -99,6 +99,17 @@ def test_absorbs_z(): ), ) + # PhasedXZGate + assert_optimizes( + before=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.125).on(q)], + [cirq.PhasedXZGate(x_exponent=0, axis_phase_exponent=0, z_exponent=1).on(q)], + ), + expected=quick_circuit( + [cirq.PhasedXPowGate(phase_exponent=0.625).on(q)], + ), + ) + # Partial Z. PhasedXZGate with z_exponent = 0. assert_optimizes( before=quick_circuit(