diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 7efd06a8bd37..e2d896893537 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -2271,6 +2271,66 @@ Function *GetFunctionFromValue(Value *fn) { } } } + if (auto LI = dyn_cast(fn)) { + auto obj = getBaseObject(LI->getPointerOperand()); + if (isa(obj)) { + std::set> done; + SmallVector, 1> todo; + Value *stored = nullptr; + bool legal = true; + for (auto U : obj->users()) { + if (auto I = dyn_cast(U)) + todo.push_back(std::make_pair(I, obj)); + else { + legal = false; + break; + } + } + while (legal && todo.size()) { + auto tup = todo.pop_back_val(); + if (done.count(tup)) + continue; + done.insert(tup); + auto cur = tup.first; + auto prev = tup.second; + if (auto SI = dyn_cast(cur)) + if (SI->getPointerOperand() == prev) { + if (stored == SI->getValueOperand()) + continue; + else if (stored == nullptr) { + stored = SI->getValueOperand(); + continue; + } else { + legal = false; + break; + } + } + + if (isPointerArithmeticInst(cur, /*includephi*/ true)) { + for (auto U : cur->users()) { + if (auto I = dyn_cast(U)) + todo.push_back(std::make_pair(I, cur)); + else { + legal = false; + break; + } + } + continue; + } + + if (!cur->mayWriteToMemory() && cur->getType()->isVoidTy()) + continue; + + legal = false; + break; + } + + if (legal && stored) { + fn = stored; + continue; + } + } + } break; }