From 2b6b22cebab5d4d24d6fc917906935c36aa45a34 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 11 Oct 2024 19:13:31 +0200 Subject: [PATCH] Single commit implementing the enzyme/autodiff frontend Co-authored-by: Lorenz Schmidt --- core/src/lib.rs | 9 +++++++++ core/src/macros/mod.rs | 18 ++++++++++++++++++ std/src/lib.rs | 9 ++++++++- 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/core/src/lib.rs b/core/src/lib.rs index 96ab5755328e1..f69a33bca8446 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -278,6 +278,15 @@ pub mod assert_matches { pub use crate::macros::{assert_matches, debug_assert_matches}; } +// We don't export this through #[macro_export] for now, to avoid breakage. +#[cfg(not(bootstrap))] +#[unstable(feature = "autodiff", issue = "124509")] +/// Unstable module containing the unstable `autodiff` macro. +pub mod autodiff { + #[unstable(feature = "autodiff", issue = "124509")] + pub use crate::macros::builtin::autodiff; +} + #[unstable(feature = "cfg_match", issue = "115585")] pub use crate::macros::cfg_match; diff --git a/core/src/macros/mod.rs b/core/src/macros/mod.rs index aa0646846e43e..b5e5b58f7051f 100644 --- a/core/src/macros/mod.rs +++ b/core/src/macros/mod.rs @@ -1539,6 +1539,24 @@ pub(crate) mod builtin { ($file:expr $(,)?) => {{ /* compiler built-in */ }}; } + /// Automatic Differentiation macro which allows generating a new function to compute + /// the derivative of a given function. It may only be applied to a function. + /// The expected usage syntax is + /// `#[autodiff(NAME, MODE, INPUT_ACTIVITIES, OUTPUT_ACTIVITY)]` + /// where: + /// NAME is a string that represents a valid function name. + /// MODE is any of Forward, Reverse, ForwardFirst, ReverseFirst. + /// INPUT_ACTIVITIES consists of one valid activity for each input parameter. + /// OUTPUT_ACTIVITY must not be set if we implicitely return nothing (or explicitely return + /// `-> ()`. Otherwise it must be set to one of the allowed activities. + #[unstable(feature = "autodiff", issue = "124509")] + #[allow_internal_unstable(rustc_attrs)] + #[rustc_builtin_macro] + #[cfg(not(bootstrap))] + pub macro autodiff($item:item) { + /* compiler built-in */ + } + /// Asserts that a boolean expression is `true` at runtime. /// /// This will invoke the [`panic!`] macro if the provided expression cannot be diff --git a/std/src/lib.rs b/std/src/lib.rs index 65a9aa66c7cc6..35ed761759bd7 100644 --- a/std/src/lib.rs +++ b/std/src/lib.rs @@ -267,6 +267,7 @@ #![allow(unused_features)] // // Features: +#![cfg_attr(not(bootstrap), feature(autodiff))] #![cfg_attr(test, feature(internal_output_capture, print_internals, update_panic_count, rt))] #![cfg_attr( all(target_vendor = "fortanix", target_env = "sgx"), @@ -627,7 +628,13 @@ pub mod simd { #[doc(inline)] pub use crate::std_float::StdFloat; } - +#[cfg(not(bootstrap))] +#[unstable(feature = "autodiff", issue = "124509")] +/// This module provides support for automatic differentiation. +pub mod autodiff { + /// This macro handles automatic differentiation. + pub use core::autodiff::autodiff; +} #[stable(feature = "futures_api", since = "1.36.0")] pub mod task { //! Types and Traits for working with asynchronous tasks.