Skip to content

Commit

Permalink
move op adding to extension
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Nov 27, 2023
1 parent 58b6eab commit 64351e2
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 25 deletions.
15 changes: 15 additions & 0 deletions src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,21 @@ impl Extension {
Entry::Vacant(ve) => Ok(Arc::get_mut(ve.insert(Arc::new(op))).unwrap()),
}
}

pub fn add_op_enum(
&mut self,
op: &impl super::simple_op::OpEnum,
) -> Result<&mut OpDef, ExtensionBuildError> {
let def = self.add_op(
op.name().into(),
op.description().to_string(),
op.def_signature(),
)?;

op.post_opdef(def);

Ok(def)
}
}

#[cfg(test)]
Expand Down
34 changes: 9 additions & 25 deletions src/extension/simple_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,36 +76,12 @@ pub trait OpEnum: FromStr + IntoEnumIterator + IntoStaticSt {
/// Try to load one of the operations of this set from an [OpDef].
fn from_op_def(op_def: &OpDef, args: &[TypeArg]) -> Result<Self, Self::LoadError>;

/// Add an operation to an extension.
fn add_to_extension<'e>(
&self,
ext: &'e mut Extension,
) -> Result<&'e OpDef, ExtensionBuildError> {
let def = ext.add_op(
self.name().into(),
self.description().to_string(),
self.def_signature(),
)?;

self.post_opdef(def);

Ok(def)
}

/// Iterator over all operations in the set. Non-trivial variants will have
/// default values used for the members.
fn all_variants() -> <Self as IntoEnumIterator>::Iterator {
<Self as IntoEnumIterator>::iter()
}

/// load all variants of a `SimpleOpEnum` in to an extension as op defs.
fn load_all_ops(extension: &mut Extension) -> Result<(), ExtensionBuildError> {
for op in Self::all_variants() {
op.add_to_extension(extension)?;
}
Ok(())
}

/// Try to instantiate a variant from an [OpType]. Default behaviour assumes
/// an [ExtensionOp] and loads from the name.
fn from_optype(op: &OpType) -> Option<Self> {
Expand All @@ -124,6 +100,14 @@ pub trait OpEnum: FromStr + IntoEnumIterator + IntoStaticSt {
op_enum: self,
}
}

/// load all variants of a [OpEnum] in to an extension as op defs.
fn load_all_ops(extension: &mut Extension) -> Result<(), ExtensionBuildError> {
for op in Self::all_variants() {
extension.add_op_enum(&op)?;
}
Ok(())
}
}

pub struct RegisteredEnum<'r, T> {
Expand Down Expand Up @@ -208,7 +192,7 @@ mod test {
let ext_name = ExtensionId::new("dummy").unwrap();
let mut e = Extension::new(ext_name.clone());

o.add_to_extension(&mut e).unwrap();
e.add_op_enum(&o).unwrap();

assert_eq!(
DummyEnum::from_op_def(e.get_op(o.name()).unwrap(), &[]).unwrap(),
Expand Down

0 comments on commit 64351e2

Please sign in to comment.