-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: constant folding for list operations #795
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,13 @@ use serde::{Deserialize, Serialize}; | |
use smol_str::SmolStr; | ||
|
||
use crate::{ | ||
extension::{ExtensionId, ExtensionSet, TypeDef, TypeDefBound}, | ||
algorithm::const_fold::sorted_consts, | ||
extension::{ | ||
simple_op::{MakeExtensionOp, OpLoadError}, | ||
ConstFold, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeDef, | ||
TypeDefBound, | ||
}, | ||
ops::{self, custom::ExtensionOp, OpName}, | ||
types::{ | ||
type_param::{TypeArg, TypeParam}, | ||
CustomCheckFailure, CustomType, FunctionType, PolyFuncType, Type, TypeBound, | ||
|
@@ -47,7 +53,9 @@ impl CustomConst for ListValue { | |
CustomCheckFailure::Message("List type check fail.".to_string()) | ||
}; | ||
|
||
get_type(&LIST_TYPENAME) | ||
EXTENSION | ||
.get_type(&LIST_TYPENAME) | ||
.unwrap() | ||
.check_custom(typ) | ||
.map_err(|_| error())?; | ||
|
||
|
@@ -72,6 +80,55 @@ impl CustomConst for ListValue { | |
.union(&ExtensionSet::singleton(&EXTENSION_NAME)) | ||
} | ||
} | ||
|
||
struct PopFold; | ||
|
||
impl ConstFold for PopFold { | ||
fn fold( | ||
&self, | ||
type_args: &[TypeArg], | ||
consts: &[(crate::IncomingPort, ops::Const)], | ||
) -> crate::extension::ConstFoldResult { | ||
let [TypeArg::Type { ty }] = type_args else { | ||
return None; | ||
}; | ||
let [list]: [&ops::Const; 1] = sorted_consts(consts).try_into().ok()?; | ||
let list: &ListValue = list.get_custom_value().expect("Should be list value."); | ||
let mut list = list.clone(); | ||
let elem = list.0.pop()?; // empty list fails to evaluate "pop" | ||
let list = make_list_const(list, type_args); | ||
let elem = ops::Const::new(elem, ty.clone()).unwrap(); | ||
|
||
Some(vec![(0.into(), list), (1.into(), elem)]) | ||
} | ||
} | ||
|
||
pub(crate) fn make_list_const(list: ListValue, type_args: &[TypeArg]) -> ops::Const { | ||
let list_type_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap(); | ||
ops::Const::new( | ||
list.into(), | ||
Type::new_extension(list_type_def.instantiate(type_args).unwrap()), | ||
) | ||
.unwrap() | ||
} | ||
|
||
struct PushFold; | ||
|
||
impl ConstFold for PushFold { | ||
fn fold( | ||
&self, | ||
type_args: &[TypeArg], | ||
consts: &[(crate::IncomingPort, ops::Const)], | ||
) -> crate::extension::ConstFoldResult { | ||
let [list, elem]: [&ops::Const; 2] = sorted_consts(consts).try_into().ok()?; | ||
let list: &ListValue = list.get_custom_value().expect("Should be list value."); | ||
let mut list = list.clone(); | ||
list.0.push(elem.value().clone()); | ||
let list = make_list_const(list, type_args); | ||
|
||
Some(vec![(0.into(), list)]) | ||
} | ||
} | ||
const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; | ||
|
||
fn extension() -> Extension { | ||
|
@@ -87,7 +144,7 @@ fn extension() -> Extension { | |
.unwrap(); | ||
let list_type_def = extension.get_type(&LIST_TYPENAME).unwrap(); | ||
|
||
let (l, e) = list_and_elem_type(list_type_def); | ||
let (l, e) = list_and_elem_type_vars(list_type_def); | ||
extension | ||
.add_op( | ||
POP_NAME, | ||
|
@@ -97,14 +154,17 @@ fn extension() -> Extension { | |
FunctionType::new(vec![l.clone()], vec![l.clone(), e.clone()]), | ||
), | ||
) | ||
.unwrap(); | ||
.unwrap() | ||
.set_constant_folder(PopFold); | ||
extension | ||
.add_op( | ||
PUSH_NAME, | ||
"Push to back of list".into(), | ||
PolyFuncType::new(vec![TP], FunctionType::new(vec![l.clone(), e], vec![l])), | ||
) | ||
.unwrap(); | ||
.unwrap() | ||
.set_constant_folder(PushFold); | ||
|
||
extension | ||
} | ||
|
||
|
@@ -113,11 +173,18 @@ lazy_static! { | |
pub static ref EXTENSION: Extension = extension(); | ||
} | ||
|
||
fn get_type(name: &str) -> &TypeDef { | ||
EXTENSION.get_type(name).unwrap() | ||
/// Get the type of a list of `elem_type` | ||
pub fn list_type(elem_type: Type) -> Type { | ||
Type::new_extension( | ||
EXTENSION | ||
.get_type(&LIST_TYPENAME) | ||
.unwrap() | ||
.instantiate(vec![TypeArg::Type { ty: elem_type }]) | ||
.unwrap(), | ||
) | ||
} | ||
|
||
fn list_and_elem_type(list_type_def: &TypeDef) -> (Type, Type) { | ||
fn list_and_elem_type_vars(list_type_def: &TypeDef) -> (Type, Type) { | ||
let elem_type = Type::new_var_use(0, TypeBound::Any); | ||
let list_type = Type::new_extension( | ||
list_type_def | ||
|
@@ -126,22 +193,107 @@ fn list_and_elem_type(list_type_def: &TypeDef) -> (Type, Type) { | |
); | ||
(list_type, elem_type) | ||
} | ||
|
||
/// A list operation | ||
#[derive(Debug, Clone, PartialEq)] | ||
pub enum ListOp { | ||
/// Pop from end of list | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the behaviour of the program if the list is empty? (Panic? Error? Doesn't seem to be captured in the signature.) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. panics, should be clear in spec once we have it |
||
Pop, | ||
/// Push to end of list | ||
Push, | ||
} | ||
|
||
impl ListOp { | ||
/// Instantiate a list operation with an `element_type` | ||
pub fn with_type(self, element_type: Type) -> ListOpInst { | ||
ListOpInst { | ||
elem_type: element_type, | ||
op: self, | ||
} | ||
} | ||
} | ||
|
||
/// A list operation with a concrete element type. | ||
#[derive(Debug, Clone, PartialEq)] | ||
pub struct ListOpInst { | ||
op: ListOp, | ||
elem_type: Type, | ||
} | ||
|
||
impl OpName for ListOpInst { | ||
fn name(&self) -> SmolStr { | ||
match self.op { | ||
ListOp::Pop => POP_NAME, | ||
ListOp::Push => PUSH_NAME, | ||
} | ||
} | ||
} | ||
|
||
impl MakeExtensionOp for ListOpInst { | ||
fn from_extension_op( | ||
ext_op: &ExtensionOp, | ||
) -> Result<Self, crate::extension::simple_op::OpLoadError> { | ||
let [TypeArg::Type { ty }] = ext_op.args() else { | ||
return Err(SignatureError::InvalidTypeArgs.into()); | ||
}; | ||
let name = ext_op.def().name(); | ||
let op = match name { | ||
// can't use const SmolStr in pattern | ||
_ if name == &POP_NAME => ListOp::Pop, | ||
_ if name == &PUSH_NAME => ListOp::Push, | ||
_ => return Err(OpLoadError::NotMember(name.to_string())), | ||
}; | ||
|
||
Ok(Self { | ||
elem_type: ty.clone(), | ||
op, | ||
}) | ||
} | ||
|
||
fn type_args(&self) -> Vec<TypeArg> { | ||
vec![TypeArg::Type { | ||
ty: self.elem_type.clone(), | ||
}] | ||
} | ||
} | ||
|
||
impl ListOpInst { | ||
/// Convert this list operation to an [`ExtensionOp`] by providing a | ||
/// registry to validate the element tyoe against. | ||
ss2165 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pub fn to_extension_op(self, elem_type_registry: &ExtensionRegistry) -> Option<ExtensionOp> { | ||
let registry = ExtensionRegistry::try_new( | ||
elem_type_registry | ||
.clone() | ||
.into_iter() | ||
// ignore self if already in registry | ||
.filter_map(|(_, ext)| (ext.name() != EXTENSION.name()).then_some(ext)) | ||
.chain(std::iter::once(EXTENSION.to_owned())), | ||
) | ||
.unwrap(); | ||
ExtensionOp::new( | ||
registry.get(&EXTENSION_NAME)?.get_op(&self.name())?.clone(), | ||
self.type_args(), | ||
®istry, | ||
) | ||
.ok() | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
use crate::{ | ||
extension::{ | ||
prelude::{ConstUsize, QB_T, USIZE_T}, | ||
ExtensionRegistry, OpDef, PRELUDE, | ||
ExtensionRegistry, PRELUDE, | ||
}, | ||
ops::OpTrait, | ||
std_extensions::arithmetic::float_types::{self, ConstF64, FLOAT64_TYPE}, | ||
types::{type_param::TypeArg, Type, TypeRow}, | ||
types::{type_param::TypeArg, TypeRow}, | ||
Extension, | ||
}; | ||
|
||
use super::*; | ||
fn get_op(name: &str) -> &OpDef { | ||
EXTENSION.get_op(name).unwrap() | ||
} | ||
|
||
#[test] | ||
fn test_extension() { | ||
let r: Extension = extension(); | ||
|
@@ -174,40 +326,30 @@ mod test { | |
|
||
#[test] | ||
fn test_list_ops() { | ||
let reg = ExtensionRegistry::try_new([ | ||
EXTENSION.to_owned(), | ||
PRELUDE.to_owned(), | ||
float_types::EXTENSION.to_owned(), | ||
]) | ||
.unwrap(); | ||
let pop_sig = get_op(&POP_NAME) | ||
.compute_signature(&[TypeArg::Type { ty: QB_T }], ®) | ||
.unwrap(); | ||
let reg = | ||
ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]) | ||
.unwrap(); | ||
let pop_op = ListOp::Pop.with_type(QB_T); | ||
let pop_ext = pop_op.clone().to_extension_op(®).unwrap(); | ||
assert_eq!(ListOpInst::from_extension_op(&pop_ext).unwrap(), pop_op); | ||
let pop_sig = pop_ext.dataflow_signature().unwrap(); | ||
|
||
let list_type = Type::new_extension(CustomType::new( | ||
LIST_TYPENAME, | ||
vec![TypeArg::Type { ty: QB_T }], | ||
EXTENSION_NAME, | ||
TypeBound::Any, | ||
)); | ||
let list_t = list_type(QB_T); | ||
|
||
let both_row: TypeRow = vec![list_type.clone(), QB_T].into(); | ||
let just_list_row: TypeRow = vec![list_type].into(); | ||
let both_row: TypeRow = vec![list_t.clone(), QB_T].into(); | ||
let just_list_row: TypeRow = vec![list_t].into(); | ||
assert_eq!(pop_sig.input(), &just_list_row); | ||
assert_eq!(pop_sig.output(), &both_row); | ||
|
||
let push_sig = get_op(&PUSH_NAME) | ||
.compute_signature(&[TypeArg::Type { ty: FLOAT64_TYPE }], ®) | ||
.unwrap(); | ||
let push_op = ListOp::Push.with_type(FLOAT64_TYPE); | ||
let push_ext = push_op.clone().to_extension_op(®).unwrap(); | ||
assert_eq!(ListOpInst::from_extension_op(&push_ext).unwrap(), push_op); | ||
let push_sig = push_ext.dataflow_signature().unwrap(); | ||
|
||
let list_type = Type::new_extension(CustomType::new( | ||
LIST_TYPENAME, | ||
vec![TypeArg::Type { ty: FLOAT64_TYPE }], | ||
EXTENSION_NAME, | ||
TypeBound::Copyable, | ||
)); | ||
let both_row: TypeRow = vec![list_type.clone(), FLOAT64_TYPE].into(); | ||
let just_list_row: TypeRow = vec![list_type].into(); | ||
let list_t = list_type(FLOAT64_TYPE); | ||
|
||
let both_row: TypeRow = vec![list_t.clone(), FLOAT64_TYPE].into(); | ||
let just_list_row: TypeRow = vec![list_t].into(); | ||
|
||
assert_eq!(push_sig.input(), &both_row); | ||
assert_eq!(push_sig.output(), &just_list_row); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
collections
extension doesn't seem to be mentioned in the main spec; is there a spec for it somewhere else?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not yet, it is somewhat of a placeholder for adding more list operations. See #725