Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: constant folding for list operations #795

Merged
merged 4 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,49 @@ mod test {
Ok(())
}

#[cfg(not(feature = "extension_inference"))] // inference fails for test graph, shouldn't
#[test]
fn test_list_ops() -> Result<(), Box<dyn std::error::Error>> {
use crate::std_extensions::collections::{self, make_list_const, ListOp, ListValue};
use crate::types::TypeArg;

let reg = ExtensionRegistry::try_new([
PRELUDE.to_owned(),
logic::EXTENSION.to_owned(),
collections::EXTENSION.to_owned(),
])
.unwrap();
let list = make_list_const(
ListValue::new(vec![Value::unit_sum(1)]),
&[TypeArg::Type { ty: BOOL_T }],
);
let mut build = DFGBuilder::new(FunctionType::new(
type_row![],
vec![list.const_type().clone()],
))
.unwrap();

let list_wire = build.add_load_const(list.clone())?;

let pop = build.add_dataflow_op(
ListOp::Pop.with_type(BOOL_T).to_extension_op(&reg).unwrap(),
[list_wire],
)?;

let push = build.add_dataflow_op(
ListOp::Push
.with_type(BOOL_T)
.to_extension_op(&reg)
.unwrap(),
pop.outputs(),
)?;
let mut h = build.finish_hugr_with_outputs(push.outputs(), &reg)?;
constant_fold_pass(&mut h, &reg);

assert_fully_folded(&h, &list);
Ok(())
}

fn assert_fully_folded(h: &Hugr, expected_const: &Const) {
// check the hugr just loads and returns a single const
let mut node_count = 0;
Expand Down
10 changes: 10 additions & 0 deletions src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ impl ExtensionRegistry {
}
}

impl IntoIterator for ExtensionRegistry {
type Item = (ExtensionId, Extension);

type IntoIter = <BTreeMap<ExtensionId, Extension> as IntoIterator>::IntoIter;

fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}

/// An Extension Registry containing no extensions.
pub const EMPTY_REG: ExtensionRegistry = ExtensionRegistry(BTreeMap::new());

Expand Down
224 changes: 183 additions & 41 deletions src/std_extensions/collections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ use serde::{Deserialize, Serialize};
use smol_str::SmolStr;

use crate::{
extension::{ExtensionId, ExtensionSet, TypeDef, TypeDefBound},
algorithm::const_fold::sorted_consts,
extension::{
simple_op::{MakeExtensionOp, OpLoadError},
ConstFold, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeDef,
TypeDefBound,
},
ops::{self, custom::ExtensionOp, OpName},
types::{
type_param::{TypeArg, TypeParam},
CustomCheckFailure, CustomType, FunctionType, PolyFuncType, Type, TypeBound,
Expand Down Expand Up @@ -47,7 +53,9 @@ impl CustomConst for ListValue {
CustomCheckFailure::Message("List type check fail.".to_string())
};

get_type(&LIST_TYPENAME)
EXTENSION
.get_type(&LIST_TYPENAME)
.unwrap()
.check_custom(typ)
.map_err(|_| error())?;

Expand All @@ -72,6 +80,55 @@ impl CustomConst for ListValue {
.union(&ExtensionSet::singleton(&EXTENSION_NAME))
}
}

struct PopFold;

impl ConstFold for PopFold {
fn fold(
&self,
type_args: &[TypeArg],
consts: &[(crate::IncomingPort, ops::Const)],
) -> crate::extension::ConstFoldResult {
let [TypeArg::Type { ty }] = type_args else {
return None;
};
let [list]: [&ops::Const; 1] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");
let mut list = list.clone();
let elem = list.0.pop()?; // empty list fails to evaluate "pop"
let list = make_list_const(list, type_args);
let elem = ops::Const::new(elem, ty.clone()).unwrap();

Some(vec![(0.into(), list), (1.into(), elem)])
}
}

pub(crate) fn make_list_const(list: ListValue, type_args: &[TypeArg]) -> ops::Const {
let list_type_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap();
ops::Const::new(
list.into(),
Type::new_extension(list_type_def.instantiate(type_args).unwrap()),
)
.unwrap()
}

struct PushFold;

impl ConstFold for PushFold {
fn fold(
&self,
type_args: &[TypeArg],
consts: &[(crate::IncomingPort, ops::Const)],
) -> crate::extension::ConstFoldResult {
let [list, elem]: [&ops::Const; 2] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");
let mut list = list.clone();
list.0.push(elem.value().clone());
let list = make_list_const(list, type_args);

Some(vec![(0.into(), list)])
}
}
const TP: TypeParam = TypeParam::Type { b: TypeBound::Any };

fn extension() -> Extension {
Expand All @@ -87,7 +144,7 @@ fn extension() -> Extension {
.unwrap();
let list_type_def = extension.get_type(&LIST_TYPENAME).unwrap();

let (l, e) = list_and_elem_type(list_type_def);
let (l, e) = list_and_elem_type_vars(list_type_def);
extension
.add_op(
POP_NAME,
Expand All @@ -97,14 +154,17 @@ fn extension() -> Extension {
FunctionType::new(vec![l.clone()], vec![l.clone(), e.clone()]),
),
)
.unwrap();
.unwrap()
.set_constant_folder(PopFold);
extension
.add_op(
PUSH_NAME,
"Push to back of list".into(),
PolyFuncType::new(vec![TP], FunctionType::new(vec![l.clone(), e], vec![l])),
)
.unwrap();
.unwrap()
.set_constant_folder(PushFold);

extension
}

Expand All @@ -113,11 +173,18 @@ lazy_static! {
pub static ref EXTENSION: Extension = extension();
}

fn get_type(name: &str) -> &TypeDef {
EXTENSION.get_type(name).unwrap()
/// Get the type of a list of `elem_type`
pub fn list_type(elem_type: Type) -> Type {
Type::new_extension(
EXTENSION
.get_type(&LIST_TYPENAME)
.unwrap()
.instantiate(vec![TypeArg::Type { ty: elem_type }])
.unwrap(),
)
}

fn list_and_elem_type(list_type_def: &TypeDef) -> (Type, Type) {
fn list_and_elem_type_vars(list_type_def: &TypeDef) -> (Type, Type) {
let elem_type = Type::new_var_use(0, TypeBound::Any);
let list_type = Type::new_extension(
list_type_def
Expand All @@ -126,22 +193,107 @@ fn list_and_elem_type(list_type_def: &TypeDef) -> (Type, Type) {
);
(list_type, elem_type)
}

/// A list operation
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The collections extension doesn't seem to be mentioned in the main spec; is there a spec for it somewhere else?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not yet, it is somewhat of a placeholder for adding more list operations. See #725

#[derive(Debug, Clone, PartialEq)]
pub enum ListOp {
/// Pop from end of list
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the behaviour of the program if the list is empty? (Panic? Error? Doesn't seem to be captured in the signature.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

panics, should be clear in spec once we have it

Pop,
/// Push to end of list
Push,
}

impl ListOp {
/// Instantiate a list operation with an `element_type`
pub fn with_type(self, element_type: Type) -> ListOpInst {
ListOpInst {
elem_type: element_type,
op: self,
}
}
}

/// A list operation with a concrete element type.
#[derive(Debug, Clone, PartialEq)]
pub struct ListOpInst {
op: ListOp,
elem_type: Type,
}

impl OpName for ListOpInst {
fn name(&self) -> SmolStr {
match self.op {
ListOp::Pop => POP_NAME,
ListOp::Push => PUSH_NAME,
}
}
}

impl MakeExtensionOp for ListOpInst {
fn from_extension_op(
ext_op: &ExtensionOp,
) -> Result<Self, crate::extension::simple_op::OpLoadError> {
let [TypeArg::Type { ty }] = ext_op.args() else {
return Err(SignatureError::InvalidTypeArgs.into());
};
let name = ext_op.def().name();
let op = match name {
// can't use const SmolStr in pattern
_ if name == &POP_NAME => ListOp::Pop,
_ if name == &PUSH_NAME => ListOp::Push,
_ => return Err(OpLoadError::NotMember(name.to_string())),
};

Ok(Self {
elem_type: ty.clone(),
op,
})
}

fn type_args(&self) -> Vec<TypeArg> {
vec![TypeArg::Type {
ty: self.elem_type.clone(),
}]
}
}

impl ListOpInst {
/// Convert this list operation to an [`ExtensionOp`] by providing a
/// registry to validate the element tyoe against.
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
pub fn to_extension_op(self, elem_type_registry: &ExtensionRegistry) -> Option<ExtensionOp> {
let registry = ExtensionRegistry::try_new(
elem_type_registry
.clone()
.into_iter()
// ignore self if already in registry
.filter_map(|(_, ext)| (ext.name() != EXTENSION.name()).then_some(ext))
.chain(std::iter::once(EXTENSION.to_owned())),
)
.unwrap();
ExtensionOp::new(
registry.get(&EXTENSION_NAME)?.get_op(&self.name())?.clone(),
self.type_args(),
&registry,
)
.ok()
}
}

#[cfg(test)]
mod test {
use crate::{
extension::{
prelude::{ConstUsize, QB_T, USIZE_T},
ExtensionRegistry, OpDef, PRELUDE,
ExtensionRegistry, PRELUDE,
},
ops::OpTrait,
std_extensions::arithmetic::float_types::{self, ConstF64, FLOAT64_TYPE},
types::{type_param::TypeArg, Type, TypeRow},
types::{type_param::TypeArg, TypeRow},
Extension,
};

use super::*;
fn get_op(name: &str) -> &OpDef {
EXTENSION.get_op(name).unwrap()
}

#[test]
fn test_extension() {
let r: Extension = extension();
Expand Down Expand Up @@ -174,40 +326,30 @@ mod test {

#[test]
fn test_list_ops() {
let reg = ExtensionRegistry::try_new([
EXTENSION.to_owned(),
PRELUDE.to_owned(),
float_types::EXTENSION.to_owned(),
])
.unwrap();
let pop_sig = get_op(&POP_NAME)
.compute_signature(&[TypeArg::Type { ty: QB_T }], &reg)
.unwrap();
let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()])
.unwrap();
let pop_op = ListOp::Pop.with_type(QB_T);
let pop_ext = pop_op.clone().to_extension_op(&reg).unwrap();
assert_eq!(ListOpInst::from_extension_op(&pop_ext).unwrap(), pop_op);
let pop_sig = pop_ext.dataflow_signature().unwrap();

let list_type = Type::new_extension(CustomType::new(
LIST_TYPENAME,
vec![TypeArg::Type { ty: QB_T }],
EXTENSION_NAME,
TypeBound::Any,
));
let list_t = list_type(QB_T);

let both_row: TypeRow = vec![list_type.clone(), QB_T].into();
let just_list_row: TypeRow = vec![list_type].into();
let both_row: TypeRow = vec![list_t.clone(), QB_T].into();
let just_list_row: TypeRow = vec![list_t].into();
assert_eq!(pop_sig.input(), &just_list_row);
assert_eq!(pop_sig.output(), &both_row);

let push_sig = get_op(&PUSH_NAME)
.compute_signature(&[TypeArg::Type { ty: FLOAT64_TYPE }], &reg)
.unwrap();
let push_op = ListOp::Push.with_type(FLOAT64_TYPE);
let push_ext = push_op.clone().to_extension_op(&reg).unwrap();
assert_eq!(ListOpInst::from_extension_op(&push_ext).unwrap(), push_op);
let push_sig = push_ext.dataflow_signature().unwrap();

let list_type = Type::new_extension(CustomType::new(
LIST_TYPENAME,
vec![TypeArg::Type { ty: FLOAT64_TYPE }],
EXTENSION_NAME,
TypeBound::Copyable,
));
let both_row: TypeRow = vec![list_type.clone(), FLOAT64_TYPE].into();
let just_list_row: TypeRow = vec![list_type].into();
let list_t = list_type(FLOAT64_TYPE);

let both_row: TypeRow = vec![list_t.clone(), FLOAT64_TYPE].into();
let just_list_row: TypeRow = vec![list_t].into();

assert_eq!(push_sig.input(), &both_row);
assert_eq!(push_sig.output(), &just_list_row);
Expand Down
Loading