Skip to content

Commit

Permalink
Initial Forward Split Mode (rust-lang#539)
Browse files Browse the repository at this point in the history
* Begin forward split

* Add atomicadd test

* Starting to function split fwd

* Start tests

* Modref

* Get most of fwd split working

* memmove

* Fix test

* Adjust tests for higher llvm
  • Loading branch information
wsmoses authored Mar 4, 2022
1 parent 0ad5203 commit e32d116
Show file tree
Hide file tree
Showing 102 changed files with 5,214 additions and 505 deletions.
229 changes: 134 additions & 95 deletions enzyme/Enzyme/AdjointGenerator.h

Large diffs are not rendered by default.

25 changes: 14 additions & 11 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,9 @@ LLVMValueRef EnzymeCreateForwardDiff(
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
EnzymeTypeAnalysisRef TA, uint8_t returnValue, CDerivativeMode mode,
unsigned width, LLVMTypeRef additionalArg, CFnTypeInfo typeInfo,
uint8_t *_uncacheable_args, size_t uncacheable_args_size) {
uint8_t freeMemory, unsigned width, LLVMTypeRef additionalArg,
CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
size_t uncacheable_args_size, EnzymeAugmentedReturnPtr augmented) {
std::vector<DIFFE_TYPE> nconstant_args((DIFFE_TYPE *)constant_args,
(DIFFE_TYPE *)constant_args +
constant_args_size);
Expand All @@ -393,9 +394,9 @@ LLVMValueRef EnzymeCreateForwardDiff(
}
return wrap(eunwrap(Logic).CreateForwardDiff(
cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args,
eunwrap(TA), returnValue, (DerivativeMode)mode, width,
eunwrap(TA), returnValue, (DerivativeMode)mode, freeMemory, width,
unwrap(additionalArg), eunwrap(typeInfo, cast<Function>(unwrap(todiff))),
uncacheable_args));
uncacheable_args, eunwrap(augmented)));
}
LLVMValueRef EnzymeCreatePrimalAndGradient(
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,
Expand Down Expand Up @@ -432,12 +433,14 @@ LLVMValueRef EnzymeCreatePrimalAndGradient(
},
eunwrap(TA), eunwrap(augmented)));
}
EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
EnzymeTypeAnalysisRef TA, uint8_t returnUsed, CFnTypeInfo typeInfo,
uint8_t *_uncacheable_args, size_t uncacheable_args_size,
uint8_t forceAnonymousTape, uint8_t AtomicAdd) {
EnzymeAugmentedReturnPtr
EnzymeCreateAugmentedPrimal(EnzymeLogicRef Logic, LLVMValueRef todiff,
CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
size_t constant_args_size, EnzymeTypeAnalysisRef TA,
uint8_t returnUsed, uint8_t shadowReturnUsed,
CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
size_t uncacheable_args_size,
uint8_t forceAnonymousTape, uint8_t AtomicAdd) {

std::vector<DIFFE_TYPE> nconstant_args((DIFFE_TYPE *)constant_args,
(DIFFE_TYPE *)constant_args +
Expand All @@ -451,7 +454,7 @@ EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
}
return ewrap(eunwrap(Logic).CreateAugmentedPrimal(
cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args,
eunwrap(TA), returnUsed,
eunwrap(TA), returnUsed, shadowReturnUsed,
eunwrap(typeInfo, cast<Function>(unwrap(todiff))), uncacheable_args,
forceAnonymousTape, AtomicAdd));
}
Expand Down
12 changes: 7 additions & 5 deletions enzyme/Enzyme/CApi.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,9 @@ LLVMValueRef EnzymeCreateForwardDiff(
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
EnzymeTypeAnalysisRef TA, uint8_t returnValue, CDerivativeMode mode,
unsigned width, LLVMTypeRef additionalArg, struct CFnTypeInfo typeInfo,
uint8_t *_uncacheable_args, size_t uncacheable_args_size);
uint8_t freeMemory, unsigned width, LLVMTypeRef additionalArg,
struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
size_t uncacheable_args_size, EnzymeAugmentedReturnPtr augmented);

LLVMValueRef EnzymeCreatePrimalAndGradient(
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
Expand All @@ -136,9 +137,10 @@ LLVMValueRef EnzymeCreatePrimalAndGradient(
EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
EnzymeTypeAnalysisRef TA, uint8_t returnUsed, struct CFnTypeInfo typeInfo,
uint8_t *_uncacheable_args, size_t uncacheable_args_size,
uint8_t forceAnonymousTape, uint8_t AtomicAdd);
EnzymeTypeAnalysisRef TA, uint8_t returnUsed, uint8_t shadowReturnUsed,
struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
size_t uncacheable_args_size, uint8_t forceAnonymousTape,
uint8_t AtomicAdd);

typedef uint8_t (*CustomRuleType)(int /*direction*/, CTypeTreeRef /*return*/,
CTypeTreeRef * /*args*/,
Expand Down
181 changes: 130 additions & 51 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ llvm::cl::opt<bool> EnzymeOMPOpt("enzyme-omp-opt", cl::init(false), cl::Hidden,
#endif
namespace {

template <const char *handlername, int numargs>
template <const char *handlername, DerivativeMode Mode, int numargs>
static void
handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g,
std::vector<GlobalVariable *> &globalsToErase) {
Expand Down Expand Up @@ -130,7 +130,8 @@ handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g,
}
}

if (numargs == 3) {
if (Mode == DerivativeMode::ReverseModeGradient) {
assert(numargs == 3);
Fs[0]->setMetadata(
"enzyme_augment",
llvm::MDTuple::get(Fs[0]->getContext(),
Expand All @@ -139,12 +140,24 @@ handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g,
"enzyme_gradient",
llvm::MDTuple::get(Fs[0]->getContext(),
{llvm::ValueAsMetadata::get(Fs[2])}));
} else if (numargs == 2) {
} else if (Mode == DerivativeMode::ForwardMode) {
assert(numargs == 2);
Fs[0]->setMetadata(
"enzyme_derivative",
llvm::MDTuple::get(Fs[0]->getContext(),
{llvm::ValueAsMetadata::get(Fs[1])}));
}
} else if (Mode == DerivativeMode::ForwardModeSplit) {
assert(numargs == 3);
Fs[0]->setMetadata(
"enzyme_augment",
llvm::MDTuple::get(Fs[0]->getContext(),
{llvm::ValueAsMetadata::get(Fs[1])}));
Fs[0]->setMetadata(
"enzyme_splitderivative",
llvm::MDTuple::get(Fs[0]->getContext(),
{llvm::ValueAsMetadata::get(Fs[2])}));
} else
assert("Unknown mode");
}
} else {
llvm::errs() << M << "\n";
Expand Down Expand Up @@ -444,6 +457,8 @@ class Enzyme : public ModulePass {
IRBuilder<> Builder(CI);
unsigned truei = 0;
unsigned width = 1;
bool returnUsed = !cast<Function>(fn)->getReturnType()->isVoidTy() &&
!cast<Function>(fn)->getReturnType()->isEmptyTy();

// determine width
#if LLVM_VERSION_MAJOR >= 14
Expand All @@ -455,44 +470,44 @@ class Enzyme : public ModulePass {
{
Value *arg = CI->getArgOperand(i);

if (getMetadataName(arg) && *getMetadataName(arg) == "enzyme_width") {
assert(mode == DerivativeMode::ForwardMode);

if (found) {
EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI,
"vector width declared more than once",
*CI->getArgOperand(i), " in", *CI);
return false;
}
if (auto MDName = getMetadataName(arg)) {
if (*MDName == "enzyme_width") {
if (found) {
EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI,
"vector width declared more than once",
*CI->getArgOperand(i), " in", *CI);
return false;
}

#if LLVM_VERSION_MAJOR >= 14
if (i + 1 >= CI->arg_size())
if (i + 1 >= CI->arg_size())
#else
if (i + 1 >= CI->getNumArgOperands())
if (i + 1 >= CI->getNumArgOperands())
#endif
{
EmitFailure("MissingVectorWidth", CI->getDebugLoc(), CI,
"constant integer followong enzyme_width is missing",
*CI->getArgOperand(i), " in", *CI);
return false;
}
{
EmitFailure("MissingVectorWidth", CI->getDebugLoc(), CI,
"constant integer followong enzyme_width is missing",
*CI->getArgOperand(i), " in", *CI);
return false;
}

Value *width_arg = CI->getArgOperand(i + 1);
if (auto cint = dyn_cast<ConstantInt>(width_arg)) {
width = cint->getZExtValue();
found = true;
} else {
EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI,
"enzyme_width must be a constant integer",
*CI->getArgOperand(i), " in", *CI);
return false;
}
Value *width_arg = CI->getArgOperand(i + 1);
if (auto cint = dyn_cast<ConstantInt>(width_arg)) {
width = cint->getZExtValue();
found = true;
} else {
EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI,
"enzyme_width must be a constant integer",
*CI->getArgOperand(i), " in", *CI);
return false;
}

if (!found) {
EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI,
"illegal enzyme vector argument width ",
*CI->getArgOperand(i), " in", *CI);
return false;
if (!found) {
EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI,
"illegal enzyme vector argument width ",
*CI->getArgOperand(i), " in", *CI);
return false;
}
}
}
}
Expand Down Expand Up @@ -579,8 +594,8 @@ class Enzyme : public ModulePass {

DIFFE_TYPE retType = whatType(cast<Function>(fn)->getReturnType(), mode);

bool differentialReturn = mode != DerivativeMode::ForwardMode &&
mode != DerivativeMode::ReverseModePrimal &&
bool differentialReturn = (mode == DerivativeMode::ReverseModeCombined ||
mode == DerivativeMode::ReverseModeGradient) &&
(retType == DIFFE_TYPE::OUT_DIFF);

std::map<int, Type *> byVal;
Expand All @@ -598,7 +613,9 @@ class Enzyme : public ModulePass {
Value *res = CI->getArgOperand(i);

if (truei >= FT->getNumParams()) {
if (mode == DerivativeMode::ReverseModeGradient) {
if (!isa<MetadataAsValue>(res) &&
(mode == DerivativeMode::ReverseModeGradient ||
mode == DerivativeMode::ForwardModeSplit)) {
if (differentialReturn && differet == nullptr) {
differet = res;
if (CI->paramHasAttr(i, Attribute::ByVal)) {
Expand Down Expand Up @@ -644,6 +661,9 @@ class Enzyme : public ModulePass {
ty = DIFFE_TYPE::OUT_DIFF;
} else if (*metaString == "enzyme_const") {
ty = DIFFE_TYPE::CONSTANT;
} else if (*metaString == "enzyme_noret") {
returnUsed = false;
continue;
} else if (*metaString == "enzyme_allocated") {
assert(!sizeOnly);
++i;
Expand Down Expand Up @@ -814,13 +834,57 @@ class Enzyme : public ModulePass {
Type *tapeType = nullptr;
const AugmentedReturn *aug;
switch (mode) {
case DerivativeMode::ForwardModeSplit:
case DerivativeMode::ForwardMode:
newFunc = Logic.CreateForwardDiff(
cast<Function>(fn), retType, constants, TA,
/*should return*/ false, mode, width,
/*addedType*/ nullptr, type_args, volatile_args);
/*should return*/ false, mode, freeMemory, width,
/*addedType*/ nullptr, type_args, volatile_args,
/*augmented*/ nullptr);
break;
case DerivativeMode::ForwardModeSplit: {
bool forceAnonymousTape = !sizeOnly && allocatedTapeSize == -1;
aug = &Logic.CreateAugmentedPrimal(
cast<Function>(fn), retType, constants, TA,
/*returnUsed*/ false, /*shadowReturnUsed*/ false, type_args,
volatile_args, forceAnonymousTape, /*atomicAdd*/ AtomicAdd);
auto &DL = cast<Function>(fn)->getParent()->getDataLayout();
if (!forceAnonymousTape) {
assert(!aug->tapeType);
if (aug->returns.find(AugmentedStruct::Tape) != aug->returns.end()) {
auto tapeIdx = aug->returns.find(AugmentedStruct::Tape)->second;
tapeType = (tapeIdx == -1)
? aug->fn->getReturnType()
: cast<StructType>(aug->fn->getReturnType())
->getElementType(tapeIdx);
} else {
if (sizeOnly) {
CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0, false));
CI->eraseFromParent();
return true;
}
}
if (sizeOnly) {
auto size = DL.getTypeSizeInBits(tapeType) / 8;
CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), size, false));
CI->eraseFromParent();
return true;
}
if (tapeType &&
DL.getTypeSizeInBits(tapeType) < 8 * (size_t)allocatedTapeSize) {
auto bytes = DL.getTypeSizeInBits(tapeType) / 8;
EmitFailure("Insufficient tape allocation size", CI->getDebugLoc(),
CI, "need ", bytes, " bytes have ", allocatedTapeSize,
" bytes");
}
} else {
tapeType = PointerType::getInt8PtrTy(fn->getContext());
}
newFunc = Logic.CreateForwardDiff(
cast<Function>(fn), retType, constants, TA,
/*should return*/ false, mode, freeMemory, width,
/*addedType*/ tapeType, type_args, volatile_args, aug);
break;
}
case DerivativeMode::ReverseModeCombined:
assert(freeMemory);
newFunc = Logic.CreatePrimalAndGradient(
Expand All @@ -841,12 +905,12 @@ class Enzyme : public ModulePass {
case DerivativeMode::ReverseModePrimal:
case DerivativeMode::ReverseModeGradient: {
bool forceAnonymousTape = !sizeOnly && allocatedTapeSize == -1;
bool returnUsed = !cast<Function>(fn)->getReturnType()->isVoidTy() &&
!cast<Function>(fn)->getReturnType()->isEmptyTy();
bool shadowReturnUsed = returnUsed && (retType == DIFFE_TYPE::DUP_ARG ||
retType == DIFFE_TYPE::DUP_NONEED);
aug = &Logic.CreateAugmentedPrimal(
cast<Function>(fn), retType, constants, TA,
/*returnUsed*/ returnUsed, type_args, volatile_args,
forceAnonymousTape, /*atomicAdd*/ AtomicAdd);
cast<Function>(fn), retType, constants, TA, returnUsed,
shadowReturnUsed, type_args, volatile_args, forceAnonymousTape,
/*atomicAdd*/ AtomicAdd);
auto &DL = cast<Function>(fn)->getParent()->getDataLayout();
if (!forceAnonymousTape) {
assert(!aug->tapeType);
Expand Down Expand Up @@ -918,7 +982,9 @@ class Enzyme : public ModulePass {
}
}

if (mode == DerivativeMode::ReverseModeGradient && tape && tapeType) {
if ((mode == DerivativeMode::ReverseModeGradient ||
mode == DerivativeMode::ForwardModeSplit) &&
tape && tapeType) {
auto &DL = cast<Function>(fn)->getParent()->getDataLayout();
if (tapeIsPointer) {
tape = Builder.CreateBitCast(
Expand Down Expand Up @@ -1227,6 +1293,7 @@ class Enzyme : public ModulePass {
Fn->getName().contains("__enzyme_call_inactive") ||
Fn->getName().contains("__enzyme_autodiff") ||
Fn->getName().contains("__enzyme_fwddiff") ||
Fn->getName().contains("__enzyme_fwdsplit") ||
Fn->getName().contains("__enzyme_augmentfwd") ||
Fn->getName().contains("__enzyme_augmentsize") ||
Fn->getName().contains("__enzyme_reverse")))
Expand Down Expand Up @@ -1484,6 +1551,9 @@ class Enzyme : public ModulePass {
} else if (Fn->getName().contains("__enzyme_fwddiff")) {
enableEnzyme = true;
mode = DerivativeMode::ForwardMode;
} else if (Fn->getName().contains("__enzyme_fwdsplit")) {
enableEnzyme = true;
mode = DerivativeMode::ForwardModeSplit;
} else if (Fn->getName().contains("__enzyme_augmentfwd")) {
enableEnzyme = true;
mode = DerivativeMode::ReverseModePrimal;
Expand Down Expand Up @@ -1679,17 +1749,26 @@ class Enzyme : public ModulePass {
"__enzyme_register_gradient";
constexpr static const char derivative_handler_name[] =
"__enzyme_register_derivative";
constexpr static const char splitderivative_handler_name[] =
"__enzyme_register_splitderivative";

Logic.clear();

bool changed = false;
std::vector<GlobalVariable *> globalsToErase;
for (GlobalVariable &g : M.globals()) {
if (g.getName().contains(gradient_handler_name)) {
handleCustomDerivative<gradient_handler_name, 3>(M, g, globalsToErase);
handleCustomDerivative<gradient_handler_name,
DerivativeMode::ReverseModeGradient, 3>(
M, g, globalsToErase);
} else if (g.getName().contains(derivative_handler_name)) {
handleCustomDerivative<derivative_handler_name, 2>(M, g,
globalsToErase);
handleCustomDerivative<derivative_handler_name,
DerivativeMode::ForwardMode, 2>(M, g,
globalsToErase);
} else if (g.getName().contains(splitderivative_handler_name)) {
handleCustomDerivative<splitderivative_handler_name,
DerivativeMode::ForwardModeSplit, 3>(
M, g, globalsToErase);
} else if (g.getName().contains("__enzyme_inactivefn")) {
handleInactiveFunction(M, g, globalsToErase);
}
Expand Down
Loading

0 comments on commit e32d116

Please sign in to comment.