Create new general pooling op and decomposition pattern that converts to maxpool2d

Use output memory config attribute in runtime, add silicon tests
LPanosTT committed Nov 4, 2024
1 parent 3dbf089 commit 29670ae
Showing 20 changed files with 651 additions and 361 deletions.
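The TTIR-level decomposition that lowers the new generic pooling op back to ttir.max_pool2d lives in files that are not shown in this excerpt. As a rough sketch only, the legality check such a pattern needs, assuming the accessors TableGen generates for the ttir.pooling arguments added below (getPoolingMethod, getWindowDimensions, getWindowStrides, getPadding) and the same NHWC, spatial-only constraints the previous direct conversion enforced, could look like this:

// Illustrative sketch, not code from this commit. Accessor names are assumed
// from standard TableGen generation for the ttir.pooling op defined below.
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"

static bool canDecomposeToMaxPool2d(mlir::tt::ttir::PoolingOp op) {
  if (op.getPoolingMethod() != mlir::tt::ttir::PoolingMethod::Max)
    return false;

  auto window = op.getWindowDimensions();
  if (window.size() != 4 || window[0] != 1 || window[3] != 1)
    return false; // the window may act on H and W only (NHWC layout)

  auto strides = op.getWindowStrides();
  if (strides.size() != 4 || strides[0] != 1 || strides[3] != 1)
    return false; // no striding over batch or channels

  // Padding is assumed flattened as (low, high) pairs per dimension; the
  // batch and channel dimensions must be unpadded.
  auto padding = op.getPadding();
  return padding.size() == 8 && padding[0] == 0 && padding[1] == 0 &&
         padding[6] == 0 && padding[7] == 0;
}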
6 changes: 6 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/CMakeLists.txt
@@ -8,6 +8,12 @@ mlir_tablegen(TTIROpsAttrs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(TTIROpsAttrsIncGen)
add_dependencies(mlir-headers TTIROpsAttrsIncGen)

set(LLVM_TARGET_DEFINITIONS TTIROpsEnums.td)
mlir_tablegen(TTIROpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(TTIROpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRTTIROpsEnumsIncGen)
add_dependencies(mlir-headers MLIRTTIROpsEnumsIncGen)

set(LLVM_TARGET_DEFINITIONS TTIROpsInterfaces.td)
mlir_tablegen(TTIROpsInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(TTIROpsInterfaces.cpp.inc -gen-op-interface-defs)
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.h
@@ -16,6 +16,8 @@

#include "TTIROpsInterfaces.h"

#include "ttmlir/Dialect/TTIR/IR/TTIROpsEnums.h.inc"

#define GET_ATTRDEF_CLASSES
#include "ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.h.inc"

31 changes: 28 additions & 3 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -708,6 +708,33 @@ def TTIR_ConvolutionOp : TTIR_DPSOp<"convolution"> {
}];
}

def TTIR_PoolingOp : TTIR_DPSOp<"pooling", [AttrSizedOperandSegments]> {
let summary = "General pooling op";
let description = [{
General pooling op
}];

let arguments = (ins
Variadic<AnyRankedTensor>:$inputs,
Variadic<AnyRankedTensor>:$outputs,
TTIR_PoolingMethodAttr:$pooling_method,
DenseI64ArrayAttr:$window_dimensions,

// Default stride of 1 over every dimension
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "SmallVector<int64_t>(getWindowDimensions().size(), 1)">:$window_strides,
// Default dilation of 1 over every dimension
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "SmallVector<int64_t>(getWindowDimensions().size(), 1)">:$base_dilations,
// Default dilation of 1 over every dimension
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "SmallVector<int64_t>(getWindowDimensions().size(), 1)">:$window_dilations,
// Default padding of 0 over every dimension
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "SmallVector<int64_t>(getWindowDimensions().size() * 2, 0)">:$padding,
TT_OperandConstraintArrayAttr:$operand_constraints
);

let results = (outs Variadic<AnyRankedTensor>);

let hasVerifier = 1;
}

def TTIR_MaxPool2dOp : TTIR_DPSOp<"max_pool2d"> {
let summary = "Applies a 2D max pooling over an input signal composed of several input planes.";
@@ -728,9 +755,7 @@ def TTIR_MaxPool2dOp : TTIR_DPSOp<"max_pool2d"> {
SI32Attr:$padding_right,
SI32Attr:$padding_top,
SI32Attr:$padding_bottom,
TT_OperandConstraintArrayAttr:$operand_constraints,
OptionalAttr<SI32Attr>:$original_height,
OptionalAttr<SI32Attr>:$original_width);
TT_OperandConstraintArrayAttr:$operand_constraints);

let results = (outs AnyRankedTensor:$result);

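The new ttir.pooling op sets hasVerifier = 1, but its verifier body (in TTIROps.cpp) is not among the files shown here. A plain C++ sketch of the structural consistency such a verifier would plausibly enforce, derived only from the attribute declarations above:

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch, not the verifier from this commit: checks that follow
// directly from the ttir.pooling attribute shapes declared above.
static bool poolingAttrsAreConsistent(std::size_t inputRank,
                                      const std::vector<std::int64_t> &windowDimensions,
                                      const std::vector<std::int64_t> &windowStrides,
                                      const std::vector<std::int64_t> &padding) {
  return windowDimensions.size() == inputRank && // one window extent per dim
         windowStrides.size() == inputRank &&    // one stride per dim
         padding.size() == 2 * inputRank;        // (low, high) padding per dim
}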
4 changes: 4 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.td
@@ -7,6 +7,10 @@

include "mlir/IR/AttrTypeBase.td"
include "ttmlir/Dialect/TTIR/IR/TTIRBase.td"
include "mlir/IR/EnumAttr.td"
include "ttmlir/Dialect/TTIR/IR/TTIROpsEnums.td"

def TTIR_PoolingMethodAttr : EnumAttr<TTIR_Dialect, TTIR_PoolingMethod, "pooling_method">;

def TTIR_ConvolutionLayoutAttr : AttrDef<TTIR_Dialect, "ConvolutionLayout", [], "::mlir::Attribute"> {
let mnemonic = "convolution_layout";
21 changes: 21 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROpsEnums.td
@@ -0,0 +1,21 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_TTIR_ENUMS_TD
#define TTMLIR_TTIR_ENUMS_TD

include "mlir/IR/EnumAttr.td"

def TTIR_PoolingMethodAverage : I32EnumAttrCase<"Average", 0>;
def TTIR_PoolingMethodMax : I32EnumAttrCase<"Max", 1>;

def TTIR_PoolingMethod : I32EnumAttr<"PoolingMethod", "TTIR PoolingMethod", [
TTIR_PoolingMethodAverage,
TTIR_PoolingMethodMax
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tt::ttir";
}

#endif
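Because genSpecializedAttr is 0, only the C++ enum and its helpers come from this file (via -gen-enum-decls and -gen-enum-defs); the attribute wrapper is declared separately in TTIROpsAttrs.td. A small usage sketch, assuming the helper names MLIR conventionally emits for an I32EnumAttr (stringifyPoolingMethod):

#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" // pulls in TTIROpsEnums.h.inc

#include "llvm/Support/raw_ostream.h"

// Illustrative only: exhaustive switch over the generated enum plus the
// generated stringify helper.
static void describePoolingMethod(mlir::tt::ttir::PoolingMethod method) {
  switch (method) {
  case mlir::tt::ttir::PoolingMethod::Average:
    llvm::outs() << "average pooling";
    break;
  case mlir::tt::ttir::PoolingMethod::Max:
    llvm::outs() << "max pooling";
    break;
  }
  llvm::outs() << " (" << mlir::tt::ttir::stringifyPoolingMethod(method)
               << ")\n";
}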
7 changes: 0 additions & 7 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
@@ -70,13 +70,6 @@ def TTIRLayout: Pass<"ttir-layout", "::mlir::ModuleOp"> {
];
}

def TTIRSlidingWindow2dFixShapes: Pass<"ttir-sliding-window-2d-fix-shapes", "::mlir::ModuleOp"> {
let summary = "Insert reshapes on the input and output of 2-dimensional sliding window ops that collapse N,H,W on the input: i.e (N, H, W, C) --> (1, 1, N*H*W, C), and unflatten the output: i.e (1, 1, N*H*W, C) --> (N, H, W, C)";
let description = [{
Insert reshapes on the input and output of 2-dimensional sliding window ops that collapse N,H,W on the input: i.e (N, H, W, C) --> (1, 1, N*H*W, C), and unflatten the output: i.e (1, 1, N*H*W, C) --> (N, H, W, C)
}];
}

def TTIRSplitCompoundLayout: Pass<"ttir-split-compound-layout", "::mlir::ModuleOp"> {
let summary = "Split compound layouts.";
let description = [{
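The removed TTIRSlidingWindow2dFixShapes pass used to wrap 2-D sliding-window ops in reshapes that collapse (N, H, W, C) into (1, 1, N*H*W, C) and restore the original shape on the output. The shape arithmetic, as a standalone illustration (not code from the pass):

#include <array>
#include <cstdint>

// Collapse an NHWC shape the way the removed pass described:
// (N, H, W, C) -> (1, 1, N*H*W, C). The inverse reshape unflattens the output.
static std::array<std::int64_t, 4>
collapseNHWC(const std::array<std::int64_t, 4> &shape) {
  const std::int64_t n = shape[0], h = shape[1], w = shape[2], c = shape[3];
  return std::array<std::int64_t, 4>{1, 1, n * h * w, c};
}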
210 changes: 113 additions & 97 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -2,6 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0

#include <limits>
#include <vector>

#include "mlir/Dialect/Traits.h"
@@ -473,11 +474,74 @@ class StableHLOToTTIRReduceWindowOpConversionPattern
rewriter.eraseOp(op);
}

bool isMaxPool2d(mlir::stablehlo::ReduceWindowOp &srcOp) const {
bool isMaxPool(mlir::stablehlo::ReduceWindowOp &srcOp) const {
if (srcOp.getBody().getBlocks().size() != 1) {
return false;
}

// Find constant input(s)
Operation *init_value;
for (uint64_t i = 0; i < srcOp.getInitValues().size(); i++) {
init_value = srcOp.getInitValues()[i].getDefiningOp();
auto name = init_value->getName().getStringRef().str();
(void)name;
while (init_value->getOpOperands().size() == 1) {
init_value = init_value->getOpOperand(0).get().getDefiningOp();
}
if (!isa<stablehlo::ConstantOp>(init_value)) {
return false;
}

stablehlo::ConstantOp init_value_op =
mlir::cast<stablehlo::ConstantOp>(init_value);

if (init_value_op.getValueAttr().size() != 1) {
return false;
}

// Constant operand must be -inf if this is to be a max pool.
// Since bfloat16 is not a type we actually have, compare the raw
// bits.
if (init_value_op.getResult().getType().getElementType().isBF16()) {
// Collect the values into a vector
std::vector<mlir::Attribute> values;
for (int64_t i = 0; i < init_value_op.getValueAttr().size(); ++i) {
values.push_back(
init_value_op.getValueAttr().getValues<mlir::Attribute>()[i]);
}

auto denseValues = ::mlir::DenseElementsAttr::get(
init_value_op.getValueAttr().getShapedType(), values);
uint16_t bfloat_bits = static_cast<uint16_t>(
(*denseValues.value_begin<llvm::APFloat>()).bitcastToAPInt().getZExtValue());
if (bfloat_bits != 0xff80) { // This is -inf in bfloat16
return false;
}
} else if (init_value_op.getValue().getType().isF32()) {
if (*init_value_op.getValue().value_begin<float>() !=
-std::numeric_limits<float>::infinity()) {
return false;
}
} else if (init_value_op.getValue().getType().isF64()) {
if (*init_value_op.getValue().value_begin<double>() !=
-std::numeric_limits<double>::infinity()) {
return false;
}
} else if (init_value_op.getValue().getType().isInteger(32)) {
if (*init_value_op.getValue().value_begin<int32_t>() !=
std::numeric_limits<int32_t>::min()) {
return false;
}
} else if (init_value_op.getValue().getType().isInteger(64)) {
if (*init_value_op.getValue().value_begin<int64_t>() !=
std::numeric_limits<int64_t>::min()) {
return false;
}
} else {
return false;
}
}

Block &block = *srcOp.getBody().getBlocks().begin();
uint32_t op_idx = 0;
for (Operation &op : block) {
@@ -501,105 +565,57 @@ class StableHLOToTTIRReduceWindowOpConversionPattern
mlir::stablehlo::ReduceWindowOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

if (isMaxPool2d(srcOp)) {
RankedTensorType outputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(srcOp.getResult(0).getType()));
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

RankedTensorType outputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(srcOp.getResult(0).getType()));
ValueRange inputs = adaptor.getInputs()[0];
ValueRange outputs = {outputTensor};

auto window_dimensions = adaptor.getWindowDimensionsAttr();
auto window_strides = adaptor.getWindowStridesAttr();
auto base_dilations = adaptor.getBaseDilationsAttr();
auto window_dilations = adaptor.getWindowDilationsAttr();
auto padding_ = adaptor.getPaddingAttr();

// Generate defaults if they don't exist
window_strides = window_strides
? window_strides
: rewriter.getDenseI64ArrayAttr(SmallVector<int64_t>(
window_dimensions.size(), 1));
base_dilations = base_dilations
? base_dilations
: rewriter.getDenseI64ArrayAttr(SmallVector<int64_t>(
window_dimensions.size(), 1));
window_dilations =
window_dilations ? window_dilations
: rewriter.getDenseI64ArrayAttr(SmallVector<int64_t>(
window_dimensions.size(), 1));
auto padding =
padding_ ? rewriter.getDenseI64ArrayAttr(
SmallVector<int64_t>(padding_.getValues<int64_t>()))
: rewriter.getDenseI64ArrayAttr(
SmallVector<int64_t>(window_dimensions.size() * 2, 1));

auto operand_constraints = rewriter.getArrayAttr(SmallVector<Attribute>(
adaptor.getOperands().size(), rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile)));

mlir::tt::ttir::PoolingMethod pooling_method;
if (isMaxPool(srcOp)) {
pooling_method = mlir::tt::ttir::PoolingMethod::Max;
} else {
return failure();
}

tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
rewriter.replaceOpWithNewOp<ttir::PoolingOp>(
srcOp, outputType, inputs, outputs,
pooling_method, window_dimensions, window_strides,
base_dilations, window_dilations, padding, operand_constraints);

// The generalized ReduceWindow allows for kernel_size, strides, dilation,
// and padding to act on all 4 input dimensions. Since we only support
// channel-last pooling, we select the middle two values for H and W.
// And fail if the others are not 1 (or 0 in the case of padding).
std::vector<int64_t> window_dimensions = adaptor.getWindowDimensions();
if (window_dimensions[0] != 1 || window_dimensions[3] != 1) {
return failure();
}
IntegerAttr kernel_height_attr = rewriter.getSI32IntegerAttr(
static_cast<int32_t>(window_dimensions[1]));
IntegerAttr kernel_width_attr = rewriter.getSI32IntegerAttr(
static_cast<int32_t>(window_dimensions[2]));

std::vector<int64_t> strides =
adaptor.getWindowStrides()
.value_or(ArrayRef<int64_t>({1, 1, 1, 1}))
.vec();

if (strides[0] != 1 || strides[3] != 1) {
return failure();
}
IntegerAttr stride_height_attr =
rewriter.getSI32IntegerAttr(static_cast<int32_t>(strides[1]));
IntegerAttr stride_width_attr =
rewriter.getSI32IntegerAttr(static_cast<int32_t>(strides[2]));

std::vector<int64_t> dilation =
adaptor.getBaseDilations()
.value_or(ArrayRef<int64_t>({1, 1, 1, 1}))
.vec();

if (dilation[0] != 1 || dilation[3] != 1) {
return failure();
}
IntegerAttr dilation_height_attr =
rewriter.getSI32IntegerAttr(static_cast<int32_t>(dilation[1]));
IntegerAttr dilation_width_attr =
rewriter.getSI32IntegerAttr(static_cast<int32_t>(dilation[2]));

// Padding here is in the form ((., .), (top, bottom), (left, right), (.,
// .)) one for each of (N, H, W, C). Since we only support maxpool2d, the
// first and last padding tuples must be zero to be valid. This list is
// flattened so we can use a single iterator to get the values.
std::vector<int32_t> padding = {0, 0, 0, 0};
if (adaptor.getPadding().has_value()) {
uint32_t pad_idx = 0;
for (auto iter = adaptor.getPadding()->value_begin<int64_t>();
iter < adaptor.getPadding()->value_end<int64_t>(); iter++) {

// TTIR requires left, right, top, bottom
if (pad_idx == 2) {
padding[2] = *iter;
} else if (pad_idx == 3) {
padding[3] = *iter;
} else if (pad_idx == 4) {
padding[0] = *iter;
} else if (pad_idx == 5) {
padding[1] = *iter;
} else if (*iter != 0) {
// Padding on the channel or batch is nonzero. TTIR/TTNN does not
// support this.
return failure();
}
pad_idx++;
}
}
::llvm::ArrayRef<int64_t> input_shape =
mlir::cast<mlir::RankedTensorType>(adaptor.getInputs()[0].getType())
.getShape();

// Dead ttir.constant sticks around and fails verification. Removing it
// like so since it's behind another op
recursiveErase(rewriter, adaptor.getInitValues()[0].getDefiningOp());
rewriter.replaceOpWithNewOp<mlir::tt::ttir::MaxPool2dOp>(
srcOp, outputType, srcOp.getInputs()[0], outputTensor,
kernel_height_attr, kernel_width_attr, stride_height_attr,
stride_width_attr, dilation_height_attr, dilation_width_attr,
rewriter.getBoolAttr(false), rewriter.getSI32IntegerAttr(padding[0]),
rewriter.getSI32IntegerAttr(padding[1]),
rewriter.getSI32IntegerAttr(padding[2]),
rewriter.getSI32IntegerAttr(padding[3]),
rewriter.getArrayAttr(
SmallVector<Attribute>(adaptor.getOperands().size() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))),
rewriter.getSI32IntegerAttr(input_shape[1]),
rewriter.getSI32IntegerAttr(input_shape[2]));

return success();
}
return failure();
return success();

}
};
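isMaxPool compares the bfloat16 init value against the raw bit pattern 0xff80. That constant follows from bfloat16 being the upper 16 bits of an IEEE-754 float32: -inf is 0xFF800000 as a float32 and truncates to 0xFF80. A standalone check (not part of the patch):

#include <cassert>
#include <cstdint>
#include <cstring>
#include <limits>

// bfloat16 keeps the sign bit, the 8 exponent bits, and the top 7 mantissa
// bits of a float32, i.e. its upper 16 bits.
int main() {
  const float negInf = -std::numeric_limits<float>::infinity();
  std::uint32_t f32Bits = 0;
  std::memcpy(&f32Bits, &negInf, sizeof(f32Bits));
  assert(f32Bits == 0xFF800000u);

  const std::uint16_t bf16Bits = static_cast<std::uint16_t>(f32Bits >> 16);
  assert(bf16Bits == 0xFF80u);
  return 0;
}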

(13 more changed files in this commit are not shown here.)
