Skip to content

Commit

Permalink
Let IsAllFloat take an additional DataLayout argument
Browse files Browse the repository at this point in the history
and use it to simplify the implementation of IsAllFloat.
  • Loading branch information
maxaehle committed Jan 4, 2024
1 parent 24c8098 commit 6a70777
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 25 deletions.
2 changes: 1 addition & 1 deletion enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
for (int i = 0; i < 2; ++i) {
auto FT =
TR.query(BO->getOperand(1 - i))
.IsAllFloat((DL.getTypeSizeInBits(BO->getType()) + 7) / 8);
.IsAllFloat((DL.getTypeSizeInBits(BO->getType()) + 7) / 8, DL);
// If ^ against 0b10000000000 and a float the result is a float
if (FT)
if (containsOnlyAtMostTopBit(BO->getOperand(i), FT, DL)) {
Expand Down
12 changes: 6 additions & 6 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -2358,7 +2358,7 @@ class AdjointGenerator
auto &dl = gutils->oldFunc->getParent()->getDataLayout();
auto size = dl.getTypeSizeInBits(BO.getType()) / 8;

auto FT = TR.query(&BO).IsAllFloat(size);
auto FT = TR.query(&BO).IsAllFloat(size, dl);
auto eFT = FT;
if (FT)
for (int i = 0; i < 2; ++i) {
Expand Down Expand Up @@ -2388,7 +2388,7 @@ class AdjointGenerator
auto &dl = gutils->oldFunc->getParent()->getDataLayout();
auto size = dl.getTypeSizeInBits(BO.getType()) / 8;

auto FT = TR.query(&BO).IsAllFloat(size);
auto FT = TR.query(&BO).IsAllFloat(size, dl);
auto eFT = FT;
// If ^ against 0b10000000000 and a float the result is a float
if (FT)
Expand Down Expand Up @@ -2426,7 +2426,7 @@ class AdjointGenerator
auto &dl = gutils->oldFunc->getParent()->getDataLayout();
auto size = dl.getTypeSizeInBits(BO.getType()) / 8;

auto FT = TR.query(&BO).IsAllFloat(size);
auto FT = TR.query(&BO).IsAllFloat(size, dl);
auto eFT = FT;
// If & against 0b10000000000 and a float the result is a float
if (FT)
Expand Down Expand Up @@ -2603,7 +2603,7 @@ class AdjointGenerator
auto size = dl.getTypeSizeInBits(BO.getType()) / 8;
Type *diffTy = gutils->getShadowType(BO.getType());

auto FT = TR.query(&BO).IsAllFloat(size);
auto FT = TR.query(&BO).IsAllFloat(size, dl);
auto eFT = FT;
if (FT)
for (int i = 0; i < 2; ++i) {
Expand Down Expand Up @@ -2631,7 +2631,7 @@ class AdjointGenerator
auto &dl = gutils->oldFunc->getParent()->getDataLayout();
auto size = dl.getTypeSizeInBits(BO.getType()) / 8;

auto FT = TR.query(&BO).IsAllFloat(size);
auto FT = TR.query(&BO).IsAllFloat(size, dl);
auto eFT = FT;

Value *dif[2] = {constantval0 ? nullptr : diffe(orig_op0, Builder2),
Expand Down Expand Up @@ -2672,7 +2672,7 @@ class AdjointGenerator
Value *dif[2] = {constantval0 ? nullptr : diffe(orig_op0, Builder2),
constantval1 ? nullptr : diffe(orig_op1, Builder2)};

auto FT = TR.query(&BO).IsAllFloat(size);
auto FT = TR.query(&BO).IsAllFloat(size, dl);
auto eFT = FT;
// If & against 0b10000000000 and a float the result is a float
if (FT)
Expand Down
8 changes: 4 additions & 4 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2926,7 +2926,7 @@ void TypeAnalyzer::visitBinaryOperation(const DataLayout &dl, llvm::Type *T,
if (direction & UP)
for (int i = 0; i < 2; ++i) {
Type *FT = nullptr;
if (!(FT = Ret.IsAllFloat(size)))
if (!(FT = Ret.IsAllFloat(size, dl)))
continue;
// If ^ against 0b10000000000, the result is a float
bool validXor = containsOnlyAtMostTopBit(Args[i], FT, dl);
Expand All @@ -2938,7 +2938,7 @@ void TypeAnalyzer::visitBinaryOperation(const DataLayout &dl, llvm::Type *T,
case BinaryOperator::Or:
for (int i = 0; i < 2; ++i) {
Type *FT = nullptr;
if (!(FT = Ret.IsAllFloat(size)))
if (!(FT = Ret.IsAllFloat(size, dl)))
continue;
// If | against a number only or'ing the exponent, the result is a float
bool validXor = false;
Expand Down Expand Up @@ -3143,7 +3143,7 @@ void TypeAnalyzer::visitBinaryOperation(const DataLayout &dl, llvm::Type *T,
} else if (Opcode == BinaryOperator::Xor) {
for (int i = 0; i < 2; ++i) {
Type *FT;
if (!(FT = (i == 0 ? RHS : LHS).IsAllFloat(size)))
if (!(FT = (i == 0 ? RHS : LHS).IsAllFloat(size, dl)))
continue;
// If ^ against 0b10000000000, the result is a float
bool validXor = containsOnlyAtMostTopBit(Args[i], FT, dl);
Expand All @@ -3154,7 +3154,7 @@ void TypeAnalyzer::visitBinaryOperation(const DataLayout &dl, llvm::Type *T,
} else if (Opcode == BinaryOperator::Or) {
for (int i = 0; i < 2; ++i) {
Type *FT;
if (!(FT = (i == 0 ? RHS : LHS).IsAllFloat(size)))
if (!(FT = (i == 0 ? RHS : LHS).IsAllFloat(size, dl)))
continue;
// If & against 0b10000000000, the result is a float
bool validXor = false;
Expand Down
16 changes: 2 additions & 14 deletions enzyme/Enzyme/TypeAnalysis/TypeTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -698,27 +698,15 @@ class TypeTree : public std::enable_shared_from_this<TypeTree> {
return dat;
}

llvm::Type *IsAllFloat(const size_t size) const {
llvm::Type *IsAllFloat(const size_t size, const llvm::DataLayout &dl) const {
auto m1 = TypeTree::operator[]({-1});
if (auto FT = m1.isFloat())
return FT;

auto m0 = TypeTree::operator[]({0});

if (auto flt = m0.isFloat()) {
size_t chunk;
if (flt->isFloatTy()) {
chunk = 4;
} else if (flt->isDoubleTy()) {
chunk = 8;
} else if (flt->isHalfTy()) {
chunk = 2;
} else if (flt->isX86_FP80Ty()) {
chunk = 10;
} else {
llvm::errs() << *flt << "\n";
assert(0 && "unhandled float type");
}
size_t chunk = dl.getTypeSizeInBits(flt) / 8;
for (size_t i = chunk; i < size; i += chunk) {
auto mx = TypeTree::operator[]({(int)i});
if (auto f2 = mx.isFloat()) {
Expand Down

0 comments on commit 6a70777

Please sign in to comment.