diff --git a/Cargo.toml b/Cargo.toml index e2e47dc..6942d17 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] members = [ "einsum-derive", - "einsum-solver", + "einsum-codegen", ] diff --git a/einsum-solver/Cargo.toml b/einsum-codegen/Cargo.toml similarity index 93% rename from einsum-solver/Cargo.toml rename to einsum-codegen/Cargo.toml index 7f28f17..a7debe1 100644 --- a/einsum-solver/Cargo.toml +++ b/einsum-codegen/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "einsum-solver" +name = "einsum-codegen" version = "0.1.0" edition = "2021" diff --git a/einsum-solver/src/codegen/format.rs b/einsum-codegen/src/codegen/format.rs similarity index 100% rename from einsum-solver/src/codegen/format.rs rename to einsum-codegen/src/codegen/format.rs diff --git a/einsum-solver/src/codegen/mod.rs b/einsum-codegen/src/codegen/mod.rs similarity index 100% rename from einsum-solver/src/codegen/mod.rs rename to einsum-codegen/src/codegen/mod.rs diff --git a/einsum-codegen/src/codegen/ndarray/mod.rs b/einsum-codegen/src/codegen/ndarray/mod.rs new file mode 100644 index 0000000..4d948ab --- /dev/null +++ b/einsum-codegen/src/codegen/ndarray/mod.rs @@ -0,0 +1,66 @@ +//! For [ndarray](https://crates.io/crates/ndarray) crate + +pub mod naive; + +use crate::subscripts::Subscripts; +use proc_macro2::TokenStream as TokenStream2; +use quote::{format_ident, quote}; + +fn dim(n: usize) -> syn::Path { + let ix = quote::format_ident!("Ix{}", n); + syn::parse_quote! { ndarray::#ix } +} + +/// Generate einsum function definition +pub fn function_definition(subscripts: &Subscripts, inner: TokenStream2) -> TokenStream2 { + let fn_name = format_ident!("{}", subscripts.escaped_ident()); + let n = subscripts.inputs.len(); + + let args = &subscripts.inputs; + let storages: Vec = (0..n).map(|n| quote::format_ident!("S{}", n)).collect(); + let dims: Vec = subscripts + .inputs + .iter() + .map(|ss| dim(ss.indices().len())) + .collect(); + + let out_dim = dim(subscripts.output.indices().len()); + + quote! { + fn #fn_name( + #( #args: ndarray::ArrayBase<#storages, #dims> ),* + ) -> ndarray::Array + where + T: ndarray::LinalgScalar, + #( #storages: ndarray::Data ),* + { + #inner + } + } +} + +#[cfg(test)] +mod test { + use crate::{codegen::format_block, *}; + + #[test] + fn function_definition_snapshot() { + let mut namespace = Namespace::init(); + let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap(); + let inner = quote::quote! { todo!() }; + let tt = format_block(super::function_definition(&subscripts, inner).to_string()); + insta::assert_snapshot!(tt, @r###" + fn ij_jk__ik( + arg0: ndarray::ArrayBase, + arg1: ndarray::ArrayBase, + ) -> ndarray::Array + where + T: ndarray::LinalgScalar, + S0: ndarray::Data, + S1: ndarray::Data, + { + todo!() + } + "###); + } +} diff --git a/einsum-solver/src/codegen/ndarray/naive.rs b/einsum-codegen/src/codegen/ndarray/naive.rs similarity index 66% rename from einsum-solver/src/codegen/ndarray/naive.rs rename to einsum-codegen/src/codegen/ndarray/naive.rs index 691cefe..c489967 100644 --- a/einsum-solver/src/codegen/ndarray/naive.rs +++ b/einsum-codegen/src/codegen/ndarray/naive.rs @@ -1,16 +1,14 @@ //! Generate einsum function with naive loop -use crate::{namespace::Position, subscripts::Subscripts}; +#[cfg(doc)] +use super::function_definition; -use proc_macro2::{Span, TokenStream as TokenStream2}; +use crate::Subscripts; + +use proc_macro2::TokenStream as TokenStream2; use quote::quote; use std::collections::HashSet; -fn dim(n: usize) -> syn::Path { - let ix = quote::format_ident!("Ix{}", n); - syn::parse_quote! { ndarray::#ix } -} - fn index_ident(i: char) -> syn::Ident { quote::format_ident!("{}", i) } @@ -19,17 +17,6 @@ fn n_ident(i: char) -> syn::Ident { quote::format_ident!("n_{}", i) } -/// Generate for loop -/// -/// ```ignore -/// for #index0 in 0..#n0 { -/// for #index1 in 0..#n1 { -/// for #index2 in 0..#n2 { -/// #inner -/// } -/// } -/// } -/// ``` fn contraction_for(indices: &[char], inner: TokenStream2) -> TokenStream2 { let mut tt = inner; for &i in indices.iter().rev() { @@ -72,7 +59,7 @@ fn contraction_inner(subscripts: &Subscripts) -> TokenStream2 { } } -/// Generate naive contraction loop, e.g. +/// Generate naive contraction loop /// /// ``` /// # use ndarray::Array2; @@ -134,7 +121,7 @@ pub fn array_size_asserts(subscripts: &Subscripts) -> TokenStream2 { .map(|m| quote::format_ident!("n_{}", m)) .collect(); // size of index defined previously, e.g. `n_i` - let n: Vec<_> = arg.indices().into_iter().map(|i| n_ident(i)).collect(); + let n: Vec<_> = arg.indices().into_iter().map(n_ident).collect(); tt.push(quote! { let (#(#n_each),*) = #arg.dim(); #(assert_eq!(#n_each, #n);)* @@ -155,46 +142,25 @@ fn define_output_array(subscripts: &Subscripts) -> TokenStream2 { } } -pub fn define(subscripts: &Subscripts) -> TokenStream2 { - let fn_name = syn::Ident::new(&subscripts.escaped_ident(), Span::call_site()); - let n = subscripts.inputs.len(); - - let args: Vec<_> = (0..n).map(|n| Position::Arg(n)).collect(); - let storages: Vec = (0..n).map(|n| quote::format_ident!("S{}", n)).collect(); - let dims: Vec = subscripts - .inputs - .iter() - .map(|ss| dim(ss.indices().len())) - .collect(); - - let out_dim = dim(subscripts.output.indices().len()); - +/// Actual component of einsum [function_definition] +pub fn inner(subscripts: &Subscripts) -> TokenStream2 { let array_size = define_array_size(subscripts); let array_size_asserts = array_size_asserts(subscripts); let output_ident = &subscripts.output; let output_tt = define_output_array(subscripts); let contraction_tt = contraction(subscripts); - quote! { - fn #fn_name( - #( #args: ndarray::ArrayBase<#storages, #dims> ),* - ) -> ndarray::Array - where - T: ndarray::LinalgScalar, - #( #storages: ndarray::Data ),* - { - #array_size - #array_size_asserts - #output_tt - #contraction_tt - #output_ident - } + #array_size + #array_size_asserts + #output_tt + #contraction_tt + #output_ident } } #[cfg(test)] mod test { - use crate::{codegen::format_block, namespace::Namespace, subscripts::Subscripts}; + use crate::{codegen::format_block, *}; #[test] fn define_array_size() { @@ -208,7 +174,7 @@ mod test { } #[test] - fn contraction_snapshots() { + fn contraction() { let mut namespace = Namespace::init(); let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap(); let tt = format_block(super::contraction(&subscripts).to_string()); @@ -224,42 +190,32 @@ mod test { } #[test] - fn einsum_fn_snapshots() { + fn inner() { let mut namespace = Namespace::init(); let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap(); - let tt = format_block(super::define(&subscripts).to_string()); + let tt = format_block(super::inner(&subscripts).to_string()); insta::assert_snapshot!(tt, @r###" - fn ij_jk__ik( - arg0: ndarray::ArrayBase, - arg1: ndarray::ArrayBase, - ) -> ndarray::Array - where - T: ndarray::LinalgScalar, - S0: ndarray::Data, - S1: ndarray::Data, + let (n_i, n_j) = arg0.dim(); + let (_, n_k) = arg1.dim(); { - let (n_i, n_j) = arg0.dim(); - let (_, n_k) = arg1.dim(); - { - let (n_0, n_1) = arg0.dim(); - assert_eq!(n_0, n_i); - assert_eq!(n_1, n_j); - } - { - let (n_0, n_1) = arg1.dim(); - assert_eq!(n_0, n_j); - assert_eq!(n_1, n_k); - } - let mut out0 = ndarray::Array::zeros((n_i, n_k)); - for i in 0..n_i { - for k in 0..n_k { - for j in 0..n_j { - out0[(i, k)] = arg0[(i, j)] * arg1[(j, k)]; - } + let (n_0, n_1) = arg0.dim(); + assert_eq!(n_0, n_i); + assert_eq!(n_1, n_j); + } + { + let (n_0, n_1) = arg1.dim(); + assert_eq!(n_0, n_j); + assert_eq!(n_1, n_k); + } + let mut out0 = ndarray::Array::zeros((n_i, n_k)); + for i in 0..n_i { + for k in 0..n_k { + for j in 0..n_j { + out0[(i, k)] = arg0[(i, j)] * arg1[(j, k)]; } } - out0 } + out0 "###); } } diff --git a/einsum-solver/src/lib.rs b/einsum-codegen/src/lib.rs similarity index 98% rename from einsum-solver/src/lib.rs rename to einsum-codegen/src/lib.rs index 6683ff5..b321692 100644 --- a/einsum-solver/src/lib.rs +++ b/einsum-codegen/src/lib.rs @@ -157,7 +157,12 @@ //! pub mod codegen; -pub mod namespace; pub mod parser; -pub mod path; -pub mod subscripts; + +mod namespace; +mod path; +mod subscripts; + +pub use namespace::*; +pub use path::*; +pub use subscripts::*; diff --git a/einsum-solver/src/namespace.rs b/einsum-codegen/src/namespace.rs similarity index 100% rename from einsum-solver/src/namespace.rs rename to einsum-codegen/src/namespace.rs diff --git a/einsum-solver/src/parser.rs b/einsum-codegen/src/parser.rs similarity index 99% rename from einsum-solver/src/parser.rs rename to einsum-codegen/src/parser.rs index efb946d..9d2bdd3 100644 --- a/einsum-solver/src/parser.rs +++ b/einsum-codegen/src/parser.rs @@ -13,7 +13,7 @@ use std::fmt; /// index = `a` | `b` | `c` | `d` | `e` | `f` | `g` | `h` | `i` | `j` | `k` | `l` |`m` | `n` | `o` | `p` | `q` | `r` | `s` | `t` | `u` | `v` | `w` | `x` |`y` | `z`; pub fn index(input: &str) -> IResult<&str, char> { - satisfy(|c| matches!(c, 'a'..='z')).parse(input) + satisfy(|c| c.is_ascii_lowercase()).parse(input) } /// ellipsis = `...` diff --git a/einsum-solver/src/path.rs b/einsum-codegen/src/path.rs similarity index 58% rename from einsum-solver/src/path.rs rename to einsum-codegen/src/path.rs index df36da8..bc32474 100644 --- a/einsum-solver/src/path.rs +++ b/einsum-codegen/src/path.rs @@ -1,41 +1,67 @@ -//! Construct and execute contraction path +//! Execution path -use crate::{namespace::Namespace, subscripts::*}; +use crate::*; use anyhow::Result; use std::collections::BTreeSet; #[derive(Debug, Clone, PartialEq, Eq)] -pub struct Path(Vec); +pub struct Path { + original: Subscripts, + reduced_subscripts: Vec, +} impl std::ops::Deref for Path { type Target = [Subscripts]; fn deref(&self) -> &[Subscripts] { - &self.0 + &self.reduced_subscripts } } impl Path { + pub fn output(&self) -> &Subscript { + &self.original.output + } + + pub fn num_args(&self) -> usize { + self.original.inputs.len() + } + pub fn compute_order(&self) -> usize { - self.0 - .iter() - .map(|ss| ss.compute_order()) - .max() - .expect("self.0 never be empty") + compute_order(&self.reduced_subscripts) } pub fn memory_order(&self) -> usize { - self.0 - .iter() - .map(|ss| ss.memory_order()) - .max() - .expect("self.0 never be empty") + memory_order(&self.reduced_subscripts) + } + + pub fn brute_force(indices: &str) -> Result { + let mut names = Namespace::init(); + let subscripts = Subscripts::from_raw_indices(&mut names, indices)?; + Ok(Path { + original: subscripts.clone(), + reduced_subscripts: brute_force_work(&mut names, subscripts)?, + }) } } -pub fn brute_force(names: &mut Namespace, subscripts: Subscripts) -> Result { +fn compute_order(ss: &[Subscripts]) -> usize { + ss.iter() + .map(|ss| ss.compute_order()) + .max() + .expect("self.0 never be empty") +} + +fn memory_order(ss: &[Subscripts]) -> usize { + ss.iter() + .map(|ss| ss.memory_order()) + .max() + .expect("self.0 never be empty") +} + +fn brute_force_work(names: &mut Namespace, subscripts: Subscripts) -> Result> { if subscripts.inputs.len() <= 2 { // Cannot be factorized anymore - return Ok(Path(vec![subscripts])); + return Ok(vec![subscripts]); } let n = subscripts.inputs.len(); @@ -59,15 +85,15 @@ pub fn brute_force(names: &mut Namespace, subscripts: Subscripts) -> Result>>()?; - subpaths.push(Path(vec![subscripts])); + .collect::>>()?; + subpaths.push(vec![subscripts]); Ok(subpaths .into_iter() - .min_by_key(|path| (path.compute_order(), path.memory_order())) + .min_by_key(|path| (compute_order(path), memory_order(path))) .expect("subpath never be empty")) } @@ -77,9 +103,7 @@ mod test { #[test] fn brute_force_ij_jk() -> Result<()> { - let mut names = Namespace::init(); - let subscripts = Subscripts::from_raw_indices(&mut names, "ij,jk->ik")?; - let path = brute_force(&mut names, subscripts)?; + let path = Path::brute_force("ij,jk->ik")?; assert_eq!(path.len(), 1); assert_eq!(path[0].to_string(), "ij,jk->ik | arg0,arg1->out0"); Ok(()) @@ -87,9 +111,7 @@ mod test { #[test] fn brute_force_ij_jk_kl_l() -> Result<()> { - let mut names = Namespace::init(); - let subscripts = Subscripts::from_raw_indices(&mut names, "ij,jk,kl,l->i")?; - let path = brute_force(&mut names, subscripts)?; + let path = Path::brute_force("ij,jk,kl,l->i")?; assert_eq!(path.len(), 3); assert_eq!(path[0].to_string(), "kl,l->k | arg2,arg3->out1"); assert_eq!(path[1].to_string(), "k,jk->j | out1,arg1->out2"); @@ -99,9 +121,7 @@ mod test { #[test] fn brute_force_i_i_i() -> Result<()> { - let mut names = Namespace::init(); - let subscripts = Subscripts::from_raw_indices(&mut names, "i,i,i->")?; - let path = brute_force(&mut names, subscripts)?; + let path = Path::brute_force("i,i,i->")?; assert_eq!(path.len(), 1); assert_eq!(path[0].to_string(), "i,i,i-> | arg0,arg1,arg2->out0"); Ok(()) diff --git a/einsum-solver/src/subscripts.rs b/einsum-codegen/src/subscripts.rs similarity index 95% rename from einsum-solver/src/subscripts.rs rename to einsum-codegen/src/subscripts.rs index 3a19c18..58f19a9 100644 --- a/einsum-solver/src/subscripts.rs +++ b/einsum-codegen/src/subscripts.rs @@ -1,8 +1,8 @@ //! Einsum subscripts, e.g. `ij,jk->ik` -use crate::{namespace::*, parser::*}; +use crate::{parser::*, *}; use anyhow::Result; use proc_macro2::TokenStream; -use quote::ToTokens; +use quote::{format_ident, quote, ToTokens, TokenStreamExt}; use std::{ collections::{BTreeMap, BTreeSet}, fmt, @@ -78,6 +78,17 @@ impl fmt::Display for Subscripts { } } +impl ToTokens for Subscripts { + fn to_tokens(&self, tokens: &mut TokenStream) { + let fn_name = format_ident!("{}", self.escaped_ident()); + let args = &self.inputs; + let out = &self.output; + tokens.append_all(quote! { + let #out = #fn_name(#(#args),*); + }); + } +} + impl Subscripts { /// Returns $\alpha$ if this subscripts requires $O(N^\alpha)$ floating point operation pub fn compute_order(&self) -> usize { @@ -106,7 +117,7 @@ impl Subscripts { /// /// ``` /// use std::str::FromStr; - /// use einsum_solver::{subscripts::*, parser::*, namespace::*}; + /// use einsum_codegen::{*, parser::*}; /// /// let mut names = Namespace::init(); /// @@ -165,7 +176,7 @@ impl Subscripts { /// ``` /// use std::str::FromStr; /// use maplit::btreeset; - /// use einsum_solver::{subscripts::Subscripts, namespace::*}; + /// use einsum_codegen::*; /// /// let mut names = Namespace::init(); /// @@ -207,7 +218,7 @@ impl Subscripts { /// ``` /// /// ``` - /// use einsum_solver::{subscripts::*, namespace::*, parser::RawSubscript}; + /// use einsum_codegen::{*, parser::RawSubscript}; /// use std::str::FromStr; /// use maplit::btreeset; /// diff --git a/einsum-derive/Cargo.toml b/einsum-derive/Cargo.toml index 73f8bb6..8628091 100644 --- a/einsum-derive/Cargo.toml +++ b/einsum-derive/Cargo.toml @@ -17,5 +17,5 @@ insta = "1.21.0" ndarray = "0.15.6" trybuild = "1.0.71" -[dependencies.einsum-solver] -path = "../einsum-solver" +[dependencies.einsum-codegen] +path = "../einsum-codegen" diff --git a/einsum-derive/src/lib.rs b/einsum-derive/src/lib.rs index 9d86770..0a2b0fe 100644 --- a/einsum-derive/src/lib.rs +++ b/einsum-derive/src/lib.rs @@ -1,49 +1,77 @@ //! proc-macro based einsum implementation +//! +//! ``` +//! use ndarray::array; +//! use einsum_derive::einsum; +//! +//! let a = array![ +//! [1.0, 2.0], +//! [3.0, 4.0] +//! ]; +//! let b = array![ +//! [1.0, 2.0], +//! [3.0, 4.0] +//! ]; +//! let c = einsum!("ij,jk->ik", a, b); +//! assert_eq!(c, array![ +//! [6.0, 8.0], +//! [12.0, 16.0] +//! ]); +//! ``` +//! +//! This proc-macro wil compile the input subscripts `"ij,jk->ik"` +//! to generate Rust code executing corresponding operation. +//! +//! Examples +//! --------- +//! +//! - `matmul3` +//! +//! ``` +//! use ndarray::array; +//! use einsum_derive::einsum; +//! +//! let a = array![[1.0, 2.0], [3.0, 4.0]]; +//! let b = array![[1.0, 2.0], [3.0, 4.0]]; +//! let c = array![[1.0, 2.0], [3.0, 4.0]]; +//! let d = einsum!("ij,jk,kl->il", a, b, c); +//! assert_eq!(d, array![[24.0, 32.0], [48.0, 64.0]]); +//! ``` +//! +//! - Take diagonal elements +//! +//! ``` +//! use ndarray::array; +//! use einsum_derive::einsum; +//! +//! let a = array![[1.0, 2.0], [3.0, 4.0]]; +//! let d = einsum!("ii->i", a); +//! assert_eq!(d, array![1.0, 4.0]); +//! ``` +//! +//! - If the subscripts and the number of input mismatches, +//! this raises compile error: +//! +//! ```compile_fail +//! use ndarray::array; +//! use einsum_derive::einsum; +//! +//! let a = array![ +//! [1.0, 2.0], +//! [3.0, 4.0] +//! ]; +//! let c = einsum!("ij,jk->ik", a /* needs one more arg */); +//! ``` +//! -use einsum_solver::{codegen::ndarray::*, namespace::*, subscripts::Subscripts}; +use einsum_codegen::{codegen::ndarray::*, *}; use proc_macro::TokenStream; -use proc_macro2::{Span, TokenStream as TokenStream2}; -use proc_macro_error::{abort_call_site, proc_macro_error, OptionExt}; +use proc_macro2::TokenStream as TokenStream2; +use proc_macro_error::{abort_call_site, proc_macro_error}; use quote::quote; use syn::parse::Parser; /// proc-macro based einsum -/// -/// ``` -/// use ndarray::array; -/// use einsum_derive::einsum; -/// -/// let a = array![ -/// [1.0, 2.0], -/// [3.0, 4.0] -/// ]; -/// let b = array![ -/// [1.0, 2.0], -/// [3.0, 4.0] -/// ]; -/// let c = einsum!("ij,jk->ik", a, b); -/// assert_eq!(c, array![ -/// [6.0, 8.0], -/// [12.0, 16.0] -/// ]); -/// ``` -/// -/// This proc-macro wil compile the input subscripts `"ij,jk->ik"` -/// to generate Rust code executing corresponding operation. -/// -/// If the subscripts and the number of input mismatches, -/// this raises compile error: -/// -/// ```compile_fail -/// use ndarray::array; -/// use einsum_derive::einsum; -/// -/// let a = array![ -/// [1.0, 2.0], -/// [3.0, 4.0] -/// ]; -/// let c = einsum!("ij,jk->ik", a /* needs one more arg */); -/// ``` #[proc_macro_error] #[proc_macro] pub fn einsum(input: TokenStream) -> TokenStream { @@ -52,26 +80,30 @@ pub fn einsum(input: TokenStream) -> TokenStream { fn einsum2(input: TokenStream2) -> TokenStream2 { let (subscripts, args) = parse(input); - - // Validate subscripts - let mut names = Namespace::init(); - let subscripts = Subscripts::from_raw_indices(&mut names, &subscripts) - .ok() - .expect_or_abort("Invalid subscripts"); - if subscripts.inputs.len() != args.len() { + let arg_ident: Vec<_> = (0..args.len()).map(Position::Arg).collect(); + let path = Path::brute_force(&subscripts).expect("Failed to construct execution path"); + let fn_defs: Vec<_> = path + .iter() + .map(|ss| { + let inner = naive::inner(ss); + function_definition(ss, inner) + }) + .collect(); + let out = path.output(); + if path.num_args() != args.len() { abort_call_site!( "Argument number mismatch: subscripts ({}), args ({})", - subscripts.inputs.len(), + path.num_args(), args.len() - ); + ) } - let einsum_fn = naive::define(&subscripts); - let fn_name = syn::Ident::new(&subscripts.escaped_ident(), Span::call_site()); quote! { { - #einsum_fn - #fn_name(#(#args),*) + #(#fn_defs)* + #(let #arg_ident = #args;)* + #(#path)* + #out } } } @@ -96,7 +128,7 @@ fn parse(input: TokenStream2) -> (String, Vec) { #[cfg(test)] mod test { use super::*; - use einsum_solver::codegen::format_block; + use einsum_codegen::codegen::format_block; use std::str::FromStr; #[test] @@ -110,7 +142,7 @@ mod test { } #[test] - fn test_snapshots() { + fn einsum_ij_jk() { let input = TokenStream2::from_str(r#""ij,jk->ik", a, b"#).unwrap(); let tt = format_block(einsum2(input).to_string()); insta::assert_snapshot!(tt, @r###" @@ -146,7 +178,88 @@ mod test { } out0 } - ij_jk__ik(a, b) + let arg0 = a; + let arg1 = b; + let out0 = ij_jk__ik(arg0, arg1); + out0 + } + "###); + } + + #[test] + fn einsum_ij_jk_kl() { + let input = TokenStream2::from_str(r#""ij,jk,kl->il", a, b, c"#).unwrap(); + let tt = format_block(einsum2(input).to_string()); + insta::assert_snapshot!(tt, @r###" + { + fn ij_jk__ik( + arg0: ndarray::ArrayBase, + arg1: ndarray::ArrayBase, + ) -> ndarray::Array + where + T: ndarray::LinalgScalar, + S0: ndarray::Data, + S1: ndarray::Data, + { + let (n_i, n_j) = arg0.dim(); + let (_, n_k) = arg1.dim(); + { + let (n_0, n_1) = arg0.dim(); + assert_eq!(n_0, n_i); + assert_eq!(n_1, n_j); + } + { + let (n_0, n_1) = arg1.dim(); + assert_eq!(n_0, n_j); + assert_eq!(n_1, n_k); + } + let mut out1 = ndarray::Array::zeros((n_i, n_k)); + for i in 0..n_i { + for k in 0..n_k { + for j in 0..n_j { + out1[(i, k)] = arg0[(i, j)] * arg1[(j, k)]; + } + } + } + out1 + } + fn ik_kl__il( + out1: ndarray::ArrayBase, + arg2: ndarray::ArrayBase, + ) -> ndarray::Array + where + T: ndarray::LinalgScalar, + S0: ndarray::Data, + S1: ndarray::Data, + { + let (n_i, n_k) = out1.dim(); + let (_, n_l) = arg2.dim(); + { + let (n_0, n_1) = out1.dim(); + assert_eq!(n_0, n_i); + assert_eq!(n_1, n_k); + } + { + let (n_0, n_1) = arg2.dim(); + assert_eq!(n_0, n_k); + assert_eq!(n_1, n_l); + } + let mut out0 = ndarray::Array::zeros((n_i, n_l)); + for i in 0..n_i { + for l in 0..n_l { + for k in 0..n_k { + out0[(i, l)] = out1[(i, k)] * arg2[(k, l)]; + } + } + } + out0 + } + let arg0 = a; + let arg1 = b; + let arg2 = c; + let out1 = ij_jk__ik(arg0, arg1); + let out0 = ik_kl__il(out1, arg2); + out0 } "###); } diff --git a/einsum-derive/tests/diag.rs b/einsum-derive/tests/diag.rs deleted file mode 100644 index 14adf33..0000000 --- a/einsum-derive/tests/diag.rs +++ /dev/null @@ -1,9 +0,0 @@ -use einsum_derive::einsum; -use ndarray::array; - -#[test] -fn diag() { - let a = array![[1.0, 2.0], [3.0, 4.0]]; - let d = einsum!("ii->i", a); - assert_eq!(d, array![1.0, 4.0]); -} diff --git a/einsum-derive/tests/matmul.rs b/einsum-derive/tests/matmul.rs deleted file mode 100644 index e2c9c17..0000000 --- a/einsum-derive/tests/matmul.rs +++ /dev/null @@ -1,10 +0,0 @@ -use einsum_derive::einsum; -use ndarray::array; - -#[test] -fn matmul() { - let a = array![[1.0, 2.0], [3.0, 4.0]]; - let b = array![[1.0, 2.0], [3.0, 4.0]]; - let c = einsum!("ij,jk->ik", a, b); - assert_eq!(c, array![[6.0, 8.0], [12.0, 16.0]]); -} diff --git a/einsum-solver/src/codegen/ndarray/mod.rs b/einsum-solver/src/codegen/ndarray/mod.rs deleted file mode 100644 index 98803e8..0000000 --- a/einsum-solver/src/codegen/ndarray/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -//! For [ndarray](https://crates.io/crates/ndarray) crate - -pub mod naive;