diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index ca71867462f..d6d7b1c6628 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -340,6 +340,21 @@ void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle, }; } + +/// This is the main entry point to register a custom derivative for language frontends. +/// It should be prefered over trying to register custom-derivatives in the llvm-ir module. +/// The main reason is that these rules can handle non-default activity cases, e.g. +/// a function call where a pointer or a float scalar is marked as const. +/// To get a better low-level understanding, the code in AdjointGenerator.h can be read. +/// +/// This Function will only handle using ReverseMode AD (either split or combined). +/// As a high-level example, assume we want to register a custom derivative for a vector resize function. +/// We pass the mangled name as first argument. +/// resizing is a simple enough to be handled in the fwd pass, so we just pass a nullptr as RevHandle. +/// For the forward pass, we will need to resize the shadow of the input (if duplicated), so the CI +/// is a CallInst of resize on the shadow argument. The IRBuilder B, the shadow argument, and gutils +/// should all be provided available in the frontend. The last three arguments ... +/// void EnzymeRegisterCallHandler(char *Name, CustomAugmentedFunctionForward FwdHandle, CustomFunctionReverse RevHandle) { @@ -363,6 +378,13 @@ void EnzymeRegisterCallHandler(char *Name, }; } +/// This is the main entry point to register a custom derivative for language frontends. +/// It should be prefered over trying to register custom-derivatives in the llvm-ir module. +/// The main reason is that these rules can handle non-default activity cases, e.g. +/// a function call where a pointer or a float scalar is marked as const. +/// To get a better low-level understanding, the code in AdjointGenerator.h can be read. +/// +/// This Function will only handle using ForwardMode AD. void EnzymeRegisterFwdCallHandler(char *Name, CustomFunctionForward FwdHandle) { auto &pair = customFwdCallHandlers[Name]; pair = [=](IRBuilder<> &B, CallInst *CI, GradientUtils &gutils,