Skip to content

Commit

Permalink
Add support for deep=True to cirq.merge_k_qubit_unitaries transformer (
Browse files Browse the repository at this point in the history
  • Loading branch information
tanujkhattar authored Mar 22, 2022
1 parent 64a6723 commit 92d19f6
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 15 deletions.
64 changes: 49 additions & 15 deletions cirq-core/cirq/transformers/merge_k_qubit_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,47 @@
import cirq


def _rewrite_merged_k_qubit_unitaries(
circuit: 'cirq.AbstractCircuit',
*,
context: Optional['cirq.TransformerContext'] = None,
k: int = 0,
rewriter: Optional[Callable[['cirq.CircuitOperation'], 'cirq.OP_TREE']] = None,
merged_circuit_op_tag: str = "_merged_k_qubit_unitaries_component",
) -> 'cirq.Circuit':
deep = context.deep if context else False

def map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
if not (protocols.num_qubits(op) <= k and protocols.has_unitary(op)):
return op
op_untagged = op.untagged
if (
deep
and isinstance(op_untagged, circuits.CircuitOperation)
and merged_circuit_op_tag not in op.tags
):
return op_untagged.replace(
circuit=_rewrite_merged_k_qubit_unitaries(
op_untagged.circuit,
context=context,
k=k,
rewriter=rewriter,
merged_circuit_op_tag=merged_circuit_op_tag,
).freeze()
).with_tags(*op.tags)
if rewriter:
return rewriter(
cast(circuits.CircuitOperation, op_untagged)
if merged_circuit_op_tag in op.tags
else circuits.CircuitOperation(circuits.FrozenCircuit(op))
)
return ops.MatrixGate(protocols.unitary(op)).on(*op.qubits)

return transformer_primitives.map_operations_and_unroll(
circuit, map_func, tags_to_ignore=context.tags_to_ignore if context else ()
).unfreeze(copy=False)


@transformer_api.transformer
def merge_k_qubit_unitaries(
circuit: 'cirq.AbstractCircuit',
Expand Down Expand Up @@ -54,24 +95,17 @@ def merge_k_qubit_unitaries(
if k <= 0:
raise ValueError(f"k should be greater than or equal to 1. Found {k}.")
merged_circuit_op_tag = "_merged_k_qubit_unitaries_component"

def map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
if not (protocols.num_qubits(op) <= k and protocols.has_unitary(op)):
return op
if rewriter:
return rewriter(
cast(circuits.CircuitOperation, op.untagged)
if merged_circuit_op_tag in op.tags
else circuits.CircuitOperation(circuits.FrozenCircuit(op))
)
return ops.MatrixGate(protocols.unitary(op)).on(*op.qubits)

circuit = transformer_primitives.merge_k_qubit_unitaries_to_circuit_op(
circuit,
k=k,
tags_to_ignore=context.tags_to_ignore if context else (),
merged_circuit_op_tag=merged_circuit_op_tag,
deep=context.deep if context else False,
)
return _rewrite_merged_k_qubit_unitaries(
circuit,
context=context,
k=k,
rewriter=rewriter,
merged_circuit_op_tag=merged_circuit_op_tag,
)
return transformer_primitives.map_operations_and_unroll(
circuit, map_func, tags_to_ignore=context.tags_to_ignore if context else ()
).unfreeze(copy=False)
65 changes: 65 additions & 0 deletions cirq-core/cirq/transformers/merge_k_qubit_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,68 @@ def rewriter_replace_with_decomp(op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
║ ║
a: ═════════════════════════════════════════════════════════════════════════════════════════════@══════════════════════════════^═══''',
)


def test_merge_k_qubit_unitaries_deep():
q = cirq.LineQubit.range(2)
h_cz_y = [cirq.H(q[0]), cirq.CZ(*q), cirq.Y(q[1])]
c_orig = cirq.Circuit(
h_cz_y,
cirq.Moment(cirq.X(q[0]).with_tags("ignore"), cirq.Y(q[1])),
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(6).with_tags("ignore"),
[cirq.CNOT(*q), cirq.CNOT(*q)],
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(4),
[cirq.CNOT(*q), cirq.CZ(*q), cirq.CNOT(*q)],
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(5).with_tags("preserve_tag"),
)

def _wrap_in_cop(ops: cirq.OP_TREE, tag: str):
return cirq.CircuitOperation(cirq.FrozenCircuit(ops)).with_tags(tag)

c_expected = cirq.Circuit(
_wrap_in_cop([h_cz_y, cirq.Y(q[1])], '1'),
cirq.Moment(cirq.X(q[0]).with_tags("ignore")),
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(6).with_tags("ignore"),
_wrap_in_cop([cirq.CNOT(*q), cirq.CNOT(*q)], '2'),
cirq.CircuitOperation(cirq.FrozenCircuit(_wrap_in_cop(h_cz_y, '3'))).repeat(4),
_wrap_in_cop([cirq.CNOT(*q), cirq.CZ(*q), cirq.CNOT(*q)], '4'),
cirq.CircuitOperation(cirq.FrozenCircuit(_wrap_in_cop(h_cz_y, '5')))
.repeat(5)
.with_tags("preserve_tag"),
strategy=cirq.InsertStrategy.NEW,
)

component_id = 0

def rewriter_merge_to_circuit_op(op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
nonlocal component_id
component_id = component_id + 1
return op.with_tags(f'{component_id}')

context = cirq.TransformerContext(tags_to_ignore=("ignore",), deep=True)
c_new = cirq.merge_k_qubit_unitaries(
c_orig,
k=2,
context=context,
rewriter=rewriter_merge_to_circuit_op,
)
cirq.testing.assert_same_circuits(c_new, c_expected)

def _wrap_in_matrix_gate(ops: cirq.OP_TREE):
op = _wrap_in_cop(ops, 'temp')
return cirq.MatrixGate(cirq.unitary(op)).on(*op.qubits)

c_expected_matrix = cirq.Circuit(
_wrap_in_matrix_gate([h_cz_y, cirq.Y(q[1])]),
cirq.Moment(cirq.X(q[0]).with_tags("ignore")),
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(6).with_tags("ignore"),
_wrap_in_matrix_gate([cirq.CNOT(*q), cirq.CNOT(*q)]),
cirq.CircuitOperation(cirq.FrozenCircuit(_wrap_in_matrix_gate(h_cz_y))).repeat(4),
_wrap_in_matrix_gate([cirq.CNOT(*q), cirq.CZ(*q), cirq.CNOT(*q)]),
cirq.CircuitOperation(cirq.FrozenCircuit(_wrap_in_matrix_gate(h_cz_y)))
.repeat(5)
.with_tags("preserve_tag"),
strategy=cirq.InsertStrategy.NEW,
)
c_new_matrix = cirq.merge_k_qubit_unitaries(c_orig, k=2, context=context)
cirq.testing.assert_same_circuits(c_new_matrix, c_expected_matrix)

0 comments on commit 92d19f6

Please sign in to comment.