diff --git a/cirq-core/cirq/transformers/transformer_primitives.py b/cirq-core/cirq/transformers/transformer_primitives.py index 9eb3add785e..a5035ec7f62 100644 --- a/cirq-core/cirq/transformers/transformer_primitives.py +++ b/cirq-core/cirq/transformers/transformer_primitives.py @@ -23,6 +23,7 @@ List, Optional, Sequence, + Union, TYPE_CHECKING, ) @@ -57,7 +58,7 @@ def _create_target_circuit_type(ops: ops.OP_TREE, target_circuit: CIRCUIT_TYPE) def map_moments( circuit: CIRCUIT_TYPE, - map_func: Callable[[ops.Moment, int], Sequence[ops.Moment]], + map_func: Callable[[ops.Moment, int], Union[ops.Moment, Sequence[ops.Moment]]], ) -> CIRCUIT_TYPE: """Applies local transformation on moments, by calling `map_func(moment)` for each moment. @@ -76,6 +77,8 @@ def map_moments( def map_operations( circuit: CIRCUIT_TYPE, map_func: Callable[[ops.Operation, int], ops.OP_TREE], + *, + raise_if_add_qubits=True, ) -> CIRCUIT_TYPE: """Applies local transformations on operations, by calling `map_func(op)` for each op. @@ -88,9 +91,12 @@ def map_operations( `cirq.CircuitOperation(cirq.FrozenCircuit(op_tree)).with_tags(MAPPED_CIRCUIT_OP_TAG)` to preserve moment structure. Utility methods like `cirq.unroll_circuit_op` can subsequently be used to unroll the mapped circuit operation. + raise_if_add_qubits: Set to True by default. If True, Raises ValueError if `map_func(op)` + adds operations on qubits outside of `op.qubits`. Raises: - ValueError if `issubset(qubit_set(map_func(op)), op.qubits) is False`. + ValueError if `issubset(qubit_set(map_func(op)), op.qubits) is False` and + `raise_if_add_qubits is True`. Returns: Copy of input circuit with mapped operations (wrapped in a tagged CircuitOperation). @@ -98,14 +104,15 @@ def map_operations( def apply_map(op: ops.Operation, idx: int) -> ops.OP_TREE: c = circuits.FrozenCircuit(map_func(op, idx)) - if not c.all_qubits().issubset(op.qubits): + if raise_if_add_qubits and not c.all_qubits().issubset(op.qubits): raise ValueError( f"Mapped operations {c.all_operations()} should act on a subset " f"of qubits of the original operation {op}" ) - if len(c) == 1: - # All operations act in the same moment; so we don't need to wrap them in a circuit_op. - return c[0].operations + if len(c) <= 1: + # Either empty circuit or all operations act in the same moment; + # So, we don't need to wrap them in a circuit_op. + return c[0].operations if c else [] circuit_op = circuits.CircuitOperation(c).with_tags(MAPPED_CIRCUIT_OP_TAG) return circuit_op diff --git a/cirq-core/cirq/transformers/transformer_primitives_test.py b/cirq-core/cirq/transformers/transformer_primitives_test.py index 7911fc5de35..b851f537d7a 100644 --- a/cirq-core/cirq/transformers/transformer_primitives_test.py +++ b/cirq-core/cirq/transformers/transformer_primitives_test.py @@ -195,6 +195,21 @@ def test_map_operations_raises_qubits_not_subset(): ) +def test_map_operations_can_add_qubits_if_flag_false(): + q = cirq.LineQubit.range(2) + c = cirq.Circuit(cirq.H(q[0])) + c_mapped = cirq.map_operations(c, lambda *_: cirq.CNOT(q[0], q[1]), raise_if_add_qubits=False) + cirq.testing.assert_same_circuits(c_mapped, cirq.Circuit(cirq.CNOT(q[0], q[1]))) + + +def test_map_operations_can_drop_operations(): + q = cirq.LineQubit.range(2) + c = cirq.Circuit(cirq.X(q[0]), cirq.Y(q[1]), cirq.X(q[1]), cirq.Y(q[0])) + c_mapped = cirq.map_operations(c, lambda op, _: op if op.gate == cirq.X else []) + c_expected = cirq.Circuit(cirq.Moment(cirq.X(q[0])), cirq.Moment(cirq.X(q[1]))) + cirq.testing.assert_same_circuits(c_mapped, c_expected) + + def test_map_moments_drop_empty_moments(): op = cirq.X(cirq.NamedQubit("x")) c = cirq.Circuit(cirq.Moment(op), cirq.Moment(), cirq.Moment(op))