diff --git a/include/ttmlir/Dialect/TTIR/IR/CMakeLists.txt b/include/ttmlir/Dialect/TTIR/IR/CMakeLists.txt index 7af4852500..e04bf3f3e3 100644 --- a/include/ttmlir/Dialect/TTIR/IR/CMakeLists.txt +++ b/include/ttmlir/Dialect/TTIR/IR/CMakeLists.txt @@ -2,6 +2,12 @@ add_mlir_dialect(TTIROps ttir) add_mlir_doc(TTIRBase TTIRDialect src/autogen/md/Dialect/ -gen-dialect-doc) add_mlir_doc(TTIROps TTIROp src/autogen/md/Dialect/ -gen-op-doc) +set(LLVM_TARGET_DEFINITIONS TTIROpsAttrs.td) +mlir_tablegen(TTIROpsAttrs.h.inc -gen-attrdef-decls) +mlir_tablegen(TTIROpsAttrs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(TTIROpsAttrsIncGen) +add_dependencies(mlir-headers TTIROpsAttrsIncGen) + set(LLVM_TARGET_DEFINITIONS TTIROpsInterfaces.td) mlir_tablegen(TTIROpsInterfaces.h.inc -gen-op-interface-decls) mlir_tablegen(TTIROpsInterfaces.cpp.inc -gen-op-interface-defs) diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIRBase.td b/include/ttmlir/Dialect/TTIR/IR/TTIRBase.td index f8a3198973..57f3dc37d3 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIRBase.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIRBase.td @@ -21,7 +21,7 @@ def TTIR_Dialect : Dialect { or dialects that are actually supported by a consuming backend. }]; let cppNamespace = "::mlir::tt::ttir"; - + let useDefaultAttributePrinterParser = 1; let dependentDialects = [ "::mlir::arith::ArithDialect", "::mlir::func::FuncDialect", diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.h b/include/ttmlir/Dialect/TTIR/IR/TTIROps.h index 26bcf1a192..cd102dbb96 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.h +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.h @@ -16,6 +16,9 @@ #include "TTIROpsInterfaces.h" +#define GET_ATTRDEF_CLASSES +#include "ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.h.inc" + #define GET_OP_CLASSES #include "ttmlir/Dialect/TTIR/IR/TTIROps.h.inc" diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index 9eae608b74..144008309d 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -7,6 +7,7 @@ include "ttmlir/Dialect/TT/IR/TTOpsTypes.td" include "ttmlir/Dialect/TTIR/IR/TTIRBase.td" +include "ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.td" include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td" include "mlir/Dialect/Linalg/IR/LinalgBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" @@ -557,6 +558,45 @@ def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> { let hasVerifier = 1; } +def TTIR_ConvolutionOp : TTIR_DPSOp<"convolution"> { + let summary = "Generalized convolution op."; + let description = [{ + Applies a convolution of the rhs with the lhs. + + This operation captures convolutions of all dimensionality as well + as deconvolution/conv transpose. + }]; + + let arguments = (ins + AnyRankedTensor:$input, + AnyRankedTensor:$weight, + Optional:$bias, + AnyRankedTensor:$output, + // Default value: one for each of the spatial dimension. + DefaultValuedOptionalAttr(getConvolutionLayout().getInputSpatialDimensions().size(), 1)">:$window_strides, + // Default value: two zeros for each of the spatial dimension. + DefaultValuedOptionalAttr(getConvolutionLayout().getInputSpatialDimensions().size()*2, 0)">:$padding, + // Default value: one for each of the spatial dimension. + DefaultValuedOptionalAttr(getConvolutionLayout().getInputSpatialDimensions().size(), 1)">:$input_dilation, + // Default value: one for each of the spatial dimension. + DefaultValuedOptionalAttr(getConvolutionLayout().getInputSpatialDimensions().size(), 1)">:$weight_dilation, + // Default value: false for each of the spatial dimension. + DefaultValuedOptionalAttr(getConvolutionLayout().getInputSpatialDimensions().size(), false)">:$window_reversal, + TTIR_ConvolutionLayoutAttr:$convolution_layout, + ConfinedAttr:$feature_group_count, + ConfinedAttr:$batch_group_count, + TT_OperandConstraintArrayAttr:$operand_constraints + ); + + let results = (outs AnyRankedTensor); + let hasVerifier = 1; + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; +} + + def TTIR_MaxPool2dOp : TTIR_DPSOp<"max_pool2d"> { let summary = "Applies a 2D max pooling over an input signal composed of several input planes."; let description = [{ diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.td b/include/ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.td new file mode 100644 index 0000000000..60943af269 --- /dev/null +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.td @@ -0,0 +1,44 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_TTIR_ATTRS_TD +#define TTMLIR_TTIR_ATTRS_TD + +include "mlir/IR/AttrTypeBase.td" +include "ttmlir/Dialect/TTIR/IR/TTIRBase.td" + +def TTIR_ConvolutionLayoutAttr : AttrDef { + let mnemonic = "convolution_layout"; + let summary = "Structure of dimension information for convolution op"; + let description = [{ + Holds the layout information for the input activation, weights, and output. + }]; + let parameters = (ins + "int64_t":$inputBatchDimension, + "int64_t":$inputFeatureDimension, + ArrayRefParameter<"int64_t">:$inputSpatialDimensions, + + "int64_t":$kernelOutputFeatureDimension, + "int64_t":$kernelInputFeatureDimension, + ArrayRefParameter<"int64_t">:$kernelSpatialDimensions, + + "int64_t":$outputBatchDimension, + "int64_t":$outputFeatureDimension, + ArrayRefParameter<"int64_t">:$outputSpatialDimensions + ); + + let assemblyFormat = [{ + `input_batch` `=` $inputBatchDimension `,` + `input_feature` `=` $inputFeatureDimension`,` + `input_spatial_dimensions` `=` custom($inputSpatialDimensions) `,` + `kernel_output_feature` `=` $kernelOutputFeatureDimension `,` + `kernel_input_feature` `=` $kernelInputFeatureDimension `,` + `kernel_spatial_dimensions` `=` custom($kernelSpatialDimensions) `,` + `output_batch` `=` $outputBatchDimension `,` + `output_feature` `=` $outputFeatureDimension `,` + `output_spatial_dimensions` `=` custom($outputSpatialDimensions) + }]; +} + +#endif // TTMLIR_TTIR_ATTRS_TD diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td index 1cee4cbb5c..709d3b0af3 100644 --- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td @@ -119,4 +119,11 @@ def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> { ]; } +def TTIRConvolutionToConv2d: Pass<"ttir-convolution-to-conv2d", "::mlir::ModuleOp"> { + let summary = "Convert eligible convolution ops to conv2d ops."; + let description = [{ + This pass converts eligible convolution ops to conv2d ops. + }]; +} + #endif diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 4890e0ad20..b9f8807025 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include +#include #include #include "mlir/IR/Builders.h" @@ -373,46 +374,28 @@ class StableHLOToTTIRConvolutionOpConversionPattern tensor::EmptyOp outputTensor = rewriter.create( srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); - std::vector strides = - adaptor.getWindowStrides().value_or(ArrayRef({1, 1})).vec(); - IntegerAttr stride_height_attr = - rewriter.getSI32IntegerAttr(static_cast(strides[0])); - IntegerAttr stride_width_attr = - rewriter.getSI32IntegerAttr(static_cast(strides[1])); - - std::vector dilation = - adaptor.getLhsDilation().value_or(ArrayRef({1, 1})).vec(); - - IntegerAttr dilation_height_attr = - rewriter.getSI32IntegerAttr(static_cast(dilation[0])); - IntegerAttr dilation_width_attr = - rewriter.getSI32IntegerAttr(static_cast(dilation[1])); - - IntegerAttr groups_attr = rewriter.getSI32IntegerAttr( - static_cast(adaptor.getFeatureGroupCount())); - - std::vector padding; - if (!adaptor.getPadding().has_value()) { - padding = {0, 0, 0, 0}; - } else { - for (auto iter = adaptor.getPadding()->value_begin(); - iter < adaptor.getPadding()->value_end(); iter++) { - padding.push_back(static_cast(*iter)); - } - } - - rewriter.replaceOpWithNewOp( + auto dimNums = adaptor.getDimensionNumbers(); + rewriter.replaceOpWithNewOp( srcOp, outputType, adaptor.getLhs(), adaptor.getRhs(), - mlir::Value(nullptr), outputTensor, stride_height_attr, - stride_width_attr, dilation_height_attr, dilation_width_attr, - groups_attr, rewriter.getSI32IntegerAttr(padding[0]), - rewriter.getSI32IntegerAttr(padding[1]), - rewriter.getSI32IntegerAttr(padding[2]), - rewriter.getSI32IntegerAttr(padding[3]), + mlir::Value(nullptr), outputTensor, adaptor.getWindowStridesAttr(), + adaptor.getPaddingAttr(), adaptor.getLhsDilationAttr(), + adaptor.getRhsDilationAttr(), adaptor.getWindowReversalAttr(), + mlir::tt::ttir::ConvolutionLayoutAttr::get( + getContext(), dimNums.getInputBatchDimension(), + dimNums.getInputFeatureDimension(), + dimNums.getInputSpatialDimensions(), + dimNums.getKernelOutputFeatureDimension(), + dimNums.getKernelInputFeatureDimension(), + dimNums.getKernelSpatialDimensions(), + dimNums.getOutputBatchDimension(), + dimNums.getOutputFeatureDimension(), + dimNums.getOutputSpatialDimensions()), + adaptor.getFeatureGroupCountAttr(), adaptor.getBatchGroupCountAttr(), rewriter.getArrayAttr( SmallVector(adaptor.getOperands().size() + 1, rewriter.getAttr( OperandConstraint::AnyDeviceTile)))); + return success(); } }; diff --git a/lib/Dialect/TTIR/IR/TTIRDialect.cpp b/lib/Dialect/TTIR/IR/TTIRDialect.cpp index 46ec9bac2a..73d259ea3c 100644 --- a/lib/Dialect/TTIR/IR/TTIRDialect.cpp +++ b/lib/Dialect/TTIR/IR/TTIRDialect.cpp @@ -8,6 +8,11 @@ #include "mlir/InitAllDialects.h" #include "mlir/Transforms/InliningUtils.h" #include "ttmlir/Dialect/TTIR/IR/TTIROps.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/ADT/TypeSwitch.h" + +#define GET_ATTRDEF_CLASSES +#include "ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.cpp.inc" using namespace mlir; using namespace mlir::tt::ttir; @@ -59,4 +64,8 @@ void TTIRDialect::initialize() { #include "ttmlir/Dialect/TTIR/IR/TTIROps.cpp.inc" >(); addInterfaces(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.cpp.inc" + >(); } diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index 5e88470cfb..54945e28ae 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -62,6 +62,54 @@ ::mlir::LogicalResult mlir::tt::ttir::Conv2dOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// ConvolutionOp +//===----------------------------------------------------------------------===// + +::mlir::LogicalResult mlir::tt::ttir::ConvolutionOp::verify() { + if (getConvolutionLayout().getInputSpatialDimensions().size() != + getConvolutionLayout().getOutputSpatialDimensions().size()) { + return emitOpError("Convolution input, output, and kernel must have the " + "same number of spatial dimensions"); + } + if (getConvolutionLayout().getInputSpatialDimensions().size() != + getConvolutionLayout().getKernelSpatialDimensions().size()) { + return emitOpError("Convolution input, output, and kernel must have the " + "same number of spatial dimensions"); + } + + // Subtract 2 from the rank as to not count batch and feature dimension + if (getInput().getType().getRank() - 2 != + (int64_t)getConvolutionLayout().getInputSpatialDimensions().size()) { + return emitOpError("Input tensor must have the same number of spatial " + "dimensions as specified in the ConvolutionLayout"); + } + + if (getWeight().getType().getRank() - 2 != + (int64_t)getConvolutionLayout().getKernelSpatialDimensions().size()) { + return emitOpError("Weight tensor must have the same number of spatial " + "dimensions as specified in the ConvolutionLayout"); + } + + std::optional<::mlir::RankedTensorType> biasType = + getBias().getImpl() ? std::make_optional(getBias().getType()) + : std::nullopt; + + if (biasType.has_value()) { + if (biasType->getRank() != 4) { + return emitOpError("Bias must be a 4D tensor"); + } + } + + if (getWindowStrides().size() != + getConvolutionLayout().getInputSpatialDimensions().size()) { + return emitOpError("Window strides must have the same number of elements " + "as the spatial dimensions of the input tensor"); + } + + return success(); +} + //===----------------------------------------------------------------------===// // MaxPool2dOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TTIR/Transforms/Transforms.cpp b/lib/Dialect/TTIR/Transforms/Transforms.cpp index 084f1a90d4..7775c02baa 100644 --- a/lib/Dialect/TTIR/Transforms/Transforms.cpp +++ b/lib/Dialect/TTIR/Transforms/Transforms.cpp @@ -3,15 +3,27 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttmlir/Dialect/TT/IR/TT.h" +#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" +#include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h" #include "ttmlir/Dialect/TTIR/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace mlir::tt::ttir { #define GEN_PASS_DEF_TTIRSLIDINGWINDOW2DFIXSHAPES +#define GEN_PASS_DEF_TTIRCONVOLUTIONTOCONV2D #include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc" //===----------------------------------------------------------------------===// @@ -151,4 +163,338 @@ class TTIRSlidingWindow2dFixShapes } }; +//===----------------------------------------------------------------------===// +// Convolution passes +//===----------------------------------------------------------------------===// + +using TransposeDims = std::tuple; + +template +using PaddingMatrix = std::array, NDims>; + +template +static PaddingMatrix getPaddingMatrix(DenseIntElementsAttr paddingAttr) { + PaddingMatrix paddingMatrix; + std::vector paddingFlattened(paddingAttr.value_begin(), + paddingAttr.value_end()); + + for (uint32_t i = 0; i < 2*NDims; i+=2) { + paddingMatrix[i/2] = {paddingFlattened[i], paddingFlattened[i+1]}; + } + return paddingMatrix; +} +/* + * The following functions are used to generate the transpose operations needed + * to convert a convolution operation to the specific op definitions for a + * ConvNdOp for any N spatial dimensions. + * + * All convolutions will have a batch and feature dimension, and the kernel will + * have an input and output feature dimension. The spatial dimensions can be + * represented by non-negative integers. + */ +enum ConvolutionDimension { BATCH = -1, FEATURE = -2, INVALID_DIM = -3 }; + +enum ConvolutionKernelDimension { + INPUT_FEATURES = -1, + OUTPUT_FEATURES = -2, + INVALID_KERNEL_DIM = -3 +}; + +static tensor::EmptyOp generateTransposeDPSOutput(Value input, int64_t dim0, + int64_t dim1, + PatternRewriter &rewriter) { + auto input_type = mlir::cast(input.getType()); + auto output_shape = input_type.getShape().vec(); + std::swap(output_shape[dim0], output_shape[dim1]); + + auto output_type = RankedTensorType::get( + output_shape, input_type.getElementType(), input_type.getEncoding()); + + return rewriter.create(input.getLoc(), output_shape, + output_type.getElementType()); +} + +static TransposeOp generateTranspose(Value input, int64_t dim0, int64_t dim1, + PatternRewriter &rewriter, + ::mlir::ArrayAttr operandConstraints) { + auto input_type = mlir::cast(input.getType()); + auto output_shape = input_type.getShape().vec(); + std::swap(output_shape[dim0], output_shape[dim1]); + + auto dim0_attr = rewriter.getSI32IntegerAttr(dim0); + auto dim1_attr = rewriter.getSI32IntegerAttr(dim1); + + auto dps_output = generateTransposeDPSOutput(input, dim0, dim1, rewriter); + return rewriter.create(input.getLoc(), dps_output.getType(), + input, dps_output, dim0_attr, dim1_attr, + operandConstraints); +} + +static std::vector generateKernelTransposeIndices( + ConvolutionOp op, + const std::vector ttnn_convolution_kernel_layout) { + std::vector transpose_indices; + + std::vector kernel_layout( + ttnn_convolution_kernel_layout.size(), + ConvolutionKernelDimension::INVALID_KERNEL_DIM); + kernel_layout[op.getConvolutionLayout().getKernelOutputFeatureDimension()] = + ConvolutionKernelDimension::OUTPUT_FEATURES; + kernel_layout[op.getConvolutionLayout().getKernelInputFeatureDimension()] = + ConvolutionKernelDimension::INPUT_FEATURES; + + int64_t spatial_count = 0; + for (int64_t spatial_dim : + op.getConvolutionLayout().getKernelSpatialDimensions()) { + kernel_layout[spatial_dim] = spatial_count; + spatial_count++; + } + + const std::vector desired_kernel_layout = + ttnn_convolution_kernel_layout; + for (int64_t i = 0; i < static_cast(kernel_layout.size()); i++) { + if (kernel_layout[i] != desired_kernel_layout[i]) { + int64_t dim0 = i; + int64_t dim1 = std::find(kernel_layout.begin(), kernel_layout.end(), + desired_kernel_layout[i]) - + kernel_layout.begin(); + transpose_indices.push_back(std::make_tuple(dim0, dim1)); + std::swap(kernel_layout[dim0], kernel_layout[dim1]); + } + } + + return transpose_indices; +} + +static std::vector generateInputTransposeIndices( + ConvolutionOp op, const std::vector ttnn_convolution_layout) { + std::vector transpose_indices; + + std::vector input_layout(ttnn_convolution_layout.size(), + ConvolutionDimension::INVALID_DIM); + input_layout[op.getConvolutionLayout().getInputBatchDimension()] = + ConvolutionDimension::BATCH; + input_layout[op.getConvolutionLayout().getInputFeatureDimension()] = + ConvolutionDimension::FEATURE; + + int64_t spatial_count = 0; + for (int64_t spatial_dim : + op.getConvolutionLayout().getInputSpatialDimensions()) { + input_layout[spatial_dim] = spatial_count; + spatial_count++; + } + + const std::vector desired_input_layout = ttnn_convolution_layout; + for (int64_t i = 0; i < static_cast(input_layout.size()); i++) { + if (input_layout[i] != desired_input_layout[i]) { + int64_t dim0 = i; + int64_t dim1 = std::find(input_layout.begin(), input_layout.end(), + desired_input_layout[i]) - + input_layout.begin(); + transpose_indices.push_back(std::make_tuple(dim0, dim1)); + std::swap(input_layout[dim0], input_layout[dim1]); + } + } + + return transpose_indices; +} + +/** + * Although this function is mostly a clone of generateInputTransposeIndices, + * its slightly different in that if the original Convolution op had the same + * input and output layout, this function will generate the same transposes, + * that were applied to the input but in reverse order. This makes optimizing + * away the inserted transposes easier. + */ +static std::vector generateOutputTransposeIndices( + ConvolutionOp op, const std::vector ttnn_convolution_layout) { + std::vector transpose_indices; + + std::vector desired_output_layout(ttnn_convolution_layout.size(), + ConvolutionDimension::INVALID_DIM); + desired_output_layout[op.getConvolutionLayout().getOutputBatchDimension()] = + ConvolutionDimension::BATCH; + desired_output_layout[op.getConvolutionLayout().getOutputFeatureDimension()] = + ConvolutionDimension::FEATURE; + + int64_t spatial_count = 0; + for (int64_t spatial_dim : + op.getConvolutionLayout().getOutputSpatialDimensions()) { + desired_output_layout[spatial_dim] = spatial_count; + spatial_count++; + } + + std::vector output_layout = ttnn_convolution_layout; + + for (int64_t i = static_cast(desired_output_layout.size()) - 1; + i >= 0; i--) { + if (desired_output_layout[i] != output_layout[i]) { + int64_t dim0 = i; + int64_t dim1 = std::find(output_layout.begin(), output_layout.end(), + desired_output_layout[i]) - + output_layout.begin(); + transpose_indices.push_back(std::make_tuple(dim0, dim1)); + std::swap(output_layout[dim0], output_layout[dim1]); + } + } + + return transpose_indices; +} + +static Value +generateTransposeSequence(Value input, PatternRewriter &rewriter, + std::vector transpose_indices, + ::mlir::ArrayAttr operandConstraints) { + for (auto [dim0, dim1] : transpose_indices) { + input = generateTranspose(input, dim0, dim1, rewriter, operandConstraints) + .getResult(); + } + + return input; +} + +class ConvolutionToConv2dPatternRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + constexpr static uint32_t numSpatialDims = 2; + constexpr static uint32_t SPATIAL_DIM_HEIGHT = 0; + constexpr static uint32_t SPATIAL_DIM_WIDTH = 1; + + // NHWC + const std::vector conv2d_layout = {ConvolutionDimension::BATCH, SPATIAL_DIM_HEIGHT, SPATIAL_DIM_WIDTH, + ConvolutionDimension::FEATURE}; + // OIHW + const std::vector conv2d_kernel_layout = { + ConvolutionKernelDimension::OUTPUT_FEATURES, + ConvolutionKernelDimension::INPUT_FEATURES, SPATIAL_DIM_HEIGHT, SPATIAL_DIM_WIDTH}; + LogicalResult isConv2d(ConvolutionOp op) const { + + // Conv2d will have 2 spatial dimensions + + assert(op.getConvolutionLayout().getInputSpatialDimensions().size() == + op.getConvolutionLayout().getOutputSpatialDimensions().size() && + "Convolution input, output, and kernel must have the same number of " + "spatial dimensions"); + assert(op.getConvolutionLayout().getInputSpatialDimensions().size() == + op.getConvolutionLayout().getKernelSpatialDimensions().size() && + "Convolution input, output, and kernel must have the same number of " + "spatial dimensions"); + + if (op.getConvolutionLayout().getInputSpatialDimensions().size() != numSpatialDims) { + return failure(); + } + + // Not currently supporting window reversal + std::vector window_reversal(op.getWindowReversal().begin(), + op.getWindowReversal().end()); + for (bool reversed : window_reversal) { + if (reversed) { + return failure(); + } + } + + // Not currently support batch groups + if (op.getBatchGroupCount() != 1) { + return failure(); + } + + return success(); + } + + LogicalResult matchAndRewrite(ConvolutionOp op, + PatternRewriter &rewriter) const final { + + if (failed(isConv2d(op))) { + return failure(); + } + + auto stride_height_attr = + rewriter.getSI32IntegerAttr(op.getWindowStrides()[SPATIAL_DIM_HEIGHT]); + auto stride_width_attr = + rewriter.getSI32IntegerAttr(op.getWindowStrides()[SPATIAL_DIM_WIDTH]); + auto dilation_height_attr = + rewriter.getSI32IntegerAttr(op.getWeightDilation()[SPATIAL_DIM_HEIGHT]); + auto dilation_width_attr = + rewriter.getSI32IntegerAttr(op.getWeightDilation()[SPATIAL_DIM_WIDTH]); + + // Padding is a list of 2-tuples, the order of the 2-tuples is in most-significant spatial dimension first order + // For Conv2d the most significant spatial dimension is the height, followed by the width. + auto padding_matrix = getPaddingMatrix(op.getPadding()); + auto padding_top_attr = rewriter.getSI32IntegerAttr(padding_matrix[SPATIAL_DIM_HEIGHT][0]); + auto padding_bottom_attr = rewriter.getSI32IntegerAttr(padding_matrix[SPATIAL_DIM_HEIGHT][1]); + auto padding_left_attr = rewriter.getSI32IntegerAttr(padding_matrix[SPATIAL_DIM_WIDTH][0]); + auto padding_right_attr = rewriter.getSI32IntegerAttr(padding_matrix[SPATIAL_DIM_WIDTH][1]); + + auto groups_attr = rewriter.getSI32IntegerAttr(op.getFeatureGroupCount()); + + auto output_shape = op.getResult().getType().getShape().vec(); + std::vector new_output_shape = { + output_shape[op.getConvolutionLayout().getOutputBatchDimension()], + output_shape[op.getConvolutionLayout().getOutputSpatialDimensions()[SPATIAL_DIM_HEIGHT]], + output_shape[op.getConvolutionLayout().getOutputSpatialDimensions()[SPATIAL_DIM_WIDTH]], + output_shape[op.getConvolutionLayout().getOutputFeatureDimension()]}; + + auto inputType = mlir::cast(op.getInput().getType()); + auto outputType = + inputType.cloneWith(new_output_shape, inputType.getElementType()); + + auto convDPSOutput = rewriter.create( + op.getInput().getLoc(), new_output_shape, outputType.getElementType()); + + auto input_transpose_indices = + generateInputTransposeIndices(op, conv2d_layout); + Value input = generateTransposeSequence(op.getInput(), rewriter, + input_transpose_indices, + op.getOperandConstraints()); + + auto kernel_transpose_indices = + generateKernelTransposeIndices(op, conv2d_kernel_layout); + Value weight = generateTransposeSequence(op.getWeight(), rewriter, + kernel_transpose_indices, + op.getOperandConstraints()); + Conv2dOp new_conv = rewriter.create( + op.getLoc(), outputType, input, weight, op.getBias(), convDPSOutput, + stride_height_attr, stride_width_attr, dilation_height_attr, + dilation_width_attr, groups_attr, padding_left_attr, padding_right_attr, + padding_top_attr, padding_bottom_attr, op.getOperandConstraints()); + + auto output_transpose_indices = + generateOutputTransposeIndices(op, conv2d_layout); + Value output = generateTransposeSequence(new_conv.getResult(), rewriter, + output_transpose_indices, + op.getOperandConstraints()); + + rewriter.replaceOp(op, output); + + return success(); + } +}; + +class TTIRConvolutionToConv2d + : public impl::TTIRConvolutionToConv2dBase { +public: + using impl::TTIRConvolutionToConv2dBase< + TTIRConvolutionToConv2d>::TTIRConvolutionToConv2dBase; + + void runOnOperation() final { + { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + FrozenRewritePatternSet patternSet(std::move(patterns)); + if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) { + signalPassFailure(); + return; + } + } + } + + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + } +}; + } // namespace mlir::tt::ttir diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp index 0acddef69c..e446b663ae 100644 --- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp +++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp @@ -25,6 +25,7 @@ void createTTNNPipelineTTIRPasses( // function. Removes all private functions. pm.addPass(mlir::createInlinerPass()); + pm.addPass(mlir::tt::ttir::createTTIRConvolutionToConv2d()); pm.addPass(mlir::tt::ttir::createTTIRSlidingWindow2dFixShapes()); pm.addPass(mlir::tt::ttir::createTTIRLoadSystemDesc(systemDescOptions)); diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/conv2d_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/conv2d_op.mlir index f2d708e827..ce4a6f6565 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/conv2d_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/conv2d_op.mlir @@ -7,16 +7,13 @@ module @jit_convolution attributes {} { window = { stride = [1, 1], pad = [[1, 1], [1, 1]], - lhs_dilate = [1, 1], - rhs_dilate = [1, 1], - reverse = [0, 0] } { feature_group_count = 1 : i64, batch_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo] } : (tensor<1x128x128x32xf32>, tensor<64x32x3x3xf32>) -> tensor<1x128x128x64xf32> // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] - // CHECK: %[[C:.*]] = "ttir.conv2d"[[C:.*]] + // CHECK: %[[C:.*]] = "ttir.convolution"[[C:.*]] return %0 : tensor<1x128x128x64xf32> } } diff --git a/test/ttmlir/Dialect/TTNN/complex_conv_channel_first.mlir b/test/ttmlir/Dialect/TTNN/complex_conv_channel_first.mlir new file mode 100644 index 0000000000..6b4fa5c5b1 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/complex_conv_channel_first.mlir @@ -0,0 +1,37 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device_tile = #tt.operand_constraint +module @jit_convolution { + func.func public @test_NCHW_HWIO_to_NHWC_OIHW_conv2d(%arg0: tensor<1x32x128x128xf32>, %arg1: tensor<3x3x32x64xf32>) -> tensor<1x64x128x128xf32> { + %0 = tensor.empty() : tensor<1x64x128x128xf32> + // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]] + // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]] + // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]] + // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]] + // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]] + // CHECK: %[[C:.*]] = "ttnn.conv2d"[[C:.*]] + %1 = "ttir.convolution"(%arg0, %arg1, %0) <{ + batch_group_count = 1 : i64, + convolution_layout = #ttir, + feature_group_count = 1 : i64, + input_dilation = array, + operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile], + padding = dense<1> : tensor<2x2xi64>, + weight_dilation = array, + window_reversal = array, + window_strides = array + }> : (tensor<1x32x128x128xf32>, tensor<3x3x32x64xf32>, tensor<1x64x128x128xf32>) -> tensor<1x64x128x128xf32> + // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]] + // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]] + return %1 : tensor<1x64x128x128xf32> + } +}