Skip to content

Commit

Permalink
addressing reviewer feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Sep 3, 2024
1 parent ca94290 commit f82e085
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 83 deletions.
163 changes: 91 additions & 72 deletions compiler/rustc_ast/src/expand/autodiff_attrs.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
/// This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute,
/// we create an `AutoDiffItem` which contains the source and target function names. The source
/// is the function to which the autodiff attribute is applied, and the target is the function
/// getting generated by us (with a name given by the user as the first autodiff arg).
use std::fmt::{self, Display, Formatter};
use std::str::FromStr;

Expand All @@ -6,27 +10,91 @@ use crate::expand::{Decodable, Encodable, HashStable_Generic};
use crate::ptr::P;
use crate::{Ty, TyKind};

#[allow(dead_code)]
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum DiffMode {
/// No autodiff is applied (usually used during error handling).
Inactive,
/// The primal function which we will differentiate.
Source,
/// The target function, to be created using forward mode AD.
Forward,
/// The target function, to be created using reverse mode AD.
Reverse,
/// The target function, to be created using forward mode AD.
/// This target function will also be used as a source for higher order derivatives,
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
ForwardFirst,
/// The target function, to be created using reverse mode AD.
/// This target function will also be used as a source for higher order derivatives,
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
ReverseFirst,
}

pub fn is_rev(mode: DiffMode) -> bool {
match mode {
DiffMode::Reverse | DiffMode::ReverseFirst => true,
_ => false,
}
/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
/// we add to the previous shadow value. To not surprise users, we picked different names.
/// Dual numbers is also a quite well known name for forward mode AD types.
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum DiffActivity {
/// Implicit or Explicit () return type, so a special case of Const.
None,
/// Don't compute derivatives with respect to this input/output.
Const,
/// Reverse Mode, Compute derivatives for this scalar input/output.
Active,
/// Reverse Mode, Compute derivatives for this scalar output, but don't compute
/// the original return value.
ActiveOnly,
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
/// with it.
Dual,
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
/// with it. Drop the code which updates the original input/output for maximum performance.
DualOnly,
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
Duplicated,
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
/// Drop the code which updates the original input for maximum performance.
DuplicatedOnly,
/// All Integers must be Const, but these are used to mark the integer which represents the
/// length of a slice/vec. This is used for safety checks on slices.
FakeActivitySize,
}
pub fn is_fwd(mode: DiffMode) -> bool {
match mode {
DiffMode::Forward | DiffMode::ForwardFirst => true,
_ => false,
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffItem {
/// The name of the function getting differentiated
pub source: String,
/// The name of the function being generated
pub target: String,
pub attrs: AutoDiffAttrs,
/// Despribe the memory layout of input types
pub inputs: Vec<TypeTree>,
/// Despribe the memory layout of the output type
pub output: TypeTree,
}
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffAttrs {
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
/// e.g. in the [JAX
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
pub mode: DiffMode,
pub ret_activity: DiffActivity,
pub input_activity: Vec<DiffActivity>,
}

impl DiffMode {
pub fn is_rev(&self) -> bool {
match self {
DiffMode::Reverse | DiffMode::ReverseFirst => true,
_ => false,
}
}
pub fn is_fwd(&self) -> bool {
match self {
DiffMode::Forward | DiffMode::ForwardFirst => true,
_ => false,
}
}
}

Expand Down Expand Up @@ -63,30 +131,20 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
}
}
}
fn is_ptr_or_ref(ty: &Ty) -> bool {
match ty.kind {
TyKind::Ptr(_) | TyKind::Ref(_, _) => true,
_ => false,
}
}
// TODO We should make this more robust to also

// FIXME(ZuseZ4) We should make this more robust to also
// accept aliases of f32 and f64
//fn is_float(ty: &Ty) -> bool {
// false
//}
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
if is_ptr_or_ref(ty) {
return activity == DiffActivity::Dual
|| activity == DiffActivity::DualOnly
|| activity == DiffActivity::Duplicated
|| activity == DiffActivity::DuplicatedOnly
|| activity == DiffActivity::Const;
match ty.kind {
TyKind::Ptr(_) | TyKind::Ref(..) => {
return activity == DiffActivity::Dual
|| activity == DiffActivity::DualOnly
|| activity == DiffActivity::Duplicated
|| activity == DiffActivity::DuplicatedOnly
|| activity == DiffActivity::Const;
}
_ => false,
}
true
//if is_scalar_ty(&ty) {
// return activity == DiffActivity::Active || activity == DiffActivity::ActiveOnly ||
// activity == DiffActivity::Const;
//}
}
pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
return match mode {
Expand Down Expand Up @@ -117,20 +175,6 @@ pub fn invalid_input_activities(mode: DiffMode, activity_vec: &[DiffActivity]) -
None
}

#[allow(dead_code)]
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum DiffActivity {
None,
Const,
Active,
ActiveOnly,
Dual,
DualOnly,
Duplicated,
DuplicatedOnly,
FakeActivitySize,
}

impl Display for DiffActivity {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Expand Down Expand Up @@ -180,30 +224,14 @@ impl FromStr for DiffActivity {
}
}

#[allow(dead_code)]
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffAttrs {
pub mode: DiffMode,
pub ret_activity: DiffActivity,
pub input_activity: Vec<DiffActivity>,
}

impl AutoDiffAttrs {
pub fn has_ret_activity(&self) -> bool {
match self.ret_activity {
DiffActivity::None => false,
_ => true,
}
self.ret_activity != DiffActivity::None
}
pub fn has_active_only_ret(&self) -> bool {
match self.ret_activity {
DiffActivity::ActiveOnly => true,
_ => false,
}
self.ret_activity == DiffActivity::ActiveOnly
}
}

impl AutoDiffAttrs {
pub fn inactive() -> Self {
AutoDiffAttrs {
mode: DiffMode::Inactive,
Expand Down Expand Up @@ -251,15 +279,6 @@ impl AutoDiffAttrs {
}
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffItem {
pub source: String,
pub target: String,
pub attrs: AutoDiffAttrs,
pub inputs: Vec<TypeTree>,
pub output: TypeTree,
}

impl fmt::Display for AutoDiffItem {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
Expand Down
21 changes: 10 additions & 11 deletions compiler/rustc_builtin_macros/src/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ use std::str::FromStr;
use std::string::String;

use rustc_ast::expand::autodiff_attrs::{
is_fwd, is_rev, valid_input_activity, valid_ty_for_activity, AutoDiffAttrs, DiffActivity,
DiffMode,
valid_input_activity, valid_ty_for_activity, AutoDiffAttrs, DiffActivity, DiffMode,
};
use rustc_ast::ptr::P;
use rustc_ast::token::{Token, TokenKind};
Expand Down Expand Up @@ -382,7 +381,7 @@ fn gen_enzyme_body(
// So that can be treated identical to not having one in the first place.
let primal_ret = sig.decl.output.has_ret() && !x.has_active_only_ret();

if primal_ret && n_active == 0 && is_rev(x.mode) {
if primal_ret && n_active == 0 && x.mode.is_rev() {
// We only have the primal ret.
body.stmts.push(ecx.stmt_expr(black_box_primal_call.clone()));
return body;
Expand Down Expand Up @@ -437,7 +436,7 @@ fn gen_enzyme_body(
panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty);
}
};
if is_fwd(x.mode) {
if x.mode.is_fwd() {
if x.ret_activity == DiffActivity::Dual {
assert!(d_ret_ty.len() == 2);
// both should be identical, by construction
Expand All @@ -451,7 +450,7 @@ fn gen_enzyme_body(
exprs.push(default_call_expr);
}
} else {
assert!(is_rev(x.mode));
assert!(x.mode.is_rev());

if primal_ret {
// We have extra handling above for the primal ret
Expand Down Expand Up @@ -562,7 +561,7 @@ fn gen_enzyme_decl(
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
ident.name
} else {
trace!("{:#?}", &shadow_arg.pat);
dbg!(&shadow_arg.pat);
panic!("not an ident?");
};
let name: String = format!("d{}", old_name);
Expand All @@ -582,7 +581,7 @@ fn gen_enzyme_decl(
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
ident.name
} else {
trace!("{:#?}", &shadow_arg.pat);
dbg!(&shadow_arg.pat);
panic!("not an ident?");
};
let name: String = format!("b{}", old_name);
Expand All @@ -601,7 +600,7 @@ fn gen_enzyme_decl(
// Nothing to do here.
}
_ => {
trace!{"{:#?}", &activity};
dbg!(&activity);
panic!("Not implemented");
}
}
Expand All @@ -614,12 +613,12 @@ fn gen_enzyme_decl(

let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
if active_only_ret {
assert!(is_rev(x.mode));
assert!(x.mode.is_rev());
}

// If we return a scalar in the primal and the scalar is active,
// then add it as last arg to the inputs.
if is_rev(x.mode) {
if x.mode.is_rev() {
match x.ret_activity {
DiffActivity::Active | DiffActivity::ActiveOnly => {
let ty = match d_decl.output {
Expand Down Expand Up @@ -651,7 +650,7 @@ fn gen_enzyme_decl(
}
d_decl.inputs = d_inputs.into();

if is_fwd(x.mode) {
if x.mode.is_fwd() {
if let DiffActivity::Dual = x.ret_activity {
let ty = match d_decl.output {
FnRetTy::Ty(ref ty) => ty.clone(),
Expand Down

0 comments on commit f82e085

Please sign in to comment.