From 29c265661b3e30d634d49ad68ef65927609f019a Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 14 Dec 2023 16:20:03 -0500 Subject: [PATCH] Make nicer sparse arg error --- enzyme/Enzyme/FunctionUtils.cpp | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 2e0ad89b03ae..ccc74f2e4f8a 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -6072,9 +6072,21 @@ bool LowerSparsification(llvm::Function *F, bool replaceAll) { auto diff = toInt(B, replacements[SI->getPointerOperand()]); SmallVector args; args.push_back(SI->getValueOperand()); - if (args[0]->getType() != store_fn->getFunctionType()->getParamType(0)) + auto sty = store_fn->getFunctionType()->getParamType(0); + if (args[0]->getType() != + store_fn->getFunctionType()->getParamType(0)) { + if (CastInst::castIsValid (Instruction::BitCast, args[0], sty)) args[0] = B.CreateBitCast( - args[0], store_fn->getFunctionType()->getParamType(0)); + args[0], sty); + else { + auto args0ty = args[0]->getType(); + EmitFailure("IllegalSparse", CI->getDebugLoc(), CI, + " first argument of store function must be the type of " + "the store found fn arg type ", + sty, + " expected ", args0ty); + } + } args.push_back(diff); for (size_t i = argstart; i < num_args; i++) args.push_back(CI->getArgOperand(i));