From 3e8d83ca09b8ad52fa0cc5f372c319192486653d Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 11 Oct 2023 15:43:08 -0400 Subject: [PATCH] Add Typetree to/from md --- enzyme/Enzyme/CApi.cpp | 25 ++++++++++- enzyme/Enzyme/FunctionUtils.cpp | 4 ++ enzyme/Enzyme/TypeAnalysis/BaseType.h | 5 ++- enzyme/Enzyme/TypeAnalysis/ConcreteType.h | 4 +- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 3 ++ enzyme/Enzyme/TypeAnalysis/TypeTree.h | 46 +++++++++++++++++++++ 6 files changed, 82 insertions(+), 5 deletions(-) diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index 7a0ab6e344ee..02fdc43d7bad 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -711,6 +711,29 @@ void EnzymeExtractReturnInfo(EnzymeAugmentedReturnPtr ret, int64_t *data, } } +static MDNode *extractMDNode(MetadataAsValue *MAV) { + Metadata *MD = MAV->getMetadata(); + assert((isa(MD) || isa(MD)) && + "Expected a metadata node or a canonicalized constant"); + + if (MDNode *N = dyn_cast(MD)) + return N; + + return MDNode::get(MAV->getContext(), MD); +} + +CTypeTreeRef EnzymeTypeTreeFromMD(LLVMValueRef Val) { + TypeTree *Ret = new TypeTree(); + MDNode *N = Val ? extractMDNode(unwrap(Val)) : nullptr; + Ret->insertFromMD(N); + return (CTypeTreeRef)N; +} + +LLVMValueRef EnzymeTypeTreeToMD(CTypeTreeRef CTR, LLVMContextRef ctx) { + auto MD = ((TypeTree *)CTR)->toMD(*unwrap(ctx)); + return wrap(MetadataAsValue::get(MD->getContext(), MD)); +} + CTypeTreeRef EnzymeNewTypeTree() { return (CTypeTreeRef)(new TypeTree()); } CTypeTreeRef EnzymeNewTypeTreeCT(CConcreteType CT, LLVMContextRef ctx) { return (CTypeTreeRef)(new TypeTree(eunwrap(CT, *unwrap(ctx)))); @@ -850,7 +873,7 @@ void EnzymeSetStringMD(LLVMValueRef Inst, const char *Kind, LLVMValueRef Val) { LLVMValueRef EnzymeGetStringMD(LLVMValueRef Inst, const char *Kind) { auto *I = unwrap(Inst); assert(I && "Expected instruction"); - if (auto *MD = I->getMetadata(KindID)) + if (auto *MD = I->getMetadata(Kind)) return wrap(MetadataAsValue::get(I->getContext(), MD)); return nullptr; } diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 58623d5d0942..43c1d69483d3 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -494,6 +494,10 @@ UpgradeAllocasToMallocs(Function *NewF, DerivativeMode mode, {ConstantAsMetadata::get(ConstantInt::get( IntegerType::get(AI->getContext(), 64), align))})); + for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type"}) + if (auto M = AI->getMetadata(MD)) + CI->setMetadata(MD, M); + if (rep != CI) { cast(rep)->setMetadata("enzyme_caststack", MDNode::get(CI->getContext(), {})); diff --git a/enzyme/Enzyme/TypeAnalysis/BaseType.h b/enzyme/Enzyme/TypeAnalysis/BaseType.h index 88da154eab52..9ba5bd2bec5c 100644 --- a/enzyme/Enzyme/TypeAnalysis/BaseType.h +++ b/enzyme/Enzyme/TypeAnalysis/BaseType.h @@ -25,6 +25,7 @@ #ifndef ENZYME_TYPE_ANALYSIS_BASE_TYPE_H #define ENZYME_TYPE_ANALYSIS_BASE_TYPE_H 1 +#include "llvm/ADT/StringRef.h" #include "llvm/Support/ErrorHandling.h" #include @@ -60,7 +61,7 @@ static inline std::string to_string(BaseType t) { } /// Convert string to BaseType -static inline BaseType parseBaseType(std::string str) { +static inline BaseType parseBaseType(llvm::StringRef str) { if (str == "Integer") return BaseType::Integer; if (str == "Float") @@ -73,4 +74,4 @@ static inline BaseType parseBaseType(std::string str) { return BaseType::Unknown; llvm_unreachable("Unknown BaseType string"); } -#endif \ No newline at end of file +#endif diff --git a/enzyme/Enzyme/TypeAnalysis/ConcreteType.h b/enzyme/Enzyme/TypeAnalysis/ConcreteType.h index 5caf2e4fa94e..c0fd754c17b4 100644 --- a/enzyme/Enzyme/TypeAnalysis/ConcreteType.h +++ b/enzyme/Enzyme/TypeAnalysis/ConcreteType.h @@ -65,9 +65,9 @@ class ConcreteType { /// Construct a ConcreteType from a string /// A Concrete Type's string representation is given by the string of the /// enum If it is a floating point it is given by Float@ - ConcreteType(std::string Str, llvm::LLVMContext &C) { + ConcreteType(llvm::StringRef Str, llvm::LLVMContext &C) { auto Sep = Str.find('@'); - if (Sep != std::string::npos) { + if (Sep != llvm::StringRef::npos) { SubTypeEnum = BaseType::Float; assert(Str.substr(0, Sep) == "Float"); auto SubName = Str.substr(Sep + 1); diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 866480f0e80a..e34bf91f7653 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -788,6 +788,9 @@ void TypeAnalyzer::considerTBAA() { for (BasicBlock &BB : *fntypeinfo.Function) { for (Instruction &I : BB) { + if (auto MD = I.getMetadata("enzyme_type")) { + updateAnalysis(&I, TypeTree::fromMD(MD), &I); + } if (CallBase *call = dyn_cast(&I)) { Function *F = call->getCalledFunction(); diff --git a/enzyme/Enzyme/TypeAnalysis/TypeTree.h b/enzyme/Enzyme/TypeAnalysis/TypeTree.h index bf31de9f8b9f..858f07cc5873 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeTree.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeTree.h @@ -1207,6 +1207,52 @@ class TypeTree : public std::enable_shared_from_this { out += "}"; return out; } + + llvm::MDNode *toMD(llvm::LLVMContext &ctx) { + llvm::SmallVector subMD; + std::map todo; + ConcreteType base(BaseType::Unknown); + for (auto &pair : mapping) { + if (pair.first.size() == 0) { + base = pair.second; + continue; + } + auto next(pair.first); + next.erase(next.begin()); + todo[pair.first[0]].mapping.insert(std::make_pair(next, pair.second)); + } + subMD.push_back(llvm::MDString::get(ctx, base.str())); + for (auto pair : todo) { + subMD.push_back(llvm::ConstantAsMetadata::get( + llvm::ConstantInt::get(llvm::IntegerType::get(ctx, 32), pair.first))); + subMD.push_back(pair.second.toMD(ctx)); + } + return llvm::MDNode::get(ctx, subMD); + }; + + void insertFromMD(llvm::MDNode *md, const std::vector &prev = {}) { + ConcreteType base( + llvm::cast(md->getOperand(0))->getString(), + md->getContext()); + if (base != BaseType::Unknown) + mapping.insert(std::make_pair(prev, base)); + for (size_t i = 1; i < md->getNumOperands(); i += 2) { + auto off = llvm::cast( + llvm::cast(md->getOperand(i)) + ->getValue()) + ->getSExtValue(); + auto next(prev); + next.push_back((int)off); + insertFromMD(llvm::cast(md->getOperand(i + 1)), next); + } + } + + static TypeTree fromMD(llvm::MDNode *md) { + TypeTree ret; + std::vector off; + ret.insertFromMD(md, off); + return ret; + } }; #endif