diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index e1f9d21dc37..933a22304e2 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -929,8 +929,25 @@ void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, FunctionType *FT = FunctionType::get(copy_retty, tys, false); auto fn = M.getOrInsertFunction(copy_name, FT); - Function *F = cast(fn.getCallee()); - attributeKnownFunctions(*F); + Value *callVal = fn.getCallee(); + Function *called = nullptr; + while (!called) { + if (auto castinst = dyn_cast(callVal)) + if (castinst->isCast()) { + callVal = castinst->getOperand(0); + continue; + } + if (auto fn = dyn_cast(callVal)) { + called = fn; + break; + } + if (auto alias = dyn_cast(callVal)) { + callVal = alias->getAliasee(); + continue; + } + break; + } + attributeKnownFunctions(*called); B.CreateCall(fn, args, bundles); }