Skip to content

Commit

Permalink
Nice error message for undifferentiable functions
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 21, 2023
1 parent 21dcb51 commit e0737ec
Show file tree
Hide file tree
Showing 17 changed files with 500 additions and 212 deletions.
23 changes: 14 additions & 9 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -3896,8 +3896,9 @@ class AdjointGenerator
Mode == DerivativeMode::ReverseModeCombined) {
if (called) {
subdata = &gutils->Logic.CreateAugmentedPrimal(
cast<Function>(called), subretType, argsInverted,
TR.analyzer.interprocedural, /*return is used*/ false,
RequestContext(&call, &BuilderZ), cast<Function>(called),
subretType, argsInverted, TR.analyzer.interprocedural,
/*return is used*/ false,
/*shadowReturnUsed*/ false, nextTypeInfo, overwritten_args, false,
gutils->getWidth(),
/*AtomicAdd*/ true,
Expand Down Expand Up @@ -4096,6 +4097,7 @@ class AdjointGenerator
}

newcalled = gutils->Logic.CreatePrimalAndGradient(
RequestContext(&call, &Builder2),
(ReverseCacheKey){.todiff = cast<Function>(called),
.retType = subretType,
.constant_args = argsInverted,
Expand Down Expand Up @@ -6851,8 +6853,9 @@ class AdjointGenerator

if (called) {
newcalled = gutils->Logic.CreateForwardDiff(
cast<Function>(called), subretType, argsInverted,
TR.analyzer.interprocedural, /*returnValue*/ subretused, Mode,
RequestContext(&call, &BuilderZ), cast<Function>(called),
subretType, argsInverted, TR.analyzer.interprocedural,
/*returnValue*/ subretused, Mode,
((DiffeGradientUtils *)gutils)->FreeMemory, gutils->getWidth(),
tape ? tape->getType() : nullptr, nextTypeInfo, overwritten_args,
/*augmented*/ subdata);
Expand Down Expand Up @@ -7254,10 +7257,10 @@ class AdjointGenerator
if (Mode == DerivativeMode::ReverseModePrimal ||
Mode == DerivativeMode::ReverseModeCombined) {
subdata = &gutils->Logic.CreateAugmentedPrimal(
cast<Function>(called), subretType, argsInverted,
TR.analyzer.interprocedural, /*return is used*/ subretused,
shadowReturnUsed, nextTypeInfo, overwritten_args, false,
gutils->getWidth(), gutils->AtomicAdd);
RequestContext(&call, &BuilderZ), cast<Function>(called),
subretType, argsInverted, TR.analyzer.interprocedural,
/*return is used*/ subretused, shadowReturnUsed, nextTypeInfo,
overwritten_args, false, gutils->getWidth(), gutils->AtomicAdd);
if (Mode == DerivativeMode::ReverseModePrimal) {
assert(augmentedReturn);
auto subaugmentations =
Expand Down Expand Up @@ -7639,6 +7642,7 @@ class AdjointGenerator
}

newcalled = gutils->Logic.CreatePrimalAndGradient(
RequestContext(&call, &Builder2),
(ReverseCacheKey){.todiff = cast<Function>(called),
.retType = subretType,
.constant_args = argsInverted,
Expand Down Expand Up @@ -10066,7 +10070,8 @@ class AdjointGenerator
auto callval = call.getCalledOperand();
if (!isa<Constant>(callval))
callval = gutils->getNewFromOriginal(callval);
newCall->setCalledOperand(gutils->Logic.CreateNoFree(callval));
newCall->setCalledOperand(gutils->Logic.CreateNoFree(
RequestContext(&call, &BuilderZ), callval));
}
if (gutils->knownRecomputeHeuristic.find(&call) !=
gutils->knownRecomputeHeuristic.end()) {
Expand Down
51 changes: 37 additions & 14 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -541,11 +541,11 @@ void EnzymeGradientUtilsSubTransferHelper(
}

LLVMValueRef EnzymeCreateForwardDiff(
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
EnzymeTypeAnalysisRef TA, uint8_t returnValue, CDerivativeMode mode,
uint8_t freeMemory, unsigned width, LLVMTypeRef additionalArg,
CFnTypeInfo typeInfo, uint8_t *_overwritten_args,
EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue,
CDerivativeMode mode, uint8_t freeMemory, unsigned width,
LLVMTypeRef additionalArg, CFnTypeInfo typeInfo, uint8_t *_overwritten_args,
size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented) {
SmallVector<DIFFE_TYPE, 4> nconstant_args((DIFFE_TYPE *)constant_args,
(DIFFE_TYPE *)constant_args +
Expand All @@ -556,16 +556,18 @@ LLVMValueRef EnzymeCreateForwardDiff(
overwritten_args.push_back(_overwritten_args[i]);
}
return wrap(eunwrap(Logic).CreateForwardDiff(
RequestContext(cast_or_null<Instruction>(unwrap(request_req)),
unwrap(request_ip)),
cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args,
eunwrap(TA), returnValue, (DerivativeMode)mode, freeMemory, width,
unwrap(additionalArg), eunwrap(typeInfo, cast<Function>(unwrap(todiff))),
overwritten_args, eunwrap(augmented)));
}
LLVMValueRef EnzymeCreatePrimalAndGradient(
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed,
CDerivativeMode mode, unsigned width, uint8_t freeMemory,
EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue,
uint8_t dretUsed, CDerivativeMode mode, unsigned width, uint8_t freeMemory,
LLVMTypeRef additionalArg, uint8_t forceAnonymousTape, CFnTypeInfo typeInfo,
uint8_t *_overwritten_args, size_t overwritten_args_size,
EnzymeAugmentedReturnPtr augmented, uint8_t AtomicAdd) {
Expand All @@ -578,6 +580,8 @@ LLVMValueRef EnzymeCreatePrimalAndGradient(
overwritten_args.push_back(_overwritten_args[i]);
}
return wrap(eunwrap(Logic).CreatePrimalAndGradient(
RequestContext(cast<Instruction>(unwrap(request_req)),
unwrap(request_ip)),
(ReverseCacheKey){
.todiff = cast<Function>(unwrap(todiff)),
.retType = (DIFFE_TYPE)retType,
Expand All @@ -596,10 +600,10 @@ LLVMValueRef EnzymeCreatePrimalAndGradient(
eunwrap(TA), eunwrap(augmented)));
}
EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
EnzymeTypeAnalysisRef TA, uint8_t returnUsed, uint8_t shadowReturnUsed,
CFnTypeInfo typeInfo, uint8_t *_overwritten_args,
EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnUsed,
uint8_t shadowReturnUsed, CFnTypeInfo typeInfo, uint8_t *_overwritten_args,
size_t overwritten_args_size, uint8_t forceAnonymousTape, unsigned width,
uint8_t AtomicAdd) {

Expand All @@ -612,14 +616,31 @@ EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
overwritten_args.push_back(_overwritten_args[i]);
}
return ewrap(eunwrap(Logic).CreateAugmentedPrimal(
RequestContext(cast_or_null<Instruction>(unwrap(request_req)),
unwrap(request_ip)),
cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args,
eunwrap(TA), returnUsed, shadowReturnUsed,
eunwrap(typeInfo, cast<Function>(unwrap(todiff))), overwritten_args,
forceAnonymousTape, width, AtomicAdd));
}

LLVMValueRef EnzymeCreateBatch(EnzymeLogicRef Logic, LLVMValueRef request_req,
LLVMBuilderRef request_ip, LLVMValueRef tobatch,
unsigned width, CBATCH_TYPE *arg_types,
size_t arg_types_size, CBATCH_TYPE retType) {

return wrap(eunwrap(Logic).CreateBatch(
RequestContext(cast_or_null<Instruction>(unwrap(request_req)),
unwrap(request_ip)),
cast<Function>(unwrap(tobatch)), width,
ArrayRef<BATCH_TYPE>((BATCH_TYPE *)arg_types,
(BATCH_TYPE *)arg_types + arg_types_size),
(BATCH_TYPE)retType));
}

LLVMValueRef EnzymeCreateTrace(
EnzymeLogicRef Logic, LLVMValueRef totrace, LLVMValueRef *sample_functions,
EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
LLVMValueRef totrace, LLVMValueRef *sample_functions,
size_t sample_functions_size, LLVMValueRef *observe_functions,
size_t observe_functions_size, const char *active_random_variables[],
size_t active_random_variables_size, CProbProgMode mode, uint8_t autodiff,
Expand All @@ -641,6 +662,8 @@ LLVMValueRef EnzymeCreateTrace(
}

return wrap(eunwrap(Logic).CreateTrace(
RequestContext(cast_or_null<Instruction>(unwrap(request_req)),
unwrap(request_ip)),
cast<Function>(unwrap(totrace)), SampleFunctions, ObserveFunctions,
ActiveRandomVariables, (ProbProgMode)mode, (bool)autodiff,
eunwrap(interface)));
Expand Down
36 changes: 2 additions & 34 deletions enzyme/Enzyme/CApi.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ typedef enum {
// but don't need the forward
} CDIFFE_TYPE;

typedef enum { BT_SCALAR = 0, BT_VECTOR = 1 } CBATCH_TYPE;

typedef enum {
DEM_ForwardMode = 0,
DEM_ReverseModePrimal = 1,
Expand All @@ -132,40 +134,6 @@ typedef enum {
DEM_Condition = 1,
} CProbProgMode;

LLVMValueRef EnzymeCreateForwardDiff(
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
EnzymeTypeAnalysisRef TA, uint8_t returnValue, CDerivativeMode mode,
uint8_t freeMemory, unsigned width, LLVMTypeRef additionalArg,
struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
size_t uncacheable_args_size, EnzymeAugmentedReturnPtr augmented);

LLVMValueRef EnzymeCreatePrimalAndGradient(
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed,
CDerivativeMode mode, unsigned width, uint8_t freeMemory,
LLVMTypeRef additionalArg, uint8_t forceAnonymousTape,
struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
size_t uncacheable_args_size, EnzymeAugmentedReturnPtr augmented,
uint8_t AtomicAdd);

EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
EnzymeTypeAnalysisRef TA, uint8_t returnUsed, uint8_t shadowReturnUsed,
struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
size_t uncacheable_args_size, uint8_t forceAnonymousTape, unsigned width,
uint8_t AtomicAdd);

LLVMValueRef CreateTrace(
EnzymeLogicRef Logic, LLVMValueRef totrace, LLVMValueRef *sample_functions,
size_t sample_functions_size, LLVMValueRef *observe_functions,
size_t observe_functions_size, LLVMValueRef *generative_functions,
size_t generative_functions_size, const char *active_random_variables[],
size_t active_random_variables_size, CProbProgMode mode, uint8_t autodiff,
EnzymeTraceInterfaceRef interface);

typedef uint8_t (*CustomRuleType)(int /*direction*/, CTypeTreeRef /*return*/,
CTypeTreeRef * /*args*/,
struct IntList * /*knownValues*/,
Expand Down
4 changes: 2 additions & 2 deletions enzyme/Enzyme/DiffeGradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone(
TargetLibraryInfo &TLI, TypeAnalysis &TA, FnTypeInfo &oldTypeInfo,
DIFFE_TYPE retType, bool diffeReturnArg, ArrayRef<DIFFE_TYPE> constant_args,
ReturnType returnValue, Type *additionalArg, bool omp) {
assert(!todiff->empty());
Function *oldFunc = todiff;
assert(mode == DerivativeMode::ReverseModeGradient ||
mode == DerivativeMode::ReverseModeCombined ||
Expand Down Expand Up @@ -149,7 +148,8 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone(
}

TypeResults TR = TA.analyzeFunction(typeInfo);
assert(TR.getFunction() == oldFunc);
if (!oldFunc->empty())
assert(TR.getFunction() == oldFunc);

auto res = new DiffeGradientUtils(Logic, newFunc, oldFunc, TLI, TA, TR,
invertedPointers, constant_values,
Expand Down
26 changes: 16 additions & 10 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1377,7 +1377,8 @@ class EnzymeBase {
? BATCH_TYPE::SCALAR
: BATCH_TYPE::VECTOR;

auto newFunc = Logic.CreateBatch(F, width, arg_types, ret_type);
auto newFunc = Logic.CreateBatch(RequestContext(CI, &Builder), F, width,
arg_types, ret_type);

if (!newFunc)
return false;
Expand Down Expand Up @@ -1432,6 +1433,7 @@ class EnzymeBase {
populate_overwritten_args(TA, fn, mode, overwritten_args);

IRBuilder Builder(CI);
RequestContext context(CI, &Builder);

// differentiate fn
Function *newFunc = nullptr;
Expand All @@ -1440,15 +1442,15 @@ class EnzymeBase {
switch (mode) {
case DerivativeMode::ForwardMode:
newFunc = Logic.CreateForwardDiff(
fn, retType, constants, TA,
context, fn, retType, constants, TA,
/*should return*/ primalReturn, mode, freeMemory, width,
/*addedType*/ nullptr, type_args, overwritten_args,
/*augmented*/ nullptr);
break;
case DerivativeMode::ForwardModeSplit: {
bool forceAnonymousTape = !sizeOnly && allocatedTapeSize == -1;
aug = &Logic.CreateAugmentedPrimal(
fn, retType, constants, TA,
context, fn, retType, constants, TA,
/*returnUsed*/ false, /*shadowReturnUsed*/ false, type_args,
overwritten_args, forceAnonymousTape, width, /*atomicAdd*/ AtomicAdd);
auto &DL = fn->getParent()->getDataLayout();
Expand Down Expand Up @@ -1484,14 +1486,15 @@ class EnzymeBase {
tapeType = PointerType::getInt8PtrTy(fn->getContext());
}
newFunc = Logic.CreateForwardDiff(
fn, retType, constants, TA,
context, fn, retType, constants, TA,
/*should return*/ primalReturn, mode, freeMemory, width,
/*addedType*/ tapeType, type_args, overwritten_args, aug);
break;
}
case DerivativeMode::ReverseModeCombined:
assert(freeMemory);
newFunc = Logic.CreatePrimalAndGradient(
context,
(ReverseCacheKey){.todiff = fn,
.retType = retType,
.constant_args = constants,
Expand All @@ -1518,8 +1521,8 @@ class EnzymeBase {
bool shadowReturnUsed = returnUsed && (retType == DIFFE_TYPE::DUP_ARG ||
retType == DIFFE_TYPE::DUP_NONEED);
aug = &Logic.CreateAugmentedPrimal(
fn, retType, constants, TA, returnUsed, shadowReturnUsed, type_args,
overwritten_args, forceAnonymousTape, width,
context, fn, retType, constants, TA, returnUsed, shadowReturnUsed,
type_args, overwritten_args, forceAnonymousTape, width,
/*atomicAdd*/ AtomicAdd);
auto &DL = fn->getParent()->getDataLayout();
if (!forceAnonymousTape) {
Expand Down Expand Up @@ -1557,6 +1560,7 @@ class EnzymeBase {
newFunc = aug->fn;
else
newFunc = Logic.CreatePrimalAndGradient(
context,
(ReverseCacheKey){.todiff = fn,
.retType = retType,
.constant_args = constants,
Expand Down Expand Up @@ -1856,9 +1860,9 @@ class EnzymeBase {
constants.push_back(DIFFE_TYPE::CONSTANT);
}

auto newFunc = Logic.CreateTrace(F, sampleFunctions, observeFunctions,
opt->ActiveRandomVariables, mode, autodiff,
interface);
auto newFunc = Logic.CreateTrace(
RequestContext(CI, &Builder), F, sampleFunctions, observeFunctions,
opt->ActiveRandomVariables, mode, autodiff, interface);

if (!autodiff) {
auto call = CallInst::Create(newFunc->getFunctionType(), newFunc, args);
Expand Down Expand Up @@ -2438,8 +2442,10 @@ class EnzymeBase {
bool AtomicAdd = Arch == Triple::nvptx || Arch == Triple::nvptx64 ||
Arch == Triple::amdgcn;

IRBuilder<> Builder(CI);
auto val = GradientUtils::GetOrCreateShadowConstant(
Logic, Logic.PPC.FAM.getResult<TargetLibraryAnalysis>(F), TA, fn,
RequestContext(CI, &Builder), Logic,
Logic.PPC.FAM.getResult<TargetLibraryAnalysis>(F), TA, fn,
pair.second, /*width*/ 1, AtomicAdd);
CI->replaceAllUsesWith(ConstantExpr::getPointerCast(val, CI->getType()));
CI->eraseFromParent();
Expand Down
Loading

0 comments on commit e0737ec

Please sign in to comment.