Skip to content

Commit

Permalink
Speed up cirq.map_operations and cirq.map_operations_and_unroll (#…
Browse files Browse the repository at this point in the history
…6250)

* Speed up cirq.map_operations and cirq.map_operations_and_unroll

* Mypy typing and minor bug fixes

* Fix pylint

* Revert unrelated change to mypy script

* Address nits
  • Loading branch information
tanujkhattar authored Aug 24, 2023
1 parent be6218e commit 7ed95aa
Showing 1 changed file with 138 additions and 28 deletions.
166 changes: 138 additions & 28 deletions cirq-core/cirq/transformers/transformer_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,134 @@ def map_moments(
)


def _map_operations_impl(
circuit: CIRCUIT_TYPE,
map_func: Callable[[ops.Operation, int], ops.OP_TREE],
*,
deep: bool = False,
raise_if_add_qubits=True,
tags_to_ignore: Sequence[Hashable] = (),
wrap_in_circuit_op: bool = True,
) -> CIRCUIT_TYPE:
"""Applies local transformations, by calling `map_func(op, moment_index)` for each operation.
This method provides a fast, iterative implementation for the two `map_operations_*` variants
exposed as public transformer primitives. The high level idea for the iterative implementation
is to
1) For each operation `op`, find the corresponding mapped operation(s) `mapped_ops`. The
set of mapped operations can be either wrapped in a circuit operation or not, depending
on the value of flag `wrap_in_circuit_op` and whether the mapped operations will end up
occupying more than one moment or not.
2) Use the `get_earliest_accommodating_moment_index` infrastructure built for `cirq.Circuit`
construction to determine the index at which the mapped operations should be inserted.
This step takes care of the nuances that arise due to (a) preserving moment structure
and (b) mapped operations spanning across multiple moments (these both are trivial when
`op` is mapped to a single `mapped_op` that acts on the same set of qubits).
By default, the function assumes `issubset(qubit_set(map_func(op, moment_index)), op.qubits)` is
True.
Args:
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
map_func: Mapping function from (cirq.Operation, moment_index) to a cirq.OP_TREE. If the
resulting optree spans more than 1 moment, it's either wrapped in a tagged circuit
operation and inserted in-place in the same moment (if `wrap_in_circuit_op` is True)
OR the mapped operations are inserted directly in the circuit, preserving moment
strucutre. The effect is equivalent to (but much faster) a two-step approach of first
wrapping the operations in a circuit operation and then calling `cirq.unroll_circuit_op`
to unroll the corresponding circuit ops.
deep: If true, `map_func` will be recursively applied to circuits wrapped inside
any circuit operations contained within `circuit`.
raise_if_add_qubits: Set to True by default. If True, raises ValueError if
`map_func(op, idx)` adds operations on qubits outside of `op.qubits`.
tags_to_ignore: Sequence of tags which should be ignored while applying `map_func` on
tagged operations -- i.e. `map_func(op, idx)` will be called only for operations that
satisfy `set(op.tags).isdisjoint(tags_to_ignore)`.
wrap_in_circuit_op: If True, the mapped operations will be wrapped in a tagged circuit
operation and inserted in-place if they occupy more than one moment.
Raises:
ValueError if `issubset(qubit_set(map_func(op, idx)), op.qubits) is False` and
`raise_if_add_qubits is True`.
Returns:
Copy of input circuit with mapped operations.
"""
tags_to_ignore_set = set(tags_to_ignore)

def apply_map_func(op: 'cirq.Operation', idx: int) -> List['cirq.Operation']:
if tags_to_ignore_set.intersection(op.tags):
return [op]
if deep and isinstance(op.untagged, circuits.CircuitOperation):
op = op.untagged.replace(
circuit=_map_operations_impl(
op.untagged.circuit,
map_func,
deep=deep,
raise_if_add_qubits=raise_if_add_qubits,
tags_to_ignore=tags_to_ignore,
wrap_in_circuit_op=wrap_in_circuit_op,
)
).with_tags(*op.tags)
mapped_ops = [*ops.flatten_to_ops(map_func(op, idx))]
op_qubits = set(op.qubits)
mapped_ops_qubits: Set['cirq.Qid'] = set()
has_overlapping_ops = False
for mapped_op in mapped_ops:
if raise_if_add_qubits and not op_qubits.issuperset(mapped_op.qubits):
raise ValueError(
f"Mapped operations {mapped_ops} should act on a subset "
f"of qubits of the original operation {op}"
)
if mapped_ops_qubits.intersection(mapped_op.qubits):
has_overlapping_ops = True
mapped_ops_qubits = mapped_ops_qubits.union(mapped_op.qubits)
if wrap_in_circuit_op and has_overlapping_ops:
# Mapped operations should be wrapped in a `CircuitOperation` only iff they occupy more
# than one moment, i.e. there are at least two operations that share a qubit.
mapped_ops = [
circuits.CircuitOperation(circuits.FrozenCircuit(mapped_ops)).with_tags(
MAPPED_CIRCUIT_OP_TAG
)
]
return mapped_ops

new_moments: List[List['cirq.Operation']] = []

# Keep track of the latest time index for each qubit, measurement key, and control key.
qubit_time_index: Dict['cirq.Qid', int] = {}
measurement_time_index: Dict['cirq.MeasurementKey', int] = {}
control_time_index: Dict['cirq.MeasurementKey', int] = {}

# New mapped operations in the current moment should be inserted after `last_moment_time_index`.
last_moment_time_index = -1

for idx, moment in enumerate(circuit):
if wrap_in_circuit_op:
new_moments.append([])
for op in moment:
mapped_ops = apply_map_func(op, idx)

for mapped_op in mapped_ops:
# Identify the earliest moment that can accommodate this op.
placement_index = circuits.circuit.get_earliest_accommodating_moment_index(
mapped_op, qubit_time_index, measurement_time_index, control_time_index
)
placement_index = max(placement_index, last_moment_time_index + 1)
new_moments.extend([[] for _ in range(placement_index - len(new_moments) + 1)])
new_moments[placement_index].append(mapped_op)
for qubit in mapped_op.qubits:
qubit_time_index[qubit] = placement_index
for key in protocols.measurement_key_objs(mapped_op):
measurement_time_index[key] = placement_index
for key in protocols.control_keys(mapped_op):
control_time_index[key] = placement_index

last_moment_time_index = len(new_moments) - 1

return _create_target_circuit_type([circuits.Moment(moment) for moment in new_moments], circuit)


def map_operations(
circuit: CIRCUIT_TYPE,
map_func: Callable[[ops.Operation, int], ops.OP_TREE],
Expand Down Expand Up @@ -139,29 +267,13 @@ def map_operations(
Returns:
Copy of input circuit with mapped operations (wrapped in a tagged CircuitOperation).
"""

def apply_map(op: ops.Operation, idx: int) -> ops.OP_TREE:
if not set(op.tags).isdisjoint(tags_to_ignore):
return op
c = circuits.FrozenCircuit(map_func(op, idx))
if raise_if_add_qubits and not c.all_qubits().issubset(op.qubits):
raise ValueError(
f"Mapped operations {c.all_operations()} should act on a subset "
f"of qubits of the original operation {op}"
)
if len(c) <= 1:
# Either empty circuit or all operations act in the same moment;
# So, we don't need to wrap them in a circuit_op.
return c[0].operations if c else []
circuit_op = circuits.CircuitOperation(c).with_tags(MAPPED_CIRCUIT_OP_TAG)
return circuit_op

return map_moments(
return _map_operations_impl(
circuit,
lambda m, i: circuits.Circuit(apply_map(op, i) for op in m.operations).moments
or [circuits.Moment()],
map_func,
deep=deep,
raise_if_add_qubits=raise_if_add_qubits,
tags_to_ignore=tags_to_ignore,
wrap_in_circuit_op=True,
)


Expand Down Expand Up @@ -191,15 +303,13 @@ def map_operations_and_unroll(
Returns:
Copy of input circuit with mapped operations, unrolled in a moment preserving way.
"""
return unroll_circuit_op(
map_operations(
circuit,
map_func,
deep=deep,
raise_if_add_qubits=raise_if_add_qubits,
tags_to_ignore=tags_to_ignore,
),
return _map_operations_impl(
circuit,
map_func,
deep=deep,
raise_if_add_qubits=raise_if_add_qubits,
tags_to_ignore=tags_to_ignore,
wrap_in_circuit_op=False,
)


Expand Down

0 comments on commit 7ed95aa

Please sign in to comment.