diff --git a/einsum-derive/src/einsum_fn.rs b/einsum-derive/src/einsum_fn.rs index 5b8b139..ea446fa 100644 --- a/einsum-derive/src/einsum_fn.rs +++ b/einsum-derive/src/einsum_fn.rs @@ -168,7 +168,7 @@ pub fn def_einsum_fn(subscripts: &Subscripts) -> TokenStream2 { mod test { use super::*; use crate::format::format_block; - use einsum_solver::subscripts::{Namespace, Subscripts}; + use einsum_solver::{namespace::Namespace, subscripts::Subscripts}; #[test] fn contraction_snapshots() { diff --git a/einsum-derive/src/lib.rs b/einsum-derive/src/lib.rs index 487f549..6f2d634 100644 --- a/einsum-derive/src/lib.rs +++ b/einsum-derive/src/lib.rs @@ -1,6 +1,6 @@ //! proc-macro based einsum implementation -use einsum_solver::subscripts::{Namespace, Subscripts}; +use einsum_solver::{namespace::*, subscripts::Subscripts}; use proc_macro::TokenStream; use proc_macro2::{Span, TokenStream as TokenStream2}; use proc_macro_error::{abort_call_site, proc_macro_error, OptionExt}; diff --git a/einsum-solver/src/lib.rs b/einsum-solver/src/lib.rs index 9fca4e0..ea3c400 100644 --- a/einsum-solver/src/lib.rs +++ b/einsum-solver/src/lib.rs @@ -67,8 +67,10 @@ //! For simplicity, both addition `+` and multiplication `*` are counted as 1 operation, //! and do not consider fused multiplication-addition (FMA). //! In the above `matmul3` example, there are $\\#K \times \\#J$ addition -//! and $2 \times \\#K \times \\#J$ multiplications, +//! and $2 \times \\#K \times \\#J$ multiplications for every indices $(i, l)$, //! where $\\#$ denotes the number of elements in the index sets. +//! Assuming the all sizes of indices are same and denoted by $N$, +//! there are $O(N^4)$ floating point operations. //! //! When we sum up partially along `j`: //! $$ @@ -79,8 +81,8 @@ //! \sum_{k \in K} c_{kl} d_{ik}, //! \text{where} \space d_{ik} = \sum_{j \in J} a_{ij} b_{jk}, //! $$ -//! there are only $2\\#K + 2\\#J$ operations with $\\#I \times \\#K$ -//! memorization storage. +//! there are $O(N^3)$ operations for both computing $d_{ik}$ and final summation +//! with $O(N^2)$ memorization storage. //! //! When is this factorization possible? We know that above `matmul3` example //! is also written as associative matrix product $ABC = A(BC) = (AB)C$, @@ -154,6 +156,7 @@ //! and the objective of this crate is to (heuristically) solve this problem. //! +pub mod namespace; pub mod parser; pub mod path; pub mod subscripts; diff --git a/einsum-solver/src/namespace.rs b/einsum-solver/src/namespace.rs new file mode 100644 index 0000000..1ed47e3 --- /dev/null +++ b/einsum-solver/src/namespace.rs @@ -0,0 +1,35 @@ +/// Names of tensors +/// +/// As the crate level document explains, +/// einsum factorization requires to track names of tensors +/// in addition to subscripts, and this struct manages it. +/// This works as a simple counter, which counts how many intermediate +/// tensor denoted `out{N}` appears and issues new `out{N+1}` identifier. +/// +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Namespace { + last: usize, +} + +impl Namespace { + /// Create new namespace + pub fn init() -> Self { + Namespace { last: 0 } + } + + /// Issue new identifier + pub fn new_ident(&mut self) -> Position { + let pos = Position::Intermidiate(self.last); + self.last += 1; + pos + } +} + +/// Which tensor the subscript specifies +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)] +pub enum Position { + /// The tensor which user inputs as N-th argument of einsum + User(usize), + /// The tensor created by einsum in its N-th step + Intermidiate(usize), +} diff --git a/einsum-solver/src/path.rs b/einsum-solver/src/path.rs index 176a0a5..581fab3 100644 --- a/einsum-solver/src/path.rs +++ b/einsum-solver/src/path.rs @@ -19,7 +19,7 @@ impl Path { /// /// ``` /// use std::str::FromStr; - /// use einsum_solver::{path::Path, subscripts::{Subscripts, Namespace}}; + /// use einsum_solver::{path::Path, namespace::Namespace, subscripts::Subscripts}; /// /// let mut names = Namespace::init(); /// let subscripts = Subscripts::from_raw_indices(&mut names, "ij,ji->").unwrap(); diff --git a/einsum-solver/src/subscripts.rs b/einsum-solver/src/subscripts.rs index 8e7a42b..9c6e2c7 100644 --- a/einsum-solver/src/subscripts.rs +++ b/einsum-solver/src/subscripts.rs @@ -1,6 +1,6 @@ //! Einsum subscripts, e.g. `ij,jk->ik` -use crate::parser; -use anyhow::{bail, Result}; +use crate::{namespace::*, parser::*}; +use anyhow::Result; use std::{ collections::{BTreeMap, BTreeSet}, fmt, @@ -9,12 +9,12 @@ use std::{ #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Subscript { - raw: parser::RawSubscript, + raw: RawSubscript, position: Position, } impl Subscript { - pub fn raw(&self) -> &parser::RawSubscript { + pub fn raw(&self) -> &RawSubscript { &self.raw } @@ -24,50 +24,15 @@ impl Subscript { pub fn indices(&self) -> Vec { match &self.raw { - parser::RawSubscript::Indices(indices) => indices.clone(), - parser::RawSubscript::Ellipsis { start, end } => { + RawSubscript::Indices(indices) => indices.clone(), + RawSubscript::Ellipsis { start, end } => { start.iter().chain(end.iter()).cloned().collect() } } } } -/// Names of tensors -/// -/// As the crate level document explains, -/// einsum factorization requires to track names of tensors -/// in addition to subscripts, and this struct manages it. -/// This works as a simple counter, which counts how many intermediate -/// tensor denoted `out{N}` appears and issues new `out{N+1}` identifier. -/// -#[derive(Debug, PartialEq, Eq, Clone)] -pub struct Namespace { - last: usize, -} - -impl Namespace { - /// Create new namespace - pub fn init() -> Self { - Namespace { last: 0 } - } - - /// Issue new identifier - pub fn new(&mut self) -> Position { - let pos = Position::Intermidiate(self.last); - self.last += 1; - pos - } -} - -/// Which tensor the subscript specifies -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)] -pub enum Position { - /// The tensor which user inputs as N-th argument of einsum - User(usize), - /// The tensor created by einsum in its N-th step - Intermidiate(usize), -} - +#[cfg_attr(doc, katexit::katexit)] /// Einsum subscripts with tensor names, e.g. `ij,jk->ik | arg0 arg1 -> out` #[derive(Debug, PartialEq, Eq)] pub struct Subscripts { @@ -92,6 +57,16 @@ impl fmt::Display for Subscripts { } impl Subscripts { + /// Returns $\alpha$ if this subscripts requires $O(N^\alpha)$ floating point operation + pub fn compute_order(&self) -> usize { + self.memory_order() + self.contraction_indices().len() + } + + /// Returns $\beta$ if this subscripts requires $O(N^\beta)$ memory + pub fn memory_order(&self) -> usize { + self.output.indices().len() + } + /// Normalize subscripts into "explicit mode" /// /// [numpy.einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html) @@ -109,7 +84,7 @@ impl Subscripts { /// /// ``` /// use std::str::FromStr; - /// use einsum_solver::{subscripts::{Subscripts, Namespace}, parser::RawSubscripts}; + /// use einsum_solver::{subscripts::*, parser::*, namespace::*}; /// /// let mut names = Namespace::init(); /// @@ -124,7 +99,7 @@ impl Subscripts { /// assert_eq!(subscripts.output.raw(), &['i', 'j']); /// ``` /// - pub fn from_raw(names: &mut Namespace, raw: parser::RawSubscripts) -> Self { + pub fn from_raw(names: &mut Namespace, raw: RawSubscripts) -> Self { let inputs = raw .inputs .iter() @@ -134,7 +109,7 @@ impl Subscripts { position: Position::User(i), }) .collect(); - let position = names.new(); + let position = names.new_ident(); if let Some(output) = raw.output { return Subscripts { inputs, @@ -147,7 +122,7 @@ impl Subscripts { let count = count_indices(&inputs); let output = Subscript { - raw: parser::RawSubscript::Indices( + raw: RawSubscript::Indices( count .iter() .filter_map(|(key, value)| if *value == 1 { Some(*key) } else { None }) @@ -159,16 +134,16 @@ impl Subscripts { } pub fn from_raw_indices(names: &mut Namespace, indices: &str) -> Result { - let raw = parser::RawSubscripts::from_str(indices)?; + let raw = RawSubscripts::from_str(indices)?; Ok(Self::from_raw(names, raw)) } - /// Indices to be factorize + /// Indices to be contracted /// /// ``` /// use std::str::FromStr; /// use maplit::btreeset; - /// use einsum_solver::subscripts::{Subscripts, Namespace}; + /// use einsum_solver::{subscripts::Subscripts, namespace::*}; /// /// let mut names = Namespace::init(); /// @@ -198,92 +173,82 @@ impl Subscripts { /// Factorize subscripts /// - /// This requires mutable reference to [Namespace] since factorization process - /// creates new identifier for intermediate storage, e.g. - /// /// ```text /// ij,jk,kl->il | arg0 arg1 arg2 -> out0 /// ``` /// - /// will be factorized into + /// will be factorized with `(arg0, arg1)` into /// /// ```text /// ij,jk->ik | arg0 arg1 -> out1 /// ik,kl->il | out1 arg2 -> out0 /// ``` /// - /// where `out1` is a new identifier. - /// - /// /// ``` - /// use einsum_solver::{subscripts::*, parser::RawSubscript}; + /// use einsum_solver::{subscripts::*, namespace::*, parser::RawSubscript}; /// use std::str::FromStr; + /// use maplit::btreeset; /// /// let mut names = Namespace::init(); /// let base = Subscripts::from_raw_indices(&mut names, "ij,jk,kl->il").unwrap(); /// - /// // Factorize along j - /// let (ijjk, ikkl) = base.factorize(&mut names, 'j').unwrap().unwrap(); - /// - /// let arg0 = &ijjk.inputs[0]; - /// assert_eq!(arg0.raw(), &RawSubscript::Indices(vec!['i', 'j'])); - /// assert_eq!(arg0.position(), &Position::User(0)); - /// - /// let arg1 = &ijjk.inputs[1]; - /// assert_eq!(arg1.raw(), &RawSubscript::Indices(vec!['j', 'k'])); - /// assert_eq!(arg1.position(), &Position::User(1)); - /// - /// let out1 = &ijjk.output; - /// assert_eq!(out1.raw(), &RawSubscript::Indices(vec!['i', 'k'])); - /// assert_eq!(out1.position(), &Position::Intermidiate(1)); - /// - /// // returns `Ok(None)` if subscript is irreducible - /// assert!(ijjk.factorize(&mut names, 'j').unwrap().is_none()); - /// assert!(ikkl.factorize(&mut names, 'k').unwrap().is_none()); + /// let (ijjk, ikkl) = base.factorize(&mut names, + /// btreeset!{ Position::User(0), Position::User(1) } + /// ).unwrap(); /// ``` - pub fn factorize(&self, names: &mut Namespace, index: char) -> Result> { - if !self.contraction_indices().contains(&index) { - bail!("Unknown index: {}", index); - } - - let mut first = Vec::new(); - let mut second = Vec::new(); - let mut out_indices = BTreeSet::new(); + pub fn factorize( + &self, + names: &mut Namespace, + inners: BTreeSet, + ) -> Result<(Self, Self)> { + let mut inner_inputs = Vec::new(); + let mut outer_inputs = Vec::new(); + let mut indices: BTreeMap = BTreeMap::new(); for input in &self.inputs { - let indices = input.indices(); - if indices.iter().any(|label| *label == index) { - first.push(input.clone()); - for c in indices { - if c != index { - out_indices.insert(c); - } + if inners.contains(&input.position) { + inner_inputs.push(input.clone()); + for c in input.indices() { + indices + .entry(c) + .and_modify(|(i, _)| *i += 1) + .or_insert((1, 0)); } } else { - second.push(input.clone()); + outer_inputs.push(input.clone()); + for c in input.indices() { + indices + .entry(c) + .and_modify(|(_, o)| *o += 1) + .or_insert((0, 1)); + } } } - - // irreducible - if second.is_empty() { - return Ok(None); - } - - let output = Subscript { - raw: parser::RawSubscript::Indices(out_indices.into_iter().collect()), - position: names.new(), + let out = Subscript { + raw: RawSubscript::Indices( + indices + .into_iter() + .filter_map(|(key, (i, o))| { + if i == 1 || (i >= 2 && o > 0) { + Some(key) + } else { + None + } + }) + .collect(), + ), + position: names.new_ident(), }; - - second.insert(0, output.clone()); - Ok(Some(( - Self { - inputs: first, - output, + outer_inputs.insert(0, out.clone()); + Ok(( + Subscripts { + inputs: inner_inputs, + output: out, }, - Self { - inputs: second, + Subscripts { + inputs: outer_inputs, output: self.output.clone(), }, - ))) + )) } }