Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring codegen #14

Merged
merged 3 commits into from
Nov 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions einsum-derive/src/args.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use proc_macro2::TokenStream as TokenStream2;
use proc_macro_error::{abort_call_site, ResultExt};
use syn::parse::Parser;

/// Split the token stream given to `einsum!` into its subscript string and
/// the remaining argument expressions.
///
/// The input is parsed as a comma-separated (optionally comma-terminated)
/// list of expressions. The first expression must be a string literal; its
/// value is returned as the subscripts, and every following expression is
/// returned, in order, as an operand.
///
/// Aborts macro expansion (via `proc_macro_error`) when the input is not a
/// valid expression list or when the first expression is not a string literal.
pub fn parse(input: TokenStream2) -> (String, Vec<syn::Expr>) {
    let parser = syn::punctuated::Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated;
    let mut exprs = parser
        .parse2(input)
        .expect_or_abort("Invalid input for einsum!")
        .into_iter();

    // The leading expression carries the subscripts and has to be a string literal.
    let subscripts = match exprs.next() {
        Some(syn::Expr::Lit(syn::ExprLit {
            lit: syn::Lit::Str(literal),
            ..
        })) => literal.value(),
        _ => abort_call_site!("einsum! must start with subscript string literal"),
    };

    // Whatever is left are the operand expressions.
    let operands: Vec<syn::Expr> = exprs.collect();
    (subscripts, operands)
}

#[cfg(test)]
mod test {
    use super::*;

    /// A subscript literal followed by two operands is split into the
    /// subscript string and the two operand expressions, in order.
    #[test]
    fn test_parse() {
        let input: TokenStream2 = r#""ij,jk->ik", a, b"#.parse().unwrap();
        let (subscripts, exprs) = parse(input);
        assert_eq!(subscripts, "ij,jk->ik");

        // Comparing whole vectors checks both the length and every element.
        let expected: Vec<syn::Expr> = ["a", "b"]
            .iter()
            .map(|src| syn::parse_str::<syn::Expr>(src).unwrap())
            .collect();
        assert_eq!(exprs, expected);
    }
}
222 changes: 222 additions & 0 deletions einsum-derive/src/einsum_fn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
use einsum_solver::subscripts::Subscripts;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::quote;
use std::collections::{hash_map::Entry, HashMap};

use crate::ident::*;

/// Wrap `inner` in one `for` loop per entry of `indices`.
///
/// The first index becomes the outermost loop and the last index the
/// innermost one. Each loop variable is `index_ident(i)` and runs over
/// `0..n_ident(i)`:
///
/// ```ignore
/// for #index0 in 0..#n0 {
///     for #index1 in 0..#n1 {
///         for #index2 in 0..#n2 {
///             #inner
///         }
///     }
/// }
/// ```
fn contraction_for(indices: &[char], inner: TokenStream2) -> TokenStream2 {
    // Build from the inside out: walking the indices backwards makes the last
    // index wrap `inner` first and the first index end up outermost.
    indices.iter().rev().fold(inner, |body, &i| {
        let index = index_ident(i);
        let n = n_ident(i);
        quote! {
            for #index in 0..#n { #body }
        }
    })
}

fn contraction_inner(subscripts: &Subscripts) -> TokenStream2 {
let mut inner_args_tt = Vec::new();
for argc in 0..subscripts.inputs.len() {
let name = arg_ident(argc);
let mut index = Vec::new();
for i in subscripts.inputs[argc].indices() {
index.push(index_ident(i));
}
inner_args_tt.push(quote! {
#name[(#(#index),*)]
})
}
let mut inner_mul = None;
for inner in inner_args_tt {
match inner_mul {
Some(i) => inner_mul = Some(quote! { #i * #inner }),
None => inner_mul = Some(inner),
}
}

let output_ident = output_ident();
let mut output_indices = Vec::new();
for i in &subscripts.output.indices() {
let index = index_ident(*i);
output_indices.push(index.clone());
}
quote! {
#output_ident[(#(#output_indices),*)] = #inner_mul;
}
}

/// Generate contraction parts, e.g.
///
/// ```ignore
/// for i in 0..n_i {
/// for k in 0..n_k {
/// for j in 0..n_j {
/// out[(i, k)] = arg0[(i, j)] * arg1[(j, k)];
/// }
/// }
/// }
/// ```
///
fn contraction(subscripts: &Subscripts) -> TokenStream2 {
let mut indices: Vec<char> = subscripts.output.indices();
for i in subscripts.contraction_indices() {
indices.push(i);
}

let inner = contraction_inner(subscripts);
contraction_for(&indices, inner)
}

fn array_size(subscripts: &Subscripts) -> Vec<TokenStream2> {
let mut n_idents: HashMap<char, proc_macro2::Ident> = HashMap::new();
let mut tt = Vec::new();
for argc in 0..subscripts.inputs.len() {
let name = arg_ident(argc);
let mut index = Vec::new();
let mut n_index_each = Vec::new();
let mut def_or_assert = Vec::new();
for (m, i) in subscripts.inputs[argc].indices().into_iter().enumerate() {
index.push(index_ident(i));
let n = n_each_ident(argc, m);
match n_idents.entry(i) {
Entry::Occupied(entry) => {
let n_ = entry.get();
def_or_assert.push(quote! {
assert_eq!(#n_, #n);
});
}
Entry::Vacant(entry) => {
let n_ident = n_ident(i);
def_or_assert.push(quote! {
let #n_ident = #n;
});
entry.insert(n_ident);
}
}
n_index_each.push(n);
}
tt.push(quote! {
let (#(#n_index_each),*) = #name.dim();
#( #def_or_assert )*
});
}
tt
}

fn def_output_array(subscripts: &Subscripts) -> TokenStream2 {
// Define output array
let output_ident = output_ident();
let mut n_output = Vec::new();
for i in subscripts.output.indices() {
n_output.push(n_ident(i));
}
quote! {
let mut #output_ident = ndarray::Array::zeros((#(#n_output),*));
}
}

pub fn def_einsum_fn(subscripts: &Subscripts) -> TokenStream2 {
let fn_name = syn::Ident::new(&format!("{}", subscripts), Span::call_site());
let n = subscripts.inputs.len();

let args: Vec<syn::Ident> = (0..n).map(arg_ident).collect();
let storages: Vec<syn::Ident> = (0..n).map(|n| quote::format_ident!("S{}", n)).collect();
let dims: Vec<syn::Path> = subscripts
.inputs
.iter()
.map(|ss| dim(ss.indices().len()))
.collect();

let out_dim = dim(subscripts.output.indices().len());

let array_size = array_size(subscripts);
let output_ident = output_ident();
let output_tt = def_output_array(subscripts);
let contraction_tt = contraction(subscripts);

quote! {
fn #fn_name<T, #(#storages),*>(
#( #args: ndarray::ArrayBase<#storages, #dims> ),*
) -> ndarray::Array<T, #out_dim>
where
T: ndarray::LinalgScalar,
#( #storages: ndarray::Data<Elem = T> ),*
{
#(#array_size)*
#output_tt
#contraction_tt
#output_ident
}
}
}

#[cfg(test)]
mod test {
use super::*;
use crate::format::format_block;
use einsum_solver::subscripts::{Namespace, Subscripts};

#[test]
fn contraction_snapshots() {
let mut namespace = Namespace::init();
let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap();
let tt = format_block(contraction(&subscripts).to_string());
insta::assert_snapshot!(tt, @r###"
for i in 0..n_i {
for k in 0..n_k {
for j in 0..n_j {
out[(i, k)] = arg0[(i, j)] * arg1[(j, k)];
}
}
}
"###);
}

#[test]
fn einsum_fn_snapshots() {
let mut namespace = Namespace::init();
let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap();
let tt = format_block(def_einsum_fn(&subscripts).to_string());
insta::assert_snapshot!(tt, @r###"
fn ij_jk__ik<T, S0, S1>(
arg0: ndarray::ArrayBase<S0, ndarray::Ix2>,
arg1: ndarray::ArrayBase<S1, ndarray::Ix2>,
) -> ndarray::Array<T, ndarray::Ix2>
where
T: ndarray::LinalgScalar,
S0: ndarray::Data<Elem = T>,
S1: ndarray::Data<Elem = T>,
{
let (n_0_0, n_0_1) = arg0.dim();
let n_i = n_0_0;
let n_j = n_0_1;
let (n_1_0, n_1_1) = arg1.dim();
assert_eq!(n_j, n_1_0);
let n_k = n_1_1;
let mut out = ndarray::Array::zeros((n_i, n_k));
for i in 0..n_i {
for k in 0..n_k {
for j in 0..n_j {
out[(i, k)] = arg0[(i, j)] * arg1[(j, k)];
}
}
}
out
}
"###);
}
}
40 changes: 40 additions & 0 deletions einsum-derive/src/format.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use std::{
io::Write,
process::{Command, Stdio},
};

/// Format generated Rust code using `rustfmt` run as external process.
///
/// `tt` is a sequence of statements or items. It is wrapped in
/// `fn main() { ... }` so that `rustfmt` accepts it, formatted, and then
/// unwrapped again: the `fn main() {` and `}` wrapper lines are dropped and
/// one level of indentation is removed from every remaining line. Lines that
/// do not start with that indentation (for instance blank lines) are dropped
/// as well. The surviving lines are joined with `\n`, without a trailing
/// newline.
///
/// # Panics
///
/// Panics when the `rustfmt` process cannot be spawned or waited for, when
/// its stdin cannot be opened, or when its output is not valid UTF-8.
pub fn format_block(tt: String) -> String {
    // One indentation level as emitted by rustfmt with its default settings.
    // The full level has to be stripped; stripping less leaves stray leading
    // spaces on every returned line.
    const INDENT: &str = "    ";

    let tt = format!("fn main() {{ {} }}", tt);

    let mut child = Command::new("rustfmt")
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .spawn()
        .expect("Failed to spawn rustfmt process");

    // Write input from another thread for avoiding deadlock.
    // See https://doc.rust-lang.org/std/process/index.html#handling-io
    let mut stdin = child.stdin.take().expect("Failed to open stdin");
    std::thread::spawn(move || {
        stdin
            .write_all(tt.as_bytes())
            .expect("Failed to write to stdin");
    });
    let output = child
        .wait_with_output()
        .expect("Failed to wait output of rustfmt process");

    // non-UTF8 comment should be handled in the tokenize phase,
    // and not be included in IR.
    let out = String::from_utf8(output.stdout).expect("rustfmt output contains non-UTF8 input");

    let formatted_lines: Vec<&str> = out
        .lines()
        .filter_map(|line| match line {
            // Wrapper lines added above.
            "fn main() {" | "}" => None,
            _ => line.strip_prefix(INDENT),
        })
        .collect();
    formatted_lines.join("\n")
}
24 changes: 24 additions & 0 deletions einsum-derive/src/ident.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/// Path of the fixed-rank `ndarray` dimension type with `n` axes,
/// i.e. `ndarray::Ix{n}`.
pub fn dim(n: usize) -> syn::Path {
    let rank_ty = quote::format_ident!("Ix{}", n);
    let path: syn::Path = syn::parse_quote!(ndarray::#rank_ty);
    path
}

pub fn output_ident() -> syn::Ident {
quote::format_ident!("out")
}

pub fn index_ident(i: char) -> syn::Ident {
quote::format_ident!("{}", i)
}

pub fn n_ident(i: char) -> syn::Ident {
quote::format_ident!("n_{}", i)
}

pub fn n_each_ident(argc: usize, i: usize) -> syn::Ident {
quote::format_ident!("n_{}_{}", argc, i)
}

pub fn arg_ident(argc: usize) -> syn::Ident {
quote::format_ident!("arg{}", argc)
}
Loading