diff --git a/cirq-core/cirq/transformers/transformer_primitives.py b/cirq-core/cirq/transformers/transformer_primitives.py index bf6239a56a1..38ec7bbde7f 100644 --- a/cirq-core/cirq/transformers/transformer_primitives.py +++ b/cirq-core/cirq/transformers/transformer_primitives.py @@ -87,7 +87,9 @@ def map_moments( continue op_untagged = cast(circuits.CircuitOperation, op.untagged) mapped_op = op_untagged.replace( - circuit=map_moments(op_untagged.circuit, map_func, deep=deep) + circuit=map_moments( + op_untagged.circuit, map_func, tags_to_ignore=tags_to_ignore, deep=deep + ) ).with_tags(*op.tags) batch_replace.append((i, op, mapped_op)) mutable_circuit = circuit.unfreeze(copy=True) @@ -149,7 +151,10 @@ def apply_map(op: ops.Operation, idx: int) -> ops.OP_TREE: return circuit_op return map_moments( - circuit, lambda m, i: [circuits.Moment(apply_map(op, i) for op in m.operations)], deep=deep + circuit, + lambda m, i: [circuits.Moment(apply_map(op, i) for op in m.operations)], + deep=deep, + tags_to_ignore=tags_to_ignore, ) diff --git a/cirq-core/cirq/transformers/transformer_primitives_test.py b/cirq-core/cirq/transformers/transformer_primitives_test.py index 9ac2a87b8a1..51f947c645a 100644 --- a/cirq-core/cirq/transformers/transformer_primitives_test.py +++ b/cirq-core/cirq/transformers/transformer_primitives_test.py @@ -213,6 +213,47 @@ def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE: # pylint: enable=line-too-long +def test_map_operations_deep_respects_tags_to_ignore(): + q = cirq.LineQubit.range(2) + c_nested = cirq.FrozenCircuit(cirq.CX(*q), cirq.CX(*q).with_tags("ignore"), cirq.CX(*q)) + c_nested_mapped = cirq.FrozenCircuit(cirq.CZ(*q), cirq.CX(*q).with_tags("ignore"), cirq.CZ(*q)) + c_orig = cirq.Circuit( + c_nested, + cirq.CircuitOperation(c_nested).repeat(4).with_tags("ignore"), + c_nested, + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tag"), + cirq.CircuitOperation(c_nested).repeat(6).with_tags("ignore"), + cirq.CircuitOperation(c_nested).repeat(7), + ) + ), + c_nested, + ) + c_expected = cirq.Circuit( + c_nested_mapped, + cirq.CircuitOperation(c_nested).repeat(4).with_tags("ignore"), + c_nested_mapped, + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.CircuitOperation(c_nested_mapped).repeat(5).with_tags("preserve_tag"), + cirq.CircuitOperation(c_nested).repeat(6).with_tags("ignore"), + cirq.CircuitOperation(c_nested_mapped).repeat(7), + ) + ), + c_nested_mapped, + ) + cirq.testing.assert_same_circuits( + cirq.map_operations( + c_orig, + lambda op, _: cirq.CZ(*op.qubits) if op.gate == cirq.CX else op, + tags_to_ignore=["ignore"], + deep=True, + ), + c_expected, + ) + + def test_map_operations_respects_tags_to_ignore(): q = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.CNOT(*q), cirq.CNOT(*q).with_tags("ignore"), cirq.CNOT(*q)) @@ -402,17 +443,29 @@ def test_map_moments_drop_empty_moments(): def test_map_moments_drop_empty_moments_deep(): op = cirq.X(cirq.NamedQubit("q")) c_nested = cirq.FrozenCircuit(cirq.Moment(op), cirq.Moment(), cirq.Moment(op)) + circuit_op = cirq.CircuitOperation(c_nested).repeat(2) + circuit_op_dropped = cirq.CircuitOperation(cirq.FrozenCircuit([op, op])).repeat(2) c_orig = cirq.Circuit( c_nested, cirq.CircuitOperation(c_nested).repeat(6).with_tags("ignore"), c_nested, - cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tag"), + cirq.CircuitOperation( + cirq.FrozenCircuit(circuit_op, circuit_op.with_tags("ignore"), circuit_op) + ) + .repeat(5) + .with_tags("preserve_tag"), ) c_expected = cirq.Circuit( [op, op], cirq.CircuitOperation(c_nested).repeat(6).with_tags("ignore"), [op, op], - cirq.CircuitOperation(cirq.FrozenCircuit([op, op])).repeat(5).with_tags("preserve_tag"), + cirq.CircuitOperation( + cirq.FrozenCircuit( + circuit_op_dropped, circuit_op.with_tags("ignore"), circuit_op_dropped + ) + ) + .repeat(5) + .with_tags("preserve_tag"), ) c_mapped = cirq.map_moments( c_orig, lambda m, i: [] if len(m) == 0 else [m], deep=True, tags_to_ignore=("ignore",)