Autodiff Upstreaming - rustc_codegen_llvm changes #130060

Open
wants to merge 1 commit into
base: master

ZuseZ4
Contributor

@ZuseZ4 ZuseZ4 commented Sep 7, 2024

Now that the autodiff/Enzyme backend is merged, this is an upstream PR for the rustc_codegen_llvm changes.
It also includes small changes to three files under compiler/rustc_ast, which overlap with my frontend PR (#129458).
Here I only include minimal definitions of structs and enums to be able to build this backend code.
The same goes for minimal changes to compiler/rustc_codegen_ssa, the majority of changes there will be in another PR, once either this or the frontend gets merged.

We currently have 68 files left to merge, 19 in the frontend PR, 21 (+3 from the frontend) in this PR, and then ~30 in the middle-end.

This PR is large because it includes two of my three large files (~800 loc each). I could also first only upstream enzyme_ffi.rs, but I think people might want to see some use of these bindings in the same PR?

To already highlight the things which reviewers might want to discuss:

  1. enzyme_ffi.rs: I do have a fallback module to make sure that we don't link rustc against Enzyme when we build rustc without autodiff support.

  2. add_panic_msg_to_global was a pain to write and I currently can't even use it. Enzyme writes gradients into shadow memory. Pass in one float scalar? We'll allocate and return an extra float telling you how this float affected the output. Pass in a slice of floats? We'll let you allocate the vector and pass in a mutable reference to a float slice; we'll then write the gradient into that slice. It should be at least as large as your original slice, so we check that and panic if not. Currently we panic silently, but I already generate a nicer panic message with this function. I just don't know how to print it to the user yet. I discussed this with a few rustc devs and the best we could come up with (for now) was to look for mangled panic calls in the IR and pick one, which works surprisingly reliably. If someone knows a good way to clean this up and print the panic message, I'm all in; otherwise I can remove the code that writes the nicer panic message and keep the silent panic, since it's enough for soundness, especially since this PR is already a bit large.

  3. SanitizeHWAddress: When differentiating C++, Enzyme can use TBAA to "understand" enums/unions, but for Rust we don't have this information. LLVM might do speculative loads which (without TBAA) confuse Enzyme, so we disable those with this attribute. This attribute is only set during the first opt run, before Enzyme differentiates code. We then remove it again once we are done with autodiff and run the opt pipeline a second time. Since enums are everywhere in Rust, support for them is crucial, but if this looks too cursed I can remove these ~100 lines and keep them in my fork for now; we can then discuss them separately to make this PR simpler?

  4. Duplicated llvm-opt runs: Differentiating already optimized code (and being able to do additional optimizations on the fly, e.g. for GPU code) is the reason why Enzyme is so fast, so the compile time is acceptable for autodiff users: https://enzyme.mit.edu/talks/Publications/ (There are also algorithmic issues in Enzyme core which are more serious than running opt twice).

  5. I assume that if we merge these minimal cg_ssa changes here already, I also need to fix the other backends (GCC and Cranelift) to have dummy implementations, correct?

  6. I'm happy to split this PR up further if reviewers have recommendations on how to.
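
The shadow-memory convention described in point 2 can be illustrated with a hand-written stand-in. This is only a sketch: the functions `d_square` and `d_sum_of_squares` are hypothetical and written by hand, not generated by Enzyme, and the choice to accumulate into the shadow slice with `+=` is an assumption made for the example.

```rust
/// Primal function: sum of squares of a slice.
fn sum_of_squares(x: &[f64]) -> f64 {
    x.iter().map(|v| v * v).sum()
}

/// Scalar case: the derivative comes back as an extra return value.
/// Hand-written for illustration, not Enzyme output.
fn d_square(x: f64, seed: f64) -> (f64, f64) {
    (x * x, 2.0 * x * seed)
}

/// Slice case: the caller allocates the shadow `dx` and passes it in mutably.
/// The gradient is written into `dx`, which must be at least as long as `x`.
fn d_sum_of_squares(x: &[f64], dx: &mut [f64], seed: f64) -> f64 {
    // Size check with a readable message instead of a silent panic.
    assert!(
        dx.len() >= x.len(),
        "shadow slice too small: {} < {}",
        dx.len(),
        x.len()
    );
    for (xi, dxi) in x.iter().zip(dx.iter_mut()) {
        // d/dx_i of sum(x_j^2) is 2 * x_i, scaled by the incoming seed.
        *dxi += 2.0 * xi * seed;
    }
    sum_of_squares(x)
}
```

Calling `d_sum_of_squares(&x, &mut dx, 1.0)` with a zero-initialized `dx` of the same length as `x` leaves the gradient in `dx`; passing a shorter `dx` trips the assertion, which is the situation the nicer panic message is meant to report.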

For the full implementation, see: #129175

Tracking:

@rustbot
Collaborator

rustbot commented Sep 7, 2024

r? @fee1-dead

rustbot has assigned @fee1-dead.
They will have a look at your PR within the next two weeks and either review your PR or reassign to another reviewer.

Use r? to explicitly pick a reviewer

@rustbot
Collaborator

rustbot commented Sep 7, 2024

⚠️ Warning ⚠️

  • These commits modify submodules.

@rustbot rustbot added S-waiting-on-review Status: Awaiting review from the assignee but also interested parties. T-bootstrap Relevant to the bootstrap subteam: Rust's build system (x.py and src/bootstrap) T-compiler Relevant to the compiler team, which will review and decide on the PR/issue. labels Sep 7, 2024
@rustbot
Collaborator

rustbot commented Sep 7, 2024

This PR modifies config.example.toml.

If appropriate, please update CONFIG_CHANGE_HISTORY in src/bootstrap/src/utils/change_tracker.rs.

Some changes occurred in cfg and check-cfg configuration

cc @Urgau

@@ -176,6 +176,8 @@ pub(crate) fn default_configuration(sess: &Session) -> Cfg {
// NOTE: These insertions should be kept in sync with
// `CheckCfg::fill_well_known` below.

ins_none!(sym::autodiff_fallback);
Member

This shouldn't be insta-stable; it should at least be gated behind a nightly compiler.

Suggested change
-ins_none!(sym::autodiff_fallback);
+if sess.is_nightly_build() {
+    ins_none!(sym::autodiff_fallback);
+}

Please also follow all the steps regarding a new cfg as defined in the top of this file (as well as the tests files):

//! ## Adding a new cfg
//!
//! Adding a new feature requires two new symbols one for the cfg it-self
//! and the second one for the unstable feature gate, those are defined in
//! `rustc_span::symbol`.
//!
//! As well as the following points,
//! - Add the activation logic in [`default_configuration`]
//! - Add the cfg to [`CheckCfg::fill_well_known`] (and related files),
//! so that the compiler can know the cfg is expected
//! - Add the cfg in [`disallow_cfgs`] to disallow users from setting it via `--cfg`
//! - Add the feature gating in `compiler/rustc_feature/src/builtin_attrs.rs`
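
For context, a cfg such as `autodiff_fallback` is typically consumed by compiling one of two interchangeable items. The sketch below assumes the cfg is set when rustc is built without Enzyme; the module name `enzyme_ffi_sketch` and the function `differentiate` are hypothetical and do not come from the PR.

```rust
// The `allow` silences the `unexpected_cfgs` lint, since this standalone
// sketch does not register `autodiff_fallback` with check-cfg.
#[allow(unexpected_cfgs)]
mod enzyme_ffi_sketch {
    /// Compiled when `autodiff_fallback` is NOT set. A real build would call
    /// into the Enzyme library here; this body is only a stand-in.
    #[cfg(not(autodiff_fallback))]
    pub fn differentiate(name: &str) -> Result<String, String> {
        Ok(format!("d{name}"))
    }

    /// Compiled when `autodiff_fallback` IS set. Same signature, but nothing
    /// here references Enzyme, so nothing needs to be linked.
    #[cfg(autodiff_fallback)]
    pub fn differentiate(name: &str) -> Result<String, String> {
        Err(format!("cannot differentiate `{name}`: built without autodiff support"))
    }
}
```

Because both variants share one signature, callers of `enzyme_ffi_sketch::differentiate` compile unchanged in either configuration. Gating the cfg behind `sess.is_nightly_build()` only affects whether user code can observe it, not this selection mechanism.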

@jieyouxu jieyouxu added the F-autodiff `#![feature(autodiff)]` label Sep 7, 2024
@fee1-dead
Member

r? compiler

@rust-log-analyzer
Collaborator

The job mingw-check-tidy failed! Check out the build log: (web) (plain)

Click to see the possible cause of the failure (guessed by this bot)

COPY host-x86_64/mingw-check/validate-toolstate.sh /scripts/
COPY host-x86_64/mingw-check/validate-error-codes.sh /scripts/

# NOTE: intentionally uses python2 for x.py so we can test it still works.
# validate-toolstate only runs in our CI, so it's ok for it to only support python3.
ENV SCRIPT TIDY_PRINT_DIFF=1 python2.7 ../x.py test \
           --stage 0 src/tools/tidy tidyselftest --extra-checks=py:lint,cpp:fmt
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
#
#    pip-compile --allow-unsafe --generate-hashes reuse-requirements.in
---
#13 2.867 Building wheels for collected packages: reuse
#13 2.868   Building wheel for reuse (pyproject.toml): started
#13 3.115   Building wheel for reuse (pyproject.toml): finished with status 'done'
#13 3.116   Created wheel for reuse: filename=reuse-4.0.3-cp310-cp310-manylinux_2_35_x86_64.whl size=132715 sha256=dfa09868353292d98f811d3efdb0d54d07389e808efc71d68e3b93c514bf8bec
#13 3.116   Stored in directory: /tmp/pip-ephem-wheel-cache-xfx0tb92/wheels/3d/8d/0a/e0fc6aba4494b28a967ab5eaf951c121d9c677958714e34532
#13 3.118 Installing collected packages: boolean-py, binaryornot, tomlkit, reuse, python-debian, markupsafe, license-expression, jinja2, chardet, attrs
#13 3.512 Successfully installed attrs-23.2.0 binaryornot-0.4.4 boolean-py-4.0 chardet-5.2.0 jinja2-3.1.4 license-expression-30.3.0 markupsafe-2.1.5 python-debian-0.1.49 reuse-4.0.3 tomlkit-0.13.0
#13 3.512 WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
#13 4.045 Collecting virtualenv
#13 4.121   Downloading virtualenv-20.26.6-py3-none-any.whl (6.0 MB)
#13 4.335      ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.0/6.0 MB 28.4 MB/s eta 0:00:00
#13 4.397 Collecting filelock<4,>=3.12.2
#13 4.403   Downloading filelock-3.16.1-py3-none-any.whl (16 kB)
#13 4.423 Collecting distlib<1,>=0.3.7
#13 4.438   Downloading distlib-0.3.8-py2.py3-none-any.whl (468 kB)
#13 4.448      ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 468.9/468.9 KB 55.6 MB/s eta 0:00:00
#13 4.497 Collecting platformdirs<5,>=3.9.1
#13 4.504   Downloading platformdirs-4.3.6-py3-none-any.whl (18 kB)
#13 4.583 Installing collected packages: distlib, platformdirs, filelock, virtualenv
#13 4.772 Successfully installed distlib-0.3.8 filelock-3.16.1 platformdirs-4.3.6 virtualenv-20.26.6
#13 DONE 4.9s

#14 [7/8] COPY host-x86_64/mingw-check/validate-toolstate.sh /scripts/
#14 DONE 0.0s
---
DirectMap4k:      200640 kB
DirectMap2M:     9236480 kB
DirectMap1G:     9437184 kB
##[endgroup]
Executing TIDY_PRINT_DIFF=1 python2.7 ../x.py test            --stage 0 src/tools/tidy tidyselftest --extra-checks=py:lint,cpp:fmt
+ TIDY_PRINT_DIFF=1 python2.7 ../x.py test --stage 0 src/tools/tidy tidyselftest --extra-checks=py:lint,cpp:fmt
    Finished `dev` profile [unoptimized] target(s) in 0.04s
##[endgroup]
downloading https://static.rust-lang.org/dist/2024-09-22/rustfmt-nightly-x86_64-unknown-linux-gnu.tar.xz
extracting /checkout/obj/build/cache/2024-09-22/rustfmt-nightly-x86_64-unknown-linux-gnu.tar.xz to /checkout/obj/build/x86_64-unknown-linux-gnu/rustfmt
---
fmt check
Diff in /checkout/compiler/rustc_codegen_llvm/src/builder.rs:3:
 use std::{iter, ptr};
 
 use libc::{c_char, c_uint};
-use rustc_codegen_ssa::MemFlags;
 use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
 use rustc_ast::expand::typetree::{FncTree, TypeTree};
+use rustc_codegen_ssa::MemFlags;
 use rustc_codegen_ssa::common::{IntPredicate, RealPredicate, SynchronizationScope, TypeKind};
 use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
 use rustc_codegen_ssa::mir::place::PlaceRef;
Diff in /checkout/compiler/rustc_codegen_llvm/src/back/write.rs:45:
 use crate::llvm::diagnostic::OptimizationDiagnosticKind;
 use crate::llvm::{
-    self, enzyme_rust_forward_diff, enzyme_rust_reverse_diff, AttributeKind, CreateEnzymeLogic,
-    CreateTypeAnalysis, DiagnosticInfo, EnzymeLogicRef, EnzymeTypeAnalysisRef, FreeTypeAnalysis,
-    LLVMAppendBasicBlockInContext, LLVMBuildCall2, LLVMBuildCondBr, LLVMBuildExtractValue,
-    LLVMBuildICmp, LLVMBuildRet, LLVMBuildRetVoid, LLVMCountParams, LLVMCountStructElementTypes,
-    LLVMCreateBuilderInContext, LLVMCreateStringAttribute, LLVMDisposeBuilder, LLVMDumpModule,
-    LLVMGetFirstBasicBlock, LLVMGetFirstFunction, LLVMGetNextFunction, LLVMGetParams,
-    LLVMGetReturnType, LLVMGetStringAttributeAtIndex, LLVMGlobalGetValueType, LLVMIsEnumAttribute,
+    self, AttributeKind, CreateEnzymeLogic, CreateTypeAnalysis, DiagnosticInfo, EnzymeLogicRef,
+    EnzymeTypeAnalysisRef, FreeTypeAnalysis, LLVMAppendBasicBlockInContext, LLVMBuildCall2,
+    LLVMBuildCondBr, LLVMBuildExtractValue, LLVMBuildICmp, LLVMBuildRet, LLVMBuildRetVoid,
+    LLVMCountParams, LLVMCountStructElementTypes, LLVMCreateBuilderInContext,
+    LLVMCreateStringAttribute, LLVMDisposeBuilder, LLVMDumpModule, LLVMGetFirstBasicBlock,
+    LLVMGetFirstFunction, LLVMGetNextFunction, LLVMGetParams, LLVMGetReturnType,
+    LLVMGetStringAttributeAtIndex, LLVMGlobalGetValueType, LLVMIsEnumAttribute,
     LLVMIsStringAttribute, LLVMMetadataAsValue, LLVMPositionBuilderAtEnd,
     LLVMRemoveStringAttributeAtIndex, LLVMRustAddEnumAttributeAtIndex,
     LLVMRustAddFunctionAttributes, LLVMRustDIGetInstMetadata, LLVMRustEraseInstBefore,
Diff in /checkout/compiler/rustc_codegen_llvm/src/back/write.rs:58:
     LLVMRustEraseInstFromParent, LLVMRustGetEnumAttributeAtIndex, LLVMRustGetFunctionType,
     LLVMRustGetLastInstruction, LLVMRustGetTerminator, LLVMRustHasMetadata,
     LLVMRustRemoveEnumAttributeAtIndex, LLVMVerifyFunction, LLVMVoidTypeInContext, PassManager,
-    Value,
+    Value, enzyme_rust_forward_diff, enzyme_rust_reverse_diff,
 use crate::type_::Type;
-use crate::{base, common, llvm_util, DiffTypeTree, LlvmCodegenBackend, ModuleLlvm};
+use crate::{DiffTypeTree, LlvmCodegenBackend, ModuleLlvm, base, common, llvm_util};
 
 pub(crate) fn llvm_err<'a>(dcx: DiagCtxtHandle<'_>, err: LlvmError<'a>) -> FatalError {
Diff in /checkout/compiler/rustc_codegen_llvm/src/back/write.rs:989:
     let src_fnc = match src_fnc_opt {
         Some(x) => x,
-            return Err(llvm_err(
-                diag_handler.handle(),
-                LlvmError::PrepareAutoDiff {
-                    src: rust_name.to_owned(),
-                    target: rust_name2.to_owned(),
-                    error: "could not find src function".to_owned(),
-            ));
+            return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
+                src: rust_name.to_owned(),
+                target: rust_name2.to_owned(),
+                error: "could not find src function".to_owned(),
         }
     };
     let target_fnc_opt = unsafe { llvm::LLVMGetNamedFunction(llmod, name2.as_ptr()) };
Diff in /checkout/compiler/rustc_codegen_llvm/src/back/write.rs:1003:
     let target_fnc = match target_fnc_opt {
         Some(x) => x,
-            return Err(llvm_err(
-                diag_handler.handle(),
-                LlvmError::PrepareAutoDiff {
-                    src: rust_name.to_owned(),
-                    target: rust_name2.to_owned(),
-                    error: "could not find target function".to_owned(),
-                },
-            ));
+            return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
+                src: rust_name.to_owned(),
+                target: rust_name2.to_owned(),
+                error: "could not find target function".to_owned(),
         }
     };
     let src_num_args = unsafe { llvm::LLVMCountParams(src_fnc) };
Diff in /checkout/compiler/rustc_codegen_llvm/src/back/write.rs:1176:
     let logic_ref_opt: EnzymeLogicRef = unsafe { CreateEnzymeLogic(fnc_opt_base as u8) };
     for item in first_order_items {
-        let res =
-            enzyme_ad(llmod, llcx, &diag_handler.handle(), item, logic_ref_opt, ad);
+        let res = enzyme_ad(llmod, llcx, &diag_handler.handle(), item, logic_ref_opt, ad);
         assert!(res.is_ok());
 
fmt: checked 5598 files
fmt error: Running `"/checkout/obj/build/x86_64-unknown-linux-gnu/rustfmt/bin/rustfmt" "--config-path" "/checkout" "--edition" "2021" "--unstable-features" "--skip-children" "--check" "/checkout/compiler/rustc_codegen_llvm/src/base.rs" "/checkout/compiler/rustc_codegen_llvm/src/value.rs" "/checkout/compiler/rustc_error_codes/src/lib.rs" "/checkout/compiler/rustc_codegen_llvm/src/debuginfo/namespace.rs" "/checkout/compiler/rustc_codegen_llvm/src/debuginfo/mod.rs" "/checkout/compiler/rustc_codegen_llvm/src/debuginfo/utils.rs" "/checkout/compiler/rustc_codegen_llvm/src/debuginfo/metadata.rs" "/checkout/compiler/rustc_codegen_llvm/src/debuginfo/gdb.rs" "/checkout/compiler/rustc_codegen_llvm/src/debuginfo/create_scope_map.rs" "/checkout/compiler/rustc_fluent_macro/src/fluent.rs" "/checkout/compiler/rustc_fluent_macro/src/lib.rs" "/checkout/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs" "/checkout/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/mod.rs" "/checkout/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs" "/checkout/compiler/rustc_codegen_llvm/src/debuginfo/metadata/type_map.rs" "/checkout/compiler/rustc_codegen_llvm/src/va_arg.rs" "/checkout/compiler/rustc_codegen_llvm/src/type_.rs" "/checkout/compiler/rustc_codegen_llvm/src/builder.rs" "/checkout/compiler/rustc_codegen_llvm/src/common.rs" "/checkout/compiler/rustc_codegen_llvm/src/allocator.rs" "/checkout/compiler/rustc_codegen_llvm/src/context.rs" "/checkout/compiler/rustc_codegen_llvm/src/callee.rs" "/checkout/compiler/rustc_codegen_llvm/src/declare.rs" "/checkout/compiler/rustc_codegen_llvm/src/mono_item.rs" "/checkout/compiler/rustc_ast_pretty/src/helpers.rs" "/checkout/compiler/rustc_codegen_llvm/src/coverageinfo/map_data.rs" "/checkout/compiler/rustc_codegen_llvm/src/coverageinfo/mod.rs" "/checkout/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs" "/checkout/compiler/rustc_codegen_llvm/src/coverageinfo/ffi.rs" 
"/checkout/compiler/rustc_codegen_llvm/src/asm.rs" "/checkout/compiler/rustc_codegen_llvm/src/intrinsic.rs" "/checkout/compiler/rustc_ast_pretty/src/pprust/mod.rs" "/checkout/compiler/rustc_codegen_llvm/src/errors.rs" "/checkout/compiler/rustc_ast_pretty/src/pprust/tests.rs" "/checkout/compiler/rustc_codegen_llvm/src/attributes.rs" "/checkout/compiler/rustc_ast_pretty/src/pprust/state.rs" "/checkout/compiler/rustc_codegen_llvm/src/llvm_util.rs" "/checkout/compiler/rustc_ast_pretty/src/pprust/state/expr.rs" "/checkout/compiler/rustc_ast_pretty/src/pprust/state/fixup.rs" "/checkout/compiler/rustc_ast_pretty/src/pprust/state/item.rs" "/checkout/compiler/rustc_ast_pretty/src/pp.rs" "/checkout/compiler/rustc_codegen_llvm/src/back/archive.rs" "/checkout/compiler/rustc_codegen_llvm/src/back/profiling.rs" "/checkout/compiler/rustc_codegen_llvm/src/back/lto.rs" "/checkout/compiler/rustc_codegen_llvm/src/back/owned_target_machine.rs" "/checkout/compiler/rustc_codegen_llvm/src/back/write.rs" "/checkout/compiler/rustc_codegen_llvm/src/typetree.rs" "/checkout/compiler/rustc_codegen_llvm/src/type_of.rs" "/checkout/compiler/rustc_codegen_llvm/src/consts.rs" "/checkout/compiler/rustc_codegen_llvm/src/lib.rs" "/checkout/compiler/rustc_ast_pretty/src/pp/convenience.rs" "/checkout/compiler/rustc_ast_pretty/src/pp/ring.rs" "/checkout/compiler/rustc_ast_pretty/src/lib.rs" "/checkout/compiler/rustc_codegen_llvm/src/llvm/archive_ro.rs" "/checkout/compiler/rustc_codegen_llvm/src/llvm/mod.rs" "/checkout/compiler/rustc_codegen_llvm/src/llvm/diagnostic.rs" "/checkout/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs" "/checkout/compiler/rustc_codegen_llvm/src/llvm/ffi.rs" "/checkout/compiler/rustc_lint_defs/src/builtin.rs" "/checkout/compiler/rustc_lint_defs/src/lib.rs" "/checkout/compiler/rustc_macros/src/type_foldable.rs" "/checkout/compiler/rustc_macros/src/extension.rs" "/checkout/compiler/rustc_macros/src/current_version.rs" "/checkout/compiler/rustc_codegen_llvm/src/abi.rs"` failed.
If you're running `tidy`, try again with `--bless`. Or, if you just want to format code, run `./x.py fmt` instead.
  local time: Tue Oct  1 00:37:44 UTC 2024
  network time: Tue, 01 Oct 2024 00:37:45 GMT
##[error]Process completed with exit code 1.
Post job cleanup.

@bors
Contributor

bors commented Oct 4, 2024

☔ The latest upstream changes (presumably #131237) made this pull request unmergeable. Please resolve the merge conflicts.

@michaelwoerister
Member

r? compiler

@rustbot rustbot assigned davidtwco and unassigned michaelwoerister Oct 7, 2024
@davidtwco
Member

There's very little chance of this being merged in one PR with one commit of this size. You'll need to split this up into well-commented/motivated PRs that can be landed one at a time. I haven't spent much time looking at this PR, so I don't have any suggestions on how to split this up. I'd recommend finding someone on the compiler team who is interested in these changes and who you can work with to do the reviews.

Contributor

@nikic nikic left a comment

I don't have time to review this, so just one drive-by note: It looks like a decent part of the extra FFI APIs in enzyme_ffi.rs are essentially duplicates of things that we already have bindings for under slightly different names and signatures. Like we already have LLVMRustAddFunctionAttributes/LLVMRustAddCallSiteAttributes and this introduces LLVMRustAddEnumAttributeAtIndex. It also looks like the code doesn't make use of the Builder abstraction and instead calls FFI APIs directly everywhere, which is probably also where the duplication comes from.

@ZuseZ4
Contributor Author

ZuseZ4 commented Oct 8, 2024

@nikic Do you mind if I drop the approach in this PR and add Enzyme back as an LLVM pass instead of using Enzyme as a library?
I already have an alternative compilation pipeline when autodiff is used under release mode (default O3 pipeline without vectorization or unrolling -> ad -> full default O3 pipeline), so I don't think doing ad as an extra pass will be controversial?
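
The release-mode ordering mentioned above can be written down as data. This is a minimal sketch of the stage order only; the `Stage` enum and `release_pipeline` function are hypothetical names for illustration and do not correspond to rustc or LLVM APIs.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Stage {
    /// Default O3 pipeline with vectorization and unrolling turned off.
    OptWithoutVectorizeUnroll,
    /// Enzyme differentiates the already optimized module.
    Autodiff,
    /// Full default O3 pipeline.
    FullOpt,
}

/// Returns the stage order for a release build, depending on whether
/// the crate uses autodiff.
fn release_pipeline(uses_autodiff: bool) -> Vec<Stage> {
    if uses_autodiff {
        vec![
            Stage::OptWithoutVectorizeUnroll,
            Stage::Autodiff,
            Stage::FullOpt,
        ]
    } else {
        vec![Stage::FullOpt]
    }
}
```

With the pass approach, the `Autodiff` stage would be just another entry in this sequence rather than a separate library call between two opt runs.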

Historically Julia (and Rust, except for the very first experiments) used the lib approach, while Clang used the pass approach. The pass used to support fewer configurations and less metadata. We have now added a parser to read all metadata from the llvm-ir module and added the missing config options, so Julia (Enzyme.jl) is now also moving to the pass since it simplifies the config a little. JuliaGPU/GPUCompiler.jl#636 (comment)
Their work is more complicated due to having JIT, GC, and type-unstable code, but even the Rust side would be simplified if I let the pass do some of the config handling.

To give a simplified example of the two approaches, the frontend rustc_builtin_macro autodiff is applied to a function f and generates

fn f(some_args: types) {
  // original user implementation, e.g. sin(x)
}
fn df(slightly_altered_args: similar_types) {
  // macro generated placeholder
}

I used to pass the llvm-ir function of f, together with metadata to the Enzyme library, and got a new function back.
Then I removed the placeholder body of df and added a call to the new function I got in the previous library call, together with adding wrappers to handle e.g. ABI optimizations which rustc does. In the future I would adjust the placeholder code to look like this

fn df(slightly_altered_args: similar_types) {
  return __enzyme_autodiff(f, some, config, args, slightly_altered_args);
}

(In reality, I would do the rewriting on llvm-ir, just using Rust for readability).
The enzyme pass recognizes __enzyme_autodiff calls, looks up f and differentiates based on the config settings.
A simple example that the pass would handle correctly (which our library doesn't handle yet) is getting the autodiff order right. Say differentiating f gives df and differentiating g gives dg, and g calls df. The dependencies form a DAG, so we must first replace the placeholder of df with the actual derivative before differentiating g to get dg; otherwise we would compute the derivative of our placeholder for df. Documentation of the current (limited) approach is here: https://enzyme.mit.edu/index.fcgi/rust/usage/higher.html. There are more improvements, one of them being that this PR becomes a lot smaller, since I don't need to call Enzyme directly and thus won't need wrappers for those functions.
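
The ordering constraint in the df/dg example amounts to a topological sort over the derivatives to generate. The following is a minimal sketch using Kahn's algorithm; it is not how the Enzyme pass is implemented, and the function `differentiation_order` and its input format are assumptions made for illustration.

```rust
use std::collections::{BTreeMap, VecDeque};

/// `deps[d]` lists the derivatives whose placeholders are called by the
/// source function of `d` and must therefore be generated first.
/// Returns a valid generation order, or `None` if the dependencies are cyclic.
fn differentiation_order(deps: &BTreeMap<&str, Vec<&str>>) -> Option<Vec<String>> {
    // Number of not-yet-generated prerequisites per derivative.
    let mut pending: BTreeMap<&str, usize> = BTreeMap::new();
    // Reverse edges: which derivatives wait on a given one.
    let mut dependents: BTreeMap<&str, Vec<&str>> = BTreeMap::new();
    for (&item, prereqs) in deps {
        pending.entry(item).or_insert(0);
        for &p in prereqs {
            pending.entry(p).or_insert(0);
            *pending.get_mut(item).unwrap() += 1;
            dependents.entry(p).or_default().push(item);
        }
    }

    // Start with everything that has no prerequisites.
    let mut ready: VecDeque<&str> = VecDeque::new();
    for (&name, &count) in &pending {
        if count == 0 {
            ready.push_back(name);
        }
    }

    let mut order = Vec::new();
    while let Some(item) = ready.pop_front() {
        order.push(item.to_string());
        if let Some(waiting) = dependents.get(item) {
            for &d in waiting {
                let n = pending.get_mut(d).unwrap();
                *n -= 1;
                if *n == 0 {
                    ready.push_back(d);
                }
            }
        }
    }

    // Anything left unprocessed is part of a cycle.
    if order.len() == pending.len() {
        Some(order)
    } else {
        None
    }
}
```

For the case in the text, a map with `"df" -> []` and `"dg" -> ["df"]` yields df before dg, so the placeholder of df is replaced before g is differentiated.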

If anyone has concerns it would be great to hear them now; otherwise I'd try to get the new PR up soon. I've already done most of the work in the past, since it also offers greater debuggability due to having all relevant info in the llvm-ir module, instead of passing data directly in memory to a library call, which is hard to test and reproduce.

Labels
F-autodiff `#![feature(autodiff)]` S-waiting-on-review Status: Awaiting review from the assignee but also interested parties. T-bootstrap Relevant to the bootstrap subteam: Rust's build system (x.py and src/bootstrap) T-compiler Relevant to the compiler team, which will review and decide on the PR/issue.
10 participants