From 2dd80bf807dccc3c752be304c326f02684949316 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 6 Dec 2024 00:52:39 -0500 Subject: [PATCH] adressing feedback --- enzyme/Enzyme/CApi.cpp | 27 ++++++--------------------- enzyme/Enzyme/GradientUtils.h | 7 +++++++ 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index d6d7b1c6628..7192fb07862 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -341,20 +341,11 @@ void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle, } -/// This is the main entry point to register a custom derivative for language frontends. -/// It should be prefered over trying to register custom-derivatives in the llvm-ir module. -/// The main reason is that these rules can handle non-default activity cases, e.g. -/// a function call where a pointer or a float scalar is marked as const. -/// To get a better low-level understanding, the code in AdjointGenerator.h can be read. -/// -/// This Function will only handle using ReverseMode AD (either split or combined). -/// As a high-level example, assume we want to register a custom derivative for a vector resize function. -/// We pass the mangled name as first argument. -/// resizing is a simple enough to be handled in the fwd pass, so we just pass a nullptr as RevHandle. -/// For the forward pass, we will need to resize the shadow of the input (if duplicated), so the CI -/// is a CallInst of resize on the shadow argument. The IRBuilder B, the shadow argument, and gutils -/// should all be provided available in the frontend. The last three arguments ... -/// +/// This Function will only handle ReverseMode AD (either split or combined). +/// As a high-level example, assume we want to register a custom derivative for `pow(x, y)`. +/// We pass the mangled name of pow as first argument. +/// The IRBuilder B, the shadow argument, and gutils should all be available in the frontend. +/// The last three arguments ... void EnzymeRegisterCallHandler(char *Name, CustomAugmentedFunctionForward FwdHandle, CustomFunctionReverse RevHandle) { @@ -378,13 +369,7 @@ void EnzymeRegisterCallHandler(char *Name, }; } -/// This is the main entry point to register a custom derivative for language frontends. -/// It should be prefered over trying to register custom-derivatives in the llvm-ir module. -/// The main reason is that these rules can handle non-default activity cases, e.g. -/// a function call where a pointer or a float scalar is marked as const. -/// To get a better low-level understanding, the code in AdjointGenerator.h can be read. -/// -/// This Function will only handle using ForwardMode AD. +/// This Function will only handle ForwardMode AD. void EnzymeRegisterFwdCallHandler(char *Name, CustomFunctionForward FwdHandle) { auto &pair = customFwdCallHandlers[Name]; pair = [=](IRBuilder<> &B, CallInst *CI, GradientUtils &gutils, diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index d739babf42e..2f9f1390fb4 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -79,6 +79,13 @@ extern llvm::StringMap &, llvm::CallInst *, GradientUtils &, llvm::Value *&, llvm::Value *&, llvm::Value *&)>,