Skip to content

Commit

Permalink
Merge pull request #18 from termoshtt/partial-summation2
Browse files Browse the repository at this point in the history
Generate partial summation Rust code
  • Loading branch information
termoshtt authored Nov 29, 2022
2 parents cff2e22 + 12b3d5a commit 495d6f6
Show file tree
Hide file tree
Showing 16 changed files with 348 additions and 199 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[workspace]
members = [
"einsum-derive",
"einsum-solver",
"einsum-codegen",
]
2 changes: 1 addition & 1 deletion einsum-solver/Cargo.toml → einsum-codegen/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[package]
name = "einsum-solver"
name = "einsum-codegen"
version = "0.1.0"
edition = "2021"

Expand Down
File renamed without changes.
File renamed without changes.
66 changes: 66 additions & 0 deletions einsum-codegen/src/codegen/ndarray/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//! For [ndarray](https://crates.io/crates/ndarray) crate
pub mod naive;

use crate::subscripts::Subscripts;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};

fn dim(n: usize) -> syn::Path {
let ix = quote::format_ident!("Ix{}", n);
syn::parse_quote! { ndarray::#ix }
}

/// Generate einsum function definition
pub fn function_definition(subscripts: &Subscripts, inner: TokenStream2) -> TokenStream2 {
let fn_name = format_ident!("{}", subscripts.escaped_ident());
let n = subscripts.inputs.len();

let args = &subscripts.inputs;
let storages: Vec<syn::Ident> = (0..n).map(|n| quote::format_ident!("S{}", n)).collect();
let dims: Vec<syn::Path> = subscripts
.inputs
.iter()
.map(|ss| dim(ss.indices().len()))
.collect();

let out_dim = dim(subscripts.output.indices().len());

quote! {
fn #fn_name<T, #(#storages),*>(
#( #args: ndarray::ArrayBase<#storages, #dims> ),*
) -> ndarray::Array<T, #out_dim>
where
T: ndarray::LinalgScalar,
#( #storages: ndarray::Data<Elem = T> ),*
{
#inner
}
}
}

#[cfg(test)]
mod test {
use crate::{codegen::format_block, *};

#[test]
fn function_definition_snapshot() {
let mut namespace = Namespace::init();
let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap();
let inner = quote::quote! { todo!() };
let tt = format_block(super::function_definition(&subscripts, inner).to_string());
insta::assert_snapshot!(tt, @r###"
fn ij_jk__ik<T, S0, S1>(
arg0: ndarray::ArrayBase<S0, ndarray::Ix2>,
arg1: ndarray::ArrayBase<S1, ndarray::Ix2>,
) -> ndarray::Array<T, ndarray::Ix2>
where
T: ndarray::LinalgScalar,
S0: ndarray::Data<Elem = T>,
S1: ndarray::Data<Elem = T>,
{
todo!()
}
"###);
}
}
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
//! Generate einsum function with naive loop
use crate::{namespace::Position, subscripts::Subscripts};
#[cfg(doc)]
use super::function_definition;

use proc_macro2::{Span, TokenStream as TokenStream2};
use crate::Subscripts;

use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use std::collections::HashSet;

fn dim(n: usize) -> syn::Path {
let ix = quote::format_ident!("Ix{}", n);
syn::parse_quote! { ndarray::#ix }
}

fn index_ident(i: char) -> syn::Ident {
quote::format_ident!("{}", i)
}
Expand All @@ -19,17 +17,6 @@ fn n_ident(i: char) -> syn::Ident {
quote::format_ident!("n_{}", i)
}

/// Generate for loop
///
/// ```ignore
/// for #index0 in 0..#n0 {
/// for #index1 in 0..#n1 {
/// for #index2 in 0..#n2 {
/// #inner
/// }
/// }
/// }
/// ```
fn contraction_for(indices: &[char], inner: TokenStream2) -> TokenStream2 {
let mut tt = inner;
for &i in indices.iter().rev() {
Expand Down Expand Up @@ -72,7 +59,7 @@ fn contraction_inner(subscripts: &Subscripts) -> TokenStream2 {
}
}

/// Generate naive contraction loop, e.g.
/// Generate naive contraction loop
///
/// ```
/// # use ndarray::Array2;
Expand Down Expand Up @@ -134,7 +121,7 @@ pub fn array_size_asserts(subscripts: &Subscripts) -> TokenStream2 {
.map(|m| quote::format_ident!("n_{}", m))
.collect();
// size of index defined previously, e.g. `n_i`
let n: Vec<_> = arg.indices().into_iter().map(|i| n_ident(i)).collect();
let n: Vec<_> = arg.indices().into_iter().map(n_ident).collect();
tt.push(quote! {
let (#(#n_each),*) = #arg.dim();
#(assert_eq!(#n_each, #n);)*
Expand All @@ -155,46 +142,25 @@ fn define_output_array(subscripts: &Subscripts) -> TokenStream2 {
}
}

pub fn define(subscripts: &Subscripts) -> TokenStream2 {
let fn_name = syn::Ident::new(&subscripts.escaped_ident(), Span::call_site());
let n = subscripts.inputs.len();

let args: Vec<_> = (0..n).map(|n| Position::Arg(n)).collect();
let storages: Vec<syn::Ident> = (0..n).map(|n| quote::format_ident!("S{}", n)).collect();
let dims: Vec<syn::Path> = subscripts
.inputs
.iter()
.map(|ss| dim(ss.indices().len()))
.collect();

let out_dim = dim(subscripts.output.indices().len());

/// Actual component of einsum [function_definition]
pub fn inner(subscripts: &Subscripts) -> TokenStream2 {
let array_size = define_array_size(subscripts);
let array_size_asserts = array_size_asserts(subscripts);
let output_ident = &subscripts.output;
let output_tt = define_output_array(subscripts);
let contraction_tt = contraction(subscripts);

quote! {
fn #fn_name<T, #(#storages),*>(
#( #args: ndarray::ArrayBase<#storages, #dims> ),*
) -> ndarray::Array<T, #out_dim>
where
T: ndarray::LinalgScalar,
#( #storages: ndarray::Data<Elem = T> ),*
{
#array_size
#array_size_asserts
#output_tt
#contraction_tt
#output_ident
}
#array_size
#array_size_asserts
#output_tt
#contraction_tt
#output_ident
}
}

#[cfg(test)]
mod test {
use crate::{codegen::format_block, namespace::Namespace, subscripts::Subscripts};
use crate::{codegen::format_block, *};

#[test]
fn define_array_size() {
Expand All @@ -208,7 +174,7 @@ mod test {
}

#[test]
fn contraction_snapshots() {
fn contraction() {
let mut namespace = Namespace::init();
let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap();
let tt = format_block(super::contraction(&subscripts).to_string());
Expand All @@ -224,42 +190,32 @@ mod test {
}

#[test]
fn einsum_fn_snapshots() {
fn inner() {
let mut namespace = Namespace::init();
let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap();
let tt = format_block(super::define(&subscripts).to_string());
let tt = format_block(super::inner(&subscripts).to_string());
insta::assert_snapshot!(tt, @r###"
fn ij_jk__ik<T, S0, S1>(
arg0: ndarray::ArrayBase<S0, ndarray::Ix2>,
arg1: ndarray::ArrayBase<S1, ndarray::Ix2>,
) -> ndarray::Array<T, ndarray::Ix2>
where
T: ndarray::LinalgScalar,
S0: ndarray::Data<Elem = T>,
S1: ndarray::Data<Elem = T>,
let (n_i, n_j) = arg0.dim();
let (_, n_k) = arg1.dim();
{
let (n_i, n_j) = arg0.dim();
let (_, n_k) = arg1.dim();
{
let (n_0, n_1) = arg0.dim();
assert_eq!(n_0, n_i);
assert_eq!(n_1, n_j);
}
{
let (n_0, n_1) = arg1.dim();
assert_eq!(n_0, n_j);
assert_eq!(n_1, n_k);
}
let mut out0 = ndarray::Array::zeros((n_i, n_k));
for i in 0..n_i {
for k in 0..n_k {
for j in 0..n_j {
out0[(i, k)] = arg0[(i, j)] * arg1[(j, k)];
}
let (n_0, n_1) = arg0.dim();
assert_eq!(n_0, n_i);
assert_eq!(n_1, n_j);
}
{
let (n_0, n_1) = arg1.dim();
assert_eq!(n_0, n_j);
assert_eq!(n_1, n_k);
}
let mut out0 = ndarray::Array::zeros((n_i, n_k));
for i in 0..n_i {
for k in 0..n_k {
for j in 0..n_j {
out0[(i, k)] = arg0[(i, j)] * arg1[(j, k)];
}
}
out0
}
out0
"###);
}
}
11 changes: 8 additions & 3 deletions einsum-solver/src/lib.rs → einsum-codegen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,12 @@
//!
pub mod codegen;
pub mod namespace;
pub mod parser;
pub mod path;
pub mod subscripts;

mod namespace;
mod path;
mod subscripts;

pub use namespace::*;
pub use path::*;
pub use subscripts::*;
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use std::fmt;

/// index = `a` | `b` | `c` | `d` | `e` | `f` | `g` | `h` | `i` | `j` | `k` | `l` |`m` | `n` | `o` | `p` | `q` | `r` | `s` | `t` | `u` | `v` | `w` | `x` |`y` | `z`;
pub fn index(input: &str) -> IResult<&str, char> {
satisfy(|c| matches!(c, 'a'..='z')).parse(input)
satisfy(|c| c.is_ascii_lowercase()).parse(input)
}

/// ellipsis = `...`
Expand Down
Loading

0 comments on commit 495d6f6

Please sign in to comment.