Skip to content

Commit

Permalink
adressing feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Dec 6, 2024
1 parent 3f3ee94 commit 2dd80bf
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 21 deletions.
27 changes: 6 additions & 21 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,20 +341,11 @@ void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle,
}


/// This is the main entry point to register a custom derivative for language frontends.
/// It should be prefered over trying to register custom-derivatives in the llvm-ir module.
/// The main reason is that these rules can handle non-default activity cases, e.g.
/// a function call where a pointer or a float scalar is marked as const.
/// To get a better low-level understanding, the code in AdjointGenerator.h can be read.
///
/// This Function will only handle using ReverseMode AD (either split or combined).
/// As a high-level example, assume we want to register a custom derivative for a vector resize function.
/// We pass the mangled name as first argument.
/// resizing is a simple enough to be handled in the fwd pass, so we just pass a nullptr as RevHandle.
/// For the forward pass, we will need to resize the shadow of the input (if duplicated), so the CI
/// is a CallInst of resize on the shadow argument. The IRBuilder B, the shadow argument, and gutils
/// should all be provided available in the frontend. The last three arguments ...
///
/// This Function will only handle ReverseMode AD (either split or combined).
/// As a high-level example, assume we want to register a custom derivative for `pow(x, y)`.
/// We pass the mangled name of pow as first argument.
/// The IRBuilder B, the shadow argument, and gutils should all be available in the frontend.
/// The last three arguments ...
void EnzymeRegisterCallHandler(char *Name,
CustomAugmentedFunctionForward FwdHandle,
CustomFunctionReverse RevHandle) {
Expand All @@ -378,13 +369,7 @@ void EnzymeRegisterCallHandler(char *Name,
};
}

/// This is the main entry point to register a custom derivative for language frontends.
/// It should be prefered over trying to register custom-derivatives in the llvm-ir module.
/// The main reason is that these rules can handle non-default activity cases, e.g.
/// a function call where a pointer or a float scalar is marked as const.
/// To get a better low-level understanding, the code in AdjointGenerator.h can be read.
///
/// This Function will only handle using ForwardMode AD.
/// This Function will only handle ForwardMode AD.
void EnzymeRegisterFwdCallHandler(char *Name, CustomFunctionForward FwdHandle) {
auto &pair = customFwdCallHandlers[Name];
pair = [=](IRBuilder<> &B, CallInst *CI, GradientUtils &gutils,
Expand Down
7 changes: 7 additions & 0 deletions enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ extern llvm::StringMap<std::function<llvm::Value *(
shadowHandlers;

class DiffeGradientUtils;

/// This is the main entry point to register a custom derivative for language frontends.
/// It is more general, and therefore recommended over trying to register
/// custom-derivatives in the llvm-ir module. Examples on why it is more general include custom-derivatives
/// for non-default activity cases (e.g. a function call where a pointer or a float scalar is marked as const).
/// It also allows using Enzyme and LLVM analysis, e.g. activity analysis, differential use analysis, alias analysis.
/// To get a better low-level understanding, the documentation in CApi.cpp can be read.
extern llvm::StringMap<std::pair<
std::function<bool(llvm::IRBuilder<> &, llvm::CallInst *, GradientUtils &,
llvm::Value *&, llvm::Value *&, llvm::Value *&)>,
Expand Down

0 comments on commit 2dd80bf

Please sign in to comment.