Skip to content

Commit

Permalink
Add Typetree to/from md
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 11, 2023
1 parent 051c838 commit 3e8d83c
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 5 deletions.
25 changes: 24 additions & 1 deletion enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,29 @@ void EnzymeExtractReturnInfo(EnzymeAugmentedReturnPtr ret, int64_t *data,
}
}

static MDNode *extractMDNode(MetadataAsValue *MAV) {
Metadata *MD = MAV->getMetadata();
assert((isa<MDNode>(MD) || isa<ConstantAsMetadata>(MD)) &&
"Expected a metadata node or a canonicalized constant");

if (MDNode *N = dyn_cast<MDNode>(MD))
return N;

return MDNode::get(MAV->getContext(), MD);
}

CTypeTreeRef EnzymeTypeTreeFromMD(LLVMValueRef Val) {
TypeTree *Ret = new TypeTree();
MDNode *N = Val ? extractMDNode(unwrap<MetadataAsValue>(Val)) : nullptr;
Ret->insertFromMD(N);
return (CTypeTreeRef)N;
}

LLVMValueRef EnzymeTypeTreeToMD(CTypeTreeRef CTR, LLVMContextRef ctx) {
auto MD = ((TypeTree *)CTR)->toMD(*unwrap(ctx));
return wrap(MetadataAsValue::get(MD->getContext(), MD));
}

CTypeTreeRef EnzymeNewTypeTree() { return (CTypeTreeRef)(new TypeTree()); }
CTypeTreeRef EnzymeNewTypeTreeCT(CConcreteType CT, LLVMContextRef ctx) {
return (CTypeTreeRef)(new TypeTree(eunwrap(CT, *unwrap(ctx))));
Expand Down Expand Up @@ -850,7 +873,7 @@ void EnzymeSetStringMD(LLVMValueRef Inst, const char *Kind, LLVMValueRef Val) {
LLVMValueRef EnzymeGetStringMD(LLVMValueRef Inst, const char *Kind) {
auto *I = unwrap<Instruction>(Inst);
assert(I && "Expected instruction");
if (auto *MD = I->getMetadata(KindID))
if (auto *MD = I->getMetadata(Kind))
return wrap(MetadataAsValue::get(I->getContext(), MD));
return nullptr;
}
Expand Down
4 changes: 4 additions & 0 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,10 @@ UpgradeAllocasToMallocs(Function *NewF, DerivativeMode mode,
{ConstantAsMetadata::get(ConstantInt::get(
IntegerType::get(AI->getContext(), 64), align))}));

for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type"})
if (auto M = AI->getMetadata(MD))
CI->setMetadata(MD, M);

if (rep != CI) {
cast<Instruction>(rep)->setMetadata("enzyme_caststack",
MDNode::get(CI->getContext(), {}));
Expand Down
5 changes: 3 additions & 2 deletions enzyme/Enzyme/TypeAnalysis/BaseType.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#ifndef ENZYME_TYPE_ANALYSIS_BASE_TYPE_H
#define ENZYME_TYPE_ANALYSIS_BASE_TYPE_H 1

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/ErrorHandling.h"
#include <string>

Expand Down Expand Up @@ -60,7 +61,7 @@ static inline std::string to_string(BaseType t) {
}

/// Convert string to BaseType
static inline BaseType parseBaseType(std::string str) {
static inline BaseType parseBaseType(llvm::StringRef str) {
if (str == "Integer")
return BaseType::Integer;
if (str == "Float")
Expand All @@ -73,4 +74,4 @@ static inline BaseType parseBaseType(std::string str) {
return BaseType::Unknown;
llvm_unreachable("Unknown BaseType string");
}
#endif
#endif
4 changes: 2 additions & 2 deletions enzyme/Enzyme/TypeAnalysis/ConcreteType.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ class ConcreteType {
/// Construct a ConcreteType from a string
/// A Concrete Type's string representation is given by the string of the
/// enum If it is a floating point it is given by Float@<specific_type>
ConcreteType(std::string Str, llvm::LLVMContext &C) {
ConcreteType(llvm::StringRef Str, llvm::LLVMContext &C) {
auto Sep = Str.find('@');
if (Sep != std::string::npos) {
if (Sep != llvm::StringRef::npos) {
SubTypeEnum = BaseType::Float;
assert(Str.substr(0, Sep) == "Float");
auto SubName = Str.substr(Sep + 1);
Expand Down
3 changes: 3 additions & 0 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,9 @@ void TypeAnalyzer::considerTBAA() {

for (BasicBlock &BB : *fntypeinfo.Function) {
for (Instruction &I : BB) {
if (auto MD = I.getMetadata("enzyme_type")) {
updateAnalysis(&I, TypeTree::fromMD(MD), &I);
}

if (CallBase *call = dyn_cast<CallBase>(&I)) {
Function *F = call->getCalledFunction();
Expand Down
46 changes: 46 additions & 0 deletions enzyme/Enzyme/TypeAnalysis/TypeTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,52 @@ class TypeTree : public std::enable_shared_from_this<TypeTree> {
out += "}";
return out;
}

llvm::MDNode *toMD(llvm::LLVMContext &ctx) {
llvm::SmallVector<llvm::Metadata *, 1> subMD;
std::map<int, TypeTree> todo;
ConcreteType base(BaseType::Unknown);
for (auto &pair : mapping) {
if (pair.first.size() == 0) {
base = pair.second;
continue;
}
auto next(pair.first);
next.erase(next.begin());
todo[pair.first[0]].mapping.insert(std::make_pair(next, pair.second));
}
subMD.push_back(llvm::MDString::get(ctx, base.str()));
for (auto pair : todo) {
subMD.push_back(llvm::ConstantAsMetadata::get(
llvm::ConstantInt::get(llvm::IntegerType::get(ctx, 32), pair.first)));
subMD.push_back(pair.second.toMD(ctx));
}
return llvm::MDNode::get(ctx, subMD);
};

void insertFromMD(llvm::MDNode *md, const std::vector<int> &prev = {}) {
ConcreteType base(
llvm::cast<llvm::MDString>(md->getOperand(0))->getString(),
md->getContext());
if (base != BaseType::Unknown)
mapping.insert(std::make_pair(prev, base));
for (size_t i = 1; i < md->getNumOperands(); i += 2) {
auto off = llvm::cast<llvm::ConstantInt>(
llvm::cast<llvm::ConstantAsMetadata>(md->getOperand(i))
->getValue())
->getSExtValue();
auto next(prev);
next.push_back((int)off);
insertFromMD(llvm::cast<llvm::MDNode>(md->getOperand(i + 1)), next);
}
}

static TypeTree fromMD(llvm::MDNode *md) {
TypeTree ret;
std::vector<int> off;
ret.insertFromMD(md, off);
return ret;
}
};

#endif

0 comments on commit 3e8d83c

Please sign in to comment.