diff --git a/einsum-solver/src/subscripts.rs b/einsum-solver/src/subscripts.rs index 8e7a42b..92862e5 100644 --- a/einsum-solver/src/subscripts.rs +++ b/einsum-solver/src/subscripts.rs @@ -1,6 +1,6 @@ //! Einsum subscripts, e.g. `ij,jk->ik` use crate::parser; -use anyhow::{bail, Result}; +use anyhow::Result; use std::{ collections::{BTreeMap, BTreeSet}, fmt, @@ -205,7 +205,7 @@ impl Subscripts { /// ij,jk,kl->il | arg0 arg1 arg2 -> out0 /// ``` /// - /// will be factorized into + /// will be factorized with `(arg0, arg1)` into /// /// ```text /// ij,jk->ik | arg0 arg1 -> out1 @@ -214,16 +214,17 @@ impl Subscripts { /// /// where `out1` is a new identifier. /// - /// /// ``` /// use einsum_solver::{subscripts::*, parser::RawSubscript}; /// use std::str::FromStr; + /// use maplit::btreeset; /// /// let mut names = Namespace::init(); /// let base = Subscripts::from_raw_indices(&mut names, "ij,jk,kl->il").unwrap(); /// - /// // Factorize along j - /// let (ijjk, ikkl) = base.factorize(&mut names, 'j').unwrap().unwrap(); + /// let (ijjk, ikkl) = base.factorize(&mut names, + /// btreeset!{ Position::User(0), Position::User(1) } + /// ).unwrap(); /// /// let arg0 = &ijjk.inputs[0]; /// assert_eq!(arg0.raw(), &RawSubscript::Indices(vec!['i', 'j'])); @@ -236,54 +237,60 @@ impl Subscripts { /// let out1 = &ijjk.output; /// assert_eq!(out1.raw(), &RawSubscript::Indices(vec!['i', 'k'])); /// assert_eq!(out1.position(), &Position::Intermidiate(1)); - /// - /// // returns `Ok(None)` if subscript is irreducible - /// assert!(ijjk.factorize(&mut names, 'j').unwrap().is_none()); - /// assert!(ikkl.factorize(&mut names, 'k').unwrap().is_none()); /// ``` - pub fn factorize(&self, names: &mut Namespace, index: char) -> Result> { - if !self.contraction_indices().contains(&index) { - bail!("Unknown index: {}", index); - } - - let mut first = Vec::new(); - let mut second = Vec::new(); - let mut out_indices = BTreeSet::new(); + pub fn factorize( + &self, + names: &mut Namespace, + inners: BTreeSet, + ) -> Result<(Self, Self)> { + let mut inner_inputs = Vec::new(); + let mut outer_inputs = Vec::new(); + let mut indices: BTreeMap = BTreeMap::new(); for input in &self.inputs { - let indices = input.indices(); - if indices.iter().any(|label| *label == index) { - first.push(input.clone()); - for c in indices { - if c != index { - out_indices.insert(c); - } + if inners.contains(&input.position) { + inner_inputs.push(input.clone()); + for c in input.indices() { + indices + .entry(c) + .and_modify(|(i, _)| *i += 1) + .or_insert((1, 0)); } } else { - second.push(input.clone()); + outer_inputs.push(input.clone()); + for c in input.indices() { + indices + .entry(c) + .and_modify(|(_, o)| *o += 1) + .or_insert((0, 1)); + } } } - - // irreducible - if second.is_empty() { - return Ok(None); - } - - let output = Subscript { - raw: parser::RawSubscript::Indices(out_indices.into_iter().collect()), + let out = Subscript { + raw: parser::RawSubscript::Indices( + indices + .into_iter() + .filter_map(|(key, (i, o))| { + if i == 1 || (i >= 2 && o > 0) { + Some(key) + } else { + None + } + }) + .collect(), + ), position: names.new(), }; - - second.insert(0, output.clone()); - Ok(Some(( - Self { - inputs: first, - output, + outer_inputs.insert(0, out.clone()); + Ok(( + Subscripts { + inputs: inner_inputs, + output: out, }, - Self { - inputs: second, + Subscripts { + inputs: outer_inputs, output: self.output.clone(), }, - ))) + )) } }