Skip to content

Commit

Permalink
Break intermediate measurements on 3+ qubits into single qubit measur…
Browse files Browse the repository at this point in the history
…ements in `RouteCQC`quantumlib#6293 (quantumlib#6349)
  • Loading branch information
shef4 authored Nov 27, 2023
1 parent 4467123 commit f81f8cf
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 27 deletions.
26 changes: 16 additions & 10 deletions cirq/transformers/routing/route_circuit_cqc.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,26 +249,32 @@ def _get_one_and_two_qubit_ops_as_timesteps(
output routed circuit, single-qubit operations are inserted before two-qubit operations.
Raises:
ValueError: if circuit has intermediate measurement op's that act on 3 or more qubits.
ValueError: if circuit has intermediate measurements that act on three or more
qubits with a custom key.
"""
two_qubit_circuit = circuits.Circuit()
single_qubit_ops: List[List[cirq.Operation]] = []

if any(
protocols.num_qubits(op) > 2 and protocols.is_measurement(op)
for op in itertools.chain(*circuit.moments[:-1])
):
# There is at least one non-terminal measurement on 3+ qubits
raise ValueError('Non-terminal measurements on three or more qubits are not supported')

for moment in circuit:
for i, moment in enumerate(circuit):
for op in moment:
timestep = two_qubit_circuit.earliest_available_moment(op)
single_qubit_ops.extend([] for _ in range(timestep + 1 - len(single_qubit_ops)))
two_qubit_circuit.append(
circuits.Moment() for _ in range(timestep + 1 - len(two_qubit_circuit))
)
if protocols.num_qubits(op) == 2:
if protocols.num_qubits(op) > 2 and protocols.is_measurement(op):
key = op.gate.key # type: ignore
default_key = ops.measure(op.qubits).gate.key # type: ignore
if len(circuit.moments) == i + 1:
single_qubit_ops[timestep].append(op)
elif key in ('', default_key):
single_qubit_ops[timestep].extend(ops.measure(qubit) for qubit in op.qubits)
else:
raise ValueError(
'Intermediate measurements on three or more qubits '
'with a custom key are not supported'
)
elif protocols.num_qubits(op) == 2:
two_qubit_circuit[timestep] = two_qubit_circuit[timestep].with_operation(op)
else:
single_qubit_ops[timestep].append(op)
Expand Down
44 changes: 27 additions & 17 deletions cirq/transformers/routing/route_circuit_cqc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,35 +107,45 @@ def test_circuit_with_measurement_gates():
cirq.testing.assert_same_circuits(routed_circuit, circuit)


def test_circuit_with_valid_intermediate_multi_qubit_measurement_gates():
device = cirq.testing.construct_ring_device(3)
def test_circuit_with_two_qubit_intermediate_measurement_gate():
device = cirq.testing.construct_ring_device(2)
device_graph = device.metadata.nx_graph
router = cirq.RouteCQC(device_graph)
q = cirq.LineQubit.range(2)
hard_coded_mapper = cirq.HardCodedInitialMapper({q[i]: q[i] for i in range(2)})

valid_circuit = cirq.Circuit(cirq.measure_each(*q), cirq.H.on_each(q))

c_routed = router(
valid_circuit, initial_mapper=hard_coded_mapper, context=cirq.TransformerContext(deep=True)
qs = cirq.LineQubit.range(2)
hard_coded_mapper = cirq.HardCodedInitialMapper({qs[i]: qs[i] for i in range(2)})
circuit = cirq.Circuit([cirq.Moment(cirq.measure(qs)), cirq.Moment(cirq.H.on_each(qs))])
routed_circuit = router(
circuit, initial_mapper=hard_coded_mapper, context=cirq.TransformerContext(deep=True)
)
device.validate_circuit(c_routed)
device.validate_circuit(routed_circuit)


def test_circuit_with_invalid_intermediate_multi_qubit_measurement_gates():
def test_circuit_with_multi_qubit_intermediate_measurement_gate_and_with_default_key():
device = cirq.testing.construct_ring_device(3)
device_graph = device.metadata.nx_graph
router = cirq.RouteCQC(device_graph)
q = cirq.LineQubit.range(3)
hard_coded_mapper = cirq.HardCodedInitialMapper({q[i]: q[i] for i in range(3)})
qs = cirq.LineQubit.range(3)
hard_coded_mapper = cirq.HardCodedInitialMapper({qs[i]: qs[i] for i in range(3)})
circuit = cirq.Circuit([cirq.Moment(cirq.measure(qs)), cirq.Moment(cirq.H.on_each(qs))])
routed_circuit = router(
circuit, initial_mapper=hard_coded_mapper, context=cirq.TransformerContext(deep=True)
)
expected = cirq.Circuit([cirq.Moment(cirq.measure_each(qs)), cirq.Moment(cirq.H.on_each(qs))])
cirq.testing.assert_same_circuits(routed_circuit, expected)

invalid_circuit = cirq.Circuit(cirq.MeasurementGate(3).on(*q), cirq.H.on_each(*q))

def test_circuit_with_multi_qubit_intermediate_measurement_gate_with_custom_key():
device = cirq.testing.construct_ring_device(3)
device_graph = device.metadata.nx_graph
router = cirq.RouteCQC(device_graph)
qs = cirq.LineQubit.range(3)
hard_coded_mapper = cirq.HardCodedInitialMapper({qs[i]: qs[i] for i in range(3)})
circuit = cirq.Circuit(
[cirq.Moment(cirq.measure(qs, key="test")), cirq.Moment(cirq.H.on_each(qs))]
)
with pytest.raises(ValueError):
_ = router(
invalid_circuit,
initial_mapper=hard_coded_mapper,
context=cirq.TransformerContext(deep=True),
circuit, initial_mapper=hard_coded_mapper, context=cirq.TransformerContext(deep=True)
)


Expand Down

0 comments on commit f81f8cf

Please sign in to comment.