Skip to content

Commit

Permalink
upstream rustc_codegen_llvm changes for enzyme/autodiff
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Nov 21, 2024
1 parent 3fee0f1 commit 78297a9
Show file tree
Hide file tree
Showing 13 changed files with 639 additions and 29 deletions.
19 changes: 3 additions & 16 deletions compiler/rustc_ast/src/expand/autodiff_attrs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
use std::fmt::{self, Display, Formatter};
use std::str::FromStr;

use crate::expand::typetree::TypeTree;
use crate::expand::{Decodable, Encodable, HashStable_Generic};
use crate::ptr::P;
use crate::{Ty, TyKind};
Expand Down Expand Up @@ -79,10 +78,6 @@ pub struct AutoDiffItem {
/// The name of the function being generated
pub target: String,
pub attrs: AutoDiffAttrs,
/// Describe the memory layout of input types
pub inputs: Vec<TypeTree>,
/// Describe the memory layout of the output type
pub output: TypeTree,
}
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffAttrs {
Expand Down Expand Up @@ -262,22 +257,14 @@ impl AutoDiffAttrs {
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
}

pub fn into_item(
self,
source: String,
target: String,
inputs: Vec<TypeTree>,
output: TypeTree,
) -> AutoDiffItem {
AutoDiffItem { source, target, inputs, output, attrs: self }
pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
AutoDiffItem { source, target, attrs: self }
}
}

impl fmt::Display for AutoDiffItem {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
write!(f, " with attributes: {:?}", self.attrs)?;
write!(f, " with inputs: {:?}", self.inputs)?;
write!(f, " with output: {:?}", self.output)
write!(f, " with attributes: {:?}", self.attrs)
}
}
4 changes: 4 additions & 0 deletions compiler/rustc_codegen_llvm/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ codegen_llvm_prepare_thin_lto_module_with_llvm_err = failed to prepare thin LTO
codegen_llvm_run_passes = failed to run LLVM passes
codegen_llvm_run_passes_with_llvm_err = failed to run LLVM passes: {$llvm_err}
codegen_llvm_prepare_autodiff = failed to prepare AutoDiff: src: {$src}, target: {$target}, {$error}
codegen_llvm_prepare_autodiff_with_llvm_err = failed to prepare AutoDiff: {$llvm_err}, src: {$src}, target: {$target}, {$error}
codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto
codegen_llvm_sanitizer_memtag_requires_mte =
`-Zsanitizer=memtag` requires `-Ctarget-feature=+mte`
Expand Down
7 changes: 6 additions & 1 deletion compiler/rustc_codegen_llvm/src/back/lto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,12 @@ pub(crate) fn run_pass_manager(
debug!("running the pass manager");
let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO };
let opt_level = config.opt_level.unwrap_or(config::OptLevel::No);
unsafe { write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) }?;
// We will run this again with different values in the context of automatic differentiation.
let first_run = true;
debug!("running llvm pm opt pipeline");
unsafe {
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run)?;
}
debug!("lto done");
Ok(())
}
Expand Down
209 changes: 201 additions & 8 deletions compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::{fs, slice, str};

use libc::{c_char, c_int, c_void, size_t};
use libc::{c_char, c_int, c_uint, c_void, size_t};
use llvm::{
LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols,
};
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_codegen_ssa::back::link::ensure_removed;
use rustc_codegen_ssa::back::versioned_llvm_target;
use rustc_codegen_ssa::back::write::{
Expand All @@ -28,7 +29,7 @@ use rustc_session::config::{
use rustc_span::InnerSpan;
use rustc_span::symbol::sym;
use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo, TlsModel};
use tracing::debug;
use tracing::{debug, trace};

use crate::back::lto::ThinBuffer;
use crate::back::owned_target_machine::OwnedTargetMachine;
Expand All @@ -41,7 +42,13 @@ use crate::errors::{
WithLlvmError, WriteBytecode,
};
use crate::llvm::diagnostic::OptimizationDiagnosticKind::*;
use crate::llvm::{self, DiagnosticInfo, PassManager};
use crate::llvm::{
self, AttributeKind, DiagnosticInfo, LLVMCreateStringAttribute, LLVMGetFirstFunction,
LLVMGetNextFunction, LLVMGetStringAttributeAtIndex, LLVMIsEnumAttribute, LLVMIsStringAttribute,
LLVMRemoveStringAttributeAtIndex, LLVMRustAddEnumAttributeAtIndex,
LLVMRustAddFunctionAttributes, LLVMRustGetEnumAttributeAtIndex,
LLVMRustRemoveEnumAttributeAtIndex, PassManager,
};
use crate::type_::Type;
use crate::{LlvmCodegenBackend, ModuleLlvm, base, common, llvm_util};

Expand Down Expand Up @@ -517,9 +524,34 @@ pub(crate) unsafe fn llvm_optimize(
config: &ModuleConfig,
opt_level: config::OptLevel,
opt_stage: llvm::OptStage,
skip_size_increasing_opts: bool,
) -> Result<(), FatalError> {
let unroll_loops =
opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
// Enzyme:
// The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized
// source code. However, benchmarks show that optimizations increasing the code size
// tend to reduce AD performance. Therefore deactivate them before AD, then differentiate the code
// and finally re-optimize the module, now with all optimizations available.
// TODO: In a future update we could figure out how to only optimize functions getting
// differentiated.

let unroll_loops;
let vectorize_slp;
let vectorize_loop;

if skip_size_increasing_opts {
unroll_loops = false;
vectorize_slp = false;
vectorize_loop = false;
} else {
unroll_loops =
opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
vectorize_slp = config.vectorize_slp;
vectorize_loop = config.vectorize_loop;
}
trace!(
"Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}",
unroll_loops, vectorize_slp, vectorize_loop
);
let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed();
let pgo_gen_path = get_pgo_gen_path(config);
let pgo_use_path = get_pgo_use_path(config);
Expand Down Expand Up @@ -583,8 +615,8 @@ pub(crate) unsafe fn llvm_optimize(
using_thin_buffers,
config.merge_functions,
unroll_loops,
config.vectorize_slp,
config.vectorize_loop,
vectorize_slp,
vectorize_loop,
config.no_builtins,
config.emit_lifetime_markers,
sanitizer_options.as_ref(),
Expand All @@ -606,6 +638,113 @@ pub(crate) unsafe fn llvm_optimize(
result.into_result().map_err(|()| llvm_err(dcx, LlvmError::RunLlvmPasses))
}

pub(crate) fn differentiate(
module: &ModuleCodegen<ModuleLlvm>,
cgcx: &CodegenContext<LlvmCodegenBackend>,
diff_items: Vec<AutoDiffItem>,
config: &ModuleConfig,
) -> Result<(), FatalError> {
for item in &diff_items {
trace!("{}", item);
}

let llmod = module.module_llvm.llmod();
let llcx = &module.module_llvm.llcx;
let diag_handler = cgcx.create_dcx();

// Before dumping the module, we want all the tt to become part of the module.
for item in diff_items.iter() {
let name = CString::new(item.source.clone()).unwrap();
let fn_def: Option<&llvm::Value> =
unsafe { llvm::LLVMGetNamedFunction(llmod, name.as_ptr()) };
let fn_def = match fn_def {
Some(x) => x,
None => {
return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
src: item.source.clone(),
target: item.target.clone(),
error: "could not find source function".to_owned(),
}));
}
};
let tgt_name = CString::new(item.target.clone()).unwrap();
dbg!("Target name: {:?}", &tgt_name);
let fn_target: Option<&llvm::Value> =
unsafe { llvm::LLVMGetNamedFunction(llmod, tgt_name.as_ptr()) };
let fn_target = match fn_target {
Some(x) => x,
None => {
return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
src: item.source.clone(),
target: item.target.clone(),
error: "could not find target function".to_owned(),
}));
}
};

crate::builder::add_opt_dbg_helper2(llmod, llcx, fn_def, fn_target, item.attrs.clone());
}

// We needed the SanitizeHWAddress attribute to prevent LLVM from optimizing enums in a way
// which Enzyme doesn't understand.
unsafe {
let mut f = LLVMGetFirstFunction(llmod);
loop {
if let Some(lf) = f {
f = LLVMGetNextFunction(lf);
let myhwattr = "enzyme_hw";
let attr = LLVMGetStringAttributeAtIndex(
lf,
c_uint::MAX,
myhwattr.as_ptr() as *const c_char,
myhwattr.as_bytes().len() as c_uint,
);
if LLVMIsStringAttribute(attr) {
LLVMRemoveStringAttributeAtIndex(
lf,
c_uint::MAX,
myhwattr.as_ptr() as *const c_char,
myhwattr.as_bytes().len() as c_uint,
);
} else {
LLVMRustRemoveEnumAttributeAtIndex(
lf,
c_uint::MAX,
AttributeKind::SanitizeHWAddress,
);
}
} else {
break;
}
}
}

if let Some(opt_level) = config.opt_level {
let opt_stage = match cgcx.lto {
Lto::Fat => llvm::OptStage::PreLinkFatLTO,
Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
_ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
_ => llvm::OptStage::PreLinkNoLTO,
};
let skip_size_increasing_opts = false;
dbg!("Running Module Optimization after differentiation");
unsafe {
llvm_optimize(
cgcx,
diag_handler.handle(),
module,
config,
opt_level,
opt_stage,
skip_size_increasing_opts,
)?
};
}
dbg!("Done with differentiate()");

Ok(())
}

// Unsafe due to LLVM calls.
pub(crate) unsafe fn optimize(
cgcx: &CodegenContext<LlvmCodegenBackend>,
Expand All @@ -628,14 +767,68 @@ pub(crate) unsafe fn optimize(
unsafe { llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()) };
}

// This code enables Enzyme to differentiate code containing Rust enums.
// By adding the SanitizeHWAddress attribute we prevent LLVM from Optimizing
// away the enums and allows Enzyme to understand why a value can be of different types in
// different code sections. We remove this attribute after Enzyme is done, to not affect the
// rest of the compilation.
#[cfg(llvm_enzyme)]
unsafe {
let mut f = LLVMGetFirstFunction(llmod);
loop {
if let Some(lf) = f {
f = LLVMGetNextFunction(lf);
let myhwattr = "enzyme_hw";
let myhwv = "";
let prevattr = LLVMRustGetEnumAttributeAtIndex(
lf,
c_uint::MAX,
AttributeKind::SanitizeHWAddress,
);
if LLVMIsEnumAttribute(prevattr) {
let attr = LLVMCreateStringAttribute(
llcx,
myhwattr.as_ptr() as *const c_char,
myhwattr.as_bytes().len() as c_uint,
myhwv.as_ptr() as *const c_char,
myhwv.as_bytes().len() as c_uint,
);
LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1);
} else {
LLVMRustAddEnumAttributeAtIndex(
llcx,
lf,
c_uint::MAX,
AttributeKind::SanitizeHWAddress,
);
}
} else {
break;
}
}
}

if let Some(opt_level) = config.opt_level {
let opt_stage = match cgcx.lto {
Lto::Fat => llvm::OptStage::PreLinkFatLTO,
Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
_ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
_ => llvm::OptStage::PreLinkNoLTO,
};
return unsafe { llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) };

// If we know that we will later run AD, then we disable vectorization and loop unrolling
let skip_size_increasing_opts = cfg!(llvm_enzyme);
return unsafe {
llvm_optimize(
cgcx,
dcx,
module,
config,
opt_level,
opt_stage,
skip_size_increasing_opts,
)
};
}
Ok(())
}
Expand Down
Loading

0 comments on commit 78297a9

Please sign in to comment.