Skip to content

Commit

Permalink
Single commit implementing the enzyme/autodiff frontend
Browse files Browse the repository at this point in the history
Co-authored-by: Lorenz Schmidt <bytesnake@mailbox.org>
  • Loading branch information
ZuseZ4 and bytesnake committed Oct 2, 2024
1 parent c3ce4e6 commit f86bed6
Show file tree
Hide file tree
Showing 23 changed files with 1,875 additions and 0 deletions.
281 changes: 281 additions & 0 deletions compiler/rustc_ast/src/expand/autodiff_attrs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
//! This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute,
//! we create an `AutoDiffItem` which contains the source and target function names. The source
//! is the function to which the autodiff attribute is applied, and the target is the function
//! getting generated by us (with a name given by the user as the first autodiff arg).

use std::fmt::{self, Display, Formatter};
use std::str::FromStr;

use crate::expand::typetree::TypeTree;
use crate::expand::{Decodable, Encodable, HashStable_Generic};
use crate::ptr::P;
use crate::{Ty, TyKind};

/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
/// are a hack to support higher order derivatives. We need to compute first order derivatives
/// before we compute second order derivatives, otherwise we would differentiate our placeholder
/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
/// as it's already done in the C++ and Julia frontend of Enzyme. (FIXME) remove *First variants.
/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum DiffMode {
/// No autodiff is applied (used during error handling).
Error,
/// The primal function which we will differentiate.
Source,
/// The target function, to be created using forward mode AD.
Forward,
/// The target function, to be created using reverse mode AD.
Reverse,
/// The target function, to be created using forward mode AD.
/// This target function will also be used as a source for higher order derivatives,
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
ForwardFirst,
/// The target function, to be created using reverse mode AD.
/// This target function will also be used as a source for higher order derivatives,
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
ReverseFirst,
}

/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
/// we add to the previous shadow value. To not surprise users, we picked different names.
/// Dual numbers is also a quite well known name for forward mode AD types.
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum DiffActivity {
/// Implicit or Explicit () return type, so a special case of Const.
None,
/// Don't compute derivatives with respect to this input/output.
Const,
/// Reverse Mode, Compute derivatives for this scalar input/output.
Active,
/// Reverse Mode, Compute derivatives for this scalar output, but don't compute
/// the original return value.
ActiveOnly,
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
/// with it.
Dual,
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
/// with it. Drop the code which updates the original input/output for maximum performance.
DualOnly,
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
Duplicated,
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
/// Drop the code which updates the original input for maximum performance.
DuplicatedOnly,
/// All Integers must be Const, but these are used to mark the integer which represents the
/// length of a slice/vec. This is used for safety checks on slices.
FakeActivitySize,
}
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffItem {
/// The name of the function getting differentiated
pub source: String,
/// The name of the function being generated
pub target: String,
pub attrs: AutoDiffAttrs,
/// Describe the memory layout of input types
pub inputs: Vec<TypeTree>,
/// Describe the memory layout of the output type
pub output: TypeTree,
}
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffAttrs {
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
/// e.g. in the [JAX
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
pub mode: DiffMode,
pub ret_activity: DiffActivity,
pub input_activity: Vec<DiffActivity>,
}

impl DiffMode {
pub fn is_rev(&self) -> bool {
matches!(self, DiffMode::Reverse | DiffMode::ReverseFirst)
}
pub fn is_fwd(&self) -> bool {
matches!(self, DiffMode::Forward | DiffMode::ForwardFirst)
}
}

impl Display for DiffMode {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
DiffMode::Error => write!(f, "Error"),
DiffMode::Source => write!(f, "Source"),
DiffMode::Forward => write!(f, "Forward"),
DiffMode::Reverse => write!(f, "Reverse"),
DiffMode::ForwardFirst => write!(f, "ForwardFirst"),
DiffMode::ReverseFirst => write!(f, "ReverseFirst"),
}
}
}

/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.
pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
if activity == DiffActivity::None {
// Only valid if primal returns (), but we can't check that here.
return true;
}
match mode {
DiffMode::Error => false,
DiffMode::Source => false,
DiffMode::Forward | DiffMode::ForwardFirst => {
activity == DiffActivity::Dual
|| activity == DiffActivity::DualOnly
|| activity == DiffActivity::Const
}
DiffMode::Reverse | DiffMode::ReverseFirst => {
activity == DiffActivity::Const
|| activity == DiffActivity::Active
|| activity == DiffActivity::ActiveOnly
}
}
}

/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
/// for the given argument, but we generally can't know the size of such a type.
/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent
/// users here from marking scalars as Duplicated, due to type aliases.
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
use DiffActivity::*;
// It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
if matches!(activity, Const) {
return true;
}
if matches!(activity, Dual | DualOnly) {
return true;
}
// FIXME(ZuseZ4) We should make this more robust to also
// handle type aliases. Once that is done, we can be more restrictive here.
if matches!(activity, Active | ActiveOnly) {
return true;
}
matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))
&& matches!(activity, Duplicated | DuplicatedOnly)
}
pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
use DiffActivity::*;
return match mode {
DiffMode::Error => false,
DiffMode::Source => false,
DiffMode::Forward | DiffMode::ForwardFirst => {
matches!(activity, Dual | DualOnly | Const)
}
DiffMode::Reverse | DiffMode::ReverseFirst => {
matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
}
};
}

impl Display for DiffActivity {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
DiffActivity::None => write!(f, "None"),
DiffActivity::Const => write!(f, "Const"),
DiffActivity::Active => write!(f, "Active"),
DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
DiffActivity::Dual => write!(f, "Dual"),
DiffActivity::DualOnly => write!(f, "DualOnly"),
DiffActivity::Duplicated => write!(f, "Duplicated"),
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
}
}
}

impl FromStr for DiffMode {
type Err = ();

fn from_str(s: &str) -> Result<DiffMode, ()> {
match s {
"Error" => Ok(DiffMode::Error),
"Source" => Ok(DiffMode::Source),
"Forward" => Ok(DiffMode::Forward),
"Reverse" => Ok(DiffMode::Reverse),
"ForwardFirst" => Ok(DiffMode::ForwardFirst),
"ReverseFirst" => Ok(DiffMode::ReverseFirst),
_ => Err(()),
}
}
}
impl FromStr for DiffActivity {
type Err = ();

fn from_str(s: &str) -> Result<DiffActivity, ()> {
match s {
"None" => Ok(DiffActivity::None),
"Active" => Ok(DiffActivity::Active),
"ActiveOnly" => Ok(DiffActivity::ActiveOnly),
"Const" => Ok(DiffActivity::Const),
"Dual" => Ok(DiffActivity::Dual),
"DualOnly" => Ok(DiffActivity::DualOnly),
"Duplicated" => Ok(DiffActivity::Duplicated),
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
_ => Err(()),
}
}
}

impl AutoDiffAttrs {
pub fn has_ret_activity(&self) -> bool {
self.ret_activity != DiffActivity::None
}
pub fn has_active_only_ret(&self) -> bool {
self.ret_activity == DiffActivity::ActiveOnly
}

pub fn error() -> Self {
AutoDiffAttrs {
mode: DiffMode::Error,
ret_activity: DiffActivity::None,
input_activity: Vec::new(),
}
}
pub fn source() -> Self {
AutoDiffAttrs {
mode: DiffMode::Source,
ret_activity: DiffActivity::None,
input_activity: Vec::new(),
}
}

pub fn is_active(&self) -> bool {
self.mode != DiffMode::Error
}

pub fn is_source(&self) -> bool {
self.mode == DiffMode::Source
}
pub fn apply_autodiff(&self) -> bool {
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
}

pub fn into_item(
self,
source: String,
target: String,
inputs: Vec<TypeTree>,
output: TypeTree,
) -> AutoDiffItem {
AutoDiffItem { source, target, inputs, output, attrs: self }
}
}

impl fmt::Display for AutoDiffItem {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
write!(f, " with attributes: {:?}", self.attrs)?;
write!(f, " with inputs: {:?}", self.inputs)?;
write!(f, " with output: {:?}", self.output)
}
}
2 changes: 2 additions & 0 deletions compiler/rustc_ast/src/expand/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use rustc_span::symbol::Ident;
use crate::MetaItem;

pub mod allocator;
pub mod autodiff_attrs;
pub mod typetree;

#[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)]
pub struct StrippedCfgItem<ModId = DefId> {
Expand Down
69 changes: 69 additions & 0 deletions compiler/rustc_ast/src/expand/typetree.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use std::fmt;

use crate::expand::{Decodable, Encodable, HashStable_Generic};

#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum Kind {
Anything,
Integer,
Pointer,
Half,
Float,
Double,
Unknown,
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct TypeTree(pub Vec<Type>);

impl TypeTree {
pub fn new() -> Self {
Self(Vec::new())
}
pub fn all_ints() -> Self {
Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }])
}
pub fn int(size: usize) -> Self {
let mut ints = Vec::with_capacity(size);
for i in 0..size {
ints.push(Type {
offset: i as isize,
size: 1,
kind: Kind::Integer,
child: TypeTree::new(),
});
}
Self(ints)
}
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct FncTree {
pub args: Vec<TypeTree>,
pub ret: TypeTree,
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct Type {
pub offset: isize,
pub size: usize,
pub kind: Kind,
pub child: TypeTree,
}

impl Type {
pub fn add_offset(self, add: isize) -> Self {
let offset = match self.offset {
-1 => add,
x => add + x,
};

Self { size: self.size, kind: self.kind, child: self.child, offset }
}
}

impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
<Self as fmt::Debug>::fmt(self, f)
}
}
4 changes: 4 additions & 0 deletions compiler/rustc_builtin_macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ name = "rustc_builtin_macros"
version = "0.0.0"
edition = "2021"


[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(llvm_enzyme)'] }

[lib]
doctest = false

Expand Down
9 changes: 9 additions & 0 deletions compiler/rustc_builtin_macros/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ builtin_macros_assert_requires_boolean = macro requires a boolean expression as
builtin_macros_assert_requires_expression = macro requires an expression as an argument
.suggestion = try removing semicolon
builtin_macros_autodiff = autodiff must be applied to function
builtin_macros_autodiff_missing_config = autodiff requires at least a name and mode
builtin_macros_autodiff_mode = unknown Mode: `{$mode}`. Use `Forward` or `Reverse`
builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode
builtin_macros_autodiff_not_build = this rustc version does not support autodiff
builtin_macros_autodiff_number_activities = expected {$expected} activities, but found {$found}
builtin_macros_autodiff_ty_activity = {$act} can not be used for this type
builtin_macros_autodiff_unknown_activity = did not recognize Activity: `{$act}`
builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s
.label = not applicable here
.label2 = not a `struct`, `enum` or `union`
Expand Down
Loading

0 comments on commit f86bed6

Please sign in to comment.