Skip to content

Commit

Permalink
Check ExtensionId strings well-formed (#491)
Browse files Browse the repository at this point in the history
...as a dot-separated list of (dot-free) identifiers, adding a new type
`IdentList` of which `ExtensionId` is an alias.
E.g. no `foo..bar` or `.foo` or `5x` but `x5`, `foo.bar` ok.

`pub const fn new_unchecked` method allows bypassing the check. Nasty
but those constants are damn useful.

Also a `const_extension_ids` that automatically generates *tests* that
the names in the consts are valid.
  • Loading branch information
acl-cqc authored Sep 8, 2023
1 parent c0f7a66 commit a295ab8
Show file tree
Hide file tree
Showing 24 changed files with 261 additions and 130 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ portgraph = { version = "0.9.0", features = ["serde", "petgraph"] }
pyo3 = { version = "0.19.0", optional = true, features = [
"multiple-pymethods",
] }
regex = "1.9.5"
cgmath = { version = "0.18.0", features = ["serde"] }
num-rational = { version = "0.4.1", features = ["serde"] }
downcast-rs = "1.2.0"
Expand Down Expand Up @@ -59,6 +60,7 @@ rmp-serde = "1.1.1"
webbrowser = "0.8.10"
urlencoding = "2.1.2"
cool_asserts = "2.0.3"
paste = "1.0"

[[bench]]
name = "bench_main"
Expand Down
2 changes: 1 addition & 1 deletion src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ mod test {
fn with_nonlinear_and_outputs() {
let my_custom_op = LeafOp::CustomOp(
crate::ops::custom::ExternalOp::Opaque(OpaqueOp::new(
"MissingRsrc".into(),
"MissingRsrc".try_into().unwrap(),
"MyOp",
"unknown op".to_string(),
vec![],
Expand Down
17 changes: 10 additions & 7 deletions src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ pub(crate) mod test {
use crate::builder::build_traits::DataflowHugr;
use crate::builder::{DataflowSubContainer, ModuleBuilder};
use crate::extension::prelude::BOOL_T;
use crate::extension::EMPTY_REG;
use crate::extension::{ExtensionId, EMPTY_REG};
use crate::hugr::validate::InterGraphEdgeError;
use crate::ops::{handle::NodeHandle, LeafOp, OpTag};

Expand Down Expand Up @@ -428,8 +428,11 @@ pub(crate) mod test {

#[test]
fn lift_node() -> Result<(), BuildError> {
let ab_extensions = ExtensionSet::from_iter(["A".into(), "B".into()]);
let c_extensions = ExtensionSet::singleton(&"C".into());
let xa: ExtensionId = "A".try_into().unwrap();
let xb: ExtensionId = "B".try_into().unwrap();
let xc = "C".try_into().unwrap();
let ab_extensions = ExtensionSet::from_iter([xa.clone(), xb.clone()]);
let c_extensions = ExtensionSet::singleton(&xc);
let abc_extensions = ab_extensions.clone().union(&c_extensions);

let parent_sig =
Expand All @@ -452,7 +455,7 @@ pub(crate) mod test {
let lift_a = add_ab.add_dataflow_op(
LeafOp::Lift {
type_row: type_row![BIT],
new_extension: "A".into(),
new_extension: xa.clone(),
},
[w],
)?;
Expand All @@ -462,9 +465,9 @@ pub(crate) mod test {
NodeType::new(
LeafOp::Lift {
type_row: type_row![BIT],
new_extension: "B".into(),
new_extension: xb,
},
ExtensionSet::from_iter(["A".into()]),
ExtensionSet::from_iter([xa]),
),
[w],
)?;
Expand All @@ -482,7 +485,7 @@ pub(crate) mod test {
NodeType::new(
LeafOp::Lift {
type_row: type_row![BIT],
new_extension: "C".into(),
new_extension: xc,
},
ab_extensions,
),
Expand Down
24 changes: 12 additions & 12 deletions src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::sync::Arc;
use smol_str::SmolStr;
use thiserror::Error;

use crate::hugr::IdentList;
use crate::ops;
use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::types::type_param::{check_type_arg, TypeArgError};
Expand All @@ -30,7 +31,7 @@ pub mod validate;
pub use prelude::{PRELUDE, PRELUDE_REGISTRY};

/// Extension Registries store extensions to be looked up e.g. during validation.
pub struct ExtensionRegistry(BTreeMap<SmolStr, Extension>);
pub struct ExtensionRegistry(BTreeMap<ExtensionId, Extension>);

impl ExtensionRegistry {
/// Makes a new (empty) registry.
Expand Down Expand Up @@ -79,10 +80,10 @@ pub enum SignatureError {
InvalidTypeArgs,
/// The Extension Registry did not contain an Extension referenced by the Signature
#[error("Extension '{0}' not found")]
ExtensionNotFound(SmolStr),
ExtensionNotFound(ExtensionId),
/// The Extension was found in the registry, but did not contain the Type(Def) referenced in the Signature
#[error("Extension '{exn}' did not contain expected TypeDef '{typ}'")]
ExtensionTypeNotFound { exn: SmolStr, typ: SmolStr },
ExtensionTypeNotFound { exn: ExtensionId, typ: SmolStr },
/// The bound recorded for a CustomType doesn't match what the TypeDef would compute
#[error("Bound on CustomType ({actual}) did not match TypeDef ({expected})")]
WrongBound {
Expand Down Expand Up @@ -204,10 +205,10 @@ impl ExtensionValue {
/// A unique identifier for a extension.
///
/// The actual [`Extension`] is stored externally.
pub type ExtensionId = SmolStr;
pub type ExtensionId = IdentList;

/// A extension is a set of capabilities required to execute a graph.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Extension {
/// Unique identifier for the extension.
pub name: ExtensionId,
Expand All @@ -232,18 +233,17 @@ pub struct Extension {
impl Extension {
/// Creates a new extension with the given name.
pub fn new(name: ExtensionId) -> Self {
Self {
name,
..Default::default()
}
Self::new_with_reqs(name, Default::default())
}

/// Creates a new extension with the given name and requirements.
pub fn new_with_reqs(name: ExtensionId, extension_reqs: ExtensionSet) -> Self {
Self {
name,
extension_reqs,
..Default::default()
types: Default::default(),
values: Default::default(),
operations: Default::default(),
}
}

Expand All @@ -263,7 +263,7 @@ impl Extension {
}

/// Returns the name of the extension.
pub fn name(&self) -> &str {
pub fn name(&self) -> &ExtensionId {
&self.name
}

Expand All @@ -284,7 +284,7 @@ impl Extension {
typed_value: ops::Const,
) -> Result<&mut ExtensionValue, ExtensionBuildError> {
let extension_value = ExtensionValue {
extension: self.name().into(),
extension: self.name.clone(),
name: name.into(),
typed_value,
};
Expand Down
66 changes: 31 additions & 35 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -671,8 +671,8 @@ mod test {
use crate::builder::test::closed_dfg_root_hugr;
use crate::builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr};
use crate::extension::{ExtensionSet, EMPTY_REG};
use crate::hugr::HugrMut;
use crate::hugr::{validate::ValidationError, Hugr, HugrView, NodeType};
use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType};
use crate::macros::const_extension_ids;
use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle, OpTrait};
use crate::type_row;
use crate::types::{FunctionType, Type};
Expand All @@ -682,11 +682,17 @@ mod test {

const NAT: Type = crate::extension::prelude::USIZE_T;

const_extension_ids! {
const A: ExtensionId = "A";
const B: ExtensionId = "B";
const C: ExtensionId = "C";
}

#[test]
// Build up a graph with some holes in its extension requirements, and infer
// them.
fn from_graph() -> Result<(), Box<dyn Error>> {
let rs = ExtensionSet::from_iter(["A".into(), "B".into(), "C".into()]);
let rs = ExtensionSet::from_iter([A, B, C]);
let main_sig =
FunctionType::new(type_row![NAT, NAT], type_row![NAT]).with_extension_delta(&rs);

Expand All @@ -706,16 +712,16 @@ mod test {
assert_matches!(hugr.get_io(hugr.root()), Some(_));

let add_a_sig = FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&"A".into()));
.with_extension_delta(&ExtensionSet::singleton(&A));

let add_b_sig = FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&"B".into()));
.with_extension_delta(&ExtensionSet::singleton(&B));

let add_ab_sig = FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::from_iter(["A".into(), "B".into()]));
.with_extension_delta(&ExtensionSet::from_iter([A, B]));

let mult_c_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&"C".into()));
.with_extension_delta(&ExtensionSet::singleton(&C));

let add_a = hugr.add_node_with_parent(
hugr.root(),
Expand Down Expand Up @@ -753,14 +759,11 @@ mod test {

let (_, closure) = infer_extensions(&hugr)?;
let empty = ExtensionSet::new();
let ab = ExtensionSet::from_iter(["A".into(), "B".into()]);
let ab = ExtensionSet::from_iter([A, B]);
assert_eq!(*closure.get(&(hugr.root())).unwrap(), empty);
assert_eq!(*closure.get(&(mult_c)).unwrap(), ab);
assert_eq!(*closure.get(&(add_ab)).unwrap(), empty);
assert_eq!(
*closure.get(&add_b).unwrap(),
ExtensionSet::singleton(&"A".into())
);
assert_eq!(*closure.get(&add_b).unwrap(), ExtensionSet::singleton(&A));
Ok(())
}

Expand All @@ -779,20 +782,19 @@ mod test {
})
.collect();

ctx.solved
.insert(metas[2], ExtensionSet::singleton(&"A".into()));
ctx.solved.insert(metas[2], ExtensionSet::singleton(&A));
ctx.add_constraint(metas[1], Constraint::Equal(metas[2]));
ctx.add_constraint(metas[0], Constraint::Plus("B".into(), metas[2]));
ctx.add_constraint(metas[4], Constraint::Plus("C".into(), metas[0]));
ctx.add_constraint(metas[0], Constraint::Plus(B, metas[2]));
ctx.add_constraint(metas[4], Constraint::Plus(C, metas[0]));
ctx.add_constraint(metas[3], Constraint::Equal(metas[4]));
ctx.add_constraint(metas[5], Constraint::Equal(metas[0]));
ctx.main_loop()?;

let a = ExtensionSet::singleton(&"A".into());
let a = ExtensionSet::singleton(&A);
let mut ab = a.clone();
ab.insert(&"B".into());
ab.insert(&B);
let mut abc = ab.clone();
abc.insert(&"C".into());
abc.insert(&C);

assert_eq!(ctx.get_solution(&metas[0]).unwrap(), &ab);
assert_eq!(ctx.get_solution(&metas[1]).unwrap(), &a);
Expand All @@ -810,7 +812,7 @@ mod test {
fn missing_lift_node() -> Result<(), Box<dyn Error>> {
let builder = DFGBuilder::new(
FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&"R".into())),
.with_extension_delta(&ExtensionSet::singleton(&"R".try_into().unwrap())),
)?;
let [w] = builder.input_wires_arr();
let hugr = builder.finish_hugr_with_outputs([w], &EMPTY_REG);
Expand Down Expand Up @@ -842,8 +844,8 @@ mod test {
.insert((NodeIndex::new(4).into(), Direction::Incoming), ab);
ctx.variables.insert(a);
ctx.variables.insert(b);
ctx.add_constraint(ab, Constraint::Plus("A".into(), b));
ctx.add_constraint(ab, Constraint::Plus("B".into(), a));
ctx.add_constraint(ab, Constraint::Plus(A, b));
ctx.add_constraint(ab, Constraint::Plus(B, a));
let solution = ctx.main_loop()?;
// We'll only find concrete solutions for the Incoming extension reqs of
// the main node created by `Hugr::default`
Expand All @@ -854,7 +856,7 @@ mod test {
#[test]
// Infer the extensions on a child node with no inputs
fn dangling_src() -> Result<(), Box<dyn Error>> {
let rs = ExtensionSet::singleton(&"R".into());
let rs = ExtensionSet::singleton(&"R".try_into().unwrap());

let mut hugr = closed_dfg_root_hugr(
FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs),
Expand Down Expand Up @@ -943,7 +945,7 @@ mod test {
const BOOLEAN: Type = Type::new_simple_predicate(2);
let just_bool = type_row![BOOLEAN];

let abc = ExtensionSet::from_iter(["A".into(), "B".into(), "C".into()]);
let abc = ExtensionSet::from_iter([A, B, C]);

// Parent graph is closed
let mut hugr = closed_dfg_root_hugr(
Expand All @@ -967,7 +969,7 @@ mod test {
child,
NodeType::open_extensions(ops::LeafOp::Lift {
type_row: just_bool,
new_extension: "C".into(),
new_extension: C,
}),
)?;

Expand All @@ -983,7 +985,7 @@ mod test {
.signature()
.unwrap()
.output_extensions(),
ExtensionSet::from_iter(["A".into(), "B".into()])
ExtensionSet::from_iter([A, B])
);

Ok(())
Expand Down Expand Up @@ -1047,7 +1049,7 @@ mod test {
}

let predicate_inputs = vec![type_row![]; 2];
let rs = ExtensionSet::from_iter(["A".into(), "B".into()]);
let rs = ExtensionSet::from_iter([A, B]);

let inputs = type_row![NAT];
let outputs = type_row![NAT];
Expand All @@ -1065,15 +1067,9 @@ mod test {
let case_op = ops::Case {
signature: FunctionType::new(inputs, outputs).with_extension_delta(&rs),
};
let case0_node = build_case(
&mut hugr,
conditional_node,
case_op.clone(),
"A".into(),
"B".into(),
)?;
let case0_node = build_case(&mut hugr, conditional_node, case_op.clone(), A, B)?;

let case1_node = build_case(&mut hugr, conditional_node, case_op, "B".into(), "A".into())?;
let case1_node = build_case(&mut hugr, conditional_node, case_op, B, A)?;

hugr.infer_extensions()?;

Expand Down
24 changes: 9 additions & 15 deletions src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use lazy_static::lazy_static;
use smol_str::SmolStr;

use crate::{
extension::TypeDefBound,
extension::{ExtensionId, TypeDefBound},
types::{
type_param::{TypeArg, TypeParam},
CustomCheckFailure, CustomType, Type, TypeBound,
Expand All @@ -16,10 +16,10 @@ use crate::{
use super::ExtensionRegistry;

/// Name of prelude extension.
pub const PRELUDE_ID: &str = "prelude";
pub const PRELUDE_ID: ExtensionId = ExtensionId::new_unchecked("prelude");
lazy_static! {
static ref PRELUDE_DEF: Extension = {
let mut prelude = Extension::new(SmolStr::new_inline(PRELUDE_ID));
let mut prelude = Extension::new(PRELUDE_ID);
prelude
.add_type(
SmolStr::new_inline("usize"),
Expand Down Expand Up @@ -53,21 +53,15 @@ lazy_static! {
pub static ref PRELUDE_REGISTRY: ExtensionRegistry = [PRELUDE_DEF.to_owned()].into();

/// Prelude extension
pub static ref PRELUDE: &'static Extension = PRELUDE_REGISTRY.get(PRELUDE_ID).unwrap();
pub static ref PRELUDE: &'static Extension = PRELUDE_REGISTRY.get(&PRELUDE_ID).unwrap();

}

pub(crate) const USIZE_CUSTOM_T: CustomType = CustomType::new_simple(
SmolStr::new_inline("usize"),
SmolStr::new_inline(PRELUDE_ID),
TypeBound::Eq,
);
pub(crate) const USIZE_CUSTOM_T: CustomType =
CustomType::new_simple(SmolStr::new_inline("usize"), PRELUDE_ID, TypeBound::Eq);

pub(crate) const QB_CUSTOM_T: CustomType = CustomType::new_simple(
SmolStr::new_inline("qubit"),
SmolStr::new_inline(PRELUDE_ID),
TypeBound::Any,
);
pub(crate) const QB_CUSTOM_T: CustomType =
CustomType::new_simple(SmolStr::new_inline("qubit"), PRELUDE_ID, TypeBound::Any);

/// Qubit type.
pub const QB_T: Type = Type::new_extension(QB_CUSTOM_T);
Expand All @@ -90,7 +84,7 @@ pub fn new_array(typ: Type, size: u64) -> Type {

pub(crate) const ERROR_TYPE: Type = Type::new_extension(CustomType::new_simple(
smol_str::SmolStr::new_inline("error"),
smol_str::SmolStr::new_inline(PRELUDE_ID),
PRELUDE_ID,
TypeBound::Eq,
));

Expand Down
Loading

0 comments on commit a295ab8

Please sign in to comment.