Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cirq.eject_phased_paulis transformer to replace cirq.EjectPhasedPaulis #4958

Merged
merged 6 commits into from
Feb 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@
decompose_two_qubit_interaction_into_four_fsim_gates,
drop_empty_moments,
drop_negligible_operations,
eject_phased_paulis,
eject_z,
expand_composite,
is_negligible_turn,
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ion/ion_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def two_qubit_matrix_to_ion_operations(
def _cleanup_operations(operations: List[ops.Operation]):
circuit = circuits.Circuit(operations)
optimizers.merge_single_qubit_gates.merge_single_qubit_gates_into_phased_x_z(circuit)
optimizers.eject_phased_paulis.EjectPhasedPaulis().optimize_circuit(circuit)
circuit = transformers.eject_phased_paulis(circuit)
circuit = transformers.eject_z(circuit)
circuit = circuits.Circuit(circuit.all_operations(), strategy=circuits.InsertStrategy.EARLIEST)
return list(circuit.all_operations())
Expand Down
306 changes: 7 additions & 299 deletions cirq-core/cirq/optimizers/eject_phased_paulis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,10 @@
"""Pushes 180 degree rotations around axes in the XY plane later in the circuit.
"""

from typing import Optional, cast, TYPE_CHECKING, Iterable, Tuple, Dict, List
import sympy

from cirq import circuits, ops, value, protocols
from cirq.transformers.analytical_decompositions import single_qubit_decompositions

if TYPE_CHECKING:
import cirq


class _OptimizerState:
def __init__(self):
# The phases of the W gates currently being pushed along each qubit.
self.held_w_phases: Dict[ops.Qid, value.TParamVal] = {}

# Accumulated commands to batch-apply to the circuit later.
self.deletions: List[Tuple[int, ops.Operation]] = []
self.inline_intos: List[Tuple[int, ops.Operation]] = []
self.insertions: List[Tuple[int, ops.Operation]] = []
from cirq import _compat, circuits, transformers


@_compat.deprecated_class(deadline='v1.0', fix='Use cirq.eject_phased_paulis instead.')
class EjectPhasedPaulis:
"""Pushes X, Y, and PhasedX gates towards the end of the circuit.

Expand All @@ -60,283 +43,8 @@ def __init__(self, tolerance: float = 1e-8, eject_parameterized: bool = False) -
self.eject_parameterized = eject_parameterized

def optimize_circuit(self, circuit: circuits.Circuit):
state = _OptimizerState()

for moment_index, moment in enumerate(circuit):
for op in moment.operations:
affected = [q for q in op.qubits if q in state.held_w_phases]

# Collect, phase, and merge Ws.
w = _try_get_known_phased_pauli(op, no_symbolic=not self.eject_parameterized)
if w is not None:
if single_qubit_decompositions.is_negligible_turn(
(w[0] - 1) / 2, self.tolerance
):
_potential_cross_whole_w(moment_index, op, self.tolerance, state)
else:
_potential_cross_partial_w(moment_index, op, state)
continue

if not affected:
continue

# Absorb Z rotations.
t = _try_get_known_z_half_turns(op, no_symbolic=not self.eject_parameterized)
if t is not None:
_absorb_z_into_w(moment_index, op, state)
continue

# Dump coherent flips into measurement bit flips.
if isinstance(op.gate, ops.MeasurementGate):
_dump_into_measurement(moment_index, op, state)

# Cross CZs using kickback.
if (
_try_get_known_cz_half_turns(op, no_symbolic=not self.eject_parameterized)
is not None
):
if len(affected) == 1:
_single_cross_over_cz(moment_index, op, affected[0], state)
else:
_double_cross_over_cz(op, state)
continue

# Don't know how to handle this situation. Dump the gates.
_dump_held(op.qubits, moment_index, state)

# Put anything that's still held at the end of the circuit.
_dump_held(state.held_w_phases.keys(), len(circuit), state)

circuit.batch_remove(state.deletions)
circuit.batch_insert_into(state.inline_intos)
circuit.batch_insert(state.insertions)


def _absorb_z_into_w(moment_index: int, op: ops.Operation, state: _OptimizerState) -> None:
"""Absorbs a Z^t gate into a W(a) flip.

[Where W(a) is shorthand for PhasedX(phase_exponent=a).]

Uses the following identity:
───W(a)───Z^t───
≡ ───W(a)───────────Z^t/2──────────Z^t/2─── (split Z)
≡ ───W(a)───W(a)───Z^-t/2───W(a)───Z^t/2─── (flip Z)
≡ ───W(a)───W(a)──────────W(a+t/2)───────── (phase W)
≡ ────────────────────────W(a+t/2)───────── (cancel Ws)
≡ ───W(a+t/2)───
"""
t = cast(value.TParamVal, _try_get_known_z_half_turns(op))
q = op.qubits[0]
state.held_w_phases[q] += t / 2
state.deletions.append((moment_index, op))


def _dump_held(qubits: Iterable[ops.Qid], moment_index: int, state: _OptimizerState):
# Note: sorting is to avoid non-determinism in the insertion order.
for q in sorted(qubits):
p = state.held_w_phases.get(q)
if p is not None:
dump_op = ops.PhasedXPowGate(phase_exponent=p).on(q)
state.insertions.append((moment_index, dump_op))
state.held_w_phases.pop(q, None)


def _dump_into_measurement(moment_index: int, op: ops.Operation, state: _OptimizerState) -> None:
measurement = cast(ops.MeasurementGate, cast(ops.GateOperation, op).gate)
new_measurement = measurement.with_bits_flipped(
*[i for i, q in enumerate(op.qubits) if q in state.held_w_phases]
).on(*op.qubits)
for q in op.qubits:
state.held_w_phases.pop(q, None)
state.deletions.append((moment_index, op))
state.inline_intos.append((moment_index, new_measurement))


def _potential_cross_whole_w(
moment_index: int, op: ops.Operation, tolerance: float, state: _OptimizerState
) -> None:
"""Grabs or cancels a held W gate against an existing W gate.

[Where W(a) is shorthand for PhasedX(phase_exponent=a).]

Uses the following identity:
───W(a)───W(b)───
≡ ───Z^-a───X───Z^a───Z^-b───X───Z^b───
≡ ───Z^-a───Z^-a───Z^b───X───X───Z^b───
≡ ───Z^-a───Z^-a───Z^b───Z^b───
≡ ───Z^2(b-a)───
"""
state.deletions.append((moment_index, op))

_, phase_exponent = cast(
Tuple[value.TParamVal, value.TParamVal], _try_get_known_phased_pauli(op)
)
q = op.qubits[0]
a = state.held_w_phases.get(q, None)
b = phase_exponent

if a is None:
# Collect the gate.
state.held_w_phases[q] = b
else:
# Cancel the gate.
del state.held_w_phases[q]
t = 2 * (b - a)
if not single_qubit_decompositions.is_negligible_turn(t / 2, tolerance):
leftover_phase = ops.Z(q) ** t
state.inline_intos.append((moment_index, leftover_phase))


def _potential_cross_partial_w(
moment_index: int, op: ops.Operation, state: _OptimizerState
) -> None:
"""Cross the held W over a partial W gate.

[Where W(a) is shorthand for PhasedX(phase_exponent=a).]

Uses the following identity:
───W(a)───W(b)^t───
≡ ───Z^-a───X───Z^a───W(b)^t────── (expand W(a))
≡ ───Z^-a───X───W(b-a)^t───Z^a──── (move Z^a across, phasing axis)
≡ ───Z^-a───W(a-b)^t───X───Z^a──── (move X across, negating axis angle)
≡ ───W(2a-b)^t───Z^-a───X───Z^a─── (move Z^-a across, phasing axis)
≡ ───W(2a-b)^t───W(a)───
"""
a = state.held_w_phases.get(op.qubits[0], None)
if a is None:
return
exponent, phase_exponent = cast(
Tuple[value.TParamVal, value.TParamVal], _try_get_known_phased_pauli(op)
)
new_op = ops.PhasedXPowGate(exponent=exponent, phase_exponent=2 * a - phase_exponent).on(
op.qubits[0]
)
state.deletions.append((moment_index, op))
state.inline_intos.append((moment_index, new_op))


def _single_cross_over_cz(
moment_index: int, op: ops.Operation, qubit_with_w: 'cirq.Qid', state: _OptimizerState
) -> None:
"""Crosses exactly one W flip over a partial CZ.

[Where W(a) is shorthand for PhasedX(phase_exponent=a).]

Uses the following identity:

──────────@─────
───W(a)───@^t───


≡ ───@──────O──────@────────────────────
| | │ (split into on/off cases)
───W(a)───W(a)───@^t──────────────────

≡ ───@─────────────@─────────────O──────
| │ | (off doesn't interact with on)
───W(a)──────────@^t───────────W(a)───

≡ ───────────Z^t───@──────@──────O──────
│ | | (crossing causes kickback)
─────────────────@^-t───W(a)───W(a)─── (X Z^t X Z^-t = exp(pi t) I)

≡ ───────────Z^t───@────────────────────
│ (merge on/off cases)
─────────────────@^-t───W(a)──────────

≡ ───Z^t───@──────────────
─────────@^-t───W(a)────
"""
t = cast(value.TParamVal, _try_get_known_cz_half_turns(op))
other_qubit = op.qubits[0] if qubit_with_w == op.qubits[1] else op.qubits[1]
negated_cz = ops.CZ(*op.qubits) ** -t
kickback = ops.Z(other_qubit) ** t

state.deletions.append((moment_index, op))
state.inline_intos.append((moment_index, negated_cz))
state.insertions.append((moment_index, kickback))


def _double_cross_over_cz(op: ops.Operation, state: _OptimizerState) -> None:
"""Crosses two W flips over a partial CZ.

[Where W(a) is shorthand for PhasedX(phase_exponent=a).]

Uses the following identity:

───W(a)───@─────
───W(b)───@^t───


≡ ──────────@────────────W(a)───
│ (single-cross top W over CZ)
───W(b)───@^-t─────────Z^t────


≡ ──────────@─────Z^-t───W(a)───
│ (single-cross bottom W over CZ)
──────────@^t───W(b)───Z^t────


≡ ──────────@─────W(a)───Z^t────
│ (flip over Z^-t)
──────────@^t───W(b)───Z^t────


≡ ──────────@─────W(a+t/2)──────
│ (absorb Zs into Ws)
──────────@^t───W(b+t/2)──────

≡ ───@─────W(a+t/2)───
───@^t───W(b+t/2)───
"""
t = cast(value.TParamVal, _try_get_known_cz_half_turns(op))
for q in op.qubits:
state.held_w_phases[q] = cast(value.TParamVal, state.held_w_phases[q]) + t / 2


def _try_get_known_cz_half_turns(
op: ops.Operation, no_symbolic: bool = False
) -> Optional[value.TParamVal]:
if not isinstance(op, ops.GateOperation) or not isinstance(op.gate, ops.CZPowGate):
return None
h = op.gate.exponent
if no_symbolic and isinstance(h, sympy.Basic):
return None
return h


def _try_get_known_phased_pauli(
op: ops.Operation, no_symbolic: bool = False
) -> Optional[Tuple[value.TParamVal, value.TParamVal]]:
if (no_symbolic and protocols.is_parameterized(op)) or not isinstance(op, ops.GateOperation):
return None
gate = op.gate

if isinstance(gate, ops.PhasedXPowGate):
e = gate.exponent
p = gate.phase_exponent
elif isinstance(gate, ops.YPowGate):
e = gate.exponent
p = 0.5
elif isinstance(gate, ops.XPowGate):
e = gate.exponent
p = 0.0
else:
return None
return value.canonicalize_half_turns(e), value.canonicalize_half_turns(p)


def _try_get_known_z_half_turns(
op: ops.Operation, no_symbolic: bool = False
) -> Optional[value.TParamVal]:
if not isinstance(op, ops.GateOperation) or not isinstance(op.gate, ops.ZPowGate):
return None
h = op.gate.exponent
if no_symbolic and isinstance(h, sympy.Basic):
return None
return h
circuit._moments = [
*transformers.eject_phased_paulis(
circuit, atol=self.tolerance, eject_parameterized=self.eject_parameterized
)
]
4 changes: 3 additions & 1 deletion cirq-core/cirq/optimizers/eject_phased_paulis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ def assert_optimizes(
compare_unitaries: bool = True,
eject_parameterized: bool = False,
):
opt = cirq.EjectPhasedPaulis(eject_parameterized=eject_parameterized)
with cirq.testing.assert_deprecated("Use cirq.eject_phased_paulis", deadline='v1.0'):
opt = cirq.EjectPhasedPaulis(eject_parameterized=eject_parameterized)

circuit = before.copy()
expected = cirq.drop_empty_moments(expected)
tanujkhattar marked this conversation as resolved.
Show resolved Hide resolved
opt.optimize_circuit(circuit)

# They should have equivalent effects.
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/optimizers/merge_interactions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit):
# Ignore differences that would be caught by follow-up optimizations.
followup_optimizations: List[Callable[[cirq.Circuit], None]] = [
cirq.merge_single_qubit_gates_into_phased_x_z,
cirq.EjectPhasedPaulis().optimize_circuit,
]
for post in followup_optimizations:
post(actual)
post(expected)

followup_transformers: List[cirq.TRANSFORMER] = [
cirq.eject_phased_paulis,
cirq.eject_z,
cirq.drop_negligible_operations,
cirq.drop_empty_moments,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit, **kwargs):
# Ignore differences that would be caught by follow-up optimizations.
followup_optimizations: List[Callable[[cirq.Circuit], None]] = [
cirq.merge_single_qubit_gates_into_phased_x_z,
cirq.EjectPhasedPaulis().optimize_circuit,
]
for post in followup_optimizations:
post(actual)
post(expected)

followup_transformers: List[cirq.TRANSFORMER] = [
cirq.eject_phased_paulis,
cirq.eject_z,
cirq.drop_negligible_operations,
cirq.drop_empty_moments,
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@

from cirq.transformers.expand_composite import expand_composite

from cirq.transformers.eject_phased_paulis import eject_phased_paulis

from cirq.transformers.drop_empty_moments import drop_empty_moments

from cirq.transformers.drop_negligible_operations import drop_negligible_operations
Expand Down
Loading