Skip to content

Commit

Permalink
fixing constant for TTIR->TTNN conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
mmanzoorTT committed Nov 6, 2024
1 parent 6bad515 commit 90c5561
Showing 1 changed file with 23 additions and 3 deletions.
26 changes: 23 additions & 3 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"
#include <cstdint>

using namespace mlir;
using namespace mlir::tt;
Expand Down Expand Up @@ -503,9 +504,11 @@ class ConstantOpConversionPattern

if (valueAttr.isSplat()) {
Value device = getOrInsertDevice(rewriter, op);
float fillValue = valueAttr.getElementType().isInteger()
? static_cast<float>(valueAttr.getSplatValue<int>())
: valueAttr.getSplatValue<float>();
float fillValue =
valueAttr.getElementType().isInteger()
? getIntegerValue(
valueAttr) // static_cast<float>(valueAttr.getSplatValue<int>())
: valueAttr.getSplatValue<mlir::APFloat>().convertToFloat();
if (fillValue == 0) {
rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
op, this->getTypeConverter()->convertType(op.getType()), device);
Expand Down Expand Up @@ -536,6 +539,23 @@ class ConstantOpConversionPattern

return success();
}

float getIntegerValue(mlir::ElementsAttr valueAttr) const {
size_t bitWidth = valueAttr.getElementType().getIntOrFloatBitWidth();
switch (bitWidth) {
case 1:
return static_cast<float>(valueAttr.getSplatValue<bool>());
case 8:
return static_cast<float>(valueAttr.getSplatValue<int8_t>());
case 16:
return static_cast<float>(valueAttr.getSplatValue<int16_t>());
case 32:
return static_cast<float>(valueAttr.getSplatValue<int>());
case 64:
return static_cast<float>(valueAttr.getSplatValue<int64_t>());
}
return 0.0;
}
};

} // namespace
Expand Down

0 comments on commit 90c5561

Please sign in to comment.