diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp index bd97f1dfe6..5a910dd616 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp @@ -51,6 +51,10 @@ class StablehloTypeConverter : public TypeConverter { if (type.getElementTypeBitWidth() == 1) { elementType = BFloat16Type::get(elementType.getContext()); changed = true; + } else if (type.getElementTypeBitWidth() == 64 && + isa(type.getElementType())) { + elementType = IntegerType::get(elementType.getContext(), 32); + changed = true; } // Create shape of 1-D tensor in case of scalar input. if (shape.size() == 0) { diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 15b1f086b4..159f387c53 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -2,16 +2,21 @@ // // SPDX-License-Identifier: Apache-2.0 -#include #include +#include "mlir/Dialect/Traits.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" + #include "ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h" +#include "ttmlir/Dialect/TT/IR/TT.h" +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TTIR/IR/TTIR.h" +#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" -#include "mlir/Dialect/Traits.h" +#include #include #include #include @@ -19,14 +24,8 @@ #include #include #include - #include -#include "ttmlir/Dialect/TT/IR/TT.h" -#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" -#include "ttmlir/Dialect/TTIR/IR/TTIR.h" -#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" - using namespace mlir; using namespace mlir::tt; @@ -315,12 +314,7 @@ class StableHLOToTTIRConstantOpConversionPattern auto outputType = mlir::cast( 
getTypeConverter()->convertType(srcOp.getResult().getType())); - // Scalar tensors are not supported by TTIR so we have to convert them to - // 1-D tensors. - mlir::ElementsAttr valueAttr = - srcOp.getValue().getShapedType().getShape().empty() - ? convertTo1DTensor(srcOp.getValue()) - : srcOp.getValue(); + mlir::ElementsAttr valueAttr = getValueAttr(srcOp.getValue()); rewriter.replaceOpWithNewOp(srcOp, outputType, valueAttr); @@ -338,13 +332,50 @@ class StableHLOToTTIRConstantOpConversionPattern return success(); } - mlir::ElementsAttr convertTo1DTensor(mlir::ElementsAttr valueAttr) const { + // Rebuild the value of the constant op for the following cases. + // 1. Scalar values: TTNN does not support scalar types, so they are converted + // to 1-D tensors. + // 2. Boolean tensor: TTNN does not support boolean data, so they are + // converted to bfloat16 tensors. + // 3. Integer tensor: TTNN does not support 64-bit integers, so they are + // converted to 32-bit integer tensors. + mlir::ElementsAttr getValueAttr(mlir::ElementsAttr valueAttr) const { + Type elementType = valueAttr.getElementType(); + size_t bitWidth = elementType.getIntOrFloatBitWidth(); + bool isTensor = !valueAttr.getShapedType().getShape().empty(); + bool isIntTensor = isTensor && isa(elementType) && + bitWidth != 1 && bitWidth != 64; + bool isFloatTensor = isTensor && isa(elementType); + + if (isTensor && (isIntTensor || isFloatTensor)) { + return valueAttr; + } + mlir::ShapedType valueType = mlir::cast( getTypeConverter()->convertType(valueAttr.getShapedType())); - if (valueAttr.getElementType().isInteger()) { - return mlir::DenseElementsAttr::get(valueType, - valueAttr.getSplatValue()); - } else { + if (isa(elementType)) { + switch (bitWidth) { + case 1: { + return rebuildValueAttr(valueAttr, 1); + } + case 8: { + return rebuildValueAttr(valueAttr, 8); + } + case 16: { + return rebuildValueAttr(valueAttr, 16); + } + case 32: { + return rebuildValueAttr(valueAttr, 32); + } + case 64: { + return 
rebuildValueAttr(valueAttr, 32); + } + default: { + assert(false && "Unsupported integer type."); + } + } + } + if (isa(elementType)) { // In case of float values llvm has a bug where not all float types are // supported for iterating in DenseElementsAttr, so we have to use a // different constructor. @@ -353,6 +384,35 @@ class StableHLOToTTIRConstantOpConversionPattern valueAttr.getValues().end()); return mlir::DenseElementsAttr::get(valueType, floatValues); } + assert(false && "Unsupported data type."); + } + + // Extract the values (using the given ElementType) and create a new data + // structure. This is used to convert scalars (of type boolean, int8, int16, + // int32, and int64) and tensors (of type boolean and int64). + template + mlir::ElementsAttr rebuildValueAttr(mlir::ElementsAttr valueAttr, + size_t bitWidth) const { + mlir::ShapedType valueType = mlir::cast( + getTypeConverter()->convertType(valueAttr.getShapedType())); + + // Create data structure for boolean type with bfloat16. + if (bitWidth == 1) { + std::vector booleanValue = {}; + for (ElementType value : valueAttr.getValues()) { + mlir::APFloat input(mlir::APFloat::BFloat(), value); + booleanValue.emplace_back(input); + } + return mlir::DenseElementsAttr::get(valueType, booleanValue); + } + + // Create data structure for other types. 
+ std::vector IntegerValue = {}; + for (ElementType value : valueAttr.getValues()) { + mlir::APInt input(bitWidth, value); + IntegerValue.emplace_back(input); + } + return mlir::DenseElementsAttr::get(valueType, IntegerValue); } }; diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/binary/concat_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/binary/concat_op.mlir index 301c6696a9..6286f47e7c 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/binary/concat_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/binary/concat_op.mlir @@ -15,7 +15,7 @@ module @jit_concat attributes {} { dimension = 0 : i64 } : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64> // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] - // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 0 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x2xi64>, tensor<1x2xi64>, tensor<4x2xi64>) -> tensor<4x2xi64> + // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 0 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x2xi32>, tensor<1x2xi32>, tensor<4x2xi32>) -> tensor<4x2xi32> return %0 : tensor<4x2xi64> } @@ -42,7 +42,7 @@ module @jit_concat attributes {} { dimension = 1 : i64 } : (tensor<256x512xi64>, tensor<256x256xi64>) -> tensor<256x768xi64> // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] - // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<256x512xi64>, tensor<256x256xi64>, tensor<256x768xi64>) -> tensor<256x768xi64> + // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<256x512xi32>, tensor<256x256xi32>, tensor<256x768xi32>) -> tensor<256x768xi32> return %0 : tensor<256x768xi64> } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir 
b/test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir index 6e5aea1e20..b00a1bec6e 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir @@ -1,31 +1,187 @@ // REQUIRES: stablehlo // RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s module @jit_constant attributes {} { - func.func public @test_splat() -> tensor<64xf32> { - %0 = stablehlo.constant dense<0.3> : tensor<64xf32> + func.func public @test_boolean_scalar() -> tensor { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<1xbf16>}> : () -> tensor<1xbf16> + %0 = stablehlo.constant dense : tensor + // CHECK: return %{{[0-9]+}} : tensor<1xbf16> + return %0 : tensor + } + + func.func public @test_boolean_splat() -> tensor<64xi1> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<64xbf16>}> : () -> tensor<64xbf16> + %0 = stablehlo.constant dense : tensor<64xi1> + // CHECK: return %{{[0-9]+}} : tensor<64xbf16> + return %0 : tensor<64xi1> + } + + func.func public @test_boolean_multiple() -> tensor<2x2xi1> { + // The ugly regex after `dense` is necessary because double square opening + // brackets indicate substitution block in FileCheck syntax. 
+ // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[1.000000e+00, 0.000000e+00], [0.000000e+00, 1.000000e+00]]> : tensor<2x2xbf16>}> : () -> tensor<2x2xbf16> + %0 = stablehlo.constant dense<[[true, false], [false, true]]> : tensor<2x2xi1> + // CHECK: return %{{[0-9]+}} : tensor<2x2xbf16> + return %0 : tensor<2x2xi1> + } + + func.func public @test_bfloat16_scalar() -> tensor { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e+00> : tensor<1xbf16>}> : () -> tensor<1xbf16> + %0 = stablehlo.constant dense<3.0> : tensor + // CHECK: return %{{[0-9]+}} : tensor<1xbf16> + return %0 : tensor + } + + func.func public @test_bfloat16_splat() -> tensor<64xbf16> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e+00> : tensor<64xbf16>}> : () -> tensor<64xbf16> + %0 = stablehlo.constant dense<3.0> : tensor<64xbf16> + // CHECK: return %{{[0-9]+}} : tensor<64xbf16> + return %0 : tensor<64xbf16> + } + + func.func public @test_bfloat16_multiple() -> tensor<2x2xbf16> { + // The ugly regex after `dense` is necessary because double square opening + // brackets indicate substitution block in FileCheck syntax. 
+ // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00]]> : tensor<2x2xbf16>}> : () -> tensor<2x2xbf16> + %0 = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xbf16> + // CHECK: return %{{[0-9]+}} : tensor<2x2xbf16> + return %0 : tensor<2x2xbf16> + } + + func.func public @test_float16_scalar() -> tensor { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e+00> : tensor<1xf16>}> : () -> tensor<1xf16> + %0 = stablehlo.constant dense<3.0> : tensor + // CHECK: return %{{[0-9]+}} : tensor<1xf16> + return %0 : tensor + } + + func.func public @test_float16_splat() -> tensor<64xf16> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e+00> : tensor<64xf16>}> : () -> tensor<64xf16> + %0 = stablehlo.constant dense<3.0> : tensor<64xf16> + // CHECK: return %{{[0-9]+}} : tensor<64xf16> + return %0 : tensor<64xf16> + } + + func.func public @test_float16_multiple() -> tensor<2x2xf16> { + // The ugly regex after `dense` is necessary because double square opening + // brackets indicate substitution block in FileCheck syntax. 
+ // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00]]> : tensor<2x2xf16>}> : () -> tensor<2x2xf16> + %0 = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf16> + // CHECK: return %{{[0-9]+}} : tensor<2x2xf16> + return %0 : tensor<2x2xf16> + } + + func.func public @test_float_scalar() -> tensor { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32> + %0 = stablehlo.constant dense<0.3> : tensor + // CHECK: return %{{[0-9]+}} : tensor<1xf32> + return %0 : tensor + } + + func.func public @test_float_splat() -> tensor<64xf32> { // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e-01> : tensor<64xf32>}> : () -> tensor<64xf32> + %0 = stablehlo.constant dense<0.3> : tensor<64xf32> + // CHECK: return %{{[0-9]+}} : tensor<64xf32> return %0 : tensor<64xf32> } - func.func public @test_multiple() -> tensor<2x2xf32> { - %0 = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32> + func.func public @test_float_multiple() -> tensor<2x2xf32> { // The ugly regex after `dense` is necessary because double square opening // brackets indicate substitution block in FileCheck syntax. 
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00]]> : tensor<2x2xf32>}> : () -> tensor<2x2xf32> + %0 = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32> + // CHECK: return %{{[0-9]+}} : tensor<2x2xf32> return %0 : tensor<2x2xf32> } - func.func public @test_scalar_int() -> tensor { - %0 = stablehlo.constant dense<3> : tensor + func.func public @test_int8_scalar() -> tensor { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8> + %0 = stablehlo.constant dense<3> : tensor + // CHECK: return %{{[0-9]+}} : tensor<1xi8> + return %0 : tensor + } + + func.func public @test_int8_splat() -> tensor<64xi8> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xi8>}> : () -> tensor<64xi8> + %0 = stablehlo.constant dense<3> : tensor<64xi8> + // CHECK: return %{{[0-9]+}} : tensor<64xi8> + return %0 : tensor<64xi8> + } + + func.func public @test_int8_multiple() -> tensor<2x2xi8> { + // The ugly regex after `dense` is necessary because double square opening + // brackets indicate substitution block in FileCheck syntax. 
+ // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xi8>}> : () -> tensor<2x2xi8> + %0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi8> + // CHECK: return %{{[0-9]+}} : tensor<2x2xi8> + return %0 : tensor<2x2xi8> + } + + func.func public @test_int16_scalar() -> tensor { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xi16>}> : () -> tensor<1xi16> + %0 = stablehlo.constant dense<3> : tensor + // CHECK: return %{{[0-9]+}} : tensor<1xi16> + return %0 : tensor + } + + func.func public @test_int16_splat() -> tensor<64xi16> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xi16>}> : () -> tensor<64xi16> + %0 = stablehlo.constant dense<3> : tensor<64xi16> + // CHECK: return %{{[0-9]+}} : tensor<64xi16> + return %0 : tensor<64xi16> + } + + func.func public @test_int16_multiple() -> tensor<2x2xi16> { + // The ugly regex after `dense` is necessary because double square opening + // brackets indicate substitution block in FileCheck syntax. 
+ // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xi16>}> : () -> tensor<2x2xi16> + %0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi16> + // CHECK: return %{{[0-9]+}} : tensor<2x2xi16> + return %0 : tensor<2x2xi16> + } + + func.func public @test_int32_scalar() -> tensor { // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xi32>}> : () -> tensor<1xi32> + %0 = stablehlo.constant dense<3> : tensor + // CHECK: return %{{[0-9]+}} : tensor<1xi32> return %0 : tensor + } + + func.func public @test_int32_splat() -> tensor<64xi32> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xi32>}> : () -> tensor<64xi32> + %0 = stablehlo.constant dense<3> : tensor<64xi32> + // CHECK: return %{{[0-9]+}} : tensor<64xi32> + return %0 : tensor<64xi32> + } + + func.func public @test_int32_multiple() -> tensor<2x2xi32> { + // The ugly regex after `dense` is necessary because double square opening + // brackets indicate substitution block in FileCheck syntax. 
+ // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xi32>}> : () -> tensor<2x2xi32> + %0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> + // CHECK: return %{{[0-9]+}} : tensor<2x2xi32> + return %0 : tensor<2x2xi32> + } + + func.func public @test_int64_scalar() -> tensor { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xi32>}> : () -> tensor<1xi32> + %0 = stablehlo.constant dense<3> : tensor // CHECK: return %{{[0-9]+}} : tensor<1xi32> + return %0 : tensor } - func.func public @test_scalar_float() -> tensor { - %0 = stablehlo.constant dense<0.3> : tensor - // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32> - return %0 : tensor - // CHECK: return %{{[0-9]+}} : tensor<1xf32> + func.func public @test_int64_splat() -> tensor<64xi64> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xi32>}> : () -> tensor<64xi32> + %0 = stablehlo.constant dense<3> : tensor<64xi64> + // CHECK: return %{{[0-9]+}} : tensor<64xi32> + return %0 : tensor<64xi64> + } + + func.func public @test_int64_multiple() -> tensor<2x2xi64> { + // The ugly regex after `dense` is necessary because double square opening + // brackets indicate substitution block in FileCheck syntax. + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xi32>}> : () -> tensor<2x2xi32> + %0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi64> + // CHECK: return %{{[0-9]+}} : tensor<2x2xi32> + return %0 : tensor<2x2xi64> } }