Skip to content

Commit

Permalink
Auto merge of rust-lang#129458 - EnzymeAD:enzyme-frontend, r=jieyouxu
Browse files Browse the repository at this point in the history
Autodiff Upstreaming - enzyme frontend

This is an upstream PR for the `autodiff` rustc_builtin_macro that is part of the autodiff feature.

For the full implementation, see: rust-lang#129175

**Content:**
It contains a new `#[autodiff(<args>)]` rustc_builtin_macro, as well as a `#[rustc_autodiff]` builtin attribute.
The autodiff macro is applied on function `f` and will expand to a second function `df` (name given by user).
It will add a dummy body to `df` to make sure it type-checks. The body will later be replaced by enzyme on llvm-ir level,
we therefore don't really care about the content. Most of the changes (700 from 1.2k) are in `compiler/rustc_builtin_macros/src/autodiff.rs`, which expand the macro. Nothing except expansion is implemented for now.
I have a fallback implementation for relevant functions in case that rustc should be build without autodiff support. The default for now will be off, although we want to flip it later (once everything landed) to on for nightly. For the sake of CI, I have flipped the defaults, I'll revert this before merging.

**Dummy function Body:**
The first line is an `inline_asm` nop to make inlining less likely (I have additional checks to prevent this in the middle end of rustc. If `f` gets inlined too early, we can't pass it to enzyme and thus can't differentiate it.
If `df` gets inlined too early, the call site will just compute this dummy code instead of the derivatives, a correctness issue. The following black_box lines make sure that none of the input arguments is getting optimized away before we replace the body.

**Motivation:**
The user facing autodiff macro can verify the user input. Then I write it as args to the rustc_attribute, so from here on I can know that these values should be sensible. A rustc_attribute also turned out to be quite nice to attach this information to the corresponding function and carry it till the backend.
This is also just an experiment, I expect to adjust the user facing autodiff macro based on user feedback, to improve usability.

As a simple example of what this will do, we can see this expansion:
From:
```
#[autodiff(df, Reverse, Duplicated, Const, Active)]
pub fn f1(x: &[f64], y: f64) -> f64 {
    unimplemented!()
}
```
to
```
#[rustc_autodiff]
#[inline(never)]
pub fn f1(x: &[f64], y: f64) -> f64 {
    ::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Reverse, Duplicated, Const, Active,)]
#[inline(never)]
pub fn df(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
    unsafe { asm!("NOP"); };
    ::core::hint::black_box(f1(x, y));
    ::core::hint::black_box((dx, dret));
    ::core::hint::black_box(f1(x, y))
}
```
I will add a few more tests once I figured out why rustc rebuilds every time I touch a test.

Tracking:

- rust-lang#124509

try-job: dist-x86_64-msvc
  • Loading branch information
bors committed Oct 15, 2024
2 parents ecb3830 + 2b6b22c commit cc7730e
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
9 changes: 9 additions & 0 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,15 @@ pub mod assert_matches {
pub use crate::macros::{assert_matches, debug_assert_matches};
}

// We don't export this through #[macro_export] for now, to avoid breakage.
#[cfg(not(bootstrap))]
#[unstable(feature = "autodiff", issue = "124509")]
/// Unstable module containing the unstable `autodiff` macro.
pub mod autodiff {
#[unstable(feature = "autodiff", issue = "124509")]
pub use crate::macros::builtin::autodiff;
}

#[unstable(feature = "cfg_match", issue = "115585")]
pub use crate::macros::cfg_match;

Expand Down
18 changes: 18 additions & 0 deletions core/src/macros/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1539,6 +1539,24 @@ pub(crate) mod builtin {
($file:expr $(,)?) => {{ /* compiler built-in */ }};
}

/// Automatic Differentiation macro which allows generating a new function to compute
/// the derivative of a given function. It may only be applied to a function.
/// The expected usage syntax is
/// `#[autodiff(NAME, MODE, INPUT_ACTIVITIES, OUTPUT_ACTIVITY)]`
/// where:
/// NAME is a string that represents a valid function name.
/// MODE is any of Forward, Reverse, ForwardFirst, ReverseFirst.
/// INPUT_ACTIVITIES consists of one valid activity for each input parameter.
/// OUTPUT_ACTIVITY must not be set if we implicitely return nothing (or explicitely return
/// `-> ()`. Otherwise it must be set to one of the allowed activities.
#[unstable(feature = "autodiff", issue = "124509")]
#[allow_internal_unstable(rustc_attrs)]
#[rustc_builtin_macro]
#[cfg(not(bootstrap))]
pub macro autodiff($item:item) {
/* compiler built-in */
}

/// Asserts that a boolean expression is `true` at runtime.
///
/// This will invoke the [`panic!`] macro if the provided expression cannot be
Expand Down
9 changes: 8 additions & 1 deletion std/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@
#![allow(unused_features)]
//
// Features:
#![cfg_attr(not(bootstrap), feature(autodiff))]
#![cfg_attr(test, feature(internal_output_capture, print_internals, update_panic_count, rt))]
#![cfg_attr(
all(target_vendor = "fortanix", target_env = "sgx"),
Expand Down Expand Up @@ -624,7 +625,13 @@ pub mod simd {
#[doc(inline)]
pub use crate::std_float::StdFloat;
}

#[cfg(not(bootstrap))]
#[unstable(feature = "autodiff", issue = "124509")]
/// This module provides support for automatic differentiation.
pub mod autodiff {
/// This macro handles automatic differentiation.
pub use core::autodiff::autodiff;
}
#[stable(feature = "futures_api", since = "1.36.0")]
pub mod task {
//! Types and Traits for working with asynchronous tasks.
Expand Down

0 comments on commit cc7730e

Please sign in to comment.