Skip to content

Commit

Permalink
Merge branch 'main' into staylor/last_updated_workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
staylorTT authored Nov 4, 2024
2 parents 6932620 + e6c60fd commit 7e18c15
Show file tree
Hide file tree
Showing 25 changed files with 465 additions and 37 deletions.
12 changes: 12 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,18 @@ def TTIR_CosOp: TTIR_ElementwiseUnaryOp<"cos"> {
}];
}

// Elementwise sign: maps each element to -1, 0, or +1 depending on its sign.
// Mirrors TTNN_SignOp in the TTNN dialect; lowered via the TTIR->TTNN
// elementwise conversion pattern.
def TTIR_SignOp: TTIR_ElementwiseUnaryOp<"sign"> {
  let summary = "Eltwise sign operation.";
  let description = [{
    Returns the sign of the `operand` element-wise and produces a `result`
    tensor.

    Example:
      %a: [[3, -2, 0], [1, -4, 4]]
      "ttir.sign"(%a, %out) -> %out: [[1, -1, 0], [1, -1, 1]]
  }];
}

def TTIR_LogicalNotOp: TTIR_ElementwiseUnaryOp<"logical_not"> {
let summary = "Eltwise logical not op.";
let description = [{
Expand Down
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ add_mlir_dialect(TTNNOps ttnn)
add_mlir_doc(TTNNBase TTNNDialect src/autogen/md/Dialect/ -gen-dialect-doc)
add_mlir_doc(TTNNOps TTNNOp src/autogen/md/Dialect/ -gen-op-doc)

add_mlir_interface(TTNNOpsBackendInterfaces)

set(LLVM_TARGET_DEFINITIONS TTNNOpsEnums.td)
mlir_tablegen(TTNNOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(TTNNOpsEnums.cpp.inc -gen-enum-defs)
Expand Down
3 changes: 2 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define TTMLIR_TTMLIR_DIALECT_TTNN_TTNNDIALECT_TD

include "mlir/IR/OpBase.td"
include "ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.td"

//===----------------------------------------------------------------------===//
// TTNN dialect definition.
Expand Down Expand Up @@ -43,6 +44,6 @@ def TTNN_Dialect : Dialect {
//===----------------------------------------------------------------------===//

class TTNN_Op<string mnemonic, list<Trait> traits = []> :
Op<TTNN_Dialect, mnemonic, traits>;
Op<TTNN_Dialect, mnemonic, !listconcat(traits, [TTNNOpBackendInterface])>;

#endif
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.h.inc"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"

#define GET_OP_CLASSES
Expand Down
16 changes: 15 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,18 @@ def TTNN_CeilOp : TTNN_ElementwiseUnaryOp<"ceil"> {
}];
}

// Elementwise sign: maps each element to -1, 0, or +1 depending on its sign.
// Target of the TTIR_SignOp lowering; serialized as EltwiseOpType::Sign in
// the flatbuffer program schema.
def TTNN_SignOp: TTNN_ElementwiseUnaryOp<"sign"> {
  let summary = "Eltwise sign operation.";
  let description = [{
    Returns the sign of the `operand` element-wise and produces a `result`
    tensor.

    Example:
      %a: [[3, -2, 0], [1, -4, 4]]
      "ttnn.sign"(%a, %out) -> %out: [[1, -1, 0], [1, -1, 1]]
  }];
}

def TTNN_CosOp : TTNN_ElementwiseUnaryOp<"cos"> {
let summary = "Eltwise cosine.";
let description = [{
Expand Down Expand Up @@ -202,7 +214,9 @@ def TTNN_ReciprocalOp : TTNN_ElementwiseUnaryOp<"reciprocal"> {
}];
}

def TTNN_ReluOp : TTNN_ElementwiseUnaryOp<"relu"> {
def TTNN_ReluOp : TTNN_ElementwiseUnaryOp<"relu",
[DeclareOpInterfaceMethods<TTNNOpBackendInterface, ["getOpPerfCycles", "getOpL1Usage", "isOpLegal"]>]
> {
let summary = "Eltwise ReLU.";
let description = [{
Eltwise ReLU operation.
Expand Down
50 changes: 50 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

// Guard renamed from ...PERFINTERFACES... to match the file name
// (TTNNOpsBackendInterfaces.td).
#ifndef TTMLIR_TTMLIR_DIALECT_TTNN_TTNNOPSBACKENDINTERFACES_TD
#define TTMLIR_TTMLIR_DIALECT_TTNN_TTNNOPSBACKENDINTERFACES_TD

include "mlir/IR/OpBase.td"

// Backend-query interface attached to every TTNN_Op (see TTNNBase.td).
// Ops may override the defaults via DeclareOpInterfaceMethods to provide
// op-specific estimates used by compiler analyses such as ShardSolver.
def TTNNOpBackendInterface : OpInterface<"TTNNOpBackend"> {
  let description = [{
    Interface for querying backend characteristics of a TTNN op for a given
    combination of input and output layouts: a performance estimate in clock
    cycles, an L1 memory usage estimate, and whether the op is legal for the
    layouts. Defaults are conservative so ops that do not override them are
    treated as maximally expensive but legal.
  }];
  let cppNamespace = "::mlir::tt::ttnn";
  let methods = [
      InterfaceMethod<
        /*desc=*/[{
          Return the op kernel estimate in clock cycles.
        }],
        /*retTy=*/"size_t",
        /*methodName=*/"getOpPerfCycles",
        /*args=*/(ins "const std::vector<tt::LayoutAttr>&":$input_layouts, "const tt::LayoutAttr&":$output_layout),
        /*methodBody=*/"",
        /*defaultImplementation=*/"return std::numeric_limits<size_t>::max();"
      >,
      InterfaceMethod<
        /*desc=*/[{
          Return the op estimated L1 memory usage.
        }],
        /*retTy=*/"size_t",
        /*methodName=*/"getOpL1Usage",
        /*args=*/(ins "const std::vector<tt::LayoutAttr>&":$input_layouts, "const tt::LayoutAttr&":$output_layout),
        /*methodBody=*/"",
        /*defaultImplementation=*/"return 0;"
      >,
      InterfaceMethod<
        /*desc=*/[{
          Return whether the op is legal for the given input/output layouts.
        }],
        /*retTy=*/"bool",
        /*methodName=*/"isOpLegal",
        /*args=*/(ins "const std::vector<tt::LayoutAttr>&":$input_layouts, "const tt::LayoutAttr&":$output_layout),
        /*methodBody=*/"",
        /*defaultImplementation=*/"return true;"
      >,
    ];
}

#endif // TTMLIR_TTMLIR_DIALECT_TTNN_TTNNOPSBACKENDINTERFACES_TD
3 changes: 2 additions & 1 deletion include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ enum EltwiseOpType: uint32 {
Cos = 27,
Log = 28,
Log1p = 29,
Expm1 = 30
Expm1 = 30,
Sign = 31
}

union EltwiseOpParams {
Expand Down
4 changes: 4 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ class StablehloTypeConverter : public TypeConverter {
if (type.getElementTypeBitWidth() == 1) {
elementType = BFloat16Type::get(elementType.getContext());
changed = true;
} else if (type.getElementTypeBitWidth() == 64 &&
isa<IntegerType>(type.getElementType())) {
elementType = IntegerType::get(elementType.getContext(), 32);
changed = true;
}
// Create shape of 1-D tensor in case of scalar input.
if (shape.size() == 0) {
Expand Down
100 changes: 81 additions & 19 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,30 @@
//
// SPDX-License-Identifier: Apache-2.0

#include <algorithm>
#include <vector>

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

#include "ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h"
#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"

#include "mlir/Dialect/Traits.h"
#include <llvm/ADT/APFloat.h>
#include <mlir/Dialect/Func/Transforms/FuncConversions.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/IR/ValueRange.h>
#include <mlir/Support/LogicalResult.h>

#include <stablehlo/dialect/StablehloOps.h>

#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"

using namespace mlir;
using namespace mlir::tt;

Expand Down Expand Up @@ -315,12 +314,7 @@ class StableHLOToTTIRConstantOpConversionPattern
auto outputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(srcOp.getResult().getType()));

// Scalar tensors are not supported by TTIR so we have to convert them to
// 1-D tensors.
mlir::ElementsAttr valueAttr =
srcOp.getValue().getShapedType().getShape().empty()
? convertTo1DTensor(srcOp.getValue())
: srcOp.getValue();
mlir::ElementsAttr valueAttr = getValueAttr(srcOp.getValue());

rewriter.replaceOpWithNewOp<mlir::tt::ttir::ConstantOp>(srcOp, outputType,
valueAttr);
Expand All @@ -338,13 +332,50 @@ class StableHLOToTTIRConstantOpConversionPattern
return success();
}

mlir::ElementsAttr convertTo1DTensor(mlir::ElementsAttr valueAttr) const {
// Rebuilding value of constant op for following cases.
// 1. Scalar values: TTNN does not support scalar types. So they are converted
// 1-D tensors.
// 2. Boolean tensor: TTNN does not support boolean data. So they are
// converted to bfloat16 tensors.
// 3. Integer tensor: TTNN does not support 64 bit integer. So they are
// converted to 32 bit tensor.
mlir::ElementsAttr getValueAttr(mlir::ElementsAttr valueAttr) const {
Type elementType = valueAttr.getElementType();
size_t bitWidth = elementType.getIntOrFloatBitWidth();
bool isTensor = !valueAttr.getShapedType().getShape().empty();
bool isIntTensor = isTensor && isa<IntegerType>(elementType) &&
bitWidth != 1 && bitWidth != 64;
bool isFloatTensor = isTensor && isa<FloatType>(elementType);

if (isTensor && (isIntTensor || isFloatTensor)) {
return valueAttr;
}

mlir::ShapedType valueType = mlir::cast<mlir::ShapedType>(
getTypeConverter()->convertType(valueAttr.getShapedType()));
if (valueAttr.getElementType().isInteger()) {
return mlir::DenseElementsAttr::get<int>(valueType,
valueAttr.getSplatValue<int>());
} else {
if (isa<IntegerType>(elementType)) {
switch (bitWidth) {
case 1: {
return rebuildValueAttr<bool>(valueAttr, 1);
}
case 8: {
return rebuildValueAttr<int8_t>(valueAttr, 8);
}
case 16: {
return rebuildValueAttr<int16_t>(valueAttr, 16);
}
case 32: {
return rebuildValueAttr<int32_t>(valueAttr, 32);
}
case 64: {
return rebuildValueAttr<int64_t>(valueAttr, 32);
}
default: {
assert(false && "Unsupported integer type.");
}
}
}
if (isa<FloatType>(elementType)) {
// In case of float values llvm has a bug where not all float types are
// supported for iterating in DenseElementsAttr, so we have to use a
// different constructor.
Expand All @@ -353,6 +384,35 @@ class StableHLOToTTIRConstantOpConversionPattern
valueAttr.getValues<mlir::APFloat>().end());
return mlir::DenseElementsAttr::get(valueType, floatValues);
}
assert(false && "Unsupported data type.");
}

// Extract the values (using the given ElementType) and create new data
// structure. This is used to convert scalars (of type boolean, int8, int16,
// int32, and int64) and tensors (of type boolean and int64).
template <typename ElementType>
mlir::ElementsAttr rebuildValueAttr(mlir::ElementsAttr valueAttr,
size_t bitWidth) const {
mlir::ShapedType valueType = mlir::cast<mlir::ShapedType>(
getTypeConverter()->convertType(valueAttr.getShapedType()));

// Create data structure for boolean type with bfloat16.
if (bitWidth == 1) {
std::vector<mlir::APFloat> booleanValue = {};
for (ElementType value : valueAttr.getValues<ElementType>()) {
mlir::APFloat input(mlir::APFloat::BFloat(), value);
booleanValue.emplace_back(input);
}
return mlir::DenseElementsAttr::get(valueType, booleanValue);
}

// Create data structure for other types.
std::vector<mlir::APInt> IntegerValue = {};
for (ElementType value : valueAttr.getValues<ElementType>()) {
mlir::APInt input(bitWidth, value);
IntegerValue.emplace_back(input);
}
return mlir::DenseElementsAttr::get(valueType, IntegerValue);
}
};

Expand Down Expand Up @@ -863,6 +923,8 @@ void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
mlir::stablehlo::Log1pOp, mlir::tt::ttir::Log1pOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::Expm1Op, mlir::tt::ttir::Expm1Op>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::SignOp, mlir::tt::ttir::SignOp>>(typeConverter, ctx);
}

void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx,
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ElementwiseOpConversionPattern<ttir::ReluOp, ttnn::ReluOp>,
ElementwiseOpConversionPattern<ttir::SqrtOp, ttnn::SqrtOp>,
ElementwiseOpConversionPattern<ttir::RsqrtOp, ttnn::RsqrtOp>,
ElementwiseOpConversionPattern<ttir::SignOp, ttnn::SignOp>,
ElementwiseOpConversionPattern<ttir::SigmoidOp, ttnn::SigmoidOp>,
ElementwiseOpConversionPattern<ttir::Log1pOp, ttnn::Log1pOp>,
ElementwiseOpConversionPattern<ttir::ReciprocalOp, ttnn::ReciprocalOp>,
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
DefaultOpConversionPattern<ttnn::ReluOp>,
DefaultOpConversionPattern<ttnn::SqrtOp>,
DefaultOpConversionPattern<ttnn::RsqrtOp>,
DefaultOpConversionPattern<ttnn::SignOp>,
DefaultOpConversionPattern<ttnn::SigmoidOp>,
DefaultOpConversionPattern<ttnn::Log1pOp>,
DefaultOpConversionPattern<ttnn::ReciprocalOp>,
Expand Down
8 changes: 8 additions & 0 deletions lib/Dialect/TTNN/Analysis/ShardSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "ttmlir/Dialect/TTNN/Analysis/ShardSolver.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
#include <mlir/Interfaces/DestinationStyleOpInterface.h>
#include <mlir/Support/LLVM.h>
#include <unordered_set>
Expand Down Expand Up @@ -503,6 +504,13 @@ bool ShardSolver::checkShardCompatible(
// TEMP : Dummy mock implementation, will be replaced.
//

if (TTNNOpBackend backend = dyn_cast<TTNNOpBackend>(consumerOp)) {
if (false ==
backend.isOpLegal(std::vector{producerLayout}, consumerLayout)) {
return false;
}
}

// May need to fetch other inputs for consumerOp(weights/join node).
//

Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/TTNN/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRTTNNDialect
TTNNDialect.cpp
TTNNOps.cpp
TTNNOpsBackendInterfaces.cpp
TTNNOpsTypes.cpp

ADDITIONAL_HEADER_DIRS
Expand Down
34 changes: 34 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

#include "ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.cpp.inc"

namespace mlir::tt::ttnn {

//===----------------------------------------------------------------------===//
// ReluOp
//===----------------------------------------------------------------------===//

// ReluOp overrides of the TTNNOpBackend interface (declared via
// DeclareOpInterfaceMethods in TTNNOps.td).
// NOTE(review): all three bodies are placeholder constants — replace with
// real layout-dependent estimates before using them for optimization
// decisions. The layout parameters are currently unused.

// Placeholder cycle-count estimate for relu.
size_t ReluOp::getOpPerfCycles(const std::vector<tt::LayoutAttr> &input_layouts,
                               const tt::LayoutAttr &output_layout) {
  // Implement a custom estimate for relu op cycles.
  return 5;
}

// Placeholder L1 memory usage estimate for relu.
size_t ReluOp::getOpL1Usage(const std::vector<tt::LayoutAttr> &input_layouts,
                            const tt::LayoutAttr &output_layout) {
  // Implement a custom estimate for relu op L1 usage.
  return 10;
}

// Placeholder legality check: relu accepts any layout combination for now.
bool ReluOp::isOpLegal(const std::vector<tt::LayoutAttr> &input_layouts,
                       const tt::LayoutAttr &output_layout) {
  // Implement a custom check for relu op legality.
  return true;
}

} // namespace mlir::tt::ttnn
Loading

0 comments on commit 7e18c15

Please sign in to comment.