From 27596386b44190520e9ba2b49ce02c3a4d1b9c3c Mon Sep 17 00:00:00 2001
From: Muhammad Asif Manzoor
Date: Wed, 6 Nov 2024 01:50:26 +0000
Subject: [PATCH] Handle different value types for constant op explicitly for
 TTNN backend

---
 lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp      |  6 ++--
 test/ttmlir/Dialect/TTNN/simple_constant.mlir | 36 +++++++++++++++++++
 2 files changed, 38 insertions(+), 4 deletions(-)

diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index a496a82363..e7cf8916c1 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -23,7 +23,6 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/LogicalResult.h"
-#include
 
 using namespace mlir;
 using namespace mlir::tt;
@@ -506,8 +505,7 @@ class ConstantOpConversionPattern
       Value device = getOrInsertDevice(rewriter, op);
       float fillValue =
           valueAttr.getElementType().isInteger()
-              ? getIntegerValue(
-                    valueAttr) // static_cast<float>(valueAttr.getSplatValue<int>())
+              ? getIntegerValue(valueAttr)
               : valueAttr.getSplatValue<mlir::APFloat>().convertToFloat();
       if (fillValue == 0) {
         rewriter.replaceOpWithNewOp<ttnn::EmptyOp>(
@@ -554,7 +552,7 @@ class ConstantOpConversionPattern
     case 64:
       return static_cast<float>(valueAttr.getSplatValue<int64_t>());
     }
-    return 0.0;
+    assert(false && "Unsupported integer type.");
   }
 };
 
diff --git a/test/ttmlir/Dialect/TTNN/simple_constant.mlir b/test/ttmlir/Dialect/TTNN/simple_constant.mlir
index fc85c13ef3..00eafaa52b 100644
--- a/test/ttmlir/Dialect/TTNN/simple_constant.mlir
+++ b/test/ttmlir/Dialect/TTNN/simple_constant.mlir
@@ -1,24 +1,60 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
 module attributes {} {
+  func.func @test_empty_int8() -> tensor<64x128xi8> {
+    %0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xi8>}> : () -> tensor<64x128xi8>
+    // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
+    return %0 : tensor<64x128xi8>
+  }
+
+  func.func @test_empty_int16() -> tensor<64x128xi16> {
+    %0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xi16>}> : () -> tensor<64x128xi16>
+    // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
+    return %0 : tensor<64x128xi16>
+  }
+
   func.func @test_empty_int() -> tensor<64x128xi32> {
     %0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xi32>}> : () -> tensor<64x128xi32>
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     return %0 : tensor<64x128xi32>
   }
+  func.func @test_empty_bfloat16() -> tensor<64x128xbf16> {
+    %0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<64x128xbf16>}> : () -> tensor<64x128xbf16>
+    // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
+    return %0 : tensor<64x128xbf16>
+  }
+
   func.func @test_empty_float() -> tensor<64x128xf32> {
     %0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<64x128xf32>}> : () -> tensor<64x128xf32>
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     return %0 : tensor<64x128xf32>
   }
+  func.func @test_full_int8() -> tensor<64x128xi8> {
+    %0 = "ttir.constant"() <{value = dense<1> : tensor<64x128xi8>}> : () -> tensor<64x128xi8>
+    // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
+    return %0 : tensor<64x128xi8>
+  }
+
+  func.func @test_full_int16() -> tensor<64x128xi16> {
+    %0 = "ttir.constant"() <{value = dense<1> : tensor<64x128xi16>}> : () -> tensor<64x128xi16>
+    // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
+    return %0 : tensor<64x128xi16>
+  }
+
   func.func @test_full_int() -> tensor<64x128xi32> {
     %0 = "ttir.constant"() <{value = dense<1> : tensor<64x128xi32>}> : () -> tensor<64x128xi32>
     // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
     return %0 : tensor<64x128xi32>
   }
+  func.func @test_full_bfloat16() -> tensor<64x128xbf16> {
+    %0 = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<64x128xbf16>}> : () -> tensor<64x128xbf16>
+    // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
+    return %0 : tensor<64x128xbf16>
+  }
+
   func.func @test_full_float() -> tensor<64x128xf32> {
     %0 = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<64x128xf32>}> : () -> tensor<64x128xf32>
     // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
     return %0 : tensor<64x128xf32>
   }
 }