Skip to content

Commit

Permalink
Add support for deep=True to cirq.expand_composite transformer (qua…
Browse files Browse the repository at this point in the history
…ntumlib#5119)

- Adds support to recursively run `cirq.expand_composite` transformer on circuits wrapped inside a circuit operation by setting deep=True in transformer context.
- Note that this does not rely on `preserve_structure` argument of `protocols.decompose` because the latter does not support handling nested circuit operations tagged with a no-compile tag (the added tests would fail if we rely on protocols.decompose(preserve_structure=True) instead transformer primitives). Hence, I would argue that we should deprecate the preserve_structure=True flag in protocols.decompose in-favour of this transformer. cc @95-martin-orion 
- Part of quantumlib#5039
  • Loading branch information
tanujkhattar authored Mar 21, 2022
1 parent 32c0e02 commit 32afaa0
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 3 deletions.
15 changes: 12 additions & 3 deletions cirq/transformers/expand_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from typing import Callable, Optional, TYPE_CHECKING

from cirq import ops, protocols
from cirq import circuits, ops, protocols
from cirq.transformers import transformer_api, transformer_primitives

if TYPE_CHECKING:
Expand Down Expand Up @@ -49,8 +49,17 @@ def expand_composite(
"""

def map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
return protocols.decompose(op, keep=no_decomp, on_stuck_raise=None)
if context and context.deep and isinstance(op.untagged, circuits.CircuitOperation):
return op
return protocols.decompose(
op,
keep=no_decomp,
on_stuck_raise=None,
)

return transformer_primitives.map_operations_and_unroll(
circuit, map_func, tags_to_ignore=context.tags_to_ignore if context else ()
circuit,
map_func,
tags_to_ignore=context.tags_to_ignore if context else (),
deep=context.deep if context else False,
).unfreeze(copy=False)
72 changes: 72 additions & 0 deletions cirq/transformers/expand_composite_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,75 @@ def test_do_not_decompose_no_compile():
c = cirq.Circuit(cirq.CNOT(q0, q1).with_tags("no_compile"))
context = cirq.TransformerContext(tags_to_ignore=("no_compile",))
assert_equal_mod_empty(c, cirq.expand_composite(c, context=context))


def test_expands_composite_recursively_preserving_structur():
q = cirq.LineQubit.range(2)
c_nested = cirq.FrozenCircuit(
cirq.SWAP(*q[:2]), cirq.SWAP(*q[:2]).with_tags("ignore"), cirq.SWAP(*q[:2])
)
c_nested_expanded = cirq.FrozenCircuit(
[cirq.CNOT(*q), cirq.CNOT(*q[::-1]), cirq.CNOT(*q)],
cirq.SWAP(*q[:2]).with_tags("ignore"),
[cirq.CNOT(*q), cirq.CNOT(*q[::-1]), cirq.CNOT(*q)],
)
c_orig = cirq.Circuit(
c_nested,
cirq.CircuitOperation(
cirq.FrozenCircuit(
c_nested,
cirq.CircuitOperation(c_nested).repeat(5).with_tags("ignore"),
cirq.CircuitOperation(c_nested).repeat(6).with_tags("preserve_tag"),
cirq.CircuitOperation(c_nested).repeat(7),
c_nested,
)
)
.repeat(4)
.with_tags("ignore"),
c_nested,
cirq.CircuitOperation(
cirq.FrozenCircuit(
c_nested,
cirq.CircuitOperation(c_nested).repeat(5).with_tags("ignore"),
cirq.CircuitOperation(c_nested).repeat(6).with_tags("preserve_tag"),
cirq.CircuitOperation(c_nested).repeat(7),
c_nested,
)
)
.repeat(5)
.with_tags("preserve_tag"),
c_nested,
)
c_expected = cirq.Circuit(
c_nested_expanded,
cirq.CircuitOperation(
cirq.FrozenCircuit(
c_nested,
cirq.CircuitOperation(c_nested).repeat(5).with_tags("ignore"),
cirq.CircuitOperation(c_nested).repeat(6).with_tags("preserve_tag"),
cirq.CircuitOperation(c_nested).repeat(7),
c_nested,
)
)
.repeat(4)
.with_tags("ignore"),
c_nested_expanded,
cirq.CircuitOperation(
cirq.FrozenCircuit(
c_nested_expanded,
cirq.CircuitOperation(c_nested).repeat(5).with_tags("ignore"),
cirq.CircuitOperation(c_nested_expanded).repeat(6).with_tags("preserve_tag"),
cirq.CircuitOperation(c_nested_expanded).repeat(7),
c_nested_expanded,
)
)
.repeat(5)
.with_tags("preserve_tag"),
c_nested_expanded,
)

context = cirq.TransformerContext(tags_to_ignore=["ignore"], deep=True)
c_expanded = cirq.expand_composite(
c_orig, no_decomp=lambda op: op.gate == cirq.CNOT, context=context
)
cirq.testing.assert_same_circuits(c_expanded, c_expected)

0 comments on commit 32afaa0

Please sign in to comment.