Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Replace Circuit::num_gates with num_operations #384

Merged
merged 2 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 80 additions & 17 deletions tket2/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,31 @@ impl<T: HugrView> Circuit<T> {
.expect("Circuit has no I/O nodes")
}

/// The number of quantum gates in the circuit.
/// The number of operations in the circuit.
///
/// This includes [`Tk2Op`]s, pytket ops, and any other custom operations.
///
/// Nested circuits are traversed to count their operations.
///
/// [`Tk2Op`]: crate::Tk2Op
#[inline]
pub fn num_gates(&self) -> usize
pub fn num_operations(&self) -> usize
where
Self: Sized,
{
// TODO: Discern quantum gates in the commands iterator.
self.hugr().children(self.parent).count() - 2
let mut count = 0;
let mut roots = vec![self.parent];
while let Some(node) = roots.pop() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't have some existing graph/hierarchy traversal we could use? is this to avoid petgraph dependencies?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's hugr::DescendantsGraph but that's optimised for arbitrary accesses to the HugrView.
It keeps a cache of nodes it knows are in graph in a RefCell<HashMap<_,_>>, and updates it as it filters throught.
We can skip that here since we have a set traversal order.

I guess we could optimise that in Portgraph. I'll open an issue, but leave the explicit loop here for now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not just iterate over all nodes in the hugr?

Copy link
Collaborator Author

@aborgna-q aborgna-q Jun 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we only care about the region describing the circuit, and its descendants.

If the hugr defines a module with multiple functions, we should only traverse the circuit's. E.g.

/// 2-qubit circuit with a Hadamard, a CNOT, and a T gate,
/// defined inside a module containing other circuits.
#[fixture]
fn module_with_circuits() -> Circuit {
let mut module = simple_module();
let other_circ = simple_circuit();
let hugr = module.hugr_mut();
hugr.insert_hugr(hugr.root(), other_circ.into_hugr());
return module;
}

This gets a bit interesting with any kind of control flow; should function calls count towards the operation count? Should conditionals count each branch?

Here I went with the best-defined option: traverse embedded DFG blocks, but ignore any control flow primitives.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leaving this here to register the backlink:
CQCL/portgraph#135

for child in self.hugr().children(node) {
let optype = self.hugr().get_optype(child);
if optype.is_custom_op() {
count += 1;
} else if OpTag::DataflowParent.is_superset(optype.tag()) {
roots.push(child);
}
}
}
count
}

/// Count the number of qubits in the circuit.
Expand Down Expand Up @@ -471,6 +488,7 @@ fn update_signature(
#[cfg(test)]
mod tests {
use cool_asserts::assert_matches;
use rstest::{fixture, rstest};

use hugr::types::FunctionType;
use hugr::{
Expand All @@ -479,38 +497,83 @@ mod tests {
};

use super::*;
use crate::utils::build_module_with_circuit;
use crate::{json::load_tk1_json_str, utils::build_simple_circuit, Tk2Op};

fn test_circuit() -> Circuit {
#[fixture]
fn tk1_circuit() -> Circuit {
load_tk1_json_str(
r#"{ "phase": "0",
"bits": [["c", [0]]],
"qubits": [["q", [0]], ["q", [1]]],
"commands": [
{"args": [["q", [0]]], "op": {"type": "H"}},
{"args": [["q", [0]], ["q", [1]]], "op": {"type": "CX"}},
{"args": [["q", [1]]], "op": {"type": "X"}}
{"args": [["q", [1]]], "op": {"params": ["0.25"], "type": "Rz"}}
],
"implicit_permutation": [[["q", [0]], ["q", [0]]], [["q", [1]], ["q", [1]]]]
}"#,
)
.unwrap()
}

#[test]
fn test_circuit_properties() {
let circ = test_circuit();
/// 2-qubit circuit with a Hadamard, a CNOT, and a Rz gate.
#[fixture]
fn simple_circuit() -> Circuit {
build_simple_circuit(2, |circ| {
circ.append(Tk2Op::H, [0])?;
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::X, [1])?;

assert_eq!(circ.name(), None);
assert_eq!(circ.circuit_signature().body().input_count(), 3);
assert_eq!(circ.circuit_signature().body().output_count(), 3);
assert_eq!(circ.qubit_count(), 2);
assert_eq!(circ.num_gates(), 3);
// TODO: Replace the `X` with the following once Hugr adds `CircuitBuilder::add_constant`.
// See https://github.com/CQCL/hugr/pull/1168

//let angle = circ.add_constant(ConstF64::new(0.5));
//circ.append_and_consume(
// Tk2Op::RzF64,
// [CircuitUnit::Linear(1), CircuitUnit::Wire(angle)],
//)?;
Ok(())
})
.unwrap()
}

/// 2-qubit circuit with a Hadamard, a CNOT, and a Rz gate,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring is innacurate (X not Rz)

/// defined inside a module.
#[fixture]
fn simple_module() -> Circuit {
build_module_with_circuit(2, |circ| {
circ.append(Tk2Op::H, [0])?;
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::X, [1])?;
Ok(())
})
.unwrap()
}

#[rstest]
#[case::simple(simple_circuit(), 2, 0, None)]
#[case::module(simple_module(), 2, 0, None)]
#[case::tk1(tk1_circuit(), 2, 1, None)]
fn test_circuit_properties(
#[case] circ: Circuit,
#[case] qubits: usize,
#[case] bits: usize,
#[case] name: Option<&str>,
) {
assert_eq!(circ.name(), name);
assert_eq!(circ.circuit_signature().body().input_count(), qubits + bits);
assert_eq!(
circ.circuit_signature().body().output_count(),
qubits + bits
);
assert_eq!(circ.qubit_count(), qubits);
assert_eq!(circ.num_operations(), 3);

assert_eq!(circ.units().count(), 3);
assert_eq!(circ.units().count(), qubits + bits);
assert_eq!(circ.nonlinear_units().count(), 0);
assert_eq!(circ.linear_units().count(), 3);
assert_eq!(circ.qubits().count(), 2);
assert_eq!(circ.linear_units().count(), qubits + bits);
assert_eq!(circ.qubits().count(), qubits);
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion tket2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
//! let mut circ: Circuit = tket2::json::load_tk1_json_file("../test_files/barenco_tof_5.json").unwrap();
//!
//! assert_eq!(circ.qubit_count(), 9);
//! assert_eq!(circ.num_gates(), 170);
//! assert_eq!(circ.num_operations(), 170);
//!
//! // Traverse the circuit and print the gates.
//! for command in circ.commands() {
Expand Down
5 changes: 4 additions & 1 deletion tket2/src/optimiser/badger/eq_circ_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ impl EqCircClass {
};

// Find the index for the smallest circuit
let min_index = circs.iter().position_min_by_key(|c| c.num_gates()).unwrap();
let min_index = circs
.iter()
.position_min_by_key(|c| c.num_operations())
.unwrap();
let representative = circs.swap_remove(min_index);
Ok(Self::new(representative, circs))
}
Expand Down
2 changes: 1 addition & 1 deletion tket2/src/portmatching/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl CircuitPattern {
/// Construct a pattern from a circuit.
pub fn try_from_circuit(circuit: &Circuit) -> Result<Self, InvalidPattern> {
let hugr = circuit.hugr();
if circuit.num_gates() == 0 {
if circuit.num_operations() == 0 {
return Err(InvalidPattern::EmptyCircuit);
}
let mut pattern = Pattern::new();
Expand Down
2 changes: 1 addition & 1 deletion tket2/src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl CircuitRewrite {
/// The difference between the new number of nodes minus the old. A positive
/// number is an increase in node count, a negative number is a decrease.
pub fn node_count_delta(&self) -> isize {
let new_count = self.replacement().num_gates() as isize;
let new_count = self.replacement().num_operations() as isize;
let old_count = self.subcircuit().node_count() as isize;
new_count - old_count
}
Expand Down
8 changes: 4 additions & 4 deletions tket2/src/rewrite/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ impl RewriteStrategy for GreedyRewriteStrategy {
}

fn circuit_cost(&self, circ: &Circuit<impl HugrView>) -> Self::Cost {
circ.num_gates()
circ.num_operations()
}

fn op_cost(&self, _op: &OpType) -> Self::Cost {
Expand Down Expand Up @@ -488,7 +488,7 @@ mod tests {
let strategy = GreedyRewriteStrategy;
let rewritten = strategy.apply_rewrites(rws, &circ).collect_vec();
assert_eq!(rewritten.len(), 1);
assert_eq!(rewritten[0].circ.num_gates(), 5);
assert_eq!(rewritten[0].circ.num_operations(), 5);

if REWRITE_TRACING_ENABLED {
assert_eq!(rewritten[0].circ.rewrite_trace().unwrap().len(), 3);
Expand All @@ -511,7 +511,7 @@ mod tests {
let strategy = LexicographicCostFunction::default_cx();
let rewritten = strategy.apply_rewrites(rws, &circ).collect_vec();
let exp_circ_lens = HashSet::from_iter([3, 7, 9]);
let circ_lens: HashSet<_> = rewritten.iter().map(|r| r.circ.num_gates()).collect();
let circ_lens: HashSet<_> = rewritten.iter().map(|r| r.circ.num_operations()).collect();
assert_eq!(circ_lens, exp_circ_lens);

if REWRITE_TRACING_ENABLED {
Expand Down Expand Up @@ -547,7 +547,7 @@ mod tests {
let strategy = GammaStrategyCost::exhaustive_cx_with_gamma(10.);
let rewritten = strategy.apply_rewrites(rws, &circ);
let exp_circ_lens = HashSet::from_iter([8, 17, 6, 9]);
let circ_lens: HashSet<_> = rewritten.map(|r| r.circ.num_gates()).collect();
let circ_lens: HashSet<_> = rewritten.map(|r| r.circ.num_operations()).collect();
assert_eq!(circ_lens, exp_circ_lens);
}

Expand Down
Loading