Skip to content

Commit

Permalink
Bugfixes in handling nested tags_to_ignore + deep=True in `cirq.map_m…
Browse files Browse the repository at this point in the history
…oments` and `cirq.map_operations` transformer primitives (#5109)

- Fixes a few more bugs in the handling of deep=True flag and nested operations to ignore using `tags_to_ignore` in `cirq.map_operations` and `cirq.map_moments` transformer primitives. Also added more tests. 
- Step towards fixing #5039
  • Loading branch information
tanujkhattar authored Mar 21, 2022
1 parent e0a64dd commit aed4eb8
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 4 deletions.
9 changes: 7 additions & 2 deletions cirq-core/cirq/transformers/transformer_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def map_moments(
continue
op_untagged = cast(circuits.CircuitOperation, op.untagged)
mapped_op = op_untagged.replace(
circuit=map_moments(op_untagged.circuit, map_func, deep=deep)
circuit=map_moments(
op_untagged.circuit, map_func, tags_to_ignore=tags_to_ignore, deep=deep
)
).with_tags(*op.tags)
batch_replace.append((i, op, mapped_op))
mutable_circuit = circuit.unfreeze(copy=True)
Expand Down Expand Up @@ -149,7 +151,10 @@ def apply_map(op: ops.Operation, idx: int) -> ops.OP_TREE:
return circuit_op

return map_moments(
circuit, lambda m, i: [circuits.Moment(apply_map(op, i) for op in m.operations)], deep=deep
circuit,
lambda m, i: [circuits.Moment(apply_map(op, i) for op in m.operations)],
deep=deep,
tags_to_ignore=tags_to_ignore,
)


Expand Down
57 changes: 55 additions & 2 deletions cirq-core/cirq/transformers/transformer_primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,47 @@ def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE:
# pylint: enable=line-too-long


def test_map_operations_deep_respects_tags_to_ignore():
q = cirq.LineQubit.range(2)
c_nested = cirq.FrozenCircuit(cirq.CX(*q), cirq.CX(*q).with_tags("ignore"), cirq.CX(*q))
c_nested_mapped = cirq.FrozenCircuit(cirq.CZ(*q), cirq.CX(*q).with_tags("ignore"), cirq.CZ(*q))
c_orig = cirq.Circuit(
c_nested,
cirq.CircuitOperation(c_nested).repeat(4).with_tags("ignore"),
c_nested,
cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tag"),
cirq.CircuitOperation(c_nested).repeat(6).with_tags("ignore"),
cirq.CircuitOperation(c_nested).repeat(7),
)
),
c_nested,
)
c_expected = cirq.Circuit(
c_nested_mapped,
cirq.CircuitOperation(c_nested).repeat(4).with_tags("ignore"),
c_nested_mapped,
cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.CircuitOperation(c_nested_mapped).repeat(5).with_tags("preserve_tag"),
cirq.CircuitOperation(c_nested).repeat(6).with_tags("ignore"),
cirq.CircuitOperation(c_nested_mapped).repeat(7),
)
),
c_nested_mapped,
)
cirq.testing.assert_same_circuits(
cirq.map_operations(
c_orig,
lambda op, _: cirq.CZ(*op.qubits) if op.gate == cirq.CX else op,
tags_to_ignore=["ignore"],
deep=True,
),
c_expected,
)


def test_map_operations_respects_tags_to_ignore():
q = cirq.LineQubit.range(2)
c = cirq.Circuit(cirq.CNOT(*q), cirq.CNOT(*q).with_tags("ignore"), cirq.CNOT(*q))
Expand Down Expand Up @@ -402,17 +443,29 @@ def test_map_moments_drop_empty_moments():
def test_map_moments_drop_empty_moments_deep():
op = cirq.X(cirq.NamedQubit("q"))
c_nested = cirq.FrozenCircuit(cirq.Moment(op), cirq.Moment(), cirq.Moment(op))
circuit_op = cirq.CircuitOperation(c_nested).repeat(2)
circuit_op_dropped = cirq.CircuitOperation(cirq.FrozenCircuit([op, op])).repeat(2)
c_orig = cirq.Circuit(
c_nested,
cirq.CircuitOperation(c_nested).repeat(6).with_tags("ignore"),
c_nested,
cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tag"),
cirq.CircuitOperation(
cirq.FrozenCircuit(circuit_op, circuit_op.with_tags("ignore"), circuit_op)
)
.repeat(5)
.with_tags("preserve_tag"),
)
c_expected = cirq.Circuit(
[op, op],
cirq.CircuitOperation(c_nested).repeat(6).with_tags("ignore"),
[op, op],
cirq.CircuitOperation(cirq.FrozenCircuit([op, op])).repeat(5).with_tags("preserve_tag"),
cirq.CircuitOperation(
cirq.FrozenCircuit(
circuit_op_dropped, circuit_op.with_tags("ignore"), circuit_op_dropped
)
)
.repeat(5)
.with_tags("preserve_tag"),
)
c_mapped = cirq.map_moments(
c_orig, lambda m, i: [] if len(m) == 0 else [m], deep=True, tags_to_ignore=("ignore",)
Expand Down

0 comments on commit aed4eb8

Please sign in to comment.