Skip to content

Commit

Permalink
Factorize with tensor names, instead of along index
Browse files Browse the repository at this point in the history
  • Loading branch information
termoshtt committed Nov 22, 2022
1 parent 0a3284c commit 3b505cc
Showing 1 changed file with 49 additions and 42 deletions.
91 changes: 49 additions & 42 deletions einsum-solver/src/subscripts.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Einsum subscripts, e.g. `ij,jk->ik`
use crate::parser;
use anyhow::{bail, Result};
use anyhow::Result;
use std::{
collections::{BTreeMap, BTreeSet},
fmt,
@@ -205,7 +205,7 @@ impl Subscripts {
/// ij,jk,kl->il | arg0 arg1 arg2 -> out0
/// ```
///
/// will be factorized into
/// will be factorized with `(arg0, arg1)` into
///
/// ```text
/// ij,jk->ik | arg0 arg1 -> out1
@@ -214,16 +214,17 @@ impl Subscripts {
///
/// where `out1` is a new identifier.
///
///
/// ```
/// use einsum_solver::{subscripts::*, parser::RawSubscript};
/// use std::str::FromStr;
/// use maplit::btreeset;
///
/// let mut names = Namespace::init();
/// let base = Subscripts::from_raw_indices(&mut names, "ij,jk,kl->il").unwrap();
///
/// // Factorize along j
/// let (ijjk, ikkl) = base.factorize(&mut names, 'j').unwrap().unwrap();
/// let (ijjk, ikkl) = base.factorize(&mut names,
/// btreeset!{ Position::User(0), Position::User(1) }
/// ).unwrap();
///
/// let arg0 = &ijjk.inputs[0];
/// assert_eq!(arg0.raw(), &RawSubscript::Indices(vec!['i', 'j']));
@@ -236,54 +237,60 @@ impl Subscripts {
/// let out1 = &ijjk.output;
/// assert_eq!(out1.raw(), &RawSubscript::Indices(vec!['i', 'k']));
/// assert_eq!(out1.position(), &Position::Intermidiate(1));
///
/// // returns `Ok(None)` if subscript is irreducible
/// assert!(ijjk.factorize(&mut names, 'j').unwrap().is_none());
/// assert!(ikkl.factorize(&mut names, 'k').unwrap().is_none());
/// ```
pub fn factorize(&self, names: &mut Namespace, index: char) -> Result<Option<(Self, Self)>> {
if !self.contraction_indices().contains(&index) {
bail!("Unknown index: {}", index);
}

let mut first = Vec::new();
let mut second = Vec::new();
let mut out_indices = BTreeSet::new();
pub fn factorize(
&self,
names: &mut Namespace,
inners: BTreeSet<Position>,
) -> Result<(Self, Self)> {
let mut inner_inputs = Vec::new();
let mut outer_inputs = Vec::new();
let mut indices: BTreeMap<char, (usize /* inner */, usize /* outer */)> = BTreeMap::new();
for input in &self.inputs {
let indices = input.indices();
if indices.iter().any(|label| *label == index) {
first.push(input.clone());
for c in indices {
if c != index {
out_indices.insert(c);
}
if inners.contains(&input.position) {
inner_inputs.push(input.clone());
for c in input.indices() {
indices
.entry(c)
.and_modify(|(i, _)| *i += 1)
.or_insert((1, 0));
}
} else {
second.push(input.clone());
outer_inputs.push(input.clone());
for c in input.indices() {
indices
.entry(c)
.and_modify(|(_, o)| *o += 1)
.or_insert((0, 1));
}
}
}

// irreducible
if second.is_empty() {
return Ok(None);
}

let output = Subscript {
raw: parser::RawSubscript::Indices(out_indices.into_iter().collect()),
let out = Subscript {
raw: parser::RawSubscript::Indices(
indices
.into_iter()
.filter_map(|(key, (i, o))| {
if i == 1 || (i >= 2 && o > 0) {
Some(key)
} else {
None
}
})
.collect(),
),
position: names.new(),
};

second.insert(0, output.clone());
Ok(Some((
Self {
inputs: first,
output,
outer_inputs.insert(0, out.clone());
Ok((
Subscripts {
inputs: inner_inputs,
output: out,
},
Self {
inputs: second,
Subscripts {
inputs: outer_inputs,
output: self.output.clone(),
},
)))
))
}
}

0 comments on commit 3b505cc

Please sign in to comment.