diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 05714731b9d4d..7ef8bc1797384 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -6,7 +6,6 @@ use std::fmt::{self, Display, Formatter}; use std::str::FromStr; -use crate::expand::typetree::TypeTree; use crate::expand::{Decodable, Encodable, HashStable_Generic}; use crate::ptr::P; use crate::{Ty, TyKind}; @@ -79,10 +78,6 @@ pub struct AutoDiffItem { /// The name of the function being generated pub target: String, pub attrs: AutoDiffAttrs, - /// Describe the memory layout of input types - pub inputs: Vec, - /// Describe the memory layout of the output type - pub output: TypeTree, } #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] pub struct AutoDiffAttrs { @@ -262,22 +257,14 @@ impl AutoDiffAttrs { !matches!(self.mode, DiffMode::Error | DiffMode::Source) } - pub fn into_item( - self, - source: String, - target: String, - inputs: Vec, - output: TypeTree, - ) -> AutoDiffItem { - AutoDiffItem { source, target, inputs, output, attrs: self } + pub fn into_item(self, source: String, target: String) -> AutoDiffItem { + AutoDiffItem { source, target, attrs: self } } } impl fmt::Display for AutoDiffItem { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Differentiating {} -> {}", self.source, self.target)?; - write!(f, " with attributes: {:?}", self.attrs)?; - write!(f, " with inputs: {:?}", self.inputs)?; - write!(f, " with output: {:?}", self.output) + write!(f, " with attributes: {:?}", self.attrs) } } diff --git a/compiler/rustc_codegen_llvm/messages.ftl b/compiler/rustc_codegen_llvm/messages.ftl index 63c64269eb805..93f91d58170f9 100644 --- a/compiler/rustc_codegen_llvm/messages.ftl +++ b/compiler/rustc_codegen_llvm/messages.ftl @@ -56,6 +56,10 @@ codegen_llvm_prepare_thin_lto_module_with_llvm_err = failed to prepare thin LTO codegen_llvm_run_passes = failed to run LLVM passes codegen_llvm_run_passes_with_llvm_err = failed to run LLVM passes: {$llvm_err} +codegen_llvm_prepare_autodiff = failed to prepare AutoDiff: src: {$src}, target: {$target}, {$error} +codegen_llvm_prepare_autodiff_with_llvm_err = failed to prepare AutoDiff: {$llvm_err}, src: {$src}, target: {$target}, {$error} +codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto + codegen_llvm_sanitizer_memtag_requires_mte = `-Zsanitizer=memtag` requires `-Ctarget-feature=+mte` diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index 48beb9be2b2a1..8216d71c3da97 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -604,7 +604,12 @@ pub(crate) fn run_pass_manager( debug!("running the pass manager"); let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO }; let opt_level = config.opt_level.unwrap_or(config::OptLevel::No); - unsafe { write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) }?; + // We will run this again with different values in the context of automatic differentiation. + let first_run = true; + debug!("running llvm pm opt pipeline"); + unsafe { + write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run)?; + } debug!("lto done"); Ok(()) } diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index a65ae4df1e378..8c2939c44d858 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -4,10 +4,11 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use std::{fs, slice, str}; -use libc::{c_char, c_int, c_void, size_t}; +use libc::{c_char, c_int, c_uint, c_void, size_t}; use llvm::{ LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, }; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_codegen_ssa::back::link::ensure_removed; use rustc_codegen_ssa::back::versioned_llvm_target; use rustc_codegen_ssa::back::write::{ @@ -28,7 +29,7 @@ use rustc_session::config::{ use rustc_span::InnerSpan; use rustc_span::symbol::sym; use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo, TlsModel}; -use tracing::debug; +use tracing::{debug, trace}; use crate::back::lto::ThinBuffer; use crate::back::owned_target_machine::OwnedTargetMachine; @@ -41,7 +42,13 @@ use crate::errors::{ WithLlvmError, WriteBytecode, }; use crate::llvm::diagnostic::OptimizationDiagnosticKind::*; -use crate::llvm::{self, DiagnosticInfo, PassManager}; +use crate::llvm::{ + self, AttributeKind, DiagnosticInfo, LLVMCreateStringAttribute, LLVMGetFirstFunction, + LLVMGetNextFunction, LLVMGetStringAttributeAtIndex, LLVMIsEnumAttribute, LLVMIsStringAttribute, + LLVMRemoveStringAttributeAtIndex, LLVMRustAddEnumAttributeAtIndex, + LLVMRustAddFunctionAttributes, LLVMRustGetEnumAttributeAtIndex, + LLVMRustRemoveEnumAttributeAtIndex, PassManager, +}; use crate::type_::Type; use crate::{LlvmCodegenBackend, ModuleLlvm, base, common, llvm_util}; @@ -517,9 +524,34 @@ pub(crate) unsafe fn llvm_optimize( config: &ModuleConfig, opt_level: config::OptLevel, opt_stage: llvm::OptStage, + skip_size_increasing_opts: bool, ) -> Result<(), FatalError> { - let unroll_loops = - opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + // Enzyme: + // The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized + // source code. However, benchmarks show that optimizations increasing the code size + // tend to reduce AD performance. Therefore deactivate them before AD, then differentiate the code + // and finally re-optimize the module, now with all optimizations available. + // TODO: In a future update we could figure out how to only optimize functions getting + // differentiated. + + let unroll_loops; + let vectorize_slp; + let vectorize_loop; + + if skip_size_increasing_opts { + unroll_loops = false; + vectorize_slp = false; + vectorize_loop = false; + } else { + unroll_loops = + opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + vectorize_slp = config.vectorize_slp; + vectorize_loop = config.vectorize_loop; + } + trace!( + "Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}", + unroll_loops, vectorize_slp, vectorize_loop + ); let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed(); let pgo_gen_path = get_pgo_gen_path(config); let pgo_use_path = get_pgo_use_path(config); @@ -583,8 +615,8 @@ pub(crate) unsafe fn llvm_optimize( using_thin_buffers, config.merge_functions, unroll_loops, - config.vectorize_slp, - config.vectorize_loop, + vectorize_slp, + vectorize_loop, config.no_builtins, config.emit_lifetime_markers, sanitizer_options.as_ref(), @@ -606,6 +638,113 @@ pub(crate) unsafe fn llvm_optimize( result.into_result().map_err(|()| llvm_err(dcx, LlvmError::RunLlvmPasses)) } +pub(crate) fn differentiate( + module: &ModuleCodegen, + cgcx: &CodegenContext, + diff_items: Vec, + config: &ModuleConfig, +) -> Result<(), FatalError> { + for item in &diff_items { + trace!("{}", item); + } + + let llmod = module.module_llvm.llmod(); + let llcx = &module.module_llvm.llcx; + let diag_handler = cgcx.create_dcx(); + + // Before dumping the module, we want all the tt to become part of the module. + for item in diff_items.iter() { + let name = CString::new(item.source.clone()).unwrap(); + let fn_def: Option<&llvm::Value> = + unsafe { llvm::LLVMGetNamedFunction(llmod, name.as_ptr()) }; + let fn_def = match fn_def { + Some(x) => x, + None => { + return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff { + src: item.source.clone(), + target: item.target.clone(), + error: "could not find source function".to_owned(), + })); + } + }; + let tgt_name = CString::new(item.target.clone()).unwrap(); + dbg!("Target name: {:?}", &tgt_name); + let fn_target: Option<&llvm::Value> = + unsafe { llvm::LLVMGetNamedFunction(llmod, tgt_name.as_ptr()) }; + let fn_target = match fn_target { + Some(x) => x, + None => { + return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff { + src: item.source.clone(), + target: item.target.clone(), + error: "could not find target function".to_owned(), + })); + } + }; + + crate::builder::add_opt_dbg_helper2(llmod, llcx, fn_def, fn_target, item.attrs.clone()); + } + + // We needed the SanitizeHWAddress attribute to prevent LLVM from optimizing enums in a way + // which Enzyme doesn't understand. + unsafe { + let mut f = LLVMGetFirstFunction(llmod); + loop { + if let Some(lf) = f { + f = LLVMGetNextFunction(lf); + let myhwattr = "enzyme_hw"; + let attr = LLVMGetStringAttributeAtIndex( + lf, + c_uint::MAX, + myhwattr.as_ptr() as *const c_char, + myhwattr.as_bytes().len() as c_uint, + ); + if LLVMIsStringAttribute(attr) { + LLVMRemoveStringAttributeAtIndex( + lf, + c_uint::MAX, + myhwattr.as_ptr() as *const c_char, + myhwattr.as_bytes().len() as c_uint, + ); + } else { + LLVMRustRemoveEnumAttributeAtIndex( + lf, + c_uint::MAX, + AttributeKind::SanitizeHWAddress, + ); + } + } else { + break; + } + } + } + + if let Some(opt_level) = config.opt_level { + let opt_stage = match cgcx.lto { + Lto::Fat => llvm::OptStage::PreLinkFatLTO, + Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO, + _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, + _ => llvm::OptStage::PreLinkNoLTO, + }; + let skip_size_increasing_opts = false; + dbg!("Running Module Optimization after differentiation"); + unsafe { + llvm_optimize( + cgcx, + diag_handler.handle(), + module, + config, + opt_level, + opt_stage, + skip_size_increasing_opts, + )? + }; + } + dbg!("Done with differentiate()"); + + Ok(()) +} + // Unsafe due to LLVM calls. pub(crate) unsafe fn optimize( cgcx: &CodegenContext, @@ -628,6 +767,47 @@ pub(crate) unsafe fn optimize( unsafe { llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()) }; } + // This code enables Enzyme to differentiate code containing Rust enums. + // By adding the SanitizeHWAddress attribute we prevent LLVM from Optimizing + // away the enums and allows Enzyme to understand why a value can be of different types in + // different code sections. We remove this attribute after Enzyme is done, to not affect the + // rest of the compilation. + #[cfg(llvm_enzyme)] + unsafe { + let mut f = LLVMGetFirstFunction(llmod); + loop { + if let Some(lf) = f { + f = LLVMGetNextFunction(lf); + let myhwattr = "enzyme_hw"; + let myhwv = ""; + let prevattr = LLVMRustGetEnumAttributeAtIndex( + lf, + c_uint::MAX, + AttributeKind::SanitizeHWAddress, + ); + if LLVMIsEnumAttribute(prevattr) { + let attr = LLVMCreateStringAttribute( + llcx, + myhwattr.as_ptr() as *const c_char, + myhwattr.as_bytes().len() as c_uint, + myhwv.as_ptr() as *const c_char, + myhwv.as_bytes().len() as c_uint, + ); + LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1); + } else { + LLVMRustAddEnumAttributeAtIndex( + llcx, + lf, + c_uint::MAX, + AttributeKind::SanitizeHWAddress, + ); + } + } else { + break; + } + } + } + if let Some(opt_level) = config.opt_level { let opt_stage = match cgcx.lto { Lto::Fat => llvm::OptStage::PreLinkFatLTO, @@ -635,7 +815,20 @@ pub(crate) unsafe fn optimize( _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, _ => llvm::OptStage::PreLinkNoLTO, }; - return unsafe { llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) }; + + // If we know that we will later run AD, then we disable vectorization and loop unrolling + let skip_size_increasing_opts = cfg!(llvm_enzyme); + return unsafe { + llvm_optimize( + cgcx, + dcx, + module, + config, + opt_level, + opt_stage, + skip_size_increasing_opts, + ) + }; } Ok(()) } diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index b5bb7630ca6c9..f36cb23ce1982 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -5,6 +5,7 @@ use std::{iter, ptr}; use libc::{c_char, c_uint}; use rustc_abi as abi; use rustc_abi::{Align, Size, WrappingRange}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; use rustc_codegen_ssa::MemFlags; use rustc_codegen_ssa::common::{IntPredicate, RealPredicate, SynchronizationScope, TypeKind}; use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; @@ -24,17 +25,225 @@ use rustc_span::Span; use rustc_target::callconv::FnAbi; use rustc_target::spec::{HasTargetSpec, SanitizerSet, Target}; use smallvec::SmallVec; -use tracing::{debug, instrument}; +use tracing::{debug, instrument, trace}; use crate::abi::FnAbiLlvmExt; use crate::attributes; use crate::common::Funclet; use crate::context::CodegenCx; -use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, True}; +use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, Metadata, True}; use crate::type_::Type; use crate::type_of::LayoutLlvmExt; use crate::value::Value; +fn get_params(fnc: &Value) -> Vec<&Value> { + unsafe { + let param_num = llvm::LLVMCountParams(fnc) as usize; + let mut fnc_args: Vec<&Value> = vec![]; + fnc_args.reserve(param_num); + llvm::LLVMGetParams(fnc, fnc_args.as_mut_ptr()); + fnc_args.set_len(param_num); + fnc_args + } +} + +pub(crate) fn add_opt_dbg_helper2<'ll>( + llmod: &'ll llvm::Module, + llcx: &'ll llvm::Context, + todiff: &'ll Value, + outer_fn: &'ll Value, + attrs: AutoDiffAttrs, +) { + let inputs = attrs.input_activity; + let output = attrs.ret_activity; + let mut ad_name: String = match attrs.mode { + DiffMode::Forward => "__enzyme_fwddiff", + DiffMode::Reverse => "__enzyme_autodiff", + DiffMode::ForwardFirst => "__enzyme_fwddiff", + DiffMode::ReverseFirst => "__enzyme_autodiff", + _ => panic!("Why are we here?"), + } + .to_string(); + // add outer_fn name to ad_name to make it unique + let outer_fn_name = unsafe { + let mut len: usize = 0; + let name = llvm::LLVMGetValueName2(outer_fn, &mut len as *mut usize); + std::ffi::CStr::from_ptr(name).to_str().unwrap() + }; + ad_name.push_str(outer_fn_name.to_string().as_str()); + + // Assuming that our todiff is the fnc square, want to generate the following llvm-ir: + // declare double @__enzyme_autodiff(...) + // + // define double @dsquare(double %x) { + // entry: + // %0 = tail call double (...) @__enzyme_autodiff(double (double)* nonnull @square, double %x) + // ret double %0 + // } + + unsafe { + let fn_ty = llvm::LLVMRustGetFunctionType(outer_fn); + let ret_ty = llvm::LLVMGetReturnType(fn_ty); + + // First we add the declaration of the __enzyme function + let enzyme_ty = llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True); + let ad_fn = llvm::LLVMRustGetOrInsertFunction( + llmod, + ad_name.as_ptr() as *const c_char, + ad_name.len().try_into().unwrap(), + enzyme_ty, + ); + llvm::LLVMRustAddEnumAttributeAtIndex( + llcx, + ad_fn, + c_uint::MAX, + llvm::AttributeKind::NoInline, + ); + + // first, remove all calls from fnc + let entry = llvm::LLVMGetFirstBasicBlock(outer_fn); + let br = llvm::LLVMRustGetTerminator(entry); + llvm::LLVMRustEraseInstFromParent(br); + + let builder = llvm::LLVMCreateBuilderInContext(llcx); + let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap(); + llvm::LLVMPositionBuilderAtEnd(builder, entry); + + let num_args = llvm::LLVMCountParams(todiff); + let mut args = Vec::with_capacity(num_args as usize + 1); + args.push(todiff); + let enzyme_const = + llvm::LLVMMDStringInContext2(llcx, "enzyme_const".as_ptr() as *const c_char, 12); + let enzyme_out = + llvm::LLVMMDStringInContext2(llcx, "enzyme_out".as_ptr() as *const c_char, 10); + let enzyme_dup = + llvm::LLVMMDStringInContext2(llcx, "enzyme_dup".as_ptr() as *const c_char, 10); + let enzyme_dupnoneed = + llvm::LLVMMDStringInContext2(llcx, "enzyme_dupnoneed".as_ptr() as *const c_char, 16); + let enzyme_primal_ret = llvm::LLVMMDStringInContext2( + llcx, + "enzyme_primal_return".as_ptr() as *const c_char, + 20, + ); + + match output { + DiffActivity::Dual => { + args.push(llvm::LLVMMetadataAsValue(llcx, enzyme_primal_ret)); + } + DiffActivity::Active => { + args.push(llvm::LLVMMetadataAsValue(llcx, enzyme_primal_ret)); + } + _ => {} + } + + trace!("Matching args"); + // Idea: We follow the outer function's arguments and activities. + // We still keep track of our inner function's arguments, but just + // to verify that they match. + let mut outer_pos: usize = 0; + let mut activity_pos = 0; + let outer_args: Vec<&llvm::Value> = get_params(outer_fn); + + while activity_pos < inputs.len() { + let activity = inputs[activity_pos as usize]; + let (activity, duplicated): (&Metadata, bool) = match activity { + DiffActivity::None => panic!(), + DiffActivity::Const => (enzyme_const, false), + DiffActivity::Active => (enzyme_out, false), + DiffActivity::ActiveOnly => (enzyme_out, false), + DiffActivity::Dual => (enzyme_dup, true), + DiffActivity::DualOnly => (enzyme_dupnoneed, true), + DiffActivity::Duplicated => (enzyme_dup, true), + DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true), + DiffActivity::FakeActivitySize => (enzyme_const, false), + }; + let outer_arg = outer_args[outer_pos]; + args.push(llvm::LLVMMetadataAsValue(llcx, activity)); + args.push(outer_arg); + // Now if we have a slice and duplicate, then it get's interesting. + // + if duplicated { + let next_outer_arg = outer_args[outer_pos + 1]; + let next_outer_ty = llvm::LLVMTypeOf(next_outer_arg); + let slice = { + if activity_pos + 1 >= inputs.len() { + // If there is no arg following our ptr, it also can't be a slice, + // since that would lead to a ptr, int pair. + false + } else { + let next_activity = inputs[activity_pos + 1]; + next_activity == DiffActivity::FakeActivitySize + } + }; + if slice { + assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Integer); + let next_outer_arg2 = outer_args[outer_pos + 2]; + let next_outer_ty2 = llvm::LLVMTypeOf(next_outer_arg2); + assert!(llvm::LLVMRustGetTypeKind(next_outer_ty2) == llvm::TypeKind::Pointer); + let next_outer_arg3 = outer_args[outer_pos + 3]; + let next_outer_ty3 = llvm::LLVMTypeOf(next_outer_arg3); + assert!(llvm::LLVMRustGetTypeKind(next_outer_ty3) == llvm::TypeKind::Integer); + args.push(next_outer_arg2); + args.push(llvm::LLVMMetadataAsValue(llcx, enzyme_const)); + args.push(next_outer_arg); + outer_pos += 4; + activity_pos += 2; + } else { + assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer); + args.push(next_outer_arg); + outer_pos += 2; + activity_pos += 1; + } + } else { + outer_pos += 1; + activity_pos += 1; + } + } + + let call = llvm::LLVMBuildCall2( + builder, + enzyme_ty, + ad_fn, + args.as_mut_ptr(), + args.len().try_into().unwrap(), + ad_name.as_ptr() as *const c_char, + ); + + // Add dummy dbg info to our newly generated call, if we have any. + let md_ty = llvm::LLVMGetMDKindIDInContext( + llcx, + "dbg".as_ptr() as *const c_char, + "dbg".len() as c_uint, + ); + + if llvm::LLVMRustHasMetadata(last_inst, md_ty) { + let md = llvm::LLVMRustDIGetInstMetadata(last_inst); + let md_todiff = llvm::LLVMMetadataAsValue(llcx, md); + let _md2 = llvm::LLVMSetMetadata(call, md_ty, md_todiff); + } else { + trace!("No dbg info"); + } + llvm::LLVMRustEraseInstBefore(entry, last_inst); + + let void_ty = llvm::LLVMVoidTypeInContext(llcx); + if llvm::LLVMTypeOf(call) != void_ty { + llvm::LLVMBuildRet(builder, call); + } else { + llvm::LLVMBuildRetVoid(builder); + }; + llvm::LLVMDisposeBuilder(builder); + + let _fnc_ok = llvm::LLVMVerifyFunction( + outer_fn, + llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction, + ); + let _fnc_ok = llvm::LLVMVerifyFunction( + todiff, + llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction, + ); + } +} + // All Builders must have an llfn associated with them #[must_use] pub(crate) struct Builder<'a, 'll, 'tcx> { diff --git a/compiler/rustc_codegen_llvm/src/errors.rs b/compiler/rustc_codegen_llvm/src/errors.rs index 3cdb5b971d908..f340b06e876cd 100644 --- a/compiler/rustc_codegen_llvm/src/errors.rs +++ b/compiler/rustc_codegen_llvm/src/errors.rs @@ -89,6 +89,11 @@ impl Diagnostic<'_, G> for ParseTargetMachineConfig<'_> { } } +#[derive(Diagnostic)] +#[diag(codegen_llvm_autodiff_without_lto)] +#[note] +pub(crate) struct AutoDiffWithoutLTO; + #[derive(Diagnostic)] #[diag(codegen_llvm_lto_disallowed)] pub(crate) struct LtoDisallowed; @@ -131,6 +136,8 @@ pub enum LlvmError<'a> { PrepareThinLtoModule, #[diag(codegen_llvm_parse_bitcode)] ParseBitcode, + #[diag(codegen_llvm_prepare_autodiff)] + PrepareAutoDiff { src: String, target: String, error: String }, } pub(crate) struct WithLlvmError<'a>(pub LlvmError<'a>, pub String); @@ -152,6 +159,7 @@ impl Diagnostic<'_, G> for WithLlvmError<'_> { } PrepareThinLtoModule => fluent::codegen_llvm_prepare_thin_lto_module_with_llvm_err, ParseBitcode => fluent::codegen_llvm_parse_bitcode_with_llvm_err, + PrepareAutoDiff { .. } => fluent::codegen_llvm_prepare_autodiff_with_llvm_err, }; self.0 .into_diag(dcx, level) diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index 3dfb86d422dd2..be2f6c1d0bd9e 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -26,9 +26,10 @@ use std::mem::ManuallyDrop; use back::owned_target_machine::OwnedTargetMachine; use back::write::{create_informational_target_machine, create_target_machine}; -use errors::ParseTargetMachineConfig; +use errors::{AutoDiffWithoutLTO, ParseTargetMachineConfig}; pub use llvm_util::target_features; use rustc_ast::expand::allocator::AllocatorKind; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; use rustc_codegen_ssa::back::write::{ CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryConfig, TargetMachineFactoryFn, @@ -42,7 +43,7 @@ use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; use rustc_middle::ty::TyCtxt; use rustc_middle::util::Providers; use rustc_session::Session; -use rustc_session::config::{OptLevel, OutputFilenames, PrintKind, PrintRequest}; +use rustc_session::config::{Lto, OptLevel, OutputFilenames, PrintKind, PrintRequest}; use rustc_span::symbol::Symbol; mod back { @@ -231,6 +232,19 @@ impl WriteBackendMethods for LlvmCodegenBackend { fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer) { (module.name, back::lto::ModuleBuffer::new(module.module_llvm.llmod())) } + /// Generate autodiff rules + fn autodiff( + cgcx: &CodegenContext, + module: &ModuleCodegen, + diff_fncs: Vec, + config: &ModuleConfig, + ) -> Result<(), FatalError> { + if cgcx.lto != Lto::Fat { + let dcx = cgcx.create_dcx(); + return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutLTO {})); + } + back::write::differentiate(module, cgcx, diff_fncs, config) + } } unsafe impl Send for LlvmCodegenBackend {} // Llvm is on a per-thread basis @@ -386,6 +400,7 @@ impl CodegenBackend for LlvmCodegenBackend { } } +#[allow(dead_code)] pub struct ModuleLlvm { llcx: &'static mut llvm::Context, llmod_raw: *const llvm::Module, diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs new file mode 100644 index 0000000000000..556b848996145 --- /dev/null +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -0,0 +1,66 @@ +#![allow(non_camel_case_types)] + +use libc::{c_char, c_uint, size_t}; + +use super::ffi::*; + +extern "C" { + // Enzyme + pub fn LLVMRustAddFncParamAttr<'a>(F: &'a Value, index: c_uint, Attr: &'a Attribute); + pub fn LLVMRustAddRetFncAttr(F: &Value, attr: &Attribute); + pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool; + pub fn LLVMRustEraseInstBefore(BB: &BasicBlock, I: &Value); + pub fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>; + pub fn LLVMRustDIGetInstMetadata(I: &Value) -> &Metadata; + pub fn LLVMRustEraseInstFromParent(V: &Value); + pub fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value; + pub fn LLVMRustAddEnumAttributeAtIndex( + C: &Context, + V: &Value, + index: c_uint, + attr: AttributeKind, + ); + pub fn LLVMRustRemoveEnumAttributeAtIndex(V: &Value, index: c_uint, attr: AttributeKind); + pub fn LLVMRustGetEnumAttributeAtIndex( + V: &Value, + index: c_uint, + attr: AttributeKind, + ) -> &Attribute; + pub fn LLVMRustAddParamAttr<'a>(Instr: &'a Value, index: c_uint, Attr: &'a Attribute); + + pub fn LLVMGetReturnType(T: &Type) -> &Type; + pub fn LLVMDumpModule(M: &Module); + pub fn LLVMCountStructElementTypes(T: &Type) -> c_uint; + pub fn LLVMVerifyFunction(V: &Value, action: LLVMVerifierFailureAction) -> bool; + pub fn LLVMGetParams(Fnc: &Value, parms: *mut &Value); + pub fn LLVMBuildCall2<'a>( + arg1: &Builder<'a>, + ty: &Type, + func: &Value, + args: *mut &Value, + num_args: size_t, + name: *const c_char, + ) -> &'a Value; + pub fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>; + pub fn LLVMGetNextFunction(V: &Value) -> Option<&Value>; + pub fn LLVMGetNamedFunction(M: &Module, Name: *const c_char) -> Option<&Value>; + pub fn LLVMRustGetFunctionType(fnc: &Value) -> &Type; + + pub fn LLVMRemoveStringAttributeAtIndex(F: &Value, Idx: c_uint, K: *const c_char, KLen: c_uint); + pub fn LLVMGetStringAttributeAtIndex( + F: &Value, + Idx: c_uint, + K: *const c_char, + KLen: c_uint, + ) -> &Attribute; + pub fn LLVMIsEnumAttribute(A: &Attribute) -> bool; + pub fn LLVMIsStringAttribute(A: &Attribute) -> bool; + +} + +#[repr(C)] +pub enum LLVMVerifierFailureAction { + LLVMAbortProcessAction, + LLVMPrintMessageAction, + LLVMReturnStatusAction, +} diff --git a/compiler/rustc_codegen_llvm/src/llvm/mod.rs b/compiler/rustc_codegen_llvm/src/llvm/mod.rs index 909afe35a179b..58b50ee12a101 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/mod.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/mod.rs @@ -22,8 +22,12 @@ use crate::common::AsCCharPtr; pub mod archive_ro; pub mod diagnostic; +pub mod enzyme_ffi; mod ffi; +pub use self::enzyme_ffi::*; +pub use self::ffi::*; + impl LLVMRustResult { pub fn into_result(self) -> Result<(), ()> { match self { diff --git a/compiler/rustc_codegen_ssa/src/back/lto.rs b/compiler/rustc_codegen_ssa/src/back/lto.rs index ab8b06a05fc74..0a257c2904c5f 100644 --- a/compiler/rustc_codegen_ssa/src/back/lto.rs +++ b/compiler/rustc_codegen_ssa/src/back/lto.rs @@ -1,11 +1,13 @@ use std::ffi::CString; use std::sync::Arc; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_data_structures::memmap::Mmap; use rustc_errors::FatalError; use super::write::CodegenContext; use crate::ModuleCodegen; +use crate::back::write::ModuleConfig; use crate::traits::*; pub struct ThinModule { @@ -81,6 +83,23 @@ impl LtoModuleCodegen { LtoModuleCodegen::Thin(ref m) => m.cost(), } } + + /// Run autodiff on Fat LTO module + pub unsafe fn autodiff( + self, + cgcx: &CodegenContext, + diff_fncs: Vec, + config: &ModuleConfig, + ) -> Result, FatalError> { + match &self { + LtoModuleCodegen::Fat(module) => { + B::autodiff(cgcx, &module, diff_fncs, config)?; + } + _ => panic!("Unreachable? Autodiff called with non-fat LTO module"), + } + + Ok(self) + } } pub enum SerializedModule { diff --git a/compiler/rustc_codegen_ssa/src/traits/write.rs b/compiler/rustc_codegen_ssa/src/traits/write.rs index aabe9e33c4aa1..97fe614aa10cd 100644 --- a/compiler/rustc_codegen_ssa/src/traits/write.rs +++ b/compiler/rustc_codegen_ssa/src/traits/write.rs @@ -1,3 +1,4 @@ +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_errors::{DiagCtxtHandle, FatalError}; use rustc_middle::dep_graph::WorkProduct; @@ -61,6 +62,12 @@ pub trait WriteBackendMethods: 'static + Sized + Clone { want_summary: bool, ) -> (String, Self::ThinBuffer); fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer); + fn autodiff( + cgcx: &CodegenContext, + module: &ModuleCodegen, + diff_fncs: Vec, + config: &ModuleConfig, + ) -> Result<(), FatalError>; } pub trait ThinBufferMethods: Send + Sync { diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index cd70c3f266920..45594bbc22cf6 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -388,11 +388,44 @@ extern "C" void LLVMRustAddCallSiteAttributes(LLVMValueRef Instr, AddAttributes(Call, Index, Attrs, AttrsLen); } +extern "C" LLVMValueRef LLVMRustGetTerminator(LLVMBasicBlockRef BB) { + Instruction *ret = unwrap(BB)->getTerminator(); + return wrap(ret); +} + +extern "C" void LLVMRustEraseInstFromParent(LLVMValueRef Instr) { + if (auto I = dyn_cast(unwrap(Instr))) { + I->eraseFromParent(); + } +} + +extern "C" LLVMTypeRef LLVMRustGetFunctionType(LLVMValueRef Fn) { + auto Ftype = unwrap(Fn)->getFunctionType(); + return wrap(Ftype); +} + +extern "C" void LLVMRustRemoveEnumAttributeAtIndex(LLVMValueRef F, size_t index, + LLVMRustAttribute RustAttr) { + LLVMRemoveEnumAttributeAtIndex(F, index, fromRust(RustAttr)); +} + extern "C" LLVMAttributeRef LLVMRustCreateAttrNoValue(LLVMContextRef C, LLVMRustAttributeKind RustAttr) { return wrap(Attribute::get(*unwrap(C), fromRust(RustAttr))); } +extern "C" void LLVMRustAddEnumAttributeAtIndex(LLVMContextRef C, + LLVMValueRef F, size_t index, + LLVMRustAttribute RustAttr) { + LLVMAddAttributeAtIndex(F, index, LLVMRustCreateAttrNoValue(C, RustAttr)); +} + +extern "C" LLVMAttributeRef +LLVMRustGetEnumAttributeAtIndex(LLVMValueRef F, size_t index, + LLVMRustAttribute RustAttr) { + return LLVMGetEnumAttributeAtIndex(F, index, fromRust(RustAttr)); +} + extern "C" LLVMAttributeRef LLVMRustCreateAlignmentAttr(LLVMContextRef C, uint64_t Bytes) { return wrap(Attribute::getWithAlignment(*unwrap(C), llvm::Align(Bytes))); @@ -954,6 +987,63 @@ extern "C" void LLVMRustAddModuleFlagString( MDString::get(unwrap(M)->getContext(), StringRef(Value, ValueLen))); } +extern "C" bool LLVMRustHasModuleFlag(LLVMModuleRef M, const char *Name, + size_t Len) { + return unwrap(M)->getModuleFlag(StringRef(Name, Len)) != nullptr; +} + +extern "C" LLVMValueRef LLVMRustGetLastInstruction(LLVMBasicBlockRef BB) { + auto Point = unwrap(BB)->rbegin(); + if (Point != unwrap(BB)->rend()) + return wrap(&*Point); + return nullptr; +} + +extern "C" void LLVMRustEraseInstBefore(LLVMBasicBlockRef bb, LLVMValueRef I) { + auto &BB = *unwrap(bb); + auto &Inst = *unwrap(I); + auto It = BB.begin(); + while (&*It != &Inst) + ++It; + assert(It != BB.end()); + // Delete in rev order to ensure no dangling references. + while (It != BB.begin()) { + auto Prev = std::prev(It); + It->eraseFromParent(); + It = Prev; + } + It->eraseFromParent(); +} + +extern "C" bool LLVMRustHasMetadata(LLVMValueRef inst, unsigned kindID) { + if (auto *I = dyn_cast(unwrap(inst))) { + return I->hasMetadata(kindID); + } + return false; +} + +extern "C" void LLVMRustAddFncParamAttr(LLVMValueRef F, unsigned i, + LLVMAttributeRef RustAttr) { + if (auto *Fn = dyn_cast(unwrap(F))) { + Fn->addParamAttr(i, unwrap(RustAttr)); + } +} + +extern "C" void LLVMRustAddRetFncAttr(LLVMValueRef F, + LLVMAttributeRef RustAttr) { + if (auto *Fn = dyn_cast(unwrap(F))) { + Fn->addRetAttr(unwrap(RustAttr)); + } +} + +extern "C" LLVMMetadataRef LLVMRustDIGetInstMetadata(LLVMValueRef x) { + if (auto *I = dyn_cast(unwrap(x))) { + auto *MD = I->getDebugLoc().getAsMDNode(); + return wrap(MD); + } + return nullptr; +} + extern "C" void LLVMRustGlobalAddMetadata(LLVMValueRef Global, unsigned Kind, LLVMMetadataRef MD) { unwrap(Global)->addMetadata(Kind, *unwrap(MD)); diff --git a/config.example.toml b/config.example.toml index d3233ad17b511..44b4f5e2597d6 100644 --- a/config.example.toml +++ b/config.example.toml @@ -161,6 +161,9 @@ # Whether to build the clang compiler. #clang = false +# Wheter to build Enzyme as AutoDiff backend. +#enzyme = true + # Whether to enable llvm compilation warnings. #enable-warnings = false