Skip to content

Commit

Permalink
Merge pull request #15 from termoshtt/factorize-using-tensor-names
Browse files Browse the repository at this point in the history
Factorize with tensor names, instead of along index
  • Loading branch information
termoshtt authored Nov 23, 2022
2 parents 0a3284c + 760ea9c commit 58e762c
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 115 deletions.
2 changes: 1 addition & 1 deletion einsum-derive/src/einsum_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ pub fn def_einsum_fn(subscripts: &Subscripts) -> TokenStream2 {
mod test {
use super::*;
use crate::format::format_block;
use einsum_solver::subscripts::{Namespace, Subscripts};
use einsum_solver::{namespace::Namespace, subscripts::Subscripts};

#[test]
fn contraction_snapshots() {
Expand Down
2 changes: 1 addition & 1 deletion einsum-derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! proc-macro based einsum implementation
use einsum_solver::subscripts::{Namespace, Subscripts};
use einsum_solver::{namespace::*, subscripts::Subscripts};
use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
use proc_macro_error::{abort_call_site, proc_macro_error, OptionExt};
Expand Down
9 changes: 6 additions & 3 deletions einsum-solver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@
//! For simplicity, both addition `+` and multiplication `*` are counted as 1 operation,
//! and do not consider fused multiplication-addition (FMA).
//! In the above `matmul3` example, there are $\\#K \times \\#J$ additions
//! and $2 \times \\#K \times \\#J$ multiplications,
//! and $2 \times \\#K \times \\#J$ multiplications for every pair of indices $(i, l)$,
//! where $\\#$ denotes the number of elements in the index sets.
//! Assuming that all indices have the same size, denoted by $N$,
//! there are $O(N^4)$ floating point operations.
//!
//! When we sum up partially along `j`:
//! $$
Expand All @@ -79,8 +81,8 @@
//! \sum_{k \in K} c_{kl} d_{ik},
//! \text{where} \space d_{ik} = \sum_{j \in J} a_{ij} b_{jk},
//! $$
//! there are only $2\\#K + 2\\#J$ operations with $\\#I \times \\#K$
//! memorization storage.
//! there are $O(N^3)$ operations for both computing $d_{ik}$ and the final summation,
//! with $O(N^2)$ storage for memoization.
//!
//! When is this factorization possible? We know that the above `matmul3` example
//! can also be written as the associative matrix product $ABC = A(BC) = (AB)C$,
Expand Down Expand Up @@ -154,6 +156,7 @@
//! and the objective of this crate is to (heuristically) solve this problem.
//!
pub mod namespace;
pub mod parser;
pub mod path;
pub mod subscripts;
35 changes: 35 additions & 0 deletions einsum-solver/src/namespace.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/// Issuer of tensor names
///
/// As the crate level document explains, einsum factorization has to track
/// the names of tensors in addition to their subscripts; this struct does
/// that bookkeeping. It is a plain counter: it remembers how many
/// intermediate tensors have been issued so far and hands out the next
/// identifier on request.
///
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Namespace {
    /// Number that will be attached to the next issued intermediate tensor
    last: usize,
}

impl Namespace {
    /// Create a fresh namespace whose counter starts at zero
    pub fn init() -> Self {
        Self { last: 0 }
    }

    /// Issue a new identifier
    ///
    /// Returns [Position::Intermidiate] carrying the current counter value
    /// and advances the counter, so consecutive calls never return the same
    /// identifier.
    pub fn new_ident(&mut self) -> Position {
        let issued = self.last;
        self.last = issued + 1;
        Position::Intermidiate(issued)
    }
}

/// Which tensor a subscript refers to
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
pub enum Position {
    /// A tensor passed by the user as the N-th argument of einsum
    User(usize),
    /// A tensor created by einsum in its N-th step
    Intermidiate(usize),
}
2 changes: 1 addition & 1 deletion einsum-solver/src/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ impl Path {
///
/// ```
/// use std::str::FromStr;
/// use einsum_solver::{path::Path, subscripts::{Subscripts, Namespace}};
/// use einsum_solver::{path::Path, namespace::Namespace, subscripts::Subscripts};
///
/// let mut names = Namespace::init();
/// let subscripts = Subscripts::from_raw_indices(&mut names, "ij,ji->").unwrap();
Expand Down
183 changes: 74 additions & 109 deletions einsum-solver/src/subscripts.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Einsum subscripts, e.g. `ij,jk->ik`
use crate::parser;
use anyhow::{bail, Result};
use crate::{namespace::*, parser::*};
use anyhow::Result;
use std::{
collections::{BTreeMap, BTreeSet},
fmt,
Expand All @@ -9,12 +9,12 @@ use std::{

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Subscript {
raw: parser::RawSubscript,
raw: RawSubscript,
position: Position,
}

impl Subscript {
pub fn raw(&self) -> &parser::RawSubscript {
pub fn raw(&self) -> &RawSubscript {
&self.raw
}

Expand All @@ -24,50 +24,15 @@ impl Subscript {

pub fn indices(&self) -> Vec<char> {
match &self.raw {
parser::RawSubscript::Indices(indices) => indices.clone(),
parser::RawSubscript::Ellipsis { start, end } => {
RawSubscript::Indices(indices) => indices.clone(),
RawSubscript::Ellipsis { start, end } => {
start.iter().chain(end.iter()).cloned().collect()
}
}
}
}

/// Issuer of tensor names
///
/// As the crate level document explains, einsum factorization has to track
/// the names of tensors in addition to their subscripts; this struct does
/// that bookkeeping. It is a plain counter: it remembers how many
/// intermediate tensors have been issued so far and hands out the next
/// identifier on request.
///
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Namespace {
    /// Number that will be attached to the next issued intermediate tensor
    last: usize,
}

impl Namespace {
    /// Create a fresh namespace whose counter starts at zero
    pub fn init() -> Self {
        Self { last: 0 }
    }

    /// Issue a new identifier
    ///
    /// Returns [Position::Intermidiate] carrying the current counter value
    /// and advances the counter, so consecutive calls never return the same
    /// identifier.
    pub fn new(&mut self) -> Position {
        let issued = self.last;
        self.last = issued + 1;
        Position::Intermidiate(issued)
    }
}

/// Which tensor a subscript refers to
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
pub enum Position {
    /// A tensor passed by the user as the N-th argument of einsum
    User(usize),
    /// A tensor created by einsum in its N-th step
    Intermidiate(usize),
}

#[cfg_attr(doc, katexit::katexit)]
/// Einsum subscripts with tensor names, e.g. `ij,jk->ik | arg0 arg1 -> out`
#[derive(Debug, PartialEq, Eq)]
pub struct Subscripts {
Expand All @@ -92,6 +57,16 @@ impl fmt::Display for Subscripts {
}

impl Subscripts {
/// Exponent of the floating point operation count
///
/// Returns $\alpha$ such that evaluating these subscripts requires
/// $O(N^\alpha)$ floating point operations. It is the sum of
/// [Self::memory_order] and the number of contraction indices.
pub fn compute_order(&self) -> usize {
    let output_order = self.memory_order();
    let contracted = self.contraction_indices().len();
    output_order + contracted
}

/// Exponent of the memory requirement
///
/// Returns $\beta$ such that these subscripts require $O(N^\beta)$ memory,
/// computed as the number of indices of the output subscript.
pub fn memory_order(&self) -> usize {
    let output_indices = self.output.indices();
    output_indices.len()
}

/// Normalize subscripts into "explicit mode"
///
/// [numpy.einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html)
Expand All @@ -109,7 +84,7 @@ impl Subscripts {
///
/// ```
/// use std::str::FromStr;
/// use einsum_solver::{subscripts::{Subscripts, Namespace}, parser::RawSubscripts};
/// use einsum_solver::{subscripts::*, parser::*, namespace::*};
///
/// let mut names = Namespace::init();
///
Expand All @@ -124,7 +99,7 @@ impl Subscripts {
/// assert_eq!(subscripts.output.raw(), &['i', 'j']);
/// ```
///
pub fn from_raw(names: &mut Namespace, raw: parser::RawSubscripts) -> Self {
pub fn from_raw(names: &mut Namespace, raw: RawSubscripts) -> Self {
let inputs = raw
.inputs
.iter()
Expand All @@ -134,7 +109,7 @@ impl Subscripts {
position: Position::User(i),
})
.collect();
let position = names.new();
let position = names.new_ident();
if let Some(output) = raw.output {
return Subscripts {
inputs,
Expand All @@ -147,7 +122,7 @@ impl Subscripts {

let count = count_indices(&inputs);
let output = Subscript {
raw: parser::RawSubscript::Indices(
raw: RawSubscript::Indices(
count
.iter()
.filter_map(|(key, value)| if *value == 1 { Some(*key) } else { None })
Expand All @@ -159,16 +134,16 @@ impl Subscripts {
}

pub fn from_raw_indices(names: &mut Namespace, indices: &str) -> Result<Self> {
let raw = parser::RawSubscripts::from_str(indices)?;
let raw = RawSubscripts::from_str(indices)?;
Ok(Self::from_raw(names, raw))
}

/// Indices to be factorize
/// Indices to be contracted
///
/// ```
/// use std::str::FromStr;
/// use maplit::btreeset;
/// use einsum_solver::subscripts::{Subscripts, Namespace};
/// use einsum_solver::{subscripts::Subscripts, namespace::*};
///
/// let mut names = Namespace::init();
///
Expand Down Expand Up @@ -198,92 +173,82 @@ impl Subscripts {

/// Factorize subscripts
///
/// This requires a mutable reference to [Namespace] since the factorization process
/// creates a new identifier for intermediate storage, e.g.
///
/// ```text
/// ij,jk,kl->il | arg0 arg1 arg2 -> out0
/// ```
///
/// will be factorized into
/// will be factorized with `(arg0, arg1)` into
///
/// ```text
/// ij,jk->ik | arg0 arg1 -> out1
/// ik,kl->il | out1 arg2 -> out0
/// ```
///
/// where `out1` is a new identifier.
///
///
/// ```
/// use einsum_solver::{subscripts::*, parser::RawSubscript};
/// use einsum_solver::{subscripts::*, namespace::*, parser::RawSubscript};
/// use std::str::FromStr;
/// use maplit::btreeset;
///
/// let mut names = Namespace::init();
/// let base = Subscripts::from_raw_indices(&mut names, "ij,jk,kl->il").unwrap();
///
/// // Factorize along j
/// let (ijjk, ikkl) = base.factorize(&mut names, 'j').unwrap().unwrap();
///
/// let arg0 = &ijjk.inputs[0];
/// assert_eq!(arg0.raw(), &RawSubscript::Indices(vec!['i', 'j']));
/// assert_eq!(arg0.position(), &Position::User(0));
///
/// let arg1 = &ijjk.inputs[1];
/// assert_eq!(arg1.raw(), &RawSubscript::Indices(vec!['j', 'k']));
/// assert_eq!(arg1.position(), &Position::User(1));
///
/// let out1 = &ijjk.output;
/// assert_eq!(out1.raw(), &RawSubscript::Indices(vec!['i', 'k']));
/// assert_eq!(out1.position(), &Position::Intermidiate(1));
///
/// // returns `Ok(None)` if subscript is irreducible
/// assert!(ijjk.factorize(&mut names, 'j').unwrap().is_none());
/// assert!(ikkl.factorize(&mut names, 'k').unwrap().is_none());
/// let (ijjk, ikkl) = base.factorize(&mut names,
/// btreeset!{ Position::User(0), Position::User(1) }
/// ).unwrap();
/// ```
pub fn factorize(&self, names: &mut Namespace, index: char) -> Result<Option<(Self, Self)>> {
if !self.contraction_indices().contains(&index) {
bail!("Unknown index: {}", index);
}

let mut first = Vec::new();
let mut second = Vec::new();
let mut out_indices = BTreeSet::new();
pub fn factorize(
&self,
names: &mut Namespace,
inners: BTreeSet<Position>,
) -> Result<(Self, Self)> {
let mut inner_inputs = Vec::new();
let mut outer_inputs = Vec::new();
let mut indices: BTreeMap<char, (usize /* inner */, usize /* outer */)> = BTreeMap::new();
for input in &self.inputs {
let indices = input.indices();
if indices.iter().any(|label| *label == index) {
first.push(input.clone());
for c in indices {
if c != index {
out_indices.insert(c);
}
if inners.contains(&input.position) {
inner_inputs.push(input.clone());
for c in input.indices() {
indices
.entry(c)
.and_modify(|(i, _)| *i += 1)
.or_insert((1, 0));
}
} else {
second.push(input.clone());
outer_inputs.push(input.clone());
for c in input.indices() {
indices
.entry(c)
.and_modify(|(_, o)| *o += 1)
.or_insert((0, 1));
}
}
}

// irreducible
if second.is_empty() {
return Ok(None);
}

let output = Subscript {
raw: parser::RawSubscript::Indices(out_indices.into_iter().collect()),
position: names.new(),
let out = Subscript {
raw: RawSubscript::Indices(
indices
.into_iter()
.filter_map(|(key, (i, o))| {
if i == 1 || (i >= 2 && o > 0) {
Some(key)
} else {
None
}
})
.collect(),
),
position: names.new_ident(),
};

second.insert(0, output.clone());
Ok(Some((
Self {
inputs: first,
output,
outer_inputs.insert(0, out.clone());
Ok((
Subscripts {
inputs: inner_inputs,
output: out,
},
Self {
inputs: second,
Subscripts {
inputs: outer_inputs,
output: self.output.clone(),
},
)))
))
}
}

Expand Down

0 comments on commit 58e762c

Please sign in to comment.