Skip to content

Commit

Permalink
move away from using deprecated StringRef API
Browse files Browse the repository at this point in the history
llvm/llvm-project@5ac1295
deprecated the use of `StringRef::startswith` and `StringRef::endswith`
in favor of `starts_with` and `ends_with`. Introduce a helper and update
all uses in Enzyme codebase.
  • Loading branch information
ftynse committed Dec 18, 2023
1 parent fb6a1c6 commit 6ca6421
Show file tree
Hide file tree
Showing 16 changed files with 125 additions and 82 deletions.
13 changes: 10 additions & 3 deletions enzyme/BCLoad/BCLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ using namespace llvm;
#include "blas_headers.h"
#undef DATA

static inline bool endsWith(llvm::StringRef string, llvm::StringRef suffix) {
#if LLVM_VERSION_MAJOR >= 18
return string.ends_with(suffix);
#else
return string.endswith(suffix);
#endif // LLVM_VERSION_MAJOR
}

bool provideDefinitions(Module &M, std::set<std::string> ignoreFunctions = {}) {
std::vector<StringRef> todo;
bool seen32 = false;
Expand All @@ -30,7 +38,7 @@ bool provideDefinitions(Module &M, std::set<std::string> ignoreFunctions = {}) {
if (strlen(postfix) == 0) {
str = F.getName().str();
if (ignoreFunctions.count(str)) continue;
} else if (F.getName().endswith(postfix)) {
} else if (endsWith(F.getName(), postfix)) {
auto blasName =
F.getName().substr(0, F.getName().size() - strlen(postfix)).str();
if (ignoreFunctions.count(blasName)) continue;
Expand All @@ -44,8 +52,7 @@ bool provideDefinitions(Module &M, std::set<std::string> ignoreFunctions = {}) {
seen32 = true;
if (index == 2)
seen64 = true;
if (StringRef(str).endswith("gemm"))
seenGemm = true;
if (endsWith(str, "gemm")) seenGemm = true;
break;
}
index++;
Expand Down
18 changes: 9 additions & 9 deletions enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,20 +465,20 @@ bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) {
std::string demangledName = llvm::demangle(Name.str());
auto dName = StringRef(demangledName);
for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) {
if (dName.startswith(FuncName)) {
if (startsWith(dName, FuncName)) {
return true;
}
}
if (demangledName == Name.str()) {
// Either demangeling failed
// or they are equal but matching failed
// if (!Name.startswith("llvm."))
// if (!startsWith(Name, "llvm."))
// llvm::errs() << "matching failed: " << Name.str() << " "
// << demangledName << "\n";
}

for (auto FuncName : KnownInactiveFunctionsStartingWith) {
if (Name.startswith(FuncName)) {
if (startsWith(Name, FuncName)) {
return true;
}
}
Expand Down Expand Up @@ -1560,15 +1560,15 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {

auto dName = demangle(funcName.str());
for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) {
if (StringRef(dName).startswith(FuncName)) {
if (startsWith(dName, FuncName)) {
InsertConstantValue(TR, Val);
insertConstantsFrom(TR, *UpHypothesis);
return true;
}
}

for (auto FuncName : KnownInactiveFunctionsStartingWith) {
if (funcName.startswith(FuncName)) {
if (startsWith(funcName, FuncName)) {
InsertConstantValue(TR, Val);
insertConstantsFrom(TR, *UpHypothesis);
return true;
Expand Down Expand Up @@ -1886,13 +1886,13 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {

auto dName = demangle(funcName.str());
for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) {
if (StringRef(dName).startswith(FuncName)) {
if (startsWith(dName, FuncName)) {
return false;
}
}

for (auto FuncName : KnownInactiveFunctionsStartingWith) {
if (funcName.startswith(FuncName)) {
if (startsWith(funcName, FuncName)) {
return false;
}
}
Expand Down Expand Up @@ -2553,13 +2553,13 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR,

auto dName = demangle(funcName.str());
for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) {
if (StringRef(dName).startswith(FuncName)) {
if (startsWith(dName, FuncName)) {
return true;
}
}

for (auto FuncName : KnownInactiveFunctionsStartingWith) {
if (funcName.startswith(FuncName)) {
if (startsWith(funcName, FuncName)) {
return true;
}
}
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -5932,7 +5932,7 @@ class AdjointGenerator
// not fully understood by LLVM. One of the results of this is that the
// visitor dispatches to visitCallInst, rather than visitIntrinsicInst, when
// presented with the intrinsic - hence why we are handling it here.
if (getFuncNameFromCall(&call).startswith("llvm.intel.subscript")) {
if (startsWith(getFuncNameFromCall(&call), ("llvm.intel.subscript"))) {
assert(isa<IntrinsicInst>(call));
visitIntrinsicInst(cast<IntrinsicInst>(call));
return;
Expand Down
8 changes: 4 additions & 4 deletions enzyme/Enzyme/CallDerivatives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2244,7 +2244,7 @@ bool AdjointGenerator<T>::handleKnownCallDerivatives(
}
}

if ((funcName.startswith("MPI_") || funcName.startswith("PMPI_")) &&
if ((startsWith(funcName, "MPI_") || startsWith(funcName, "PMPI_")) &&
(!gutils->isConstantInstruction(&call) || funcName == "MPI_Barrier" ||
funcName == "MPI_Comm_free" || funcName == "MPI_Comm_disconnect" ||
MPIInactiveCommAllocators.find(funcName) !=
Expand All @@ -2263,8 +2263,8 @@ bool AdjointGenerator<T>::handleKnownCallDerivatives(
}

if (funcName == "printf" || funcName == "puts" ||
funcName.startswith("_ZN3std2io5stdio6_print") ||
funcName.startswith("_ZN4core3fmt")) {
startsWith(funcName, "_ZN3std2io5stdio6_print") ||
startsWith(funcName, "_ZN4core3fmt")) {
if (Mode == DerivativeMode::ReverseModeGradient) {
eraseIfUnused(call, /*erase*/ true, /*check*/ false);
}
Expand Down Expand Up @@ -2353,7 +2353,7 @@ bool AdjointGenerator<T>::handleKnownCallDerivatives(
return true;
}

if (funcName.startswith("__kmpc") &&
if (startsWith(funcName, "__kmpc") &&
funcName != "__kmpc_global_thread_num") {
llvm::errs() << *gutils->oldFunc << "\n";
llvm::errs() << call << "\n";
Expand Down
4 changes: 2 additions & 2 deletions enzyme/Enzyme/Clang/EnzymeClang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ class EnzymePlugin final : public clang::ASTConsumer {
std::string pluginPath;
#endif
for (auto P : Opts.Plugins)
if (llvm::sys::path::stem(P).endswith(PluginName)) {
if (endsWith(llvm::sys::path::stem(P), PluginName)) {
#if LLVM_VERSION_MAJOR < 18
pluginPath = P;
#endif
for (auto passPlugin : CGOpts.PassPlugins) {
if (llvm::sys::path::stem(passPlugin).endswith(PluginName)) {
if (endsWith(llvm::sys::path::stem(passPlugin), PluginName)) {
contains = true;
break;
}
Expand Down
6 changes: 3 additions & 3 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ class EnzymeBase {
Value *res = CI->getArgOperand(i);
auto metaString = getMetadataName(res);
// handle metadata
if (metaString && metaString->startswith("enzyme_")) {
if (metaString && startsWith(*metaString, "enzyme_")) {
if (*metaString == "enzyme_const_return") {
retType = DIFFE_TYPE::CONSTANT;
continue;
Expand Down Expand Up @@ -862,7 +862,7 @@ class EnzymeBase {
bool skipArg = false;

// handle metadata
while (metaString && metaString->startswith("enzyme_")) {
while (metaString && startsWith(*metaString, "enzyme_")) {
if (*metaString == "enzyme_not_overwritten") {
overwritten = false;
} else if (*metaString == "enzyme_byref") {
Expand Down Expand Up @@ -1368,7 +1368,7 @@ class EnzymeBase {
auto metaString = getMetadataName(res);

// handle metadata
if (metaString && metaString->startswith("enzyme_")) {
if (metaString && startsWith(*metaString, "enzyme_")) {
if (*metaString == "enzyme_scalar") {
ty = BATCH_TYPE::SCALAR;
} else if (*metaString == "enzyme_vector") {
Expand Down
7 changes: 4 additions & 3 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,8 @@ struct CacheAnalysis {
return {};
}

if (funcName.startswith("MPI_") || funcName.startswith("enzyme_wrapmpi$$"))
if (startsWith(funcName, "MPI_") ||
startsWith(funcName, "enzyme_wrapmpi$$"))
return {};

if (funcName == "__kmpc_for_static_init_4" ||
Expand Down Expand Up @@ -615,7 +616,7 @@ struct CacheAnalysis {
// We do not need uncacheable args for intrinsic functions. So skip
// such callsites.
if (auto II = dyn_cast<IntrinsicInst>(&inst)) {
if (!II->getCalledFunction()->getName().startswith("llvm.julia"))
if (!startsWith(II->getCalledFunction()->getName(), "llvm.julia"))
continue;
}

Expand Down Expand Up @@ -5314,7 +5315,7 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) {
"MPI_Allreduce",
};

if (F->getName().startswith("_ZNSolsE") || NoFrees.count(F->getName()))
if (startsWith(F->getName(), "_ZNSolsE") || NoFrees.count(F->getName()))
return F;

switch (F->getIntrinsicID()) {
Expand Down
13 changes: 7 additions & 6 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1131,12 +1131,13 @@ static void ForceRecursiveInlining(Function *NewF, size_t Limit) {
continue;
if (CI->getCalledFunction()->empty())
continue;
if (CI->getCalledFunction()->getName().startswith(
"_ZN3std2io5stdio6_print"))
if (startsWith(CI->getCalledFunction()->getName(),
"_ZN3std2io5stdio6_print"))
continue;
if (CI->getCalledFunction()->getName().startswith("_ZN4core3fmt"))
if (startsWith(CI->getCalledFunction()->getName(), "_ZN4core3fmt"))
continue;
if (CI->getCalledFunction()->getName().startswith("enzyme_wrapmpi$$"))
if (startsWith(CI->getCalledFunction()->getName(),
"enzyme_wrapmpi$$"))
continue;
if (CI->getCalledFunction()->hasFnAttribute(
Attribute::ReturnsTwice) ||
Expand Down Expand Up @@ -1539,7 +1540,7 @@ Function *PreProcessCache::preprocessForClone(Function *F,
if (F && F->getName().contains("__enzyme_double")) {
continue;
}
if (F && (F->getName().startswith("f90io") ||
if (F && (startsWith(F->getName(), "f90io") ||
F->getName() == "ftnio_fmt_write64" ||
F->getName() == "__mth_i_ipowi" ||
F->getName() == "f90_pausea")) {
Expand Down Expand Up @@ -1599,7 +1600,7 @@ Function *PreProcessCache::preprocessForClone(Function *F,
if (F && F->getName().contains("__enzyme_double")) {
continue;
}
if (F && (F->getName().startswith("f90io") ||
if (F && (startsWith(F->getName(), "f90io") ||
F->getName() == "ftnio_fmt_write64" ||
F->getName() == "__mth_i_ipowi" ||
F->getName() == "f90_pausea")) {
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/FunctionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ static inline void calculateUnusedValues(
}
}

if (false && oldFunc.getName().endswith("subfn")) {
if (false && endsWith(oldFunc.getName(), "subfn")) {
llvm::errs() << "Prepping values for: " << oldFunc.getName()
<< " returnValue: " << returnValue << "\n";
for (auto v : unnecessaryInstructions) {
Expand Down
8 changes: 4 additions & 4 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
}

std::pair<Value *, BasicBlock *> idx = std::make_pair(val, scope);
// assert(!val->getName().startswith("$tapeload"));
// assert(!startsWith(val->getName(), "$tapeload"));
if (permitCache) {
auto found0 = unwrap_cache.find(BuilderM.GetInsertBlock());
if (found0 != unwrap_cache.end()) {
Expand Down Expand Up @@ -4010,7 +4010,7 @@ bool GradientUtils::legalRecompute(const Value *val,
n == "lgammal_r" || n == "__lgamma_r_finite" ||
n == "__lgammaf_r_finite" || n == "__lgammal_r_finite" || n == "tanh" ||
n == "tanhf" || n == "__pow_finite" ||
n == "julia.pointer_from_objref" || n.startswith("enzyme_wrapmpi$$") ||
n == "julia.pointer_from_objref" || startsWith(n, "enzyme_wrapmpi$$") ||
n == "omp_get_thread_num" || n == "omp_get_max_threads") {
return true;
}
Expand Down Expand Up @@ -4160,7 +4160,7 @@ bool GradientUtils::shouldRecompute(const Value *val,
n == "lgammal_r" || n == "__lgamma_r_finite" ||
n == "__lgammaf_r_finite" || n == "__lgammal_r_finite" || n == "tanh" ||
n == "tanhf" || n == "__pow_finite" ||
n == "julia.pointer_from_objref" || n.startswith("enzyme_wrapmpi$$") ||
n == "julia.pointer_from_objref" || startsWith(n, "enzyme_wrapmpi$$") ||
n == "omp_get_thread_num" || n == "omp_get_max_threads") {
return true;
}
Expand Down Expand Up @@ -4451,7 +4451,7 @@ Constant *GradientUtils::GetOrCreateShadowConstant(
if (arg->getName() == "_ZTVN10__cxxabiv120__si_class_type_infoE" ||
arg->getName() == "_ZTVN10__cxxabiv117__class_type_infoE" ||
arg->getName() == "_ZTVN10__cxxabiv121__vmi_class_type_infoE" ||
arg->getName().startswith("??_R")) // any of the MS RTTI manglings
startsWith(arg->getName(), "??_R")) // any of the MS RTTI manglings
return arg;

if (hasMetadata(arg, "enzyme_shadow")) {
Expand Down
26 changes: 13 additions & 13 deletions enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,20 +282,20 @@ bool mlir::enzyme::ActivityAnalyzer::isFunctionArgumentConstant(
std::string demangledName = llvm::demangle(Name.str());
auto dName = StringRef(demangledName);
for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) {
if (dName.startswith(FuncName)) {
if (startsWith(dName, FuncName)) {
return true;
}
}
if (demangledName == Name.str()) {
// Either demangeling failed
// or they are equal but matching failed
// if (!Name.startswith("llvm."))
// if (!startsWith(Name, "llvm."))
// llvm::errs() << "matching failed: " << Name.str() << " "
// << demangledName << "\n";
}

for (auto FuncName : KnownInactiveFunctionsStartingWith) {
if (Name.startswith(FuncName)) {
if (startsWith(Name, FuncName)) {
return true;
}
}
Expand All @@ -315,7 +315,7 @@ bool mlir::enzyme::ActivityAnalyzer::isFunctionArgumentConstant(
}

for (unsigned intrinsicID : constantIntrinsics) {
if (Name.startswith(llvm::Intrinsic::getBaseName(intrinsicID)))
if (startsWith(Name, llvm::Intrinsic::getBaseName(intrinsicID)))
return true;
}

Expand Down Expand Up @@ -410,13 +410,13 @@ bool mlir::enzyme::ActivityAnalyzer::isFunctionArgumentConstant(
// // Certain intrinsics are inactive by definition
// // and have nothing to propagate.
// for (unsigned intrinsicID : constantIntrinsics) {
// if (Name.startswith(llvm::Intrinsic::getBaseName(intrinsicID)))
// if (startsWith(Name, llvm::Intrinsic::getBaseName(intrinsicID)))
// return;
// }

// if (Name.startswith(
// if (startsWith(Name,
// llvm::Intrinsic::getBaseName(llvm::Intrinsic::memcpy)) ||
// Name.startswith(
// startsWith(Name,
// llvm::Intrinsic::getBaseName(llvm::Intrinsic::memmove))) {
// propagateFromOperand(CI.getArgOperands()[0]);
// propagateFromOperand(CI.getArgOperands()[1]);
Expand Down Expand Up @@ -1586,15 +1586,15 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR,

auto dName = llvm::demangle(funcName.str());
for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) {
if (StringRef(dName).startswith(FuncName)) {
if (startsWith(dName, FuncName)) {
InsertConstantValue(TR, Val);
insertConstantsFrom(TR, *UpHypothesis);
return true;
}
}

for (auto FuncName : KnownInactiveFunctionsStartingWith) {
if (funcName.startswith(FuncName)) {
if (startsWith(funcName, FuncName)) {
InsertConstantValue(TR, Val);
insertConstantsFrom(TR, *UpHypothesis);
return true;
Expand Down Expand Up @@ -1868,13 +1868,13 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR,

auto dName = llvm::demangle(funcName.str());
for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) {
if (StringRef(dName).startswith(FuncName)) {
if (startsWith(dName, FuncName)) {
return false;
}
}

for (auto FuncName : KnownInactiveFunctionsStartingWith) {
if (funcName.startswith(FuncName)) {
if (startsWith(funcName, FuncName)) {
return false;
}
}
Expand Down Expand Up @@ -2594,13 +2594,13 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin(

auto dName = llvm::demangle(funcName.str());
for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) {
if (StringRef(dName).startswith(FuncName)) {
if (startsWith(dName, FuncName)) {
return true;
}
}

for (auto FuncName : KnownInactiveFunctionsStartingWith) {
if (funcName.startswith(FuncName)) {
if (startsWith(funcName, FuncName)) {
return true;
}
}
Expand Down
Loading

0 comments on commit 6ca6421

Please sign in to comment.