Skip to content

Commit

Permalink
Add support for deep=True to cirq.optimize_for_target_gateset trans…
Browse files Browse the repository at this point in the history
…former (quantumlib#5124)

- Adds support for `deep=True` flag to `cirq.optimize_for_target_gateset` which enables optimizing circuits preserving the sub-circuit structure (i.e. without unrolling circuit operations).
- Part of quantumlib#5039
  • Loading branch information
tanujkhattar authored and rht committed May 1, 2023
1 parent ccdc1d7 commit 88917e1
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 2 deletions.
19 changes: 17 additions & 2 deletions cirq-core/cirq/transformers/optimize_for_target_gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@

"""Transformers to rewrite a circuit using gates from a given target gateset."""

from typing import Optional, Callable, TYPE_CHECKING
from typing import Optional, Callable, Hashable, Sequence, TYPE_CHECKING

from cirq import circuits
from cirq.protocols import decompose_protocol as dp
from cirq.transformers import transformer_api, transformer_primitives

Expand All @@ -38,6 +39,7 @@ def _decompose_operations_to_target_gateset(
gateset: Optional['cirq.Gateset'] = None,
decomposer: Callable[['cirq.Operation', int], dp.DecomposeResult] = lambda *_: NotImplemented,
ignore_failures: bool = True,
tags_to_decompose: Sequence[Hashable] = (),
) -> 'cirq.Circuit':
"""Decomposes every operation to `gateset` using `cirq.decompose` and `decomposer`.
Expand All @@ -56,6 +58,8 @@ def _decompose_operations_to_target_gateset(
- `None` or `NotImplemented` if does not know how to decompose a given `op`.
ignore_failures: If set, operations that fail to convert are left unchanged. If not set,
conversion failures raise a ValueError.
tags_to_decompose: `cirq.CircuitOperation`s tagged with any of `tags_to_decompose` will
be decomposed even if context.deep is True.
Returns:
An equivalent circuit containing gates accepted by `gateset`.
Expand All @@ -65,6 +69,13 @@ def _decompose_operations_to_target_gateset(
"""

def map_func(op: 'cirq.Operation', moment_index: int):
if (
context
and context.deep
and isinstance(op.untagged, circuits.CircuitOperation)
and set(op.tags).isdisjoint(tags_to_decompose)
):
return op
return dp.decompose(
op,
intercepting_decomposer=lambda o: decomposer(o, moment_index),
Expand All @@ -77,7 +88,10 @@ def map_func(op: 'cirq.Operation', moment_index: int):
)

return transformer_primitives.map_operations_and_unroll(
circuit, map_func, tags_to_ignore=context.tags_to_ignore if context else ()
circuit,
map_func,
tags_to_ignore=context.tags_to_ignore if context else (),
deep=context.deep if context else False,
).unfreeze(copy=False)


Expand Down Expand Up @@ -122,6 +136,7 @@ def optimize_for_target_gateset(
gateset=gateset,
decomposer=gateset.decompose_to_target_gateset,
ignore_failures=ignore_failures,
tags_to_decompose=(gateset._intermediate_result_tag,),
)

for transformer in gateset.postprocess_transformers:
Expand Down
50 changes: 50 additions & 0 deletions cirq-core/cirq/transformers/optimize_for_target_gateset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,53 @@ def test_optimize_for_target_gateset():
_ = cirq.optimize_for_target_gateset(
c_orig, gateset=gateset, context=context, ignore_failures=False
)


def test_optimize_for_target_gateset_deep():
q0, q1 = cirq.LineQubit.range(2)
c_nested = cirq.FrozenCircuit(cirq.CX(q0, q1))
c_orig = cirq.Circuit(
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.H(q0), cirq.CircuitOperation(c_nested).repeat(3))
).repeat(5)
)
c_expected = cirq.Circuit(
cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.single_qubit_matrix_to_phxz(cirq.unitary(cirq.H(q0))).on(q0),
cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.MatrixGate(c_nested.unitary(qubit_order=[q0, q1]), name="M").on(q0, q1)
)
).repeat(3),
)
).repeat(5)
)
gateset = MatrixGateTargetGateset()
context = cirq.TransformerContext(deep=True)
c_new = cirq.optimize_for_target_gateset(c_orig, gateset=gateset, context=context)
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(c_new, c_expected)
cirq.testing.assert_has_diagram(
c_orig,
'''
[ [ 0: ───@─── ] ]
[ 0: ───H───[ │ ]──────────── ]
0: ───[ [ 1: ───X─── ](loops=3) ]────────────
[ │ ]
[ 1: ───────#2──────────────────────── ](loops=5)
1: ───#2──────────────────────────────────────────────────
''',
)
cirq.testing.assert_has_diagram(
c_new,
'''
[ [ 0: ───M[1]─── ] ]
[ 0: ───PhXZ(a=-0.5,x=0.5,z=-1)───[ │ ]──────────── ]
0: ───[ [ 1: ───M[2]─── ](loops=3) ]────────────
[ │ ]
[ 1: ─────────────────────────────#2─────────────────────────── ](loops=5)
1: ───#2───────────────────────────────────────────────────────────────────────────
''',
)

0 comments on commit 88917e1

Please sign in to comment.