From 6a707777e364fd060d6f21da78f0ed371322c7f5 Mon Sep 17 00:00:00 2001 From: Max Aehle Date: Thu, 4 Jan 2024 19:08:01 +0100 Subject: [PATCH] Let IsAllFloat take an additional DataLayout argument and use it to simplify the implementation of IsAllFloat. --- enzyme/Enzyme/ActivityAnalysis.cpp | 2 +- enzyme/Enzyme/AdjointGenerator.h | 12 ++++++------ enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 8 ++++---- enzyme/Enzyme/TypeAnalysis/TypeTree.h | 16 ++-------------- 4 files changed, 13 insertions(+), 25 deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 8ada308d4977..361541b954dd 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -1352,7 +1352,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { for (int i = 0; i < 2; ++i) { auto FT = TR.query(BO->getOperand(1 - i)) - .IsAllFloat((DL.getTypeSizeInBits(BO->getType()) + 7) / 8); + .IsAllFloat((DL.getTypeSizeInBits(BO->getType()) + 7) / 8, DL); // If ^ against 0b10000000000 and a float the result is a float if (FT) if (containsOnlyAtMostTopBit(BO->getOperand(i), FT, DL)) { diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 86ba20134f97..7dc91bec08d5 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -2358,7 +2358,7 @@ class AdjointGenerator auto &dl = gutils->oldFunc->getParent()->getDataLayout(); auto size = dl.getTypeSizeInBits(BO.getType()) / 8; - auto FT = TR.query(&BO).IsAllFloat(size); + auto FT = TR.query(&BO).IsAllFloat(size, dl); auto eFT = FT; if (FT) for (int i = 0; i < 2; ++i) { @@ -2388,7 +2388,7 @@ class AdjointGenerator auto &dl = gutils->oldFunc->getParent()->getDataLayout(); auto size = dl.getTypeSizeInBits(BO.getType()) / 8; - auto FT = TR.query(&BO).IsAllFloat(size); + auto FT = TR.query(&BO).IsAllFloat(size, dl); auto eFT = FT; // If ^ against 0b10000000000 and a float the result is a float if (FT) @@ -2426,7 +2426,7 @@ class AdjointGenerator auto &dl = gutils->oldFunc->getParent()->getDataLayout(); auto size = dl.getTypeSizeInBits(BO.getType()) / 8; - auto FT = TR.query(&BO).IsAllFloat(size); + auto FT = TR.query(&BO).IsAllFloat(size, dl); auto eFT = FT; // If & against 0b10000000000 and a float the result is a float if (FT) @@ -2603,7 +2603,7 @@ class AdjointGenerator auto size = dl.getTypeSizeInBits(BO.getType()) / 8; Type *diffTy = gutils->getShadowType(BO.getType()); - auto FT = TR.query(&BO).IsAllFloat(size); + auto FT = TR.query(&BO).IsAllFloat(size, dl); auto eFT = FT; if (FT) for (int i = 0; i < 2; ++i) { @@ -2631,7 +2631,7 @@ class AdjointGenerator auto &dl = gutils->oldFunc->getParent()->getDataLayout(); auto size = dl.getTypeSizeInBits(BO.getType()) / 8; - auto FT = TR.query(&BO).IsAllFloat(size); + auto FT = TR.query(&BO).IsAllFloat(size, dl); auto eFT = FT; Value *dif[2] = {constantval0 ? nullptr : diffe(orig_op0, Builder2), @@ -2672,7 +2672,7 @@ class AdjointGenerator Value *dif[2] = {constantval0 ? nullptr : diffe(orig_op0, Builder2), constantval1 ? nullptr : diffe(orig_op1, Builder2)}; - auto FT = TR.query(&BO).IsAllFloat(size); + auto FT = TR.query(&BO).IsAllFloat(size, dl); auto eFT = FT; // If & against 0b10000000000 and a float the result is a float if (FT) diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index bd8d5d85c4ad..381cd0729114 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -2926,7 +2926,7 @@ void TypeAnalyzer::visitBinaryOperation(const DataLayout &dl, llvm::Type *T, if (direction & UP) for (int i = 0; i < 2; ++i) { Type *FT = nullptr; - if (!(FT = Ret.IsAllFloat(size))) + if (!(FT = Ret.IsAllFloat(size, dl))) continue; // If ^ against 0b10000000000, the result is a float bool validXor = containsOnlyAtMostTopBit(Args[i], FT, dl); @@ -2938,7 +2938,7 @@ void TypeAnalyzer::visitBinaryOperation(const DataLayout &dl, llvm::Type *T, case BinaryOperator::Or: for (int i = 0; i < 2; ++i) { Type *FT = nullptr; - if (!(FT = Ret.IsAllFloat(size))) + if (!(FT = Ret.IsAllFloat(size, dl))) continue; // If | against a number only or'ing the exponent, the result is a float bool validXor = false; @@ -3143,7 +3143,7 @@ void TypeAnalyzer::visitBinaryOperation(const DataLayout &dl, llvm::Type *T, } else if (Opcode == BinaryOperator::Xor) { for (int i = 0; i < 2; ++i) { Type *FT; - if (!(FT = (i == 0 ? RHS : LHS).IsAllFloat(size))) + if (!(FT = (i == 0 ? RHS : LHS).IsAllFloat(size, dl))) continue; // If ^ against 0b10000000000, the result is a float bool validXor = containsOnlyAtMostTopBit(Args[i], FT, dl); @@ -3154,7 +3154,7 @@ void TypeAnalyzer::visitBinaryOperation(const DataLayout &dl, llvm::Type *T, } else if (Opcode == BinaryOperator::Or) { for (int i = 0; i < 2; ++i) { Type *FT; - if (!(FT = (i == 0 ? RHS : LHS).IsAllFloat(size))) + if (!(FT = (i == 0 ? RHS : LHS).IsAllFloat(size, dl))) continue; // If & against 0b10000000000, the result is a float bool validXor = false; diff --git a/enzyme/Enzyme/TypeAnalysis/TypeTree.h b/enzyme/Enzyme/TypeAnalysis/TypeTree.h index 4bcd6fa09cb4..b69b83e5b355 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeTree.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeTree.h @@ -698,7 +698,7 @@ class TypeTree : public std::enable_shared_from_this { return dat; } - llvm::Type *IsAllFloat(const size_t size) const { + llvm::Type *IsAllFloat(const size_t size, const llvm::DataLayout &dl) const { auto m1 = TypeTree::operator[]({-1}); if (auto FT = m1.isFloat()) return FT; @@ -706,19 +706,7 @@ class TypeTree : public std::enable_shared_from_this { auto m0 = TypeTree::operator[]({0}); if (auto flt = m0.isFloat()) { - size_t chunk; - if (flt->isFloatTy()) { - chunk = 4; - } else if (flt->isDoubleTy()) { - chunk = 8; - } else if (flt->isHalfTy()) { - chunk = 2; - } else if (flt->isX86_FP80Ty()) { - chunk = 10; - } else { - llvm::errs() << *flt << "\n"; - assert(0 && "unhandled float type"); - } + size_t chunk = dl.getTypeSizeInBits(flt) / 8; for (size_t i = chunk; i < size; i += chunk) { auto mx = TypeTree::operator[]({(int)i}); if (auto f2 = mx.isFloat()) {