Handle different value types for constant op explicitly for TTNN backend
mmanzoorTT committed Nov 6, 2024
1 parent 90c5561 commit 2759638
Showing 2 changed files with 38 additions and 4 deletions.
6 changes: 2 additions & 4 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -23,7 +23,6 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/LogicalResult.h"
-#include <cstdint>
 
 using namespace mlir;
 using namespace mlir::tt;
@@ -506,8 +505,7 @@ class ConstantOpConversionPattern
     Value device = getOrInsertDevice(rewriter, op);
     float fillValue =
         valueAttr.getElementType().isInteger()
-            ? getIntegerValue(
-                  valueAttr) // static_cast<float>(valueAttr.getSplatValue<int>())
+            ? getIntegerValue(valueAttr)
             : valueAttr.getSplatValue<mlir::APFloat>().convertToFloat();
     if (fillValue == 0) {
       rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
@@ -554,7 +552,7 @@ class ConstantOpConversionPattern
     case 64:
       return static_cast<float>(valueAttr.getSplatValue<int64_t>());
     }
-    return 0.0;
+    assert(false && "Unsupported integer type.");
   }
 };

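For context, the getIntegerValue helper edited here is only partially visible in the hunks above. A minimal sketch of its presumable shape, assuming the 8-, 16-, and 32-bit cases mirror the visible 64-bit case and that the parameter type matches the getSplatValue calls above (both are assumptions, not confirmed by this diff):

#include <cassert>
#include <cstdint>

#include "mlir/IR/BuiltinAttributes.h"

// Sketch only: reconstructed around the visible "case 64" and the new
// assert. The 8/16/32-bit cases are assumptions modeled on the i8/i16/i32
// tests added in this commit; the real signature may differ.
static float getIntegerValue(mlir::ElementsAttr valueAttr) {
  switch (valueAttr.getElementType().getIntOrFloatBitWidth()) {
  case 8:
    return static_cast<float>(valueAttr.getSplatValue<int8_t>());
  case 16:
    return static_cast<float>(valueAttr.getSplatValue<int16_t>());
  case 32:
    return static_cast<float>(valueAttr.getSplatValue<int32_t>());
  case 64:
    return static_cast<float>(valueAttr.getSplatValue<int64_t>());
  }
  // Falling through mirrors the diff: unsupported widths assert rather
  // than silently returning 0.0 as the old code did.
  assert(false && "Unsupported integer type.");
}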
36 changes: 36 additions & 0 deletions test/ttmlir/Dialect/TTNN/simple_constant.mlir
@@ -1,24 +1,60 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
  func.func @test_empty_int8() -> tensor<64x128xi8> {
    %0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xi8>}> : () -> tensor<64x128xi8>
    // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
    return %0 : tensor<64x128xi8>
  }

  func.func @test_empty_int16() -> tensor<64x128xi16> {
    %0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xi16>}> : () -> tensor<64x128xi16>
    // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
    return %0 : tensor<64x128xi16>
  }

  func.func @test_empty_int() -> tensor<64x128xi32> {
    %0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xi32>}> : () -> tensor<64x128xi32>
    // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
    return %0 : tensor<64x128xi32>
  }

  func.func @test_empty_bfloat16() -> tensor<64x128xbf16> {
    %0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<64x128xbf16>}> : () -> tensor<64x128xbf16>
    // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
    return %0 : tensor<64x128xbf16>
  }

  func.func @test_empty_float() -> tensor<64x128xf32> {
    %0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<64x128xf32>}> : () -> tensor<64x128xf32>
    // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
    return %0 : tensor<64x128xf32>
  }

  func.func @test_full_int8() -> tensor<64x128xi8> {
    %0 = "ttir.constant"() <{value = dense<1> : tensor<64x128xi8>}> : () -> tensor<64x128xi8>
    // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
    return %0 : tensor<64x128xi8>
  }

  func.func @test_full_int16() -> tensor<64x128xi16> {
    %0 = "ttir.constant"() <{value = dense<1> : tensor<64x128xi16>}> : () -> tensor<64x128xi16>
    // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
    return %0 : tensor<64x128xi16>
  }

  func.func @test_full_int() -> tensor<64x128xi32> {
    %0 = "ttir.constant"() <{value = dense<1> : tensor<64x128xi32>}> : () -> tensor<64x128xi32>
    // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
    return %0 : tensor<64x128xi32>
  }

  func.func @test_full_bfloat16() -> tensor<64x128xbf16> {
    %0 = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<64x128xbf16>}> : () -> tensor<64x128xbf16>
    // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
    return %0 : tensor<64x128xbf16>
  }

  func.func @test_full_float() -> tensor<64x128xf32> {
    %0 = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<64x128xf32>}> : () -> tensor<64x128xf32>
    // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
    return %0 : tensor<64x128xf32>
  }
}
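The tests pair each element type with both a zero and a non-zero splat because the conversion branches on the widened fill value, as shown in the second hunk above. A standalone illustration of the expectation the CHECK lines encode (plain C++, no MLIR required; loweredOpFor is a hypothetical name, not part of the repository):

#include <iostream>
#include <string>

// Hypothetical helper mirroring the branch in ConstantOpConversionPattern:
// a zero fill value lowers to "ttnn.empty", anything else to "ttnn.full".
static std::string loweredOpFor(float fillValue) {
  return fillValue == 0.0f ? "ttnn.empty" : "ttnn.full";
}

int main() {
  std::cout << loweredOpFor(0.0f) << "\n"; // the dense<0> tests above
  std::cout << loweredOpFor(1.0f) << "\n"; // the dense<1> tests above
}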
