diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index 50ea430c4..16e4fe573 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -332,6 +332,49 @@ mod test { Ok(()) } + #[cfg(not(feature = "extension_inference"))] // inference fails for test graph, shouldn't + #[test] + fn test_list_ops() -> Result<(), Box> { + use crate::std_extensions::collections::{self, make_list_const, ListOp, ListValue}; + use crate::types::TypeArg; + + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + logic::EXTENSION.to_owned(), + collections::EXTENSION.to_owned(), + ]) + .unwrap(); + let list = make_list_const( + ListValue::new(vec![Value::unit_sum(1)]), + &[TypeArg::Type { ty: BOOL_T }], + ); + let mut build = DFGBuilder::new(FunctionType::new( + type_row![], + vec![list.const_type().clone()], + )) + .unwrap(); + + let list_wire = build.add_load_const(list.clone())?; + + let pop = build.add_dataflow_op( + ListOp::Pop.with_type(BOOL_T).to_extension_op(®).unwrap(), + [list_wire], + )?; + + let push = build.add_dataflow_op( + ListOp::Push + .with_type(BOOL_T) + .to_extension_op(®) + .unwrap(), + pop.outputs(), + )?; + let mut h = build.finish_hugr_with_outputs(push.outputs(), ®)?; + constant_fold_pass(&mut h, ®); + + assert_fully_folded(&h, &list); + Ok(()) + } + fn assert_fully_folded(h: &Hugr, expected_const: &Const) { // check the hugr just loads and returns a single const let mut node_count = 0; diff --git a/src/extension.rs b/src/extension.rs index 237dc3f4f..f87925755 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -72,6 +72,16 @@ impl ExtensionRegistry { } } +impl IntoIterator for ExtensionRegistry { + type Item = (ExtensionId, Extension); + + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + /// An Extension Registry containing no extensions. pub const EMPTY_REG: ExtensionRegistry = ExtensionRegistry(BTreeMap::new()); diff --git a/src/std_extensions/collections.rs b/src/std_extensions/collections.rs index ebec9bda7..d36b44252 100644 --- a/src/std_extensions/collections.rs +++ b/src/std_extensions/collections.rs @@ -5,7 +5,13 @@ use serde::{Deserialize, Serialize}; use smol_str::SmolStr; use crate::{ - extension::{ExtensionId, ExtensionSet, TypeDef, TypeDefBound}, + algorithm::const_fold::sorted_consts, + extension::{ + simple_op::{MakeExtensionOp, OpLoadError}, + ConstFold, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeDef, + TypeDefBound, + }, + ops::{self, custom::ExtensionOp, OpName}, types::{ type_param::{TypeArg, TypeParam}, CustomCheckFailure, CustomType, FunctionType, PolyFuncType, Type, TypeBound, @@ -47,7 +53,9 @@ impl CustomConst for ListValue { CustomCheckFailure::Message("List type check fail.".to_string()) }; - get_type(&LIST_TYPENAME) + EXTENSION + .get_type(&LIST_TYPENAME) + .unwrap() .check_custom(typ) .map_err(|_| error())?; @@ -72,6 +80,55 @@ impl CustomConst for ListValue { .union(&ExtensionSet::singleton(&EXTENSION_NAME)) } } + +struct PopFold; + +impl ConstFold for PopFold { + fn fold( + &self, + type_args: &[TypeArg], + consts: &[(crate::IncomingPort, ops::Const)], + ) -> crate::extension::ConstFoldResult { + let [TypeArg::Type { ty }] = type_args else { + return None; + }; + let [list]: [&ops::Const; 1] = sorted_consts(consts).try_into().ok()?; + let list: &ListValue = list.get_custom_value().expect("Should be list value."); + let mut list = list.clone(); + let elem = list.0.pop()?; // empty list fails to evaluate "pop" + let list = make_list_const(list, type_args); + let elem = ops::Const::new(elem, ty.clone()).unwrap(); + + Some(vec![(0.into(), list), (1.into(), elem)]) + } +} + +pub(crate) fn make_list_const(list: ListValue, type_args: &[TypeArg]) -> ops::Const { + let list_type_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap(); + ops::Const::new( + list.into(), + Type::new_extension(list_type_def.instantiate(type_args).unwrap()), + ) + .unwrap() +} + +struct PushFold; + +impl ConstFold for PushFold { + fn fold( + &self, + type_args: &[TypeArg], + consts: &[(crate::IncomingPort, ops::Const)], + ) -> crate::extension::ConstFoldResult { + let [list, elem]: [&ops::Const; 2] = sorted_consts(consts).try_into().ok()?; + let list: &ListValue = list.get_custom_value().expect("Should be list value."); + let mut list = list.clone(); + list.0.push(elem.value().clone()); + let list = make_list_const(list, type_args); + + Some(vec![(0.into(), list)]) + } +} const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; fn extension() -> Extension { @@ -87,7 +144,7 @@ fn extension() -> Extension { .unwrap(); let list_type_def = extension.get_type(&LIST_TYPENAME).unwrap(); - let (l, e) = list_and_elem_type(list_type_def); + let (l, e) = list_and_elem_type_vars(list_type_def); extension .add_op( POP_NAME, @@ -97,14 +154,17 @@ fn extension() -> Extension { FunctionType::new(vec![l.clone()], vec![l.clone(), e.clone()]), ), ) - .unwrap(); + .unwrap() + .set_constant_folder(PopFold); extension .add_op( PUSH_NAME, "Push to back of list".into(), PolyFuncType::new(vec![TP], FunctionType::new(vec![l.clone(), e], vec![l])), ) - .unwrap(); + .unwrap() + .set_constant_folder(PushFold); + extension } @@ -113,11 +173,18 @@ lazy_static! { pub static ref EXTENSION: Extension = extension(); } -fn get_type(name: &str) -> &TypeDef { - EXTENSION.get_type(name).unwrap() +/// Get the type of a list of `elem_type` +pub fn list_type(elem_type: Type) -> Type { + Type::new_extension( + EXTENSION + .get_type(&LIST_TYPENAME) + .unwrap() + .instantiate(vec![TypeArg::Type { ty: elem_type }]) + .unwrap(), + ) } -fn list_and_elem_type(list_type_def: &TypeDef) -> (Type, Type) { +fn list_and_elem_type_vars(list_type_def: &TypeDef) -> (Type, Type) { let elem_type = Type::new_var_use(0, TypeBound::Any); let list_type = Type::new_extension( list_type_def @@ -126,22 +193,107 @@ fn list_and_elem_type(list_type_def: &TypeDef) -> (Type, Type) { ); (list_type, elem_type) } + +/// A list operation +#[derive(Debug, Clone, PartialEq)] +pub enum ListOp { + /// Pop from end of list + Pop, + /// Push to end of list + Push, +} + +impl ListOp { + /// Instantiate a list operation with an `element_type` + pub fn with_type(self, element_type: Type) -> ListOpInst { + ListOpInst { + elem_type: element_type, + op: self, + } + } +} + +/// A list operation with a concrete element type. +#[derive(Debug, Clone, PartialEq)] +pub struct ListOpInst { + op: ListOp, + elem_type: Type, +} + +impl OpName for ListOpInst { + fn name(&self) -> SmolStr { + match self.op { + ListOp::Pop => POP_NAME, + ListOp::Push => PUSH_NAME, + } + } +} + +impl MakeExtensionOp for ListOpInst { + fn from_extension_op( + ext_op: &ExtensionOp, + ) -> Result { + let [TypeArg::Type { ty }] = ext_op.args() else { + return Err(SignatureError::InvalidTypeArgs.into()); + }; + let name = ext_op.def().name(); + let op = match name { + // can't use const SmolStr in pattern + _ if name == &POP_NAME => ListOp::Pop, + _ if name == &PUSH_NAME => ListOp::Push, + _ => return Err(OpLoadError::NotMember(name.to_string())), + }; + + Ok(Self { + elem_type: ty.clone(), + op, + }) + } + + fn type_args(&self) -> Vec { + vec![TypeArg::Type { + ty: self.elem_type.clone(), + }] + } +} + +impl ListOpInst { + /// Convert this list operation to an [`ExtensionOp`] by providing a + /// registry to validate the element type against. + pub fn to_extension_op(self, elem_type_registry: &ExtensionRegistry) -> Option { + let registry = ExtensionRegistry::try_new( + elem_type_registry + .clone() + .into_iter() + // ignore self if already in registry + .filter_map(|(_, ext)| (ext.name() != EXTENSION.name()).then_some(ext)) + .chain(std::iter::once(EXTENSION.to_owned())), + ) + .unwrap(); + ExtensionOp::new( + registry.get(&EXTENSION_NAME)?.get_op(&self.name())?.clone(), + self.type_args(), + ®istry, + ) + .ok() + } +} + #[cfg(test)] mod test { use crate::{ extension::{ prelude::{ConstUsize, QB_T, USIZE_T}, - ExtensionRegistry, OpDef, PRELUDE, + ExtensionRegistry, PRELUDE, }, + ops::OpTrait, std_extensions::arithmetic::float_types::{self, ConstF64, FLOAT64_TYPE}, - types::{type_param::TypeArg, Type, TypeRow}, + types::{type_param::TypeArg, TypeRow}, Extension, }; use super::*; - fn get_op(name: &str) -> &OpDef { - EXTENSION.get_op(name).unwrap() - } + #[test] fn test_extension() { let r: Extension = extension(); @@ -174,40 +326,30 @@ mod test { #[test] fn test_list_ops() { - let reg = ExtensionRegistry::try_new([ - EXTENSION.to_owned(), - PRELUDE.to_owned(), - float_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let pop_sig = get_op(&POP_NAME) - .compute_signature(&[TypeArg::Type { ty: QB_T }], ®) - .unwrap(); + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]) + .unwrap(); + let pop_op = ListOp::Pop.with_type(QB_T); + let pop_ext = pop_op.clone().to_extension_op(®).unwrap(); + assert_eq!(ListOpInst::from_extension_op(&pop_ext).unwrap(), pop_op); + let pop_sig = pop_ext.dataflow_signature().unwrap(); - let list_type = Type::new_extension(CustomType::new( - LIST_TYPENAME, - vec![TypeArg::Type { ty: QB_T }], - EXTENSION_NAME, - TypeBound::Any, - )); + let list_t = list_type(QB_T); - let both_row: TypeRow = vec![list_type.clone(), QB_T].into(); - let just_list_row: TypeRow = vec![list_type].into(); + let both_row: TypeRow = vec![list_t.clone(), QB_T].into(); + let just_list_row: TypeRow = vec![list_t].into(); assert_eq!(pop_sig.input(), &just_list_row); assert_eq!(pop_sig.output(), &both_row); - let push_sig = get_op(&PUSH_NAME) - .compute_signature(&[TypeArg::Type { ty: FLOAT64_TYPE }], ®) - .unwrap(); + let push_op = ListOp::Push.with_type(FLOAT64_TYPE); + let push_ext = push_op.clone().to_extension_op(®).unwrap(); + assert_eq!(ListOpInst::from_extension_op(&push_ext).unwrap(), push_op); + let push_sig = push_ext.dataflow_signature().unwrap(); - let list_type = Type::new_extension(CustomType::new( - LIST_TYPENAME, - vec![TypeArg::Type { ty: FLOAT64_TYPE }], - EXTENSION_NAME, - TypeBound::Copyable, - )); - let both_row: TypeRow = vec![list_type.clone(), FLOAT64_TYPE].into(); - let just_list_row: TypeRow = vec![list_type].into(); + let list_t = list_type(FLOAT64_TYPE); + + let both_row: TypeRow = vec![list_t.clone(), FLOAT64_TYPE].into(); + let just_list_row: TypeRow = vec![list_t].into(); assert_eq!(push_sig.input(), &both_row); assert_eq!(push_sig.output(), &just_list_row);