Skip to content

Commit

Permalink
prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Dec 6, 2024
1 parent 57b718b commit 3f3ee94
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,21 @@ void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle,
};
}


/// This is the main entry point to register a custom derivative for language frontends.
/// It should be prefered over trying to register custom-derivatives in the llvm-ir module.
/// The main reason is that these rules can handle non-default activity cases, e.g.
/// a function call where a pointer or a float scalar is marked as const.
/// To get a better low-level understanding, the code in AdjointGenerator.h can be read.
///
/// This Function will only handle using ReverseMode AD (either split or combined).
/// As a high-level example, assume we want to register a custom derivative for a vector resize function.
/// We pass the mangled name as first argument.
/// resizing is a simple enough to be handled in the fwd pass, so we just pass a nullptr as RevHandle.
/// For the forward pass, we will need to resize the shadow of the input (if duplicated), so the CI
/// is a CallInst of resize on the shadow argument. The IRBuilder B, the shadow argument, and gutils
/// should all be provided available in the frontend. The last three arguments ...
///
void EnzymeRegisterCallHandler(char *Name,
CustomAugmentedFunctionForward FwdHandle,
CustomFunctionReverse RevHandle) {
Expand All @@ -363,6 +378,13 @@ void EnzymeRegisterCallHandler(char *Name,
};
}

/// This is the main entry point to register a custom derivative for language frontends.
/// It should be prefered over trying to register custom-derivatives in the llvm-ir module.
/// The main reason is that these rules can handle non-default activity cases, e.g.
/// a function call where a pointer or a float scalar is marked as const.
/// To get a better low-level understanding, the code in AdjointGenerator.h can be read.
///
/// This Function will only handle using ForwardMode AD.
void EnzymeRegisterFwdCallHandler(char *Name, CustomFunctionForward FwdHandle) {
auto &pair = customFwdCallHandlers[Name];
pair = [=](IRBuilder<> &B, CallInst *CI, GradientUtils &gutils,
Expand Down

0 comments on commit 3f3ee94

Please sign in to comment.