diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs new file mode 100644 index 0000000000000..c45b74fe9b1ec --- /dev/null +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -0,0 +1,79 @@ +/// This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute, +/// we create an `AutoDiffItem` which contains the source and target function names. The source +/// is the function to which the autodiff attribute is applied, and the target is the function +/// getting generated by us (with a name given by the user as the first autodiff arg). +use crate::expand::typetree::TypeTree; +use crate::expand::{Decodable, Encodable, HashStable_Generic}; + +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub enum DiffMode { + /// No autodiff is applied (usually used during error handling). + Inactive, + /// The primal function which we will differentiate. + Source, + /// The target function, to be created using forward mode AD. + Forward, + /// The target function, to be created using reverse mode AD. + Reverse, + /// The target function, to be created using forward mode AD. + /// This target function will also be used as a source for higher order derivatives, + /// so compute it before all Forward/Reverse targets and optimize it through llvm. + ForwardFirst, + /// The target function, to be created using reverse mode AD. + /// This target function will also be used as a source for higher order derivatives, + /// so compute it before all Forward/Reverse targets and optimize it through llvm. + ReverseFirst, +} + +/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity. +/// However, under forward mode we overwrite the previous shadow value, while for reverse mode +/// we add to the previous shadow value. To not surprise users, we picked different names. 
+/// Dual numbers is also a quite well known name for forward mode AD types. +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub enum DiffActivity { + /// Implicit or Explicit () return type, so a special case of Const. + None, + /// Don't compute derivatives with respect to this input/output. + Const, + /// Reverse Mode, Compute derivatives for this scalar input/output. + Active, + /// Reverse Mode, Compute derivatives for this scalar output, but don't compute + /// the original return value. + ActiveOnly, + /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument + /// with it. + Dual, + /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument + /// with it. Drop the code which updates the original input/output for maximum performance. + DualOnly, + /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument. + Duplicated, + /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument. + /// Drop the code which updates the original input for maximum performance. + DuplicatedOnly, + /// All Integers must be Const, but these are used to mark the integer which represents the + /// length of a slice/vec. This is used for safety checks on slices. + FakeActivitySize, +} +/// We generate one of these structs for each `#[autodiff(...)]` attribute. 
+#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct AutoDiffItem { + /// The name of the function getting differentiated + pub source: String, + /// The name of the function being generated + pub target: String, + pub attrs: AutoDiffAttrs, + /// Despribe the memory layout of input types + pub inputs: Vec, + /// Despribe the memory layout of the output type + pub output: TypeTree, +} +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct AutoDiffAttrs { + /// Conceptually either forward or reverse mode AD, as described in various autodiff papers and + /// e.g. in the [JAX + /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions). + pub mode: DiffMode, + pub ret_activity: DiffActivity, + pub input_activity: Vec, +} diff --git a/compiler/rustc_ast/src/expand/mod.rs b/compiler/rustc_ast/src/expand/mod.rs index 13413281bc7cc..d259677e98e3d 100644 --- a/compiler/rustc_ast/src/expand/mod.rs +++ b/compiler/rustc_ast/src/expand/mod.rs @@ -7,6 +7,8 @@ use rustc_span::symbol::Ident; use crate::MetaItem; pub mod allocator; +pub mod autodiff_attrs; +pub mod typetree; #[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)] pub struct StrippedCfgItem { diff --git a/compiler/rustc_ast/src/expand/typetree.rs b/compiler/rustc_ast/src/expand/typetree.rs new file mode 100644 index 0000000000000..e4a0bfc32aff4 --- /dev/null +++ b/compiler/rustc_ast/src/expand/typetree.rs @@ -0,0 +1,69 @@ +use std::fmt; + +use crate::expand::{Decodable, Encodable, HashStable_Generic}; + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub enum Kind { + Anything, + Integer, + Pointer, + Half, + Float, + Double, + Unknown, +} + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct TypeTree(pub Vec); + +impl TypeTree { + pub fn new() -> Self { + 
Self(Vec::new()) + } + pub fn all_ints() -> Self { + Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }]) + } + pub fn int(size: usize) -> Self { + let mut ints = Vec::with_capacity(size); + for i in 0..size { + ints.push(Type { + offset: i as isize, + size: 1, + kind: Kind::Integer, + child: TypeTree::new(), + }); + } + Self(ints) + } +} + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct FncTree { + pub args: Vec, + pub ret: TypeTree, +} + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct Type { + pub offset: isize, + pub size: usize, + pub kind: Kind, + pub child: TypeTree, +} + +impl Type { + pub fn add_offset(self, add: isize) -> Self { + let offset = match self.offset { + -1 => add, + x => add + x, + }; + + Self { size: self.size, kind: self.kind, child: self.child, offset } + } +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt(self, f) + } +} diff --git a/compiler/rustc_codegen_llvm/messages.ftl b/compiler/rustc_codegen_llvm/messages.ftl index df2198df14b6a..65ecf2143c178 100644 --- a/compiler/rustc_codegen_llvm/messages.ftl +++ b/compiler/rustc_codegen_llvm/messages.ftl @@ -54,6 +54,10 @@ codegen_llvm_prepare_thin_lto_module_with_llvm_err = failed to prepare thin LTO codegen_llvm_run_passes = failed to run LLVM passes codegen_llvm_run_passes_with_llvm_err = failed to run LLVM passes: {$llvm_err} +codegen_llvm_prepare_autodiff = failed to prepare AutoDiff: src: {$src}, target: {$target}, {$error} +codegen_llvm_prepare_autodiff_with_llvm_err = failed to prepare AutoDiff: {$llvm_err}, src: {$src}, target: {$target}, {$error} +codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto + codegen_llvm_sanitizer_memtag_requires_mte = `-Zsanitizer=memtag` requires `-Ctarget-feature=+mte` diff --git a/compiler/rustc_codegen_llvm/src/attributes.rs 
b/compiler/rustc_codegen_llvm/src/attributes.rs index 92a857c2adcf4..5ee683404518f 100644 --- a/compiler/rustc_codegen_llvm/src/attributes.rs +++ b/compiler/rustc_codegen_llvm/src/attributes.rs @@ -1,6 +1,8 @@ //! Set and unset common attributes on LLVM values. use rustc_attr::{InlineAttr, InstructionSetAttr, OptimizeAttr}; +// FIXME(ZuseZ4): Re-enable once the middle-end is merged. +//use rustc_ast::expand::autodiff_attrs::AutoDiffAttrs; use rustc_codegen_ssa::traits::*; use rustc_hir::def_id::DefId; use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, PatchableFunctionEntry}; @@ -333,6 +335,8 @@ pub(crate) fn llfn_attrs_from_instance<'ll, 'tcx>( instance: ty::Instance<'tcx>, ) { let codegen_fn_attrs = cx.tcx.codegen_fn_attrs(instance.def_id()); + // FIXME(ZuseZ4): Re-enable once the middle-end is merged. + //let autodiff_attrs: &AutoDiffAttrs = cx.tcx.autodiff_attrs(instance.def_id()); let mut to_add = SmallVec::<[_; 16]>::new(); @@ -350,6 +354,9 @@ pub(crate) fn llfn_attrs_from_instance<'ll, 'tcx>( let inline = if codegen_fn_attrs.inline == InlineAttr::None && instance.def.requires_inline(cx.tcx) { InlineAttr::Hint + // FIXME(ZuseZ4): re-enable once the middle-end is merged. + //} else if autodiff_attrs.is_active() { + // InlineAttr::Never } else { codegen_fn_attrs.inline }; diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index aa6842c75cec7..51f226f0c4bf2 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -622,7 +622,12 @@ pub(crate) fn run_pass_manager( } let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO }; let opt_level = config.opt_level.unwrap_or(config::OptLevel::No); - write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage)?; + + // We will run this again with different values in the context of automatic differentiation. 
+ let first_run = true; + let noop = false; + debug!("running llvm pm opt pipeline"); + write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run, noop)?; } debug!("lto done"); Ok(()) @@ -729,7 +734,12 @@ pub(crate) unsafe fn optimize_thin_module( let llcx = unsafe { llvm::LLVMRustContextCreate(cgcx.fewer_names) }; let llmod_raw = parse_module(llcx, module_name, thin_module.data(), dcx)? as *const _; let mut module = ModuleCodegen { - module_llvm: ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(tm) }, + module_llvm: ModuleLlvm { + llmod_raw, + llcx, + tm: ManuallyDrop::new(tm), + typetrees: Default::default(), + }, name: thin_module.name().to_string(), kind: ModuleKind::Regular, }; diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index a5c27d2282eaf..d4a143964537b 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -4,10 +4,13 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use std::{fs, slice, str}; -use libc::{c_char, c_int, c_void, size_t}; +use libc::{c_char, c_int, c_uint, c_void, size_t}; use llvm::{ - LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, + IntPredicate, LLVMRustLLVMHasZlibCompressionForDebugSymbols, + LLVMRustLLVMHasZstdCompressionForDebugSymbols, }; +use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity, DiffMode}; +use rustc_ast::expand::typetree::FncTree; use rustc_codegen_ssa::back::link::ensure_removed; use rustc_codegen_ssa::back::write::{ BitcodeSection, CodegenContext, EmitObj, ModuleConfig, TargetMachineFactoryConfig, @@ -15,19 +18,21 @@ use rustc_codegen_ssa::back::write::{ }; use rustc_codegen_ssa::traits::*; use rustc_codegen_ssa::{CompiledModule, ModuleCodegen}; +use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::profiling::SelfProfilerRef; use rustc_data_structures::small_c_str::SmallCStr; use 
rustc_errors::{DiagCtxtHandle, FatalError, Level}; use rustc_fs_util::{link_or_copy, path_to_c_string}; use rustc_middle::ty::TyCtxt; use rustc_session::config::{ - self, Lto, OutputType, Passes, RemapPathScopeComponents, SplitDwarfKind, SwitchWithOptPath, + self, AutoDiff, Lto, OutputType, Passes, RemapPathScopeComponents, SplitDwarfKind, + SwitchWithOptPath, }; use rustc_session::Session; use rustc_span::symbol::sym; use rustc_span::InnerSpan; use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo, TlsModel}; -use tracing::debug; +use tracing::{debug, trace}; use crate::back::lto::ThinBuffer; use crate::back::owned_target_machine::OwnedTargetMachine; @@ -39,9 +44,24 @@ use crate::errors::{ WithLlvmError, WriteBytecode, }; use crate::llvm::diagnostic::OptimizationDiagnosticKind; -use crate::llvm::{self, DiagnosticInfo, PassManager}; +use crate::llvm::{ + self, enzyme_rust_forward_diff, enzyme_rust_reverse_diff, AttributeKind, CreateEnzymeLogic, + CreateTypeAnalysis, DiagnosticInfo, EnzymeLogicRef, EnzymeTypeAnalysisRef, FreeTypeAnalysis, + LLVMAppendBasicBlockInContext, LLVMBuildCall2, LLVMBuildCondBr, LLVMBuildExtractValue, + LLVMBuildICmp, LLVMBuildRet, LLVMBuildRetVoid, LLVMCountParams, LLVMCountStructElementTypes, + LLVMCreateBuilderInContext, LLVMCreateStringAttribute, LLVMDisposeBuilder, LLVMDumpModule, + LLVMGetFirstBasicBlock, LLVMGetFirstFunction, LLVMGetNextFunction, LLVMGetParams, + LLVMGetReturnType, LLVMGetStringAttributeAtIndex, LLVMGlobalGetValueType, LLVMIsEnumAttribute, + LLVMIsStringAttribute, LLVMMetadataAsValue, LLVMPositionBuilderAtEnd, + LLVMRemoveStringAttributeAtIndex, LLVMRustAddEnumAttributeAtIndex, + LLVMRustAddFunctionAttributes, LLVMRustDIGetInstMetadata, LLVMRustEraseInstBefore, + LLVMRustEraseInstFromParent, LLVMRustGetEnumAttributeAtIndex, LLVMRustGetFunctionType, + LLVMRustGetLastInstruction, LLVMRustGetTerminator, LLVMRustHasMetadata, + LLVMRustRemoveEnumAttributeAtIndex, LLVMVerifyFunction, 
LLVMVoidTypeInContext, PassManager, + Value, +}; use crate::type_::Type; -use crate::{base, common, llvm_util, LlvmCodegenBackend, ModuleLlvm}; +use crate::{base, common, llvm_util, DiffTypeTree, LlvmCodegenBackend, ModuleLlvm}; pub(crate) fn llvm_err<'a>(dcx: DiagCtxtHandle<'_>, err: LlvmError<'a>) -> FatalError { match llvm::last_error() { @@ -509,9 +529,38 @@ pub(crate) unsafe fn llvm_optimize( config: &ModuleConfig, opt_level: config::OptLevel, opt_stage: llvm::OptStage, + first_run: bool, + noop: bool, ) -> Result<(), FatalError> { - let unroll_loops = - opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + if noop { + return Ok(()); + } + // Enzyme: + // The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized + // source code. However, benchmarks show that optimizations increasing the code size + // tend to reduce AD performance. Therefore deactivate them before AD, then differentiate the code + // and finally re-optimize the module, now with all optimizations available. + // TODO: In a future update we could figure out how to only optimize functions getting + // differentiated. 
+ + let unroll_loops; + let vectorize_slp; + let vectorize_loop; + + if first_run { + unroll_loops = false; + vectorize_slp = false; + vectorize_loop = false; + } else { + unroll_loops = + opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + vectorize_slp = config.vectorize_slp; + vectorize_loop = config.vectorize_loop; + } + trace!( + "Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}", + unroll_loops, vectorize_slp, vectorize_loop + ); let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed(); let pgo_gen_path = get_pgo_gen_path(config); let pgo_use_path = get_pgo_use_path(config); @@ -575,8 +624,8 @@ pub(crate) unsafe fn llvm_optimize( using_thin_buffers, config.merge_functions, unroll_loops, - config.vectorize_slp, - config.vectorize_loop, + vectorize_slp, + vectorize_loop, config.no_builtins, config.emit_lifetime_markers, sanitizer_options.as_ref(), @@ -599,6 +648,660 @@ pub(crate) unsafe fn llvm_optimize( result.into_result().map_err(|()| llvm_err(dcx, LlvmError::RunLlvmPasses)) } +fn get_params(fnc: &Value) -> Vec<&Value> { + // SAFETY: Assumes that Value is a function. + unsafe { + let param_num = LLVMCountParams(fnc) as usize; + let mut fnc_args: Vec<&Value> = vec![]; + fnc_args.reserve(param_num); + LLVMGetParams(fnc, fnc_args.as_mut_ptr()); + fnc_args.set_len(param_num); + fnc_args + } +} + +// DESIGN: +// Today we have our placeholder function, and our Enzyme generated one. +// We create a wrapper function and delete the placeholder body. You can see the +// placeholder by running `cargo expand` on an autodiff invocation. We call the wrapper +// from the placeholder. This function is a bit longer, because it matches the Rust level +// autodiff macro with LLVM level Enzyme autodiff expectations. +// +// Think of computing the derivative with respect to &[f32] by marking it as duplicated. 
+// The user will then pass an extra &mut [f32] and we want add the derivative to that. +// On LLVM/Enzyme level, &[f32] however becomes `ptr, i64` and we mark ptr as duplicated, +// and i64 (len) as const. Enzyme will then expect `ptr, ptr, i64` as arguments. See how the +// second i64 from the mut slice isn't used? That's why we add a safety check to assert +// that the second (mut) slice is at least as long as the first (const) slice. Otherwise, +// Enzyme would write out of bounds if the first (const) slice is longer than the second. + +unsafe fn create_call<'a>( + tgt: &'a Value, + src: &'a Value, + llmod: &'a llvm::Module, + llcx: &llvm::Context, + // FIXME: Instead of recomputing the positions as we do it below, we should + // start using this list of positions that indicate length integers. + _size_positions: &[usize], + ad: &[AutoDiff], +) { + unsafe { + // first, remove all calls from fnc + let bb = LLVMGetFirstBasicBlock(tgt); + let br = LLVMRustGetTerminator(bb); + LLVMRustEraseInstFromParent(br); + + // now add a call to inner. + // append call to src at end of bb. + let f_ty = LLVMRustGetFunctionType(src); + + let inner_param_num = LLVMCountParams(src); + let outer_param_num = LLVMCountParams(tgt); + let outer_args: Vec<&Value> = get_params(tgt); + let inner_args: Vec<&Value> = get_params(src); + let mut call_args: Vec<&Value> = vec![]; + + let mut safety_vals = vec![]; + let builder = LLVMCreateBuilderInContext(llcx); + let last_inst = LLVMRustGetLastInstruction(bb).unwrap(); + LLVMPositionBuilderAtEnd(builder, bb); + + let safety_run_checks = !ad.contains(&AutoDiff::NoSafetyChecks); + + if inner_param_num == outer_param_num { + call_args = outer_args; + } else { + trace!("Different number of args, adjusting"); + let mut outer_pos: usize = 0; + let mut inner_pos: usize = 0; + // copy over if they are identical. + // If not, skip the outer arg (and assert it's int). 
+ while outer_pos < outer_param_num as usize { + let inner_arg = inner_args[inner_pos]; + let outer_arg = outer_args[outer_pos]; + let inner_arg_ty = llvm::LLVMTypeOf(inner_arg); + let outer_arg_ty = llvm::LLVMTypeOf(outer_arg); + if inner_arg_ty == outer_arg_ty { + call_args.push(outer_arg); + inner_pos += 1; + outer_pos += 1; + } else { + // out: rust: (&[f32], &mut [f32]) + // out: llvm: (ptr, <>int1, ptr, int2) + // inner: (ptr, <>ptr, int) + // goal: call (ptr, ptr, int1), skipping int2 + // we are here: <> + assert!(llvm::LLVMRustGetTypeKind(outer_arg_ty) == llvm::TypeKind::Integer); + assert!(llvm::LLVMRustGetTypeKind(inner_arg_ty) == llvm::TypeKind::Pointer); + let next_outer_arg = outer_args[outer_pos + 1]; + let next_inner_arg = inner_args[inner_pos + 1]; + let next_outer_arg_ty = llvm::LLVMTypeOf(next_outer_arg); + let next_inner_arg_ty = llvm::LLVMTypeOf(next_inner_arg); + assert!( + llvm::LLVMRustGetTypeKind(next_outer_arg_ty) == llvm::TypeKind::Pointer + ); + assert!( + llvm::LLVMRustGetTypeKind(next_inner_arg_ty) == llvm::TypeKind::Integer + ); + let next2_outer_arg = outer_args[outer_pos + 2]; + let next2_outer_arg_ty = llvm::LLVMTypeOf(next2_outer_arg); + assert!( + llvm::LLVMRustGetTypeKind(next2_outer_arg_ty) == llvm::TypeKind::Integer + ); + call_args.push(next_outer_arg); + call_args.push(outer_arg); + + outer_pos += 3; + inner_pos += 2; + + if safety_run_checks { + // Now we assert if int1 <= int2 + let res = LLVMBuildICmp( + builder, + IntPredicate::IntULE as u32, + outer_arg, + next2_outer_arg, + "safety_check".as_ptr() as *const c_char, + ); + safety_vals.push(res); + } + } + } + } + + if inner_param_num as usize != call_args.len() { + panic!( + "Args len shouldn't differ. Please report this. {} : {}", + inner_param_num, + call_args.len() + ); + } + + // Now add the safety checks. 
+ if !safety_vals.is_empty() { + dbg!("Adding safety checks"); + assert!(safety_run_checks); + // first we create one bb per check and two more for the fail and success case. + let fail_bb = LLVMAppendBasicBlockInContext( + llcx, + tgt, + "ad_safety_fail".as_ptr() as *const c_char, + ); + let success_bb = LLVMAppendBasicBlockInContext( + llcx, + tgt, + "ad_safety_success".as_ptr() as *const c_char, + ); + for i in 1..safety_vals.len() { + // 'or' all safety checks together + // Doing some binary tree style or'ing here would be more efficient, + // but I assume LLVM will opt it anyway + let prev = safety_vals[i - 1]; + let curr = safety_vals[i]; + let res = llvm::LLVMBuildOr( + builder, + prev, + curr, + "safety_check".as_ptr() as *const c_char, + ); + safety_vals[i] = res; + } + LLVMBuildCondBr(builder, safety_vals.last().unwrap(), success_bb, fail_bb); + LLVMPositionBuilderAtEnd(builder, fail_bb); + + let panic_name: CString = get_panic_name(llmod); + + let mut arg_vec = vec![add_panic_msg_to_global(llmod, llcx)]; + + let fnc1 = llvm::LLVMGetNamedFunction(llmod, panic_name.as_ptr() as *const c_char); + assert!(fnc1.is_some()); + let fnc1 = fnc1.unwrap(); + let ty = LLVMRustGetFunctionType(fnc1); + let call = LLVMBuildCall2( + builder, + ty, + fnc1, + arg_vec.as_mut_ptr(), + arg_vec.len(), + panic_name.as_ptr() as *const c_char, + ); + llvm::LLVMSetTailCall(call, 1); + llvm::LLVMBuildUnreachable(builder); + LLVMPositionBuilderAtEnd(builder, success_bb); + } + + let inner_fnc_name = llvm::get_value_name(src); + let c_inner_fnc_name = CString::new(inner_fnc_name).unwrap(); + + let mut struct_ret = LLVMBuildCall2( + builder, + f_ty, + src, + call_args.as_mut_ptr(), + call_args.len(), + c_inner_fnc_name.as_ptr(), + ); + + // Add dummy dbg info to our newly generated call, if we have any. 
+ let md_ty = llvm::LLVMGetMDKindIDInContext( + llcx, + "dbg".as_ptr() as *const c_char, + "dbg".len() as c_uint, + ); + + if LLVMRustHasMetadata(last_inst, md_ty) { + let md = LLVMRustDIGetInstMetadata(last_inst); + let md_val = LLVMMetadataAsValue(llcx, md); + let _md2 = llvm::LLVMSetMetadata(struct_ret, md_ty, md_val); + } else { + trace!("No dbg info"); + } + + // Now clean up placeholder code. + LLVMRustEraseInstBefore(bb, last_inst); + + let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(src)); + let f_is_struct = llvm::LLVMRustIsStructType(f_return_type); + let void_type = LLVMVoidTypeInContext(llcx); + // Now unwrap the struct_ret if it's actually a struct + if f_is_struct { + let num_elem_in_ret_struct = LLVMCountStructElementTypes(f_return_type); + if num_elem_in_ret_struct == 1 { + let inner_grad_name = "foo".to_string(); + let c_inner_grad_name = CString::new(inner_grad_name).unwrap(); + struct_ret = + LLVMBuildExtractValue(builder, struct_ret, 0, c_inner_grad_name.as_ptr()); + } + } + if f_return_type != void_type { + let _ret = LLVMBuildRet(builder, struct_ret); + } else { + let _ret = LLVMBuildRetVoid(builder); + } + LLVMDisposeBuilder(builder); + let _fnc_ok = + LLVMVerifyFunction(tgt, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction); + } +} + +unsafe fn get_panic_name(llmod: &llvm::Module) -> CString { + // The names are mangled and their ending changes based on a hash, so just take whichever. 
+ let mut f = unsafe { LLVMGetFirstFunction(llmod) }; + loop { + if let Some(lf) = f { + f = unsafe { LLVMGetNextFunction(lf) }; + let fnc_name = llvm::get_value_name(lf); + let fnc_name: String = String::from_utf8(fnc_name.to_vec()).unwrap(); + if fnc_name.starts_with("_ZN4core9panicking14panic_explicit") { + return CString::new(fnc_name).unwrap(); + } else if fnc_name.starts_with("_RN4core9panicking14panic_explicit") { + return CString::new(fnc_name).unwrap(); + } + } else { + break; + } + } + panic!("Could not find panic function"); +} + +// This code is called when Enzyme detects at runtime that one of the safety invariants is violated. +// For now we only check if shadow arguments are large enough. In this case we look for Rust panic +// functions in the module and call it. Due to hashing we can't hardcode the panic function name. +// Note: This worked even for panic=abort tests so seems solid enough for now. +// FIXME: Pick a panic function which allows displaying an error message. +// FIXME: We probably want to keep a handle at higher level and pass it down instead of searching. 
+unsafe fn add_panic_msg_to_global<'a>( + llmod: &'a llvm::Module, + llcx: &'a llvm::Context, +) -> &'a llvm::Value { + unsafe { + use llvm::*; + + // Convert the message to a CString + let msg = "autodiff safety check failed!"; + let cmsg = CString::new(msg).unwrap(); + + let msg_global_name = "ad_safety_msg".to_string(); + let cmsg_global_name = CString::new(msg_global_name).unwrap(); + + // Get the length of the message + let msg_len = msg.len(); + + // Create the array type + let i8_array_type = LLVMArrayType2(LLVMInt8TypeInContext(llcx), msg_len as u64); + + // Create the string constant + let _string_const_val = + LLVMConstStringInContext2(llcx, cmsg.as_ptr() as *const i8, msg_len as usize, 0); + + // Create the array initializer + let mut array_elems: Vec<_> = Vec::with_capacity(msg_len); + for i in 0..msg_len { + let char_value = + LLVMConstInt(LLVMInt8TypeInContext(llcx), cmsg.as_bytes()[i] as u64, 0); + array_elems.push(char_value); + } + let array_initializer = + LLVMConstArray2(LLVMInt8TypeInContext(llcx), array_elems.as_mut_ptr(), msg_len as u64); + + // Create the struct type + let global_type = LLVMStructTypeInContext(llcx, [i8_array_type].as_mut_ptr(), 1, 0); + + // Create the struct initializer + let struct_initializer = + LLVMConstStructInContext(llcx, [array_initializer].as_mut_ptr(), 1, 0); + + // Add the global variable to the module + let global_var = LLVMAddGlobal(llmod, global_type, cmsg_global_name.as_ptr() as *const i8); + LLVMRustSetLinkage(global_var, Linkage::PrivateLinkage); + LLVMSetInitializer(global_var, struct_initializer); + + global_var + } +} +use rustc_errors::DiagCtxt; + +pub(crate) fn enzyme_ad( + llmod: &llvm::Module, + llcx: &llvm::Context, + diag_handler: &DiagCtxt, + item: AutoDiffItem, + logic_ref: EnzymeLogicRef, + ad: &[AutoDiff], +) -> Result<(), FatalError> { + let autodiff_mode = item.attrs.mode; + let rust_name = item.source; + let rust_name2 = &item.target; + + let args_activity = 
item.attrs.input_activity.clone(); + let ret_activity: DiffActivity = item.attrs.ret_activity; + + // get target and source function + let name = CString::new(rust_name.to_owned()).unwrap(); + let name2 = CString::new(rust_name2.clone()).unwrap(); + let src_fnc_opt = unsafe { llvm::LLVMGetNamedFunction(llmod, name.as_c_str().as_ptr()) }; + let src_fnc = match src_fnc_opt { + Some(x) => x, + None => { + return Err(llvm_err( + diag_handler.handle(), + LlvmError::PrepareAutoDiff { + src: rust_name.to_owned(), + target: rust_name2.to_owned(), + error: "could not find src function".to_owned(), + }, + )); + } + }; + let target_fnc_opt = unsafe { llvm::LLVMGetNamedFunction(llmod, name2.as_ptr()) }; + let target_fnc = match target_fnc_opt { + Some(x) => x, + None => { + return Err(llvm_err( + diag_handler.handle(), + LlvmError::PrepareAutoDiff { + src: rust_name.to_owned(), + target: rust_name2.to_owned(), + error: "could not find target function".to_owned(), + }, + )); + } + }; + let src_num_args = unsafe { llvm::LLVMCountParams(src_fnc) }; + let target_num_args = unsafe { llvm::LLVMCountParams(target_fnc) }; + // A really simple check + assert!(src_num_args <= target_num_args); + + let type_analysis: EnzymeTypeAnalysisRef = + unsafe { CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0) }; + + llvm::set_strict_aliasing(false); + + if ad.contains(&AutoDiff::PrintTA) { + llvm::set_print_type(true); + } + if ad.contains(&AutoDiff::PrintTA) { + llvm::set_print_type(true); + } + if ad.contains(&AutoDiff::PrintPerf) { + llvm::set_print_perf(true); + } + if ad.contains(&AutoDiff::Print) { + llvm::set_print(true); + } + + let mode = match autodiff_mode { + DiffMode::Forward => DiffMode::Forward, + DiffMode::Reverse => DiffMode::Reverse, + DiffMode::ForwardFirst => DiffMode::Forward, + DiffMode::ReverseFirst => DiffMode::Reverse, + _ => unreachable!(), + }; + + unsafe { + let void_type = LLVMVoidTypeInContext(llcx); + let return_type = 
LLVMGetReturnType(LLVMGlobalGetValueType(src_fnc)); + let void_ret = void_type == return_type; + let tmp = match mode { + DiffMode::Forward => enzyme_rust_forward_diff( + logic_ref, + type_analysis, + src_fnc, + args_activity, + ret_activity, + void_ret, + ), + DiffMode::Reverse => enzyme_rust_reverse_diff( + logic_ref, + type_analysis, + src_fnc, + args_activity, + ret_activity, + ), + _ => unreachable!(), + }; + let res: &Value = tmp.0; + // res is getting wrapped, but we don't want the perf overhead of a fnc call indirection. + // So we'll add an alwaysinline attribute to let llvm handle it for us. + // + // FIXME(ZuseZ4): enable this, but with the right position of the arg + //let always_inline = llvm::AttributeKind::AlwaysInline; + //let attr = llvm::LLVMRustCreateAttrNoValue(llcx, always_inline); + //llvm::LLVMRustAddFunctionAttributes(res, 9, &attr, 1); + + let size_positions: Vec = tmp.1; + + create_call(target_fnc, res, llmod, llcx, &size_positions, ad); + // TODO: implement drop for wrapper type? + FreeTypeAnalysis(type_analysis); + } + + Ok(()) +} + +pub(crate) unsafe fn differentiate( + module: &ModuleCodegen, + cgcx: &CodegenContext, + diff_items: Vec, + _typetrees: FxHashMap, + config: &ModuleConfig, +) -> Result<(), FatalError> { + for item in &diff_items { + trace!("{:?}", item); + } + + let llmod = module.module_llvm.llmod(); + let llcx = &module.module_llvm.llcx; + let diag_handler = cgcx.create_dcx(); + + llvm::set_strict_aliasing(false); + + let ad = &config.autodiff; + + if ad.contains(&AutoDiff::LooseTypes) { + dbg!("Setting loose types to true"); + llvm::set_loose_types(true); + } + + // Before dumping the module, we want all the tt to become part of the module. 
+ for (i, item) in diff_items.iter().enumerate() { + let tt: FncTree = FncTree { args: item.inputs.clone(), ret: item.output.clone() }; + let name = CString::new(item.source.clone()).unwrap(); + let fn_def: &llvm::Value = + unsafe { llvm::LLVMGetNamedFunction(llmod, name.as_ptr()).unwrap() }; + crate::builder::add_tt2(llmod, llcx, fn_def, tt); + + // Before dumping the module, we also might want to add dummy functions, which will + // trigger the LLVMEnzyme pass to run on them, if we invoke the opt binary. + // This is super helpfull if we want to create a MWE bug reproducer, e.g. to run in + // Enzyme's compiler explorer. TODO: Can we run llvm-extract on the module to remove all other functions? + if ad.contains(&AutoDiff::OPT) { + dbg!("Enable extra debug helper to debug Enzyme through the opt plugin"); + crate::builder::add_opt_dbg_helper(llmod, llcx, fn_def, item.attrs.clone(), i); + } + } + + if ad.contains(&AutoDiff::PrintModBefore) || ad.contains(&AutoDiff::OPT) { + unsafe { + LLVMDumpModule(llmod); + } + } + + if ad.contains(&AutoDiff::Inline) { + dbg!("Setting inline to true"); + llvm::set_inline(true); + } + + if ad.contains(&AutoDiff::RuntimeActivity) { + dbg!("Setting runtime activity check to true"); + llvm::set_runtime_activity_check(true); + } + + for val in ad { + match &val { + AutoDiff::TTDepth(depth) => { + assert!(*depth >= 1); + llvm::set_max_int_offset(*depth); + } + AutoDiff::TTWidth(width) => { + assert!(*width >= 1); + llvm::set_max_type_offset(*width); + } + _ => {} + } + } + + let differentiate = !diff_items.is_empty(); + let mut first_order_items: Vec = vec![]; + let mut higher_order_items: Vec = vec![]; + for item in diff_items { + if item.attrs.mode == DiffMode::ForwardFirst || item.attrs.mode == DiffMode::ReverseFirst { + first_order_items.push(item); + } else { + // default + higher_order_items.push(item); + } + } + + let fnc_opt = ad.contains(&AutoDiff::EnableFncOpt); + + // If a function is a base for some higher order ad, always 
optimize + let fnc_opt_base = true; + let logic_ref_opt: EnzymeLogicRef = unsafe { CreateEnzymeLogic(fnc_opt_base as u8) }; + + for item in first_order_items { + let res = + enzyme_ad(llmod, llcx, &diag_handler.handle(), item, logic_ref_opt, ad); + assert!(res.is_ok()); + } + + // For the rest, follow the user choice on debug vs release. + // Reuse the opt one if possible for better compile time (Enzyme internal caching). + let logic_ref = match fnc_opt { + true => { + dbg!("Enable extra optimizations for Enzyme"); + logic_ref_opt + } + false => unsafe { CreateEnzymeLogic(fnc_opt as u8) }, + }; + for item in higher_order_items { + let res = enzyme_ad(llmod, llcx, &diag_handler.handle(), item, logic_ref, ad); + assert!(res.is_ok()); + } + + // Remove these attributes after Enzyme is done, + // since we have only added them to prevent a specific LLVM optimization around enums. + // TODO(ZuseZ4, wsmoses): This probably wants some further discussions. + unsafe { + let mut f = LLVMGetFirstFunction(llmod); + loop { + if let Some(lf) = f { + f = LLVMGetNextFunction(lf); + let myhwattr = "enzyme_hw"; + let attr = LLVMGetStringAttributeAtIndex( + lf, + c_uint::MAX, + myhwattr.as_ptr() as *const c_char, + myhwattr.as_bytes().len() as c_uint, + ); + if LLVMIsStringAttribute(attr) { + LLVMRemoveStringAttributeAtIndex( + lf, + c_uint::MAX, + myhwattr.as_ptr() as *const c_char, + myhwattr.as_bytes().len() as c_uint, + ); + } else { + LLVMRustRemoveEnumAttributeAtIndex( + lf, + c_uint::MAX, + AttributeKind::SanitizeHWAddress, + ); + } + } else { + break; + } + } + if ad.contains(&AutoDiff::PrintModAfterEnzyme) { + LLVMDumpModule(llmod); + } + } + + if ad.contains(&AutoDiff::NoModOptAfter) || !differentiate { + trace!("Skipping module optimization after automatic differentiation"); + } else { + if let Some(opt_level) = config.opt_level { + let opt_stage = match cgcx.lto { + Lto::Fat => llvm::OptStage::PreLinkFatLTO, + Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO, + 
_ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, + _ => llvm::OptStage::PreLinkNoLTO, + }; + let mut first_run = false; + dbg!("Running Module Optimization after differentiation"); + if ad.contains(&AutoDiff::NoVecUnroll) { + // disables vectorization and loop unrolling + first_run = true; + } + if ad.contains(&AutoDiff::AltPipeline) { + dbg!("Running first postAD optimization"); + first_run = true; + } + let noop = false; + unsafe { + llvm_optimize( + cgcx, + diag_handler.handle(), + module, + config, + opt_level, + opt_stage, + first_run, + noop, + )? + }; + } + if ad.contains(&AutoDiff::AltPipeline) { + dbg!("Running Second postAD optimization"); + if let Some(opt_level) = config.opt_level { + let opt_stage = match cgcx.lto { + Lto::Fat => llvm::OptStage::PreLinkFatLTO, + Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO, + _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, + _ => llvm::OptStage::PreLinkNoLTO, + }; + let mut first_run = false; + dbg!("Running Module Optimization after differentiation"); + if ad.contains(&AutoDiff::NoVecUnroll) { + // enables vectorization and loop unrolling + first_run = false; + } + let noop = false; + unsafe { + llvm_optimize( + cgcx, + diag_handler.handle(), + module, + config, + opt_level, + opt_stage, + first_run, + noop, + )? + }; + } + } + } + + if ad.contains(&AutoDiff::PrintModAfterOpts) { + unsafe { + LLVMDumpModule(llmod); + } + } + + Ok(()) +} + // Unsafe due to LLVM calls. pub(crate) unsafe fn optimize( cgcx: &CodegenContext, @@ -621,6 +1324,47 @@ pub(crate) unsafe fn optimize( unsafe { llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()) }; } + // This code enables Enzyme to differentiate code containing Rust enums. + // By adding the SanitizeHWAddress attribute we prevent LLVM from Optimizing + // away the enums and allows Enzyme to understand why a value can be of different types in + // different code sections. 
We remove this attribute after Enzyme is done, to not affect the + // rest of the compilation. + // TODO: only enable this code when at least one function gets differentiated. + unsafe { + let mut f = LLVMGetFirstFunction(llmod); + loop { + if let Some(lf) = f { + f = LLVMGetNextFunction(lf); + let myhwattr = "enzyme_hw"; + let myhwv = ""; + let prevattr = LLVMRustGetEnumAttributeAtIndex( + lf, + c_uint::MAX, + AttributeKind::SanitizeHWAddress, + ); + if LLVMIsEnumAttribute(prevattr) { + let attr = LLVMCreateStringAttribute( + llcx, + myhwattr.as_ptr() as *const c_char, + myhwattr.as_bytes().len() as c_uint, + myhwv.as_ptr() as *const c_char, + myhwv.as_bytes().len() as c_uint, + ); + LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1); + } else { + LLVMRustAddEnumAttributeAtIndex( + llcx, + lf, + c_uint::MAX, + AttributeKind::SanitizeHWAddress, + ); + } + } else { + break; + } + } + } + if let Some(opt_level) = config.opt_level { let opt_stage = match cgcx.lto { Lto::Fat => llvm::OptStage::PreLinkFatLTO, @@ -628,7 +1372,19 @@ pub(crate) unsafe fn optimize( _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, _ => llvm::OptStage::PreLinkNoLTO, }; - return unsafe { llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) }; + + // Second run only relevant for AD + let first_run = true; + let noop = false; + //if ad.contains(&AutoDiff::AltPipeline) { + // noop = true; + // dbg!("Skipping PreAD optimization"); + //} else { + // noop = false; + //} + return unsafe { + llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run, noop) + }; } Ok(()) } diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 40783825cae57..4d7a403ca8bea 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -3,6 +3,8 @@ use std::ops::Deref; use std::{iter, ptr}; use libc::{c_char, c_uint}; +use 
rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; +use rustc_ast::expand::typetree::{FncTree, TypeTree}; use rustc_codegen_ssa::common::{IntPredicate, RealPredicate, SynchronizationScope, TypeKind}; use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; use rustc_codegen_ssa::mir::place::PlaceRef; @@ -23,7 +25,7 @@ use rustc_target::abi::call::FnAbi; use rustc_target::abi::{self, Align, Size, WrappingRange}; use rustc_target::spec::{HasTargetSpec, SanitizerSet, Target}; use smallvec::SmallVec; -use tracing::{debug, instrument}; +use tracing::{debug, instrument, trace}; use crate::abi::FnAbiLlvmExt; use crate::common::Funclet; @@ -31,9 +33,205 @@ use crate::context::CodegenCx; use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, True}; use crate::type_::Type; use crate::type_of::LayoutLlvmExt; +use crate::typetree::to_enzyme_typetree; use crate::value::Value; use crate::{attributes, llvm_util}; +pub(crate) fn add_tt2<'ll>( + llmod: &'ll llvm::Module, + llcx: &'ll llvm::Context, + fn_def: &'ll Value, + tt: FncTree, +) { + let inputs = tt.args; + let ret_tt: TypeTree = tt.ret; + let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; + let llvm_data_layout = + std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes()) + .expect("got a non-UTF8 data-layout from LLVM"); + let attr_name = "enzyme_type"; + let c_attr_name = std::ffi::CString::new(attr_name).unwrap(); + for (i, &ref input) in inputs.iter().enumerate() { + let c_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); + let c_str = unsafe { llvm::EnzymeTypeTreeToString(c_tt.inner) }; + let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) }; + unsafe { + let attr = llvm::LLVMCreateStringAttribute( + llcx, + c_attr_name.as_ptr(), + c_attr_name.as_bytes().len() as c_uint, + c_str.as_ptr(), + c_str.to_bytes().len() as c_uint, + ); + llvm::LLVMRustAddFncParamAttr(fn_def, i as u32, attr); + } 
+ unsafe { llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()) }; + } + let ret_attr = unsafe { + let c_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); + let c_str = llvm::EnzymeTypeTreeToString(c_tt.inner); + let c_str = std::ffi::CStr::from_ptr(c_str); + let attr = llvm::LLVMCreateStringAttribute( + llcx, + c_attr_name.as_ptr(), + c_attr_name.as_bytes().len() as c_uint, + c_str.as_ptr(), + c_str.to_bytes().len() as c_uint, + ); + llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); + attr + }; + unsafe { + llvm::LLVMRustAddRetFncAttr(fn_def, ret_attr); + } +} + +#[allow(unused)] +pub(crate) fn add_opt_dbg_helper<'ll>( + llmod: &'ll llvm::Module, + llcx: &'ll llvm::Context, + val: &'ll Value, + attrs: AutoDiffAttrs, + i: usize, +) { + let inputs = attrs.input_activity; + let outputs = attrs.ret_activity; + let ad_name = match attrs.mode { + DiffMode::Forward => "__enzyme_fwddiff", + DiffMode::Reverse => "__enzyme_autodiff", + DiffMode::ForwardFirst => "__enzyme_fwddiff", + DiffMode::ReverseFirst => "__enzyme_autodiff", + _ => panic!("Why are we here?"), + }; + + // Assuming that our val is the fnc square, want to generate the following llvm-ir: + // declare double @__enzyme_autodiff(...) + // + // define double @dsquare(double %x) { + // entry: + // %0 = tail call double (...) 
@__enzyme_autodiff(double (double)* nonnull @square, double %x) + // ret double %0 + // } + + let mut final_num_args; + unsafe { + let fn_ty = llvm::LLVMRustGetFunctionType(val); + let ret_ty = llvm::LLVMGetReturnType(fn_ty); + + // First we add the declaration of the __enzyme function + let enzyme_ty = llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True); + let ad_fn = llvm::LLVMRustGetOrInsertFunction( + llmod, + ad_name.as_ptr() as *const c_char, + ad_name.len().try_into().unwrap(), + enzyme_ty, + ); + + let wrapper_name = String::from("enzyme_opt_helper_") + i.to_string().as_str(); + let wrapper_fn = llvm::LLVMRustGetOrInsertFunction( + llmod, + wrapper_name.as_ptr() as *const c_char, + wrapper_name.len().try_into().unwrap(), + fn_ty, + ); + let entry = llvm::LLVMAppendBasicBlockInContext( + llcx, + wrapper_fn, + "entry".as_ptr() as *const c_char, + ); + let builder = llvm::LLVMCreateBuilderInContext(llcx); + llvm::LLVMPositionBuilderAtEnd(builder, entry); + let num_args = llvm::LLVMCountParams(wrapper_fn); + let mut args = Vec::with_capacity(num_args as usize + 1); + args.push(val); + let enzyme_const = + llvm::LLVMMDStringInContext(llcx, "enzyme_const".as_ptr() as *const c_char, 12); + let enzyme_out = + llvm::LLVMMDStringInContext(llcx, "enzyme_out".as_ptr() as *const c_char, 10); + let enzyme_dup = + llvm::LLVMMDStringInContext(llcx, "enzyme_dup".as_ptr() as *const c_char, 10); + let enzyme_dupnoneed = + llvm::LLVMMDStringInContext(llcx, "enzyme_dupnoneed".as_ptr() as *const c_char, 16); + final_num_args = num_args * 2 + 1; + for i in 0..num_args { + let arg = llvm::LLVMGetParam(wrapper_fn, i); + let activity = inputs[i as usize]; + let (activity, duplicated): (&Value, bool) = match activity { + DiffActivity::None => panic!(), + DiffActivity::Const => (enzyme_const, false), + DiffActivity::Active => (enzyme_out, false), + DiffActivity::ActiveOnly => (enzyme_out, false), + DiffActivity::Dual => (enzyme_dup, true), + DiffActivity::DualOnly => 
(enzyme_dupnoneed, true), + DiffActivity::Duplicated => (enzyme_dup, true), + DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true), + DiffActivity::FakeActivitySize => (enzyme_const, false), + }; + args.push(activity); + args.push(arg); + if duplicated { + final_num_args += 1; + args.push(arg); + } + } + + // declare void @__enzyme_autodiff(...) + + // define void @enzyme_opt_helper_0(ptr %0, ptr %1) { + // call void (...) @__enzyme_autodiff(ptr @ffff, ptr %0, ptr %1) + // ret void + // } + + let call = llvm::LLVMBuildCall2( + builder, + enzyme_ty, + ad_fn, + args.as_mut_ptr(), + final_num_args as usize, + ad_name.as_ptr() as *const c_char, + ); + let void_ty = llvm::LLVMVoidTypeInContext(llcx); + if llvm::LLVMTypeOf(call) != void_ty { + llvm::LLVMBuildRet(builder, call); + } else { + llvm::LLVMBuildRetVoid(builder); + } + llvm::LLVMDisposeBuilder(builder); + + let _fnc_ok = llvm::LLVMVerifyFunction( + wrapper_fn, + llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction, + ); + } +} + +fn add_tt<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, val: &'ll Value, tt: FncTree) { + let inputs = tt.args; + let _ret: TypeTree = tt.ret; + let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; + let llvm_data_layout = + std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes()) + .expect("got a non-UTF8 data-layout from LLVM"); + let attr_name = "enzyme_type"; + let c_attr_name = std::ffi::CString::new(attr_name).unwrap(); + for (i, &ref input) in inputs.iter().enumerate() { + let c_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); + let c_str = unsafe { llvm::EnzymeTypeTreeToString(c_tt.inner) }; + let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) }; + unsafe { + let attr = llvm::LLVMCreateStringAttribute( + llcx, + c_attr_name.as_ptr(), + c_attr_name.as_bytes().len() as c_uint, + c_str.as_ptr(), + c_str.to_bytes().len() as c_uint, + ); + llvm::LLVMRustAddParamAttr(val, i as u32, attr); 
+ } + unsafe { llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()) }; + } +} + // All Builders must have an llfn associated with them #[must_use] pub(crate) struct Builder<'a, 'll, 'tcx> { @@ -949,11 +1147,14 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align: Align, size: &'ll Value, flags: MemFlags, + //FIXME(ZuseZ4): re-enable once autodiff middle-end is merged + //tt: Option, ) { + let tt = None; assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported"); let size = self.intcast(size, self.type_isize(), false); let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let val = unsafe { llvm::LLVMRustBuildMemCpy( self.llbuilder, dst, @@ -962,7 +1163,15 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align.bytes() as c_uint, size, is_volatile, - ); + ) + }; + + if let Some(tt) = tt { + let llmod = self.cx.llmod; + let llcx = self.cx.llcx; + add_tt(llmod, llcx, val, tt); + } else { + trace!("builder: no tt"); } } @@ -974,11 +1183,14 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align: Align, size: &'ll Value, flags: MemFlags, + //FIXME(ZuseZ4): re-enable once autodiff middle-end is merged + //tt: Option, ) { + let tt = None; assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memmove not supported"); let size = self.intcast(size, self.type_isize(), false); let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let val = unsafe { llvm::LLVMRustBuildMemMove( self.llbuilder, dst, @@ -987,7 +1199,13 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align.bytes() as c_uint, size, is_volatile, - ); + ) + }; + + if let Some(tt) = tt { + let llmod = self.cx.llmod; + let llcx = self.cx.llcx; + add_tt(llmod, llcx, val, tt); } } @@ -998,10 +1216,13 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { size: &'ll Value, align: Align, flags: MemFlags, + 
//FIXME(ZuseZ4): re-enable once autodiff middle-end is merged + //tt: Option, ) { + let tt = None; assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memset not supported"); let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let val = unsafe { llvm::LLVMRustBuildMemSet( self.llbuilder, ptr, @@ -1009,7 +1230,13 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { fill_byte, size, is_volatile, - ); + ) + }; + + if let Some(tt) = tt { + let llmod = self.cx.llmod; + let llcx = self.cx.llcx; + add_tt(llmod, llcx, val, tt); } } diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index 01aae24ab56c2..69f440b73ef23 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -705,6 +705,11 @@ impl<'ll, 'tcx> MiscMethods<'tcx> for CodegenCx<'ll, 'tcx> { None } } + + // TODO(ZuseZ4): I think we can drop this and construct the empty vec on the fly? + fn create_autodiff(&self) -> Vec { + return vec![]; + } } impl<'ll> CodegenCx<'ll, '_> { diff --git a/compiler/rustc_codegen_llvm/src/errors.rs b/compiler/rustc_codegen_llvm/src/errors.rs index bb481d2a30856..7caac4ff3101e 100644 --- a/compiler/rustc_codegen_llvm/src/errors.rs +++ b/compiler/rustc_codegen_llvm/src/errors.rs @@ -104,6 +104,11 @@ impl Diagnostic<'_, G> for TargetFeatureDisableOrEnable<'_ } } +#[derive(Diagnostic)] +#[diag(codegen_llvm_autodiff_without_lto)] +#[note] +pub(crate) struct AutoDiffWithoutLTO; + #[derive(Diagnostic)] #[diag(codegen_llvm_lto_disallowed)] pub(crate) struct LtoDisallowed; @@ -146,6 +151,8 @@ pub enum LlvmError<'a> { PrepareThinLtoModule, #[diag(codegen_llvm_parse_bitcode)] ParseBitcode, + #[diag(codegen_llvm_prepare_autodiff)] + PrepareAutoDiff { src: String, target: String, error: String }, } pub(crate) struct WithLlvmError<'a>(pub LlvmError<'a>, pub String); @@ -167,6 +174,7 @@ impl Diagnostic<'_, G> for WithLlvmError<'_> { } 
PrepareThinLtoModule => fluent::codegen_llvm_prepare_thin_lto_module_with_llvm_err, ParseBitcode => fluent::codegen_llvm_parse_bitcode_with_llvm_err, + PrepareAutoDiff { .. } => fluent::codegen_llvm_prepare_autodiff_with_llvm_err, }; self.0 .into_diag(dcx, level) diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index 6a303e1e6024b..00bdb7f52ad5c 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -27,22 +27,25 @@ use std::mem::ManuallyDrop; use back::owned_target_machine::OwnedTargetMachine; use back::write::{create_informational_target_machine, create_target_machine}; -use errors::ParseTargetMachineConfig; +use errors::{AutoDiffWithoutLTO, ParseTargetMachineConfig}; +#[allow(unused_imports)] +use llvm::TypeTree; pub use llvm_util::target_features; use rustc_ast::expand::allocator::AllocatorKind; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; use rustc_codegen_ssa::back::write::{ CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryConfig, TargetMachineFactoryFn, }; use rustc_codegen_ssa::traits::*; use rustc_codegen_ssa::{CodegenResults, CompiledModule, ModuleCodegen}; -use rustc_data_structures::fx::FxIndexMap; +use rustc_data_structures::fx::{FxHashMap, FxIndexMap}; use rustc_errors::{DiagCtxtHandle, ErrorGuaranteed, FatalError}; use rustc_metadata::EncodedMetadata; use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; use rustc_middle::ty::TyCtxt; use rustc_middle::util::Providers; -use rustc_session::config::{OptLevel, OutputFilenames, PrintKind, PrintRequest}; +use rustc_session::config::{Lto, OptLevel, OutputFilenames, PrintKind, PrintRequest}; use rustc_session::Session; use rustc_span::symbol::Symbol; @@ -69,6 +72,7 @@ mod debuginfo; mod declare; mod errors; mod intrinsic; +mod typetree; // The following is a workaround that replaces `pub mod llvm;` and that 
fixes issue 53912. #[path = "llvm/mod.rs"] @@ -164,6 +168,7 @@ impl WriteBackendMethods for LlvmCodegenBackend { type TargetMachineError = crate::errors::LlvmError<'static>; type ThinData = back::lto::ThinData; type ThinBuffer = back::lto::ThinBuffer; + type TypeTree = DiffTypeTree; fn print_pass_timings(&self) { unsafe { let mut size = 0; @@ -250,6 +255,26 @@ impl WriteBackendMethods for LlvmCodegenBackend { fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer) { (module.name, back::lto::ModuleBuffer::new(module.module_llvm.llmod())) } + /// Generate autodiff rules + fn autodiff( + cgcx: &CodegenContext, + module: &ModuleCodegen, + diff_fncs: Vec, + typetrees: FxHashMap, + config: &ModuleConfig, + ) -> Result<(), FatalError> { + if cgcx.lto != Lto::Fat { + let dcx = cgcx.create_dcx(); + return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutLTO {})); + } + unsafe { back::write::differentiate(module, cgcx, diff_fncs, typetrees, config) } + } + + // The typetrees contain all information, their order therefore is irrelevant. + #[allow(rustc::potential_query_instability)] + fn typetrees(module: &mut Self::Module) -> FxHashMap { + module.typetrees.drain().collect() + } } unsafe impl Send for LlvmCodegenBackend {} // Llvm is on a per-thread basis @@ -405,6 +430,13 @@ impl CodegenBackend for LlvmCodegenBackend { } } +#[derive(Clone, Debug)] +pub struct DiffTypeTree { + pub ret_tt: TypeTree, + pub input_tt: Vec, +} + +#[allow(dead_code)] pub struct ModuleLlvm { llcx: &'static mut llvm::Context, llmod_raw: *const llvm::Module, @@ -412,6 +444,7 @@ pub struct ModuleLlvm { // This field is `ManuallyDrop` because it is important that the `TargetMachine` // is disposed prior to the `Context` being disposed otherwise UAFs can occur. 
tm: ManuallyDrop, + typetrees: FxHashMap, } unsafe impl Send for ModuleLlvm {} @@ -426,6 +459,7 @@ impl ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(create_target_machine(tcx, mod_name)), + typetrees: Default::default(), } } } @@ -438,6 +472,7 @@ impl ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(create_informational_target_machine(tcx.sess, false)), + typetrees: Default::default(), } } } @@ -459,7 +494,12 @@ impl ModuleLlvm { } }; - Ok(ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(tm) }) + Ok(ModuleLlvm { + llmod_raw, + llcx, + tm: ManuallyDrop::new(tm), + typetrees: Default::default(), + }) } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs new file mode 100644 index 0000000000000..887aa8689d8dc --- /dev/null +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -0,0 +1,807 @@ +#![allow(non_camel_case_types)] + +use libc::{c_char, c_uint, size_t}; +use rustc_ast::expand::autodiff_attrs::DiffActivity; +use tracing::trace; + +use super::ffi::*; + +extern "C" { + pub fn LLVMRustAddFncParamAttr<'a>(F: &'a Value, index: c_uint, Attr: &'a Attribute); + + pub fn LLVMRustAddRetFncAttr(F: &Value, attr: &Attribute); + pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool; + pub fn LLVMRustEraseInstBefore(BB: &BasicBlock, I: &Value); + pub fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>; + pub fn LLVMRustDIGetInstMetadata(I: &Value) -> &Metadata; + pub fn LLVMRustEraseInstFromParent(V: &Value); + pub fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value; + pub fn LLVMGetReturnType(T: &Type) -> &Type; + pub fn LLVMRustIsStructType(T: &Type) -> bool; + pub fn LLVMDumpModule(M: &Module); + pub fn LLVMCountStructElementTypes(T: &Type) -> c_uint; + pub fn LLVMVerifyFunction(V: &Value, action: LLVMVerifierFailureAction) -> bool; + pub fn LLVMGetParams(Fnc: &Value, parms: *mut &Value); + pub fn LLVMBuildCall2<'a>( + arg1: &Builder<'a>, + ty: &Type, + 
func: &Value, + args: *mut &Value, + num_args: size_t, + name: *const c_char, + ) -> &'a Value; + pub fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>; + pub fn LLVMGetNextFunction(V: &Value) -> Option<&Value>; + pub fn LLVMGetNamedFunction(M: &Module, Name: *const c_char) -> Option<&Value>; + pub fn LLVMRustGetFunctionType(fnc: &Value) -> &Type; + + pub fn LLVMRemoveStringAttributeAtIndex(F: &Value, Idx: c_uint, K: *const c_char, KLen: c_uint); + pub fn LLVMGetStringAttributeAtIndex( + F: &Value, + Idx: c_uint, + K: *const c_char, + KLen: c_uint, + ) -> &Attribute; + pub fn LLVMIsEnumAttribute(A: &Attribute) -> bool; + pub fn LLVMIsStringAttribute(A: &Attribute) -> bool; + pub fn LLVMRustAddEnumAttributeAtIndex( + C: &Context, + V: &Value, + index: c_uint, + attr: AttributeKind, + ); + pub fn LLVMRustRemoveEnumAttributeAtIndex(V: &Value, index: c_uint, attr: AttributeKind); + pub fn LLVMRustGetEnumAttributeAtIndex( + V: &Value, + index: c_uint, + attr: AttributeKind, + ) -> &Attribute; + + pub fn LLVMRustAddParamAttr<'a>(Instr: &'a Value, index: c_uint, Attr: &'a Attribute); + +} + +#[repr(C)] +pub enum LLVMVerifierFailureAction { + LLVMAbortProcessAction, + LLVMPrintMessageAction, + LLVMReturnStatusAction, +} + +pub(crate) unsafe fn enzyme_rust_forward_diff( + logic_ref: EnzymeLogicRef, + type_analysis: EnzymeTypeAnalysisRef, + fnc: &Value, + input_diffactivity: Vec, + ret_diffactivity: DiffActivity, + void_ret: bool, +) -> (&Value, Vec) { + let ret_activity = cdiffe_from(ret_diffactivity); + assert!(ret_activity != CDIFFE_TYPE::DFT_OUT_DIFF); + let mut input_activity: Vec = vec![]; + for input in input_diffactivity { + let act = cdiffe_from(input); + assert!( + act == CDIFFE_TYPE::DFT_CONSTANT + || act == CDIFFE_TYPE::DFT_DUP_ARG + || act == CDIFFE_TYPE::DFT_DUP_NONEED + ); + input_activity.push(act); + } + + // if we have void ret, this must be false; + let ret_primary_ret = if void_ret { + false + } else { + match ret_activity { + 
CDIFFE_TYPE::DFT_CONSTANT => true, + CDIFFE_TYPE::DFT_DUP_ARG => true, + CDIFFE_TYPE::DFT_DUP_NONEED => false, + _ => panic!("Implementation error in enzyme_rust_forward_diff."), + } + }; + trace!("ret_primary_ret: {}", &ret_primary_ret); + + // We don't support volatile / extern / (global?) values. + // Just because I didn't have time to test them, and it seems less urgent. + let args_uncacheable = vec![0; input_activity.len()]; + let num_fnc_args = unsafe { LLVMCountParams(fnc) }; + trace!("num_fnc_args: {}", num_fnc_args); + trace!("input_activity.len(): {}", input_activity.len()); + assert!(num_fnc_args == input_activity.len() as u32); + + let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 }; + + let mut known_values = vec![kv_tmp; input_activity.len()]; + + let tree_tmp = TypeTree::new(); + let mut args_tree = vec![tree_tmp.inner; input_activity.len()]; + + let ret_tt = TypeTree::new(); + let dummy_type = CFnTypeInfo { + Arguments: args_tree.as_mut_ptr(), + Return: ret_tt.inner, + KnownValues: known_values.as_mut_ptr(), + }; + + trace!("ret_activity: {}", &ret_activity); + for i in &input_activity { + trace!("input_activity i: {}", &i); + } + trace!("before calling Enzyme"); + let res = unsafe { + EnzymeCreateForwardDiff( + logic_ref, // Logic + std::ptr::null(), + std::ptr::null(), + fnc, + ret_activity, // LLVM function, return type + input_activity.as_ptr(), + input_activity.len(), // constant arguments + type_analysis, // type analysis struct + ret_primary_ret as u8, + CDerivativeMode::DEM_ForwardMode, // return value, dret_used, top_level which was 1 + 1, // free memory + 1, // vector mode width + Option::None, + dummy_type, // additional_arg, type info (return + args) + args_uncacheable.as_ptr(), + args_uncacheable.len(), // uncacheable arguments + std::ptr::null_mut(), // write augmented function to this + ) + }; + trace!("after calling Enzyme"); + (res, vec![]) +} + +pub(crate) unsafe fn enzyme_rust_reverse_diff( + logic_ref: EnzymeLogicRef, + 
type_analysis: EnzymeTypeAnalysisRef, + fnc: &Value, + rust_input_activity: Vec, + ret_activity: DiffActivity, +) -> (&Value, Vec) { + let (primary_ret, ret_activity) = match ret_activity { + DiffActivity::Const => (true, CDIFFE_TYPE::DFT_CONSTANT), + DiffActivity::Active => (true, CDIFFE_TYPE::DFT_OUT_DIFF), + DiffActivity::ActiveOnly => (false, CDIFFE_TYPE::DFT_OUT_DIFF), + DiffActivity::None => (false, CDIFFE_TYPE::DFT_CONSTANT), + _ => panic!("Invalid return activity"), + }; + // This is only needed for split-mode AD, which we don't support. + // See Julia: + // https://github.com/EnzymeAD/Enzyme.jl/blob/a511e4e6979d6161699f5c9919d49801c0764a09/src/compiler.jl#L3132 + // https://github.com/EnzymeAD/Enzyme.jl/blob/a511e4e6979d6161699f5c9919d49801c0764a09/src/compiler.jl#L3092 + let diff_ret = false; + + let mut primal_sizes = vec![]; + let mut input_activity: Vec = vec![]; + for (i, &x) in rust_input_activity.iter().enumerate() { + if is_size(x) { + primal_sizes.push(i); + input_activity.push(CDIFFE_TYPE::DFT_CONSTANT); + continue; + } + input_activity.push(cdiffe_from(x)); + } + + // We don't support volatile / extern / (global?) values. + // Just because I didn't have time to test them, and it seems less urgent. 
+ let args_uncacheable = vec![0; input_activity.len()]; + let num_fnc_args = unsafe { LLVMCountParams(fnc) }; + println!("num_fnc_args: {}", num_fnc_args); + println!("input_activity.len(): {}", input_activity.len()); + assert!(num_fnc_args == input_activity.len() as u32); + let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 }; + + let mut known_values = vec![kv_tmp; input_activity.len()]; + + let tree_tmp = TypeTree::new(); + let mut args_tree = vec![tree_tmp.inner; input_activity.len()]; + let ret_tt = TypeTree::new(); + let dummy_type = CFnTypeInfo { + Arguments: args_tree.as_mut_ptr(), + Return: ret_tt.inner, + KnownValues: known_values.as_mut_ptr(), + }; + + trace!("primary_ret: {}", &primary_ret); + trace!("ret_activity: {}", &ret_activity); + for i in &input_activity { + trace!("input_activity i: {}", &i); + } + trace!("before calling Enzyme"); + let res = unsafe { + EnzymeCreatePrimalAndGradient( + logic_ref, // Logic + std::ptr::null(), + std::ptr::null(), + fnc, + ret_activity, // LLVM function, return type + input_activity.as_ptr(), + input_activity.len(), // constant arguments + type_analysis, // type analysis struct + primary_ret as u8, + diff_ret as u8, //0 + CDerivativeMode::DEM_ReverseModeCombined, // return value, dret_used, top_level which was 1 + 1, // vector mode width + 1, // free memory + Option::None, + 0, // do not force anonymous tape + dummy_type, // additional_arg, type info (return + args) + args_uncacheable.as_ptr(), + args_uncacheable.len(), // uncacheable arguments + std::ptr::null_mut(), // write augmented function to this + 0, + ) + }; + trace!("after calling Enzyme"); + (res, primal_sizes) +} + +#[cfg(not(llvm_enzyme))] +pub use self::Fallback_AD::*; + +#[cfg(not(llvm_enzyme))] +pub mod Fallback_AD { + #![allow(unused_variables)] + use super::*; + + pub fn EnzymeNewTypeTree() -> CTypeTreeRef { + unimplemented!() + } + pub fn EnzymeFreeTypeTree(CTT: CTypeTreeRef) { + unimplemented!() + } + pub fn EnzymeSetCLBool(arg1: *mut 
::std::os::raw::c_void, arg2: u8) { + unimplemented!() + } + pub fn EnzymeSetCLInteger(arg1: *mut ::std::os::raw::c_void, arg2: i64) { + unimplemented!() + } + + pub fn set_inline(val: bool) { + unimplemented!() + } + pub fn set_runtime_activity_check(check: bool) { + unimplemented!() + } + pub fn set_max_int_offset(offset: u64) { + unimplemented!() + } + pub fn set_max_type_offset(offset: u64) { + unimplemented!() + } + pub fn set_max_type_depth(depth: u64) { + unimplemented!() + } + pub fn set_print_perf(print: bool) { + unimplemented!() + } + pub fn set_print_activity(print: bool) { + unimplemented!() + } + pub fn set_print_type(print: bool) { + unimplemented!() + } + pub fn set_print(print: bool) { + unimplemented!() + } + pub fn set_strict_aliasing(strict: bool) { + unimplemented!() + } + pub fn set_loose_types(loose: bool) { + unimplemented!() + } + + pub fn EnzymeCreatePrimalAndGradient<'a>( + arg1: EnzymeLogicRef, + _builderCtx: *const u8, // &'a Builder<'_>, + _callerCtx: *const u8, // &'a Value, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + dretUsed: u8, + mode: CDerivativeMode, + width: ::std::os::raw::c_uint, + freeMemory: u8, + additionalArg: Option<&Type>, + forceAnonymousTape: u8, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + AtomicAdd: u8, + ) -> &'a Value { + unimplemented!() + } + pub fn EnzymeCreateForwardDiff<'a>( + arg1: EnzymeLogicRef, + _builderCtx: *const u8, // &'a Builder<'_>, + _callerCtx: *const u8, // &'a Value, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + mode: CDerivativeMode, + freeMemory: u8, + width: ::std::os::raw::c_uint, + additionalArg: Option<&Type>, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + 
uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + ) -> &'a Value { + unimplemented!() + } + pub type CustomRuleType = ::std::option::Option< + unsafe extern "C" fn( + direction: ::std::os::raw::c_int, + ret: CTypeTreeRef, + args: *mut CTypeTreeRef, + known_values: *mut IntList, + num_args: size_t, + fnc: &Value, + ta: *const ::std::os::raw::c_void, + ) -> u8, + >; + extern "C" { + pub fn CreateTypeAnalysis( + Log: EnzymeLogicRef, + customRuleNames: *mut *mut ::std::os::raw::c_char, + customRules: *mut CustomRuleType, + numRules: size_t, + ) -> EnzymeTypeAnalysisRef; + } + //pub fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef) { unimplemented!() } + pub fn FreeTypeAnalysis(arg1: EnzymeTypeAnalysisRef) { + unimplemented!() + } + pub fn CreateEnzymeLogic(PostOpt: u8) -> EnzymeLogicRef { + unimplemented!() + } + pub fn ClearEnzymeLogic(arg1: EnzymeLogicRef) { + unimplemented!() + } + pub fn FreeEnzymeLogic(arg1: EnzymeLogicRef) { + unimplemented!() + } + + pub fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef { + unimplemented!() + } + pub fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef { + unimplemented!() + } + pub fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool { + unimplemented!() + } + pub fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64) { + unimplemented!() + } + pub fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef) { + unimplemented!() + } + pub fn EnzymeTypeTreeShiftIndiciesEq( + arg1: CTypeTreeRef, + data_layout: *const c_char, + offset: i64, + max_size: i64, + add_offset: u64, + ) { + unimplemented!() + } + pub fn EnzymeTypeTreeToStringFree(arg1: *const c_char) { + unimplemented!() + } + pub fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char { + unimplemented!() + } +} + +// Enzyme specific, but doesn't require Enzyme to be build +pub use self::Shared_AD::*; +pub mod Shared_AD { + // Depending on the AD backend (Enzyme or Fallback), some functions might or might not 
be + // unsafe. So we just allways call them in an unsafe context. + #![allow(unused_unsafe)] + #![allow(unused_variables)] + + use core::fmt; + use std::ffi::{CStr, CString}; + + use libc::size_t; + use rustc_ast::expand::autodiff_attrs::DiffActivity; + + use super::Context; + #[cfg(llvm_enzyme)] + use super::Enzyme_AD::*; + #[cfg(not(llvm_enzyme))] + use super::Fallback_AD::*; + #[repr(u32)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub enum CDIFFE_TYPE { + DFT_OUT_DIFF = 0, + DFT_DUP_ARG = 1, + DFT_CONSTANT = 2, + DFT_DUP_NONEED = 3, + } + + impl fmt::Display for CDIFFE_TYPE { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let value = match self { + CDIFFE_TYPE::DFT_OUT_DIFF => "DFT_OUT_DIFF", + CDIFFE_TYPE::DFT_DUP_ARG => "DFT_DUP_ARG", + CDIFFE_TYPE::DFT_CONSTANT => "DFT_CONSTANT", + CDIFFE_TYPE::DFT_DUP_NONEED => "DFT_DUP_NONEED", + }; + write!(f, "{}", value) + } + } + + pub fn cdiffe_from(act: DiffActivity) -> CDIFFE_TYPE { + return match act { + DiffActivity::None => CDIFFE_TYPE::DFT_CONSTANT, + DiffActivity::Const => CDIFFE_TYPE::DFT_CONSTANT, + DiffActivity::Active => CDIFFE_TYPE::DFT_OUT_DIFF, + DiffActivity::ActiveOnly => CDIFFE_TYPE::DFT_OUT_DIFF, + DiffActivity::Dual => CDIFFE_TYPE::DFT_DUP_ARG, + DiffActivity::DualOnly => CDIFFE_TYPE::DFT_DUP_NONEED, + DiffActivity::Duplicated => CDIFFE_TYPE::DFT_DUP_ARG, + DiffActivity::DuplicatedOnly => CDIFFE_TYPE::DFT_DUP_NONEED, + DiffActivity::FakeActivitySize => panic!("Implementation error"), + }; + } + + pub fn is_size(act: DiffActivity) -> bool { + return act == DiffActivity::FakeActivitySize; + } + + #[repr(u32)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub enum CDerivativeMode { + DEM_ForwardMode = 0, + DEM_ReverseModePrimal = 1, + DEM_ReverseModeGradient = 2, + DEM_ReverseModeCombined = 3, + DEM_ForwardModeSplit = 4, + } + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct EnzymeOpaqueTypeAnalysis { + _unused: [u8; 0], + } + pub type EnzymeTypeAnalysisRef 
= *mut EnzymeOpaqueTypeAnalysis; + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct EnzymeOpaqueLogic { + _unused: [u8; 0], + } + pub type EnzymeLogicRef = *mut EnzymeOpaqueLogic; + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct EnzymeOpaqueAugmentedReturn { + _unused: [u8; 0], + } + pub type EnzymeAugmentedReturnPtr = *mut EnzymeOpaqueAugmentedReturn; + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct IntList { + pub data: *mut i64, + pub size: size_t, + } + #[repr(u32)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub enum CConcreteType { + DT_Anything = 0, + DT_Integer = 1, + DT_Pointer = 2, + DT_Half = 3, + DT_Float = 4, + DT_Double = 5, + DT_Unknown = 6, + } + + pub type CTypeTreeRef = *mut EnzymeTypeTree; + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct EnzymeTypeTree { + _unused: [u8; 0], + } + pub struct TypeTree { + pub inner: CTypeTreeRef, + } + + impl TypeTree { + pub fn new() -> TypeTree { + let inner = unsafe { EnzymeNewTypeTree() }; + TypeTree { inner } + } + + #[must_use] + pub fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree { + let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) }; + TypeTree { inner } + } + + #[must_use] + pub fn only(self, idx: isize) -> TypeTree { + unsafe { + EnzymeTypeTreeOnlyEq(self.inner, idx as i64); + } + self + } + + #[must_use] + pub fn data0(self) -> TypeTree { + unsafe { + EnzymeTypeTreeData0Eq(self.inner); + } + self + } + + pub fn merge(self, other: Self) -> Self { + unsafe { + EnzymeMergeTypeTree(self.inner, other.inner); + } + drop(other); + self + } + + #[must_use] + pub fn shift( + self, + layout: &str, + offset: isize, + max_size: isize, + add_offset: usize, + ) -> Self { + let layout = CString::new(layout).unwrap(); + + unsafe { + EnzymeTypeTreeShiftIndiciesEq( + self.inner, + layout.as_ptr(), + offset as i64, + max_size as i64, + add_offset as u64, + ) + } + + self + } + } + + impl Clone for TypeTree { + fn clone(&self) -> Self { + let inner = unsafe { 
EnzymeNewTypeTreeTR(self.inner) }; + TypeTree { inner } + } + } + + impl fmt::Display for TypeTree { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let ptr = unsafe { EnzymeTypeTreeToString(self.inner) }; + let cstr = unsafe { CStr::from_ptr(ptr) }; + match cstr.to_str() { + Ok(x) => write!(f, "{}", x)?, + Err(err) => write!(f, "could not parse: {}", err)?, + } + + // delete C string pointer + unsafe { EnzymeTypeTreeToStringFree(ptr) } + + Ok(()) + } + } + + impl fmt::Debug for TypeTree { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt(self, f) + } + } + + impl Drop for TypeTree { + fn drop(&mut self) { + unsafe { EnzymeFreeTypeTree(self.inner) } + } + } + + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct CFnTypeInfo { + #[doc = " Types of arguments, assumed of size len(Arguments)"] + pub Arguments: *mut CTypeTreeRef, + #[doc = " Type of return"] + pub Return: CTypeTreeRef, + #[doc = " The specific constant(s) known to represented by an argument, if constant"] + pub KnownValues: *mut IntList, + } +} + +#[cfg(llvm_enzyme)] +pub use self::Enzyme_AD::*; + +// Enzyme is an optional component, so we do need to provide a fallback when it is not getting +// compiled. We deny the usage of #[autodiff(..)] on a higher level, so a placeholder implementation +// here is completely fine.
+#[cfg(llvm_enzyme)] +pub mod Enzyme_AD { + use libc::{c_char, c_void, size_t}; + + use super::*; + + extern "C" { + pub fn EnzymeNewTypeTree() -> CTypeTreeRef; + pub fn EnzymeFreeTypeTree(CTT: CTypeTreeRef); + pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8); + pub fn EnzymeSetCLInteger(arg1: *mut ::std::os::raw::c_void, arg2: i64); + } + + extern "C" { + static mut MaxIntOffset: c_void; + static mut MaxTypeOffset: c_void; + static mut EnzymeMaxTypeDepth: c_void; + + static mut EnzymeRuntimeActivityCheck: c_void; + static mut EnzymePrintPerf: c_void; + static mut EnzymePrintActivity: c_void; + static mut EnzymePrintType: c_void; + static mut EnzymePrint: c_void; + static mut EnzymeStrictAliasing: c_void; + static mut looseTypeAnalysis: c_void; + static mut EnzymeInline: c_void; + } + pub fn set_runtime_activity_check(check: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeRuntimeActivityCheck), check as u8); + } + } + pub fn set_max_int_offset(offset: u64) { + let offset = offset.try_into().unwrap(); + unsafe { + EnzymeSetCLInteger(std::ptr::addr_of_mut!(MaxIntOffset), offset); + } + } + pub fn set_max_type_offset(offset: u64) { + let offset = offset.try_into().unwrap(); + unsafe { + EnzymeSetCLInteger(std::ptr::addr_of_mut!(MaxTypeOffset), offset); + } + } + pub fn set_max_type_depth(depth: u64) { + let depth = depth.try_into().unwrap(); + unsafe { + EnzymeSetCLInteger(std::ptr::addr_of_mut!(EnzymeMaxTypeDepth), depth); + } + } + pub fn set_print_perf(print: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintPerf), print as u8); + } + } + pub fn set_print_activity(print: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintActivity), print as u8); + } + } + pub fn set_print_type(print: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8); + } + } + pub fn set_print(print: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as 
u8); + } + } + pub fn set_strict_aliasing(strict: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeStrictAliasing), strict as u8); + } + } + pub fn set_loose_types(loose: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(looseTypeAnalysis), loose as u8); + } + } + pub fn set_inline(val: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeInline), val as u8); + } + } + + extern "C" { + pub fn EnzymeCreatePrimalAndGradient<'a>( + arg1: EnzymeLogicRef, + _builderCtx: *const u8, // &'a Builder<'_>, + _callerCtx: *const u8, // &'a Value, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + dretUsed: u8, + mode: CDerivativeMode, + width: ::std::os::raw::c_uint, + freeMemory: u8, + additionalArg: Option<&Type>, + forceAnonymousTape: u8, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + AtomicAdd: u8, + ) -> &'a Value; + } + extern "C" { + pub fn EnzymeCreateForwardDiff<'a>( + arg1: EnzymeLogicRef, + _builderCtx: *const u8, // &'a Builder<'_>, + _callerCtx: *const u8, // &'a Value, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + mode: CDerivativeMode, + freeMemory: u8, + width: ::std::os::raw::c_uint, + additionalArg: Option<&Type>, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + ) -> &'a Value; + } + pub type CustomRuleType = ::std::option::Option< + unsafe extern "C" fn( + direction: ::std::os::raw::c_int, + ret: CTypeTreeRef, + args: *mut CTypeTreeRef, + known_values: *mut IntList, + num_args: size_t, + fnc: &Value, + ta: *const ::std::os::raw::c_void, + ) -> u8, + >; + extern "C" { + pub fn CreateTypeAnalysis( + Log: EnzymeLogicRef, + 
customRuleNames: *mut *mut ::std::os::raw::c_char, + customRules: *mut CustomRuleType, + numRules: size_t, + ) -> EnzymeTypeAnalysisRef; + } + extern "C" { + //pub(super) fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef); + pub fn FreeTypeAnalysis(arg1: EnzymeTypeAnalysisRef); + pub fn CreateEnzymeLogic(PostOpt: u8) -> EnzymeLogicRef; + pub fn ClearEnzymeLogic(arg1: EnzymeLogicRef); + pub fn FreeEnzymeLogic(arg1: EnzymeLogicRef); + } + + extern "C" { + pub(super) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef; + pub(super) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef; + pub(super) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool; + pub(super) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64); + pub(super) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef); + pub(super) fn EnzymeTypeTreeShiftIndiciesEq( + arg1: CTypeTreeRef, + data_layout: *const c_char, + offset: i64, + max_size: i64, + add_offset: u64, + ); + pub fn EnzymeTypeTreeToStringFree(arg1: *const c_char); + pub fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; + } +} diff --git a/compiler/rustc_codegen_llvm/src/llvm/mod.rs b/compiler/rustc_codegen_llvm/src/llvm/mod.rs index d0db350a149e8..880f960151840 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/mod.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/mod.rs @@ -20,8 +20,10 @@ pub use self::RealPredicate::*; pub mod archive_ro; pub mod diagnostic; +pub mod enzyme_ffi; mod ffi; +pub use self::enzyme_ffi::*; pub use self::ffi::*; impl LLVMRustResult { diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs new file mode 100644 index 0000000000000..9612ac335a873 --- /dev/null +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -0,0 +1,34 @@ +use rustc_ast::expand::typetree::{Kind, TypeTree}; + +use crate::llvm; + +pub(crate) fn to_enzyme_typetree( + tree: TypeTree, + llvm_data_layout: &str, + llcx: &llvm::Context, +) -> llvm::TypeTree { 
+ tree.0.iter().fold(llvm::TypeTree::new(), |obj, x| { + let scalar = match x.kind { + Kind::Integer => llvm::CConcreteType::DT_Integer, + Kind::Float => llvm::CConcreteType::DT_Float, + Kind::Double => llvm::CConcreteType::DT_Double, + Kind::Pointer => llvm::CConcreteType::DT_Pointer, + _ => panic!("Unknown kind {:?}", x.kind), + }; + + let tt = llvm::TypeTree::from_type(scalar, llcx).only(-1); + + let tt = if !x.child.0.is_empty() { + let inner_tt = to_enzyme_typetree(x.child.clone(), llvm_data_layout, llcx); + tt.merge(inner_tt.only(-1)) + } else { + tt + }; + + if x.offset != -1 { + obj.merge(tt.shift(llvm_data_layout, 0, x.size as isize, x.offset as usize)) + } else { + obj.merge(tt) + } + }) +} diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index feb27c148a188..2170bafcb0c6b 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -121,6 +121,7 @@ pub struct ModuleConfig { pub merge_functions: bool, pub emit_lifetime_markers: bool, pub llvm_plugins: Vec, + pub autodiff: Vec, } impl ModuleConfig { @@ -281,6 +282,7 @@ impl ModuleConfig { emit_lifetime_markers: sess.emit_lifetime_markers(), llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]), + autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]), } } diff --git a/compiler/rustc_codegen_ssa/src/traits/misc.rs b/compiler/rustc_codegen_ssa/src/traits/misc.rs index 40a49b3e1b578..608a4e038c91b 100644 --- a/compiler/rustc_codegen_ssa/src/traits/misc.rs +++ b/compiler/rustc_codegen_ssa/src/traits/misc.rs @@ -27,4 +27,6 @@ pub trait MiscMethods<'tcx>: BackendTypes { fn apply_target_cpu_attr(&self, llfn: Self::Function); /// Declares the extern "C" main function for the entry point. Returns None if the symbol already exists. 
fn declare_c_main(&self, fn_type: Self::Type) -> Option; + // TODO: Manuel: I think we can drop this and construct the empty vec on the fly? + fn create_autodiff(&self) -> Vec; } diff --git a/compiler/rustc_codegen_ssa/src/traits/write.rs b/compiler/rustc_codegen_ssa/src/traits/write.rs index aabe9e33c4aa1..5846edbd441c3 100644 --- a/compiler/rustc_codegen_ssa/src/traits/write.rs +++ b/compiler/rustc_codegen_ssa/src/traits/write.rs @@ -1,3 +1,5 @@ +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; +use rustc_data_structures::fx::FxHashMap; use rustc_errors::{DiagCtxtHandle, FatalError}; use rustc_middle::dep_graph::WorkProduct; @@ -12,6 +14,7 @@ pub trait WriteBackendMethods: 'static + Sized + Clone { type ModuleBuffer: ModuleBufferMethods; type ThinData: Send + Sync; type ThinBuffer: ThinBufferMethods; + type TypeTree: Clone; /// Merge all modules into main_module and returning it fn run_link( @@ -61,6 +64,15 @@ pub trait WriteBackendMethods: 'static + Sized + Clone { want_summary: bool, ) -> (String, Self::ThinBuffer); fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer); + /// Generate autodiff rules + fn autodiff( + cgcx: &CodegenContext, + module: &ModuleCodegen, + diff_fncs: Vec, + typetrees: FxHashMap, + config: &ModuleConfig, + ) -> Result<(), FatalError>; + fn typetrees(module: &mut Self::Module) -> FxHashMap; } pub trait ThinBufferMethods: Send + Sync { diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index ed12318c88dae..8cf15e6e720d0 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -342,6 +342,10 @@ extern "C" void LLVMRustAddFunctionAttributes(LLVMValueRef Fn, unsigned Index, AddAttributes(F, Index, Attrs, AttrsLen); } +extern "C" bool LLVMRustIsStructType(LLVMTypeRef Ty) { + return unwrap(Ty)->isStructTy(); +} + extern "C" void LLVMRustAddCallSiteAttributes(LLVMValueRef Instr, unsigned Index, 
LLVMAttributeRef *Attrs, @@ -350,11 +354,44 @@ extern "C" void LLVMRustAddCallSiteAttributes(LLVMValueRef Instr, AddAttributes(Call, Index, Attrs, AttrsLen); } +extern "C" LLVMValueRef LLVMRustGetTerminator(LLVMBasicBlockRef BB) { + Instruction *ret = unwrap(BB)->getTerminator(); + return wrap(ret); +} + +extern "C" void LLVMRustEraseInstFromParent(LLVMValueRef Instr) { + if (auto I = dyn_cast(unwrap(Instr))) { + I->eraseFromParent(); + } +} + +extern "C" LLVMTypeRef LLVMRustGetFunctionType(LLVMValueRef Fn) { + auto Ftype = unwrap(Fn)->getFunctionType(); + return wrap(Ftype); +} + +extern "C" void LLVMRustRemoveEnumAttributeAtIndex(LLVMValueRef F, size_t index, + LLVMRustAttribute RustAttr) { + LLVMRemoveEnumAttributeAtIndex(F, index, fromRust(RustAttr)); +} + extern "C" LLVMAttributeRef LLVMRustCreateAttrNoValue(LLVMContextRef C, LLVMRustAttribute RustAttr) { return wrap(Attribute::get(*unwrap(C), fromRust(RustAttr))); } +extern "C" void LLVMRustAddEnumAttributeAtIndex(LLVMContextRef C, + LLVMValueRef F, size_t index, + LLVMRustAttribute RustAttr) { + LLVMAddAttributeAtIndex(F, index, LLVMRustCreateAttrNoValue(C, RustAttr)); +} + +extern "C" LLVMAttributeRef +LLVMRustGetEnumAttributeAtIndex(LLVMValueRef F, size_t index, + LLVMRustAttribute RustAttr) { + return LLVMGetEnumAttributeAtIndex(F, index, fromRust(RustAttr)); +} + extern "C" LLVMAttributeRef LLVMRustCreateAlignmentAttr(LLVMContextRef C, uint64_t Bytes) { return wrap(Attribute::getWithAlignment(*unwrap(C), llvm::Align(Bytes))); @@ -878,6 +915,66 @@ extern "C" bool LLVMRustHasModuleFlag(LLVMModuleRef M, const char *Name, return unwrap(M)->getModuleFlag(StringRef(Name, Len)) != nullptr; } +extern "C" LLVMValueRef LLVMRustGetLastInstruction(LLVMBasicBlockRef BB) { + auto Point = unwrap(BB)->rbegin(); + if (Point != unwrap(BB)->rend()) + return wrap(&*Point); + return nullptr; +} + +extern "C" void LLVMRustEraseInstBefore(LLVMBasicBlockRef bb, LLVMValueRef I) { + auto &BB = *unwrap(bb); + auto &Inst = 
*unwrap(I); + auto It = BB.begin(); + while (&*It != &Inst) + ++It; + assert(It != BB.end()); + // Delete in rev order to ensure no dangling references. + while (It != BB.begin()) { + auto Prev = std::prev(It); + It->eraseFromParent(); + It = Prev; + } + It->eraseFromParent(); +} + +extern "C" bool LLVMRustHasMetadata(LLVMValueRef inst, unsigned kindID) { + if (auto *I = dyn_cast(unwrap(inst))) { + return I->hasMetadata(kindID); + } + return false; +} + +extern "C" void LLVMRustAddFncParamAttr(LLVMValueRef F, unsigned i, + LLVMAttributeRef RustAttr) { + if (auto *Fn = dyn_cast(unwrap(F))) { + Fn->addParamAttr(i, unwrap(RustAttr)); + } +} + +extern "C" void LLVMRustAddRetFncAttr(LLVMValueRef F, + LLVMAttributeRef RustAttr) { + if (auto *Fn = dyn_cast(unwrap(F))) { + Fn->addRetAttr(unwrap(RustAttr)); + } +} + +extern "C" LLVMMetadataRef LLVMRustDIGetInstMetadata(LLVMValueRef x) { + if (auto *I = dyn_cast(unwrap(x))) { + // auto *MD = I->getMetadata(LLVMContext::MD_dbg); + auto *MD = I->getDebugLoc().getAsMDNode(); + return wrap(MD); + } + return nullptr; +} + +extern "C" void LLVMRustAddParamAttr(LLVMValueRef call, unsigned i, + LLVMAttributeRef RustAttr) { + if (auto *CI = dyn_cast(unwrap(call))) { + CI->addParamAttr(i, unwrap(RustAttr)); + } +} + extern "C" void LLVMRustGlobalAddMetadata(LLVMValueRef Global, unsigned Kind, LLVMMetadataRef MD) { unwrap(Global)->addMetadata(Kind, *unwrap(MD)); diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs index f3e3b36111c54..67d98c9f6dfca 100644 --- a/compiler/rustc_session/src/config.rs +++ b/compiler/rustc_session/src/config.rs @@ -194,6 +194,53 @@ impl Default for CoverageLevel { } } +/// The different settings that the `-Z ad` flag can have. 
+#[derive(Clone, Copy, PartialEq, Hash, Debug)] +pub enum AutoDiff { + /// Print TypeAnalysis information + PrintTA, + /// Print ActivityAnalysis Information + PrintAA, + /// Print Performance Warnings from Enzyme + PrintPerf, + /// Combines the three print flags above. + Print, + /// Print the whole module, before running opts. + PrintModBefore, + /// Print the whole module just before we pass it to Enzyme. + /// For Debug purpose, prefer the OPT flag below + PrintModAfterOpts, + /// Print the module after Enzyme differentiated everything. + PrintModAfterEnzyme, + + /// Enzyme's loose type debug helper (can cause incorrect gradients) + LooseTypes, + /// Output a Module using __enzyme calls to prepare it for opt + enzyme pass usage + OPT, + + /// TypeTree options + /// TODO: Figure out how to let users construct these, + /// or whether we want to leave this option in the first place. + TTWidth(u64), + TTDepth(u64), + + /// More flags + NoModOptAfter, + /// Tell Enzyme to run LLVM Opts on each function it generated. By default off, + /// since we already optimize the whole module after Enzyme is done. + EnableFncOpt, + NoVecUnroll, + /// Obviously unsafe, disable the length checks that we have for shadow args. + NoSafetyChecks, + RuntimeActivity, + /// Runs Enzyme specific Inlining + Inline, + /// Runs Optimization twice after AD, and zero times after. + /// This is mainly for Benchmarking purpose to show that + /// compiler based AD has a performance benefit. TODO: fix + AltPipeline, +} + /// Settings for `-Z instrument-xray` flag. 
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)] pub struct InstrumentXRay { @@ -3017,7 +3064,7 @@ pub(crate) mod dep_tracking { }; use super::{ - BranchProtection, CFGuard, CFProtection, CollapseMacroDebuginfo, CoverageOptions, + AutoDiff, BranchProtection, CFGuard, CFProtection, CollapseMacroDebuginfo, CoverageOptions, CrateType, DebugInfo, DebugInfoCompression, ErrorOutputType, FmtDebug, FunctionReturn, InliningThreshold, InstrumentCoverage, InstrumentXRay, LinkerPluginLto, LocationDetail, LtoCli, NextSolverConfig, OomStrategy, OptLevel, OutFileName, OutputType, OutputTypes, @@ -3065,6 +3112,7 @@ pub(crate) mod dep_tracking { } impl_dep_tracking_hash_via_hash!( + AutoDiff, bool, usize, NonZero, diff --git a/compiler/rustc_session/src/config/cfg.rs b/compiler/rustc_session/src/config/cfg.rs index 0fa776ecb5c18..6b7d67e16afe2 100644 --- a/compiler/rustc_session/src/config/cfg.rs +++ b/compiler/rustc_session/src/config/cfg.rs @@ -176,6 +176,8 @@ pub(crate) fn default_configuration(sess: &Session) -> Cfg { // NOTE: These insertions should be kept in sync with // `CheckCfg::fill_well_known` below. + ins_none!(sym::autodiff_fallback); + if sess.opts.debug_assertions { ins_none!(sym::debug_assertions); } @@ -339,6 +341,7 @@ impl CheckCfg { // Don't forget to update `src/doc/rustc/src/check-cfg.md` // in the unstable book as well! 
+ ins!(sym::autodiff_fallback, no_values); ins!(sym::debug_assertions, no_values); ins!(sym::fmt_debug, empty_values).extend(FmtDebug::all()); diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs index a57dc80b3168d..a5206613c74c8 100644 --- a/compiler/rustc_session/src/options.rs +++ b/compiler/rustc_session/src/options.rs @@ -370,6 +370,7 @@ mod desc { pub(crate) const parse_list: &str = "a space-separated list of strings"; pub(crate) const parse_list_with_polarity: &str = "a comma-separated list of strings, with elements beginning with + or -"; + pub(crate) const parse_autodiff: &str = "various values"; pub(crate) const parse_comma_list: &str = "a comma-separated list of strings"; pub(crate) const parse_opt_comma_list: &str = parse_comma_list; pub(crate) const parse_number: &str = "a number"; @@ -978,6 +979,38 @@ mod parse { } } + pub(crate) fn parse_autodiff(slot: &mut Vec, v: Option<&str>) -> bool { + let Some(v) = v else { + *slot = vec![]; + return true; + }; + let mut v: Vec<&str> = v.split(",").collect(); + v.sort_unstable(); + for &val in v.iter() { + let variant = match val { + "PrintTA" => AutoDiff::PrintTA, + "PrintAA" => AutoDiff::PrintAA, + "PrintPerf" => AutoDiff::PrintPerf, + "Print" => AutoDiff::Print, + "PrintModBefore" => AutoDiff::PrintModBefore, + "PrintModAfterOpts" => AutoDiff::PrintModAfterOpts, + "PrintModAfterEnzyme" => AutoDiff::PrintModAfterEnzyme, + "LooseTypes" => AutoDiff::LooseTypes, + "OPT" => AutoDiff::OPT, + "NoModOptAfter" => AutoDiff::NoModOptAfter, + "EnableFncOpt" => AutoDiff::EnableFncOpt, + "NoVecUnroll" => AutoDiff::NoVecUnroll, + "NoSafetyChecks" => AutoDiff::NoSafetyChecks, + "Inline" => AutoDiff::Inline, + "AltPipeline" => AutoDiff::AltPipeline, + _ => return false, + }; + slot.push(variant); + } + + true + } + pub(crate) fn parse_instrument_coverage( slot: &mut InstrumentCoverage, v: Option<&str>, @@ -1646,6 +1679,8 @@ options! 
{ either `loaded` or `not-loaded`."), assume_incomplete_release: bool = (false, parse_bool, [TRACKED], "make cfg(version) treat the current version as incomplete (default: no)"), + autodiff: Vec = (Vec::new(), parse_autodiff, [TRACKED], + "a list autodiff flags to enable (comma separated)"), #[rustc_lint_opt_deny_field_access("use `Session::binary_dep_depinfo` instead of this field")] binary_dep_depinfo: bool = (false, parse_bool, [TRACKED], "include artifacts (sysroot, crate dependencies) used during compilation in dep-info \ diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 418d1078900ac..6dae822f55e2d 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -477,6 +477,8 @@ symbols! { audit_that, augmented_assignments, auto_traits, + autodiff, + autodiff_fallback, automatically_derived, avx, avx512_target_feature, @@ -535,6 +537,7 @@ symbols! { cfg_accessible, cfg_attr, cfg_attr_multi, + cfg_autodiff_fallback, cfg_doctest, cfg_eval, cfg_fmt_debug, @@ -1607,6 +1610,7 @@ symbols! { rustc_allow_incoherent_impl, rustc_allowed_through_unstable_modules, rustc_attrs, + rustc_autodiff, rustc_box, rustc_builtin_macro, rustc_capture_analysis, diff --git a/config.example.toml b/config.example.toml index e9433c9c9bd08..155a9c44ebc4b 100644 --- a/config.example.toml +++ b/config.example.toml @@ -155,6 +155,9 @@ # Whether to build the clang compiler. #clang = false +# Whether to build Enzyme as the AutoDiff backend. +#enzyme = true + # Whether to enable llvm compilation warnings. #enable-warnings = false diff --git a/src/tools/enzyme b/src/tools/enzyme index 2fe5164a2423d..ce81c9c8bb517 160000 --- a/src/tools/enzyme +++ b/src/tools/enzyme @@ -1 +1 @@ -Subproject commit 2fe5164a2423dd67ef25e2c4fb204fd06362494b +Subproject commit ce81c9c8bb517b48d103975acce427ad3d348aa3