Skip to content

Commit

Permalink
add fwd mode example
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Dec 7, 2024
1 parent 2dd80bf commit c7545e8
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 11 deletions.
10 changes: 4 additions & 6 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,11 +341,8 @@ void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle,
}


/// This Function will only handle ReverseMode AD (either split or combined).
/// As a high-level example, assume we want to register a custom derivative for `pow(x, y)`.
/// We pass the mangled name of pow as first argument.
/// The IRBuilder B, the shadow argument, and gutils should all be available in the frontend.
/// The last three arguments ...
/// This is the entry point to register reverse-mode custom derivatives programmatically.
/// A more detailed documentation is available in GradientUtils.h
void EnzymeRegisterCallHandler(char *Name,
CustomAugmentedFunctionForward FwdHandle,
CustomFunctionReverse RevHandle) {
Expand All @@ -369,7 +366,8 @@ void EnzymeRegisterCallHandler(char *Name,
};
}

/// This Function will only handle ForwardMode AD.
/// This is the entry point to register forward-mode custom derivatives programmatically.
/// A more detailed documentation is available in GradientUtils.h
void EnzymeRegisterFwdCallHandler(char *Name, CustomFunctionForward FwdHandle) {
auto &pair = customFwdCallHandlers[Name];
pair = [=](IRBuilder<> &B, CallInst *CI, GradientUtils &gutils,
Expand Down
71 changes: 66 additions & 5 deletions enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,80 @@ extern llvm::StringMap<std::function<llvm::Value *(

class DiffeGradientUtils;

/// This is the main entry point to register a custom derivative for language frontends.
/// It is more general, and therefore recommended over trying to register
/// custom-derivatives in the llvm-ir module. Examples on why it is more general include custom-derivatives
/// for non-default activity cases (e.g. a function call where a pointer or a float scalar is marked as const).
/// This is the entry point to register custom derivatives programmatically.
/// It is more general than registering custom-derivatives in the llvm-ir module, at the cost of higher complexity.
/// Examples on why it is more general include custom-derivatives for non-default activity cases
/// (e.g. a function call where a pointer or a float scalar is marked as const).
/// It also allows using Enzyme and LLVM analysis, e.g. activity analysis, differential use analysis, alias analysis.
/// To get a better low-level understanding, the documentation in CApi.cpp can be read.
///
///
extern llvm::StringMap<std::pair<
std::function<bool(llvm::IRBuilder<> &, llvm::CallInst *, GradientUtils &,
llvm::Value *&, llvm::Value *&, llvm::Value *&)>,
std::function<void(llvm::IRBuilder<> &, llvm::CallInst *,
DiffeGradientUtils &, llvm::Value *)>>>
customCallHandlers;

/// The StringMap allows looking up a (forward-mode) custom rule based on the mangled name of the function.
/// The first argument is the IRBuilder, the third argument are gradientutils, both of which should be already
/// available in the frontend. The second argument is the CallInst, this should be for a function call which will
/// compute the forward-mode derivative, while taking into consideration which input arguments are active or const.
/// The function returns true, if the custom rule was applied, and false otherwise (e.g. because the combination of
/// activities is not yet supported). The last two arguments are ...
///
/// Example:
/// define double @my_pow(double %x, double %y) {
/// %call = call double @llvm.pow(double %x, double %y)
/// ret double %call
/// }
///
/// The custom rule for this function could be:
/// customCallHandlers["my_pow"] = [](llvm::IRBuilder<> &Builder, llvm::CallInst *CI, GradientUtils &gutils, llvm::Value *&dcall, llvm::Value *&normalReturn, llvm::Value *&shadowReturn) {
/// auto x = CI->getArgOperand(0);
/// auto y = CI->getArgOperand(1);
/// auto xprime = gutils.getNewFromOriginal(x);
/// auto yprime = gutils.getNewFromOriginal(y);
/// bool is_x_active = !gutils.isConstantValue(x);
/// bool is_y_active = !gutils.isConstantValue(y);
/// normalreturn = Builder.CreateCall(Intrinsic::pow, {x, y});
/// if (is_x_active) {
/// auto ym1 = Builder.CreateFSub(y, ConstantFP::get(Type::getDoubleTy(CI->getContext()), 1.0));
/// auto pow = Builder.CreateCall(Intrinsic::pow, {x, ym1});
/// auto ypow = Builder.CreateFMul(y, pow);
/// shadowReturn = Builder.CreateFMul(xprime, ypow);
///
/// // if y were inactive, this would be conceptually equivalent to generating
/// // define internal double @fwddiffetester(double %x, double %"x'", double %y) #1 {
/// // %0 = fsub fast double %y, 1.000000e+00
/// // %1 = call fast double @llvm.pow.f64(double %x, double %0)
/// // %2 = fmul fast double %y, %1
/// // %3 = fmul fast double %"x'", %2
/// // ret double %3
/// // }
/// }
/// if (is_y_active) {
/// auto pow = Builder.CreateCall(Intrinsic::pow, {x, y});
/// auto log = Builder.CreateCall(Intrinsic::log, {x});
/// auto logpow = Builder.CreateFMul(pow, log);
/// auto ylogpow = Builder.CreateFMul(yprime, logpow);
/// if (is_x_active) {
/// shadowReturn = Builder.CreateFAdd(ylogpow, shadowReturn);
/// } else {
/// shadowReturn = ylogpow;
/// }
///
/// // if x was inactive, this would be conceptually equivalent to generating
/// // define internal double @fwddiffetester.1(double %x, double %y, double %"y'") #1 {
/// // %0 = call fast double @llvm.pow.f64(double %x, double %y)
/// // %1 = call fast double @llvm.log.f64(double %x)
/// // %2 = fmul fast double %0, %1
/// // %3 = fmul fast double %"y'", %2
/// // ret double %3
/// // }
/// }
/// // We covered all 2x2 combinations, so always return true
/// return true;
/// }
extern llvm::StringMap<
std::function<bool(llvm::IRBuilder<> &, llvm::CallInst *, GradientUtils &,
llvm::Value *&, llvm::Value *&)>>
Expand Down

0 comments on commit c7545e8

Please sign in to comment.