Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port RemoveDiagonalGatesBeforeMeasure to rust #13065

Merged
merged 17 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/accelerate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub mod isometry;
pub mod nlayout;
pub mod optimize_1q_gates;
pub mod pauli_exp_val;
pub mod remove_diagonal_gates_before_measure;
pub mod results;
pub mod sabre;
pub mod sampled_exp_val;
Expand Down
107 changes: 107 additions & 0 deletions crates/accelerate/src/remove_diagonal_gates_before_measure.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// This code is part of Qiskit.
//
// (C) Copyright IBM 2024
//
// This code is licensed under the Apache License, Version 2.0. You may
// obtain a copy of this license in the LICENSE.txt file in the root directory
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
//
// Any modifications or derivative works of this code must retain this
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

/// Remove diagonal gates (including diagonal 2Q gates) before a measurement.
use pyo3::prelude::*;

use qiskit_circuit::dag_circuit::{DAGCircuit, NodeType};
use qiskit_circuit::operations::Operation;
use qiskit_circuit::operations::StandardGate;

/// Run the RemoveDiagonalGatesBeforeMeasure pass on `dag`.
/// Args:
/// dag (DAGCircuit): the DAG to be optimized.
/// Returns:
/// DAGCircuit: the optimized DAG.
#[pyfunction]
#[pyo3(name = "remove_diagonal_gates_before_measure")]
fn run_remove_diagonal_before_measure(dag: &mut DAGCircuit) -> PyResult<()> {
static DIAGONAL_1Q_GATES: [StandardGate; 8] = [
Cryoris marked this conversation as resolved.
Show resolved Hide resolved
StandardGate::RZGate,
StandardGate::ZGate,
StandardGate::TGate,
StandardGate::SGate,
StandardGate::TdgGate,
StandardGate::SdgGate,
StandardGate::U1Gate,
StandardGate::PhaseGate,
];
static DIAGONAL_2Q_GATES: [StandardGate; 7] = [
StandardGate::CZGate,
StandardGate::CRZGate,
StandardGate::CU1Gate,
StandardGate::RZZGate,
StandardGate::CPhaseGate,
StandardGate::CSGate,
StandardGate::CSdgGate,
];
static DIAGONAL_3Q_GATES: [StandardGate; 1] = [StandardGate::CCZGate];

Copy link
Member Author

@ShellyGarion ShellyGarion Sep 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that there are some 1-qubit and 2-qubit diagonal gates that did not appear in the original list, so I added them here (I'll add them to the tests later).

There is also a 3-qubit diagonal gate: CCZGate (which was not added in this PR)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, there is also an n-qubit diagonal gate: MCPhaseGate. This gate was not handled this PR.
This is since the algorithm given here that goes over all the successors of each of the predecessors would be O(n^2) and not O(n).

let mut nodes_to_remove = Vec::new();
for index in dag.op_nodes(true) {
let node = &dag.dag[index];
let NodeType::Operation(inst) = node else {panic!()};

if inst.op.name() == "measure" {
let predecessor = (dag.quantum_predecessors(index))
.next()
.expect("index is an operation node, so it must have a predecessor.");

match &dag.dag[predecessor] {
NodeType::Operation(pred_inst) => match pred_inst.standard_gate() {
Some(gate) => {
if DIAGONAL_1Q_GATES.contains(&gate) {
nodes_to_remove.push(predecessor);
} else if DIAGONAL_2Q_GATES.contains(&gate)
|| DIAGONAL_3Q_GATES.contains(&gate)
{
let successors = dag.quantum_successors(predecessor);
let remove_s = successors
.map(|s| {
let node_s = &dag.dag[s];
if let NodeType::Operation(inst_s) = node_s {
inst_s.op.name() == "measure"
} else {
false
}
})
.all(|ok_to_remove| ok_to_remove);
if remove_s {
nodes_to_remove.push(predecessor);
}
}
}
None => {
continue;
}
},
_ => {
continue;
}
}
}
}

for node_to_remove in nodes_to_remove {
if dag.dag.node_weight(node_to_remove).is_some() {
dag.remove_op_node(node_to_remove);
}
}

Ok(())
}

#[pymodule]
pub fn remove_diagonal_gates_before_measure(m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(run_remove_diagonal_before_measure))?;
Ok(())
}
4 changes: 2 additions & 2 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5122,7 +5122,7 @@ impl DAGCircuit {
}
}

fn quantum_predecessors(&self, node: NodeIndex) -> impl Iterator<Item = NodeIndex> + '_ {
pub fn quantum_predecessors(&self, node: NodeIndex) -> impl Iterator<Item = NodeIndex> + '_ {
self.dag
.edges_directed(node, Incoming)
.filter_map(|e| match e.weight() {
Expand All @@ -5132,7 +5132,7 @@ impl DAGCircuit {
.unique()
}

fn quantum_successors(&self, node: NodeIndex) -> impl Iterator<Item = NodeIndex> + '_ {
pub fn quantum_successors(&self, node: NodeIndex) -> impl Iterator<Item = NodeIndex> + '_ {
self.dag
.edges_directed(node, Outgoing)
.filter_map(|e| match e.weight() {
Expand Down
8 changes: 7 additions & 1 deletion crates/pyext/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ use qiskit_accelerate::{
commutation_checker::commutation_checker, convert_2q_block_matrix::convert_2q_block_matrix,
dense_layout::dense_layout, error_map::error_map,
euler_one_qubit_decomposer::euler_one_qubit_decomposer, isometry::isometry, nlayout::nlayout,
optimize_1q_gates::optimize_1q_gates, pauli_exp_val::pauli_expval, results::results,
optimize_1q_gates::optimize_1q_gates, pauli_exp_val::pauli_expval,
remove_diagonal_gates_before_measure::remove_diagonal_gates_before_measure, results::results,
sabre::sabre, sampled_exp_val::sampled_exp_val, sparse_pauli_op::sparse_pauli_op,
star_prerouting::star_prerouting, stochastic_swap::stochastic_swap, synthesis::synthesis,
target_transpiler::target, two_qubit_decompose::two_qubit_decompose, uc_gate::uc_gate,
Expand Down Expand Up @@ -50,6 +51,11 @@ fn _accelerate(m: &Bound<PyModule>) -> PyResult<()> {
add_submodule(m, optimize_1q_gates, "optimize_1q_gates")?;
add_submodule(m, pauli_expval, "pauli_expval")?;
add_submodule(m, synthesis, "synthesis")?;
add_submodule(
m,
remove_diagonal_gates_before_measure,
"remove_diagonal_gates_before_measure",
)?;
add_submodule(m, results, "results")?;
add_submodule(m, sabre, "sabre")?;
add_submodule(m, sampled_exp_val, "sampled_exp_val")?;
Expand Down
3 changes: 3 additions & 0 deletions qiskit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@
sys.modules["qiskit._accelerate.pauli_expval"] = _accelerate.pauli_expval
sys.modules["qiskit._accelerate.qasm2"] = _accelerate.qasm2
sys.modules["qiskit._accelerate.qasm3"] = _accelerate.qasm3
sys.modules["qiskit._accelerate.remove_diagonal_gates_before_measure"] = (
_accelerate.remove_diagonal_gates_before_measure
)
sys.modules["qiskit._accelerate.results"] = _accelerate.results
sys.modules["qiskit._accelerate.sabre"] = _accelerate.sabre
sys.modules["qiskit._accelerate.sampled_exp_val"] = _accelerate.sampled_exp_val
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,13 @@

"""Remove diagonal gates (including diagonal 2Q gates) before a measurement."""

from qiskit.circuit import Measure
from qiskit.circuit.library.standard_gates import (
RZGate,
ZGate,
TGate,
SGate,
TdgGate,
SdgGate,
U1Gate,
CZGate,
CRZGate,
CU1Gate,
RZZGate,
)
from qiskit.dagcircuit import DAGOpNode
from qiskit.transpiler.basepasses import TransformationPass
from qiskit.transpiler.passes.utils import control_flow

from qiskit._accelerate.remove_diagonal_gates_before_measure import (
remove_diagonal_gates_before_measure,
)


class RemoveDiagonalGatesBeforeMeasure(TransformationPass):
"""Remove diagonal gates (including diagonal 2Q gates) before a measurement.
Expand All @@ -48,22 +37,5 @@ def run(self, dag):
Returns:
DAGCircuit: the optimized DAG.
"""
diagonal_1q_gates = (RZGate, ZGate, TGate, SGate, TdgGate, SdgGate, U1Gate)
diagonal_2q_gates = (CZGate, CRZGate, CU1Gate, RZZGate)

nodes_to_remove = set()
for measure in dag.op_nodes(Measure):
predecessor = next(dag.quantum_predecessors(measure))

if isinstance(predecessor, DAGOpNode) and isinstance(predecessor.op, diagonal_1q_gates):
nodes_to_remove.add(predecessor)

if isinstance(predecessor, DAGOpNode) and isinstance(predecessor.op, diagonal_2q_gates):
successors = dag.quantum_successors(predecessor)
if all(isinstance(s, DAGOpNode) and isinstance(s.op, Measure) for s in successors):
nodes_to_remove.add(predecessor)

for node_to_remove in nodes_to_remove:
dag.remove_op_node(node_to_remove)

remove_diagonal_gates_before_measure(dag)
ShellyGarion marked this conversation as resolved.
Show resolved Hide resolved
return dag
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,29 @@ def test_optimize_1rz_1measure(self):

self.assertEqual(circuit_to_dag(expected), after)

def test_optimize_1phase_1measure(self):
"""Remove a single PhaseGate
qr0:--P--m-- qr0:--m-
| |
qr1:-----|-- ==> qr1:--|-
| |
cr0:-----.-- cr0:--.-
"""
qr = QuantumRegister(2, "qr")
cr = ClassicalRegister(1, "cr")
circuit = QuantumCircuit(qr, cr)
circuit.p(0.1, qr[0])
circuit.measure(qr[0], cr[0])
dag = circuit_to_dag(circuit)

expected = QuantumCircuit(qr, cr)
expected.measure(qr[0], cr[0])

pass_ = RemoveDiagonalGatesBeforeMeasure()
after = pass_.run(dag)

self.assertEqual(circuit_to_dag(expected), after)

def test_optimize_1z_1measure(self):
"""Remove a single ZGate
qr0:--Z--m-- qr0:--m-
Expand All @@ -74,7 +97,7 @@ def test_optimize_1z_1measure(self):
self.assertEqual(circuit_to_dag(expected), after)

def test_optimize_1t_1measure(self):
"""Remove a single TGate, SGate, TdgGate, SdgGate, U1Gate
"""Remove a single TGate
qr0:--T--m-- qr0:--m-
| |
qr1:-----|-- ==> qr1:--|-
Expand Down Expand Up @@ -298,6 +321,56 @@ def test_optimize_1cz_2measure(self):

self.assertEqual(circuit_to_dag(expected), after)

def test_optimize_1cs_2measure(self):
"""Remove a single CSGate
qr0:-CS--m--- qr0:--m---
| | |
qr1:--.--|-m- ==> qr1:--|-m-
| | | |
cr0:-----.-.- cr0:--.-.-
"""
qr = QuantumRegister(2, "qr")
cr = ClassicalRegister(1, "cr")
circuit = QuantumCircuit(qr, cr)
circuit.cs(qr[0], qr[1])
circuit.measure(qr[0], cr[0])
circuit.measure(qr[1], cr[0])
dag = circuit_to_dag(circuit)

expected = QuantumCircuit(qr, cr)
expected.measure(qr[0], cr[0])
expected.measure(qr[1], cr[0])

pass_ = RemoveDiagonalGatesBeforeMeasure()
after = pass_.run(dag)

self.assertEqual(circuit_to_dag(expected), after)

def test_optimize_1csdg_2measure(self):
"""Remove a single CSdgGate
qr0:-CSdg--m--- qr0:--m---
| | |
qr1:----.--|-m- ==> qr1:--|-m-
| | | |
cr0:-------.-.- cr0:--.-.-
"""
qr = QuantumRegister(2, "qr")
cr = ClassicalRegister(1, "cr")
circuit = QuantumCircuit(qr, cr)
circuit.csdg(qr[0], qr[1])
circuit.measure(qr[0], cr[0])
circuit.measure(qr[1], cr[0])
dag = circuit_to_dag(circuit)

expected = QuantumCircuit(qr, cr)
expected.measure(qr[0], cr[0])
expected.measure(qr[1], cr[0])

pass_ = RemoveDiagonalGatesBeforeMeasure()
after = pass_.run(dag)

self.assertEqual(circuit_to_dag(expected), after)

def test_optimize_1crz_2measure(self):
"""Remove a single CRZGate
qr0:-RZ--m--- qr0:--m---
Expand All @@ -323,6 +396,31 @@ def test_optimize_1crz_2measure(self):

self.assertEqual(circuit_to_dag(expected), after)

def test_optimize_1cp_2measure(self):
"""Remove a single CPhaseGate
qr0:-CP--m--- qr0:--m---
| | |
qr1:--.--|-m- ==> qr1:--|-m-
| | | |
cr0:-----.-.- cr0:--.-.-
"""
qr = QuantumRegister(2, "qr")
cr = ClassicalRegister(1, "cr")
circuit = QuantumCircuit(qr, cr)
circuit.cp(0.1, qr[0], qr[1])
circuit.measure(qr[0], cr[0])
circuit.measure(qr[1], cr[0])
dag = circuit_to_dag(circuit)

expected = QuantumCircuit(qr, cr)
expected.measure(qr[0], cr[0])
expected.measure(qr[1], cr[0])

pass_ = RemoveDiagonalGatesBeforeMeasure()
after = pass_.run(dag)

self.assertEqual(circuit_to_dag(expected), after)

def test_optimize_1cu1_2measure(self):
"""Remove a single CU1Gate
qr0:-CU1-m--- qr0:--m---
Expand Down Expand Up @@ -373,6 +471,28 @@ def test_optimize_1rzz_2measure(self):

self.assertEqual(circuit_to_dag(expected), after)

def test_optimize_1ccz_3measure(self):
"""Remove a single CCZGate
"""
qr = QuantumRegister(3, "qr")
cr = ClassicalRegister(1, "cr")
circuit = QuantumCircuit(qr, cr)
circuit.ccz(qr[0], qr[1], qr[2])
circuit.measure(qr[0], cr[0])
circuit.measure(qr[1], cr[0])
circuit.measure(qr[2], cr[0])
dag = circuit_to_dag(circuit)

expected = QuantumCircuit(qr, cr)
expected.measure(qr[0], cr[0])
expected.measure(qr[1], cr[0])
expected.measure(qr[2], cr[0])

pass_ = RemoveDiagonalGatesBeforeMeasure()
after = pass_.run(dag)

self.assertEqual(circuit_to_dag(expected), after)


class TestRemoveDiagonalGatesBeforeMeasureOveroptimizations(QiskitTestCase):
"""Test situations where remove_diagonal_gates_before_measure should not optimize"""
Expand Down Expand Up @@ -401,6 +521,23 @@ def test_optimize_1cz_1measure(self):

self.assertEqual(expected, after)

def test_optimize_1ccz_1measure(self):
"""Do not remove a CCZGate because measure happens on only one of the wires
"""
qr = QuantumRegister(3, "qr")
cr = ClassicalRegister(1, "cr")
circuit = QuantumCircuit(qr, cr)
circuit.ccz(qr[0], qr[1], qr[2])
circuit.measure(qr[1], cr[0])
dag = circuit_to_dag(circuit)

expected = deepcopy(dag)

pass_ = RemoveDiagonalGatesBeforeMeasure()
after = pass_.run(dag)

self.assertEqual(expected, after)

def test_do_not_optimize_with_conditional(self):
"""Diagonal gates with conditionals on a measurement target.
See https://github.com/Qiskit/qiskit-terra/pull/2208#issuecomment-487238819
Expand Down
Loading