diff --git a/crates/accelerate/src/consolidate_blocks.rs b/crates/accelerate/src/consolidate_blocks.rs index 1edd592ce87..0fec3fa2909 100644 --- a/crates/accelerate/src/consolidate_blocks.rs +++ b/crates/accelerate/src/consolidate_blocks.rs @@ -54,11 +54,12 @@ const MAX_2Q_DEPTH: usize = 20; #[allow(clippy::too_many_arguments)] #[pyfunction] -#[pyo3(signature = (dag, decomposer, force_consolidate, target=None, basis_gates=None, blocks=None, runs=None))] +#[pyo3(signature = (dag, decomposer, basis_gate_name, force_consolidate, target=None, basis_gates=None, blocks=None, runs=None))] pub(crate) fn consolidate_blocks( py: Python, dag: &mut DAGCircuit, decomposer: &TwoQubitBasisDecomposer, + basis_gate_name: &str, force_consolidate: bool, target: Option<&Target>, basis_gates: Option>, @@ -125,7 +126,7 @@ pub(crate) fn consolidate_blocks( let inst = dag.dag()[*node].unwrap_operation(); block_qargs.extend(dag.get_qargs(inst.qubits)); all_block_gates.insert(*node); - if inst.op.name() == decomposer.gate_name() { + if inst.op.name() == basis_gate_name { basis_count += 1; } if !is_supported( diff --git a/qiskit/transpiler/passes/optimization/consolidate_blocks.py b/qiskit/transpiler/passes/optimization/consolidate_blocks.py index ce64233699b..f31401abb6a 100644 --- a/qiskit/transpiler/passes/optimization/consolidate_blocks.py +++ b/qiskit/transpiler/passes/optimization/consolidate_blocks.py @@ -109,6 +109,7 @@ def run(self, dag): consolidate_blocks( dag, self.decomposer._inner_decomposer, + self.decomposer.gate.name, self.force_consolidate, target=self.target, basis_gates=self.basis_gates, diff --git a/test/python/transpiler/test_consolidate_blocks.py b/test/python/transpiler/test_consolidate_blocks.py index fbe11ff77db..83379a4eecb 100644 --- a/test/python/transpiler/test_consolidate_blocks.py +++ b/test/python/transpiler/test_consolidate_blocks.py @@ -18,8 +18,17 @@ import numpy as np from ddt import ddt, data -from qiskit.circuit import QuantumCircuit, QuantumRegister, IfElseOp, Gate -from qiskit.circuit.library import U2Gate, SwapGate, CXGate, CZGate, UnitaryGate +from qiskit.circuit import QuantumCircuit, QuantumRegister, IfElseOp, Gate, Parameter +from qiskit.circuit.library import ( + U2Gate, + SwapGate, + CXGate, + CZGate, + UnitaryGate, + SXGate, + XGate, + RZGate, +) from qiskit.converters import circuit_to_dag from qiskit.quantum_info.operators import Operator from qiskit.quantum_info.operators.measures import process_fidelity @@ -590,6 +599,73 @@ def test_no_kak_gates_in_preset_pm(self, opt_level): tqc = pm.run(qc) self.assertEqual(ref_tqc, tqc) + def test_non_cx_basis_gate(self): + """Test a non-cx kak gate is consolidated correctly.""" + qc = QuantumCircuit(2) + qc.cz(0, 1) + qc.x(0) + qc.h(1) + qc.z(1) + qc.t(1) + qc.h(0) + qc.t(0) + qc.cz(1, 0) + qc.sx(0) + qc.sx(1) + qc.cz(0, 1) + qc.sx(0) + qc.sx(1) + qc.cz(1, 0) + qc.x(0) + qc.h(1) + qc.z(1) + qc.t(1) + qc.h(0) + qc.t(0) + qc.cz(0, 1) + + consolidate_pass = ConsolidateBlocks(basis_gates=["sx", "x", "rz", "cz"]) + res = consolidate_pass(qc) + self.assertEqual({"unitary": 1}, res.count_ops()) + self.assertEqual(Operator.from_circuit(qc), Operator(res.data[0].operation.params[0])) + + def test_non_cx_target(self): + """Test a non-cx kak gate is consolidated correctly.""" + qc = QuantumCircuit(2) + qc.cz(0, 1) + qc.x(0) + qc.h(1) + qc.z(1) + qc.t(1) + qc.h(0) + qc.t(0) + qc.cz(1, 0) + qc.sx(0) + qc.sx(1) + qc.cz(0, 1) + qc.sx(0) + qc.sx(1) + qc.cz(1, 0) + qc.x(0) + qc.h(1) + qc.z(1) + qc.t(1) + qc.h(0) + qc.t(0) + qc.cz(0, 1) + + phi = Parameter("phi") + target = Target(num_qubits=2) + target.add_instruction(SXGate(), {(0,): None, (1,): None}) + target.add_instruction(XGate(), {(0,): None, (1,): None}) + target.add_instruction(RZGate(phi), {(0,): None, (1,): None}) + target.add_instruction(CZGate(), {(0, 1): None, (1, 0): None}) + + consolidate_pass = ConsolidateBlocks(target=target) + res = consolidate_pass(qc) + self.assertEqual({"unitary": 1}, res.count_ops()) + self.assertEqual(Operator.from_circuit(qc), Operator(res.data[0].operation.params[0])) + if __name__ == "__main__": unittest.main()