diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index e7d015afb1..c09de0d8d4 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -1050,22 +1050,55 @@ def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> { def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> { let summary = "Conv2d operation."; let description = [{ - Applies a 2D convolution over an input image composed of several input planes. + Applies a 2D convolution over an input image composed of several input planes. + + Inputs: + - `input` AnyRankedTensor: expected in the following format (N, H_in, W_in, C) where: + - N is the batch size + - H_in is the height of the input planes + - W_in is the width of the input planes + - C is the number of channels + + - `weight` AnyRankedTensor: expected in the following format (O, C/G, K_H, K_W). + - `bias` Optional: expected in the following format (1, 1, 1, O) where: + - C is the number of input channels + - O is the number of output channels + - G is the number of groups + - K_H is the height of the kernel + - K_W is the width of the kernel + + - `output` AnyRankedTensor: expected in the following format (N, H_out, W_out, O) where: + - `H_out = (H_in + padding_top + padding_bottom - dilation[0] * (K_H - 1) - 1) / stride[0] + 1` + - `W_out = (W_in + padding_left + padding_right - dilation[1] * (K_W - 1) - 1) / stride[1] + 1` + + Attributes: + - `stride` (i32 | array<2xi32>): Stride of the convolution. + - `padding` (i32 | array<2xi32> | array<4xi32>): Padding on height/width. Can be symmetric (`[pH, pW]`) or asymmetric (`[pT, pL, pB, pR]`). + - `dilation` (i32 | array<2xi32>): Spacing between kernel elements. + - `groups` i32: Number of blocked connections from input channels to output channels. input and output channels must both be divisible by groups. + + Example: + %input = tensor.empty() : () -> tensor<1x32x32x64xbf16> + %weight = tensor.empty() : () -> tensor<64x64x3x3xbf16> + %bias = tensor.empty() : () -> tensor<1x1x1x64xbf16> + %output = tensor.empty() : () -> tensor<1x30x30x64xbf16> + %0 = "ttir.conv2d"(%input, %weight, %bias, %output) + <{ + stride = 1: i32, + padding = 0: i32, + dilation = 1: i32, + groups = 1: i32 + > : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x30x30x64xbf16>) -> tensor<1x30x30x64xbf16> }]; let arguments = (ins AnyRankedTensor:$input, AnyRankedTensor:$weight, Optional:$bias, AnyRankedTensor:$output, - SI32Attr:$stride_height, - SI32Attr:$stride_width, - SI32Attr:$dilation_height, - SI32Attr:$dilation_width, - SI32Attr:$groups, - SI32Attr:$padding_left, - SI32Attr:$padding_right, - SI32Attr:$padding_top, - SI32Attr:$padding_bottom); + AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$stride, + AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$padding, + AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$dilation, + I32Attr:$groups); let results = (outs AnyRankedTensor:$result); diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index b731d1f761..f18ff83b85 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -993,15 +993,12 @@ def TTNN_Conv2dOp : TTNN_NamedDPSOp<"conv2d"> { I32Attr:$batch_size, I32Attr:$input_height, I32Attr:$input_width, - I32Attr:$kernel_height, - I32Attr:$kernel_width, - I32Attr:$stride_height, - I32Attr:$stride_width, - I32Attr:$padding_height, - I32Attr:$padding_width, - I32Attr:$dilation_height, - I32Attr:$dilation_width, - I32Attr:$groups); + DenseI32ArrayAttr:$kernel_size, + DenseI32ArrayAttr:$stride, + DenseI32ArrayAttr:$padding, + DenseI32ArrayAttr:$dilation, + I32Attr:$groups, + OptionalAttr:$conv2d_config); let results = (outs AnyRankedTensor:$result); @@ -1018,10 +1015,24 @@ def TTNN_ConvTranspose2dOp : TTNN_NamedDPSOp<"conv_transpose2d"> { Applies a 2D transposed convolution operator over an input image composed of several input planes. Inputs: - - `input` AnyRankedTensor: NHWC format (batch_size x height x width x channels) - - `weight` AnyRankedTensor: OIHW format (output_channels x input_channels x height x width) - - `bias` Optional: (1 x 1 x 1 x output_channels) - - `output` AnyRankedTensor: (1 x 1 x (batch_size * height * width) x channels) + - `input (AnyRankedTensor): Expected format `(N x H_in x W_in x C)` + - `N`: Batch size + - `H_in`: Height of the input + - `W_in`: Width of the input + - `C`: Number of input channels + + - `weight` (AnyRankedTensor): Expected format `(C x O / G x K_H x K_W)` + - `C`: Number of input channels + - `O`: Number of output channels + - `G`: Number of groups + - `K_H`: Kernel height + - `K_W`: Kernel width + + - `bias` (Optional): Expected format `(1 x 1 x 1 x O / G)` + + - `output` (AnyRankedTensor): Expected format `(1 x 1 x (N * H_out * W_out) x O)` + - `H_out = (H_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1` + - `W_out = (W_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1` Attributes: - `in_channels` i32: The number of input channels. diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td index 8a5140c01b..07b7e496d6 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td @@ -94,6 +94,53 @@ def TTNN_MemoryConfigAttr : TTNN_Attr<"MemoryConfig", "memory_config"> { }]; } +def TTNN_Conv2dConfigAttr : TTNN_Attr<"Conv2dConfig", "conv2d_config"> { + let summary = "TTNN Conv2dConfig attribute"; + let description = [{ + TTNN conv2d config attribute + }]; + + let parameters = (ins + "DataType":$dtype, + "DataType":$weightsDtype, + "StringAttr":$activation, + "IntegerAttr":$inputChannelsAlignment, + "BoolAttr":$deallocateActivation, + "BoolAttr":$reallocateHaloOutput, + "IntegerAttr":$actBlockHOverride, + "IntegerAttr":$actBlockWDiv, + "BoolAttr":$reshardIfNotOptimal, + "BoolAttr":$overrideShardingConfig, + OptionalParameter<"TensorMemoryLayoutAttr", "TTNN tensor memory layout">:$shardLayout, + // TODO: Finish adding this attribute + OptionalParameter<"Attribute", "TTNN core grid">:$coreGrid, + "BoolAttr":$transposeShards, + "Layout":$outputLayout, + "BoolAttr":$enableActDoubleBuffer, + "BoolAttr":$enableWeightsDoubleBuffer, + "BoolAttr":$enableSplitReader, + "BoolAttr":$enableSubblockPadding); + +let assemblyFormat = "`<` `dtype` `=` $dtype `,` " + "`weightsDtype` `=` $weightsDtype `,` " + "`activation` `=` $activation `,` " + "`inputChannelsAlignment` `=` $inputChannelsAlignment `,` " + "`deallocateActivation` `=` $deallocateActivation `,` " + "`reallocateHaloOutput` `=` $reallocateHaloOutput `,` " + "`actBlockHOverride` `=` $actBlockHOverride `,` " + "`actBlockWDiv` `=` $actBlockWDiv `,` " + "`reshardIfNotOptimal` `=` $reshardIfNotOptimal `,` " + "`overrideShardingConfig` `=` $overrideShardingConfig `,` " + "(`shardLayout` `=` $shardLayout^ `,`)? " + "(`coreGrid` `=` $coreGrid^ `,`)? " + "`transposeShards` `=` $transposeShards `,` " + "`outputLayout` `=` $outputLayout `,` " + "`enableActDoubleBuffer` `=` $enableActDoubleBuffer `,` " + "`enableWeightsDoubleBuffer` `=` $enableWeightsDoubleBuffer `,` " + "`enableSplitReader` `=` $enableSplitReader `,` " + "`enableSubblockPadding` `=` $enableSubblockPadding `>`"; +} + def TTNN_MeshShapeAttr : TTNN_Attr<"MeshShape", "mesh_shape"> { let summary = "TTNN Mesh Shape"; let description = [{ diff --git a/include/ttmlir/Target/Common/types.fbs b/include/ttmlir/Target/Common/types.fbs index c5814ce0af..153c952a6c 100644 --- a/include/ttmlir/Target/Common/types.fbs +++ b/include/ttmlir/Target/Common/types.fbs @@ -97,6 +97,27 @@ table MemoryConfigDesc { shard_spec: ShardSpec; } +table Conv2dConfigDesc { + dtype: DataType; + weights_dtype: DataType; + activation: string; + input_channels_alignment: uint32; + deallocate_activation: bool; + reallocate_halo_output: bool; + act_block_h_override: uint32; + act_block_w_div: uint32; + reshard_if_not_optimal: bool; + override_sharding_config: bool; + shard_layout: TensorMemoryLayout = null; + core_grid: bool = null; + transpose_shards: bool; + output_layout: TensorLayout; + enable_act_double_buffer: bool; + enable_weights_double_buffer: bool; + enable_split_reader: bool; + enable_subblock_padding: bool; +} + table ReplicateTensor { replication_factor: uint32; } diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs index 87dce4e44c..9c8b60dd66 100644 --- a/include/ttmlir/Target/TTNN/program.fbs +++ b/include/ttmlir/Target/TTNN/program.fbs @@ -299,15 +299,12 @@ table Conv2dOp { batch_size: uint32; input_height: uint32; input_width: uint32; - kernel_height: uint32; - kernel_width: uint32; - stride_height: uint32; - stride_width: uint32; - padding_height: uint32; - padding_width: uint32; - dilation_height: uint32; - dilation_width: uint32; + kernel_size: [int32]; + stride: [int32]; + padding: [int32]; + dilation: [int32]; groups: uint32; + conv2d_config: Conv2dConfigDesc; } table ConvTranspose2dOp { diff --git a/include/ttmlir/Utils.h b/include/ttmlir/Utils.h index ebc3ba015b..d30b52d734 100644 --- a/include/ttmlir/Utils.h +++ b/include/ttmlir/Utils.h @@ -248,6 +248,51 @@ getPairOfInteger(mlir::Attribute attr) { return std::make_pair(x, y); } +/// For a given mlir::Attribute attr, returns a tuple of four integers of type +/// ReturnTy. If attr is an IntegerAttr, it's interpreted as a +/// (value(attr), value(attr), value(attr), value(attr)) tuple, where +/// value(attr) is of type ScalarTy. If attr is a +/// DenseArrayAttr of size 2, it's interpreted as a (attr[0], +/// attr[1], attr[0], attr[1]) tuple. If attr is a +/// DenseArrayAttr of size 4, it is returned directly as +/// (attr[0], attr[1], attr[2], attr[3]). Otherwise, returns an error message. +template +inline llvm::Expected> +getQuadrupleOfInteger(mlir::Attribute attr) { + ReturnTy x{}, y{}, z{}, w{}; + + // If attr is IntegerAttr, interpret it as (attr, attr, attr, attr) + if (auto value = mlir::dyn_cast(attr)) { + x = y = z = w = integerAs(value.getValue()); + } + // If attr is DenseArrayAttr, handle based on its size + else if (auto tuple = mlir::dyn_cast< + ::mlir::detail::DenseArrayAttrImpl>(attr)) { + if (tuple.size() == 2) { + x = tuple[0]; + y = tuple[1]; + z = tuple[0]; + w = tuple[1]; + } else if (tuple.size() == 4) { + x = tuple[0]; + y = tuple[1]; + z = tuple[2]; + w = tuple[3]; + } else { + return llvm::createStringError("Expected integer, pair, or tuple of size " + "4, but got tuple of size %lu", + tuple.size()); + } + } + // If attr is of an unsupported type + else { + return llvm::createStringError("Unexpected attribute type"); + } + + return std::make_tuple(x, y, z, w); +} + } // namespace ttmlir::utils #endif diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp index 2f01299a69..c88e9dcb40 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp +++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp @@ -379,31 +379,29 @@ struct ConvolutionToConv2dPattern : public ConvolutionDecompositionPattern { return failure(); } - auto strideHeightAttr = rewriter.getSI32IntegerAttr( - adaptor.getWindowStrides()[SPATIAL_DIM_HEIGHT]); - auto strideWidthAttr = rewriter.getSI32IntegerAttr( - adaptor.getWindowStrides()[SPATIAL_DIM_WIDTH]); - auto dilationHeightAttr = rewriter.getSI32IntegerAttr( - adaptor.getWeightDilation()[SPATIAL_DIM_HEIGHT]); - auto dilationWidthAttr = rewriter.getSI32IntegerAttr( - adaptor.getWeightDilation()[SPATIAL_DIM_WIDTH]); + auto strideAttr = rewriter.getDenseI32ArrayAttr({ + static_cast(adaptor.getWindowStrides()[SPATIAL_DIM_HEIGHT]), + static_cast(adaptor.getWindowStrides()[SPATIAL_DIM_WIDTH]), + }); + auto dilationAttr = rewriter.getDenseI32ArrayAttr({ + static_cast(adaptor.getWeightDilation()[SPATIAL_DIM_HEIGHT]), + static_cast(adaptor.getWeightDilation()[SPATIAL_DIM_WIDTH]), + }); // Padding is a list of 2-tuples, the order of the 2-tuples is in // most-significant spatial dimension first order For Conv2d the most // significant spatial dimension is the height, followed by the width. auto paddingMatrix = getPaddingMatrix(adaptor.getPadding()); - auto paddingTopAttr = - rewriter.getSI32IntegerAttr(paddingMatrix[SPATIAL_DIM_HEIGHT][0]); - auto paddingBottomAttr = - rewriter.getSI32IntegerAttr(paddingMatrix[SPATIAL_DIM_HEIGHT][1]); - auto paddingLeftAttr = - rewriter.getSI32IntegerAttr(paddingMatrix[SPATIAL_DIM_WIDTH][0]); - auto paddingRightAttr = - rewriter.getSI32IntegerAttr(paddingMatrix[SPATIAL_DIM_WIDTH][1]); + auto paddingAttr = rewriter.getDenseI32ArrayAttr({ + static_cast(paddingMatrix[SPATIAL_DIM_HEIGHT][0]), + static_cast(paddingMatrix[SPATIAL_DIM_WIDTH][0]), + static_cast(paddingMatrix[SPATIAL_DIM_HEIGHT][1]), + static_cast(paddingMatrix[SPATIAL_DIM_WIDTH][1]), + }); auto groupsAttr = - rewriter.getSI32IntegerAttr(adaptor.getFeatureGroupCount()); + rewriter.getI32IntegerAttr(adaptor.getFeatureGroupCount()); llvm::ArrayRef outputShape = op.getResult().getType().getShape(); llvm::SmallVector newOutputShape{ @@ -445,9 +443,7 @@ struct ConvolutionToConv2dPattern : public ConvolutionDecompositionPattern { weightDPSOutput, kernelPermutation); ttir::Conv2dOp newConv = rewriter.create( op.getLoc(), outputType, input, weight, adaptor.getBias(), - convDPSOutput, strideHeightAttr, strideWidthAttr, dilationHeightAttr, - dilationWidthAttr, groupsAttr, paddingLeftAttr, paddingRightAttr, - paddingTopAttr, paddingBottomAttr); + convDPSOutput, strideAttr, paddingAttr, dilationAttr, groupsAttr); // Applying the inverse of permutation to the output will restore the // tensor to the original layout. diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 99f31cae8f..e0ee62f94e 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -926,81 +926,78 @@ class Conv2dOpConversionPattern : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op); - auto kernel_ty = - mlir::cast(adaptor.getWeight().getType()); - llvm::ArrayRef kernel_shape = kernel_ty.getShape(); - auto input_ty = mlir::cast(adaptor.getInput().getType()); - llvm::ArrayRef input_shape = input_ty.getShape(); + auto inputTy = mlir::cast(adaptor.getInput().getType()); + auto kernelTy = mlir::cast(adaptor.getWeight().getType()); + auto outputTy = mlir::cast(adaptor.getOutput().getType()); - auto output_ty = - mlir::cast(adaptor.getOutput().getType()); - llvm::ArrayRef output_shape = output_ty.getShape(); + auto batchSizeAttr = rewriter.getI32IntegerAttr(inputTy.getDimSize(0)); + auto inputHeightAttr = rewriter.getI32IntegerAttr(inputTy.getDimSize(1)); + auto inputWidthAttr = rewriter.getI32IntegerAttr(inputTy.getDimSize(2)); + auto inChannelsAttr = rewriter.getI32IntegerAttr(inputTy.getDimSize(3)); + auto outChannelsAttr = rewriter.getI32IntegerAttr(outputTy.getDimSize(3)); - auto in_channels = - rewriter.getI32IntegerAttr(input_shape[input_shape.size() - 1]); - auto out_channels = - rewriter.getI32IntegerAttr(output_shape[output_shape.size() - 1]); - auto batch_size = - rewriter.getI32IntegerAttr(input_shape[input_shape.size() - 4]); - auto input_height = - rewriter.getI32IntegerAttr(input_shape[input_shape.size() - 3]); - auto input_width = - rewriter.getI32IntegerAttr(input_shape[input_shape.size() - 2]); - - auto kernel_height = - rewriter.getI32IntegerAttr(kernel_shape[kernel_shape.size() - 2]); - auto kernel_width = - rewriter.getI32IntegerAttr(kernel_shape[kernel_shape.size() - 1]); - - auto stride_height = rewriter.getI32IntegerAttr(adaptor.getStrideHeight()); - auto stride_width = rewriter.getI32IntegerAttr(adaptor.getStrideWidth()); - - assert( - adaptor.getPaddingBottom() == adaptor.getPaddingTop() && - "TTNN only supports padding height/width attributes. Thus, padding_top " - "must equal padding_bottom for the op to execute as expected."); - assert(adaptor.getPaddingLeft() == adaptor.getPaddingRight() && - "TTNN only supports padding height/width attributes. Thus, " - "padding_left must equal padding_right for the op to execute as " - "expected."); - auto padding_height = rewriter.getI32IntegerAttr(adaptor.getPaddingTop()); - auto padding_width = rewriter.getI32IntegerAttr(adaptor.getPaddingRight()); - - auto dilation_height = - rewriter.getI32IntegerAttr(adaptor.getDilationHeight()); - auto dilation_width = - rewriter.getI32IntegerAttr(adaptor.getDilationWidth()); - auto groups = rewriter.getI32IntegerAttr(adaptor.getGroups()); - - std::vector flattenedInputShape = { - 1, 1, input_shape[0] * input_shape[1] * input_shape[2], input_shape[3]}; - Value flattenedInput = ttir_to_ttnn::utils::generateNHWFlatten( - mlir::cast>(adaptor.getInput()), - rewriter); + auto kernelSizeAttr = rewriter.getDenseI32ArrayAttr( + {static_cast(kernelTy.getDimSize(2)), + static_cast(kernelTy.getDimSize(3))}); - std::vector flattenedOutputShape = { + auto strideAttr = attrToDenseI32ArrayAttr(adaptor.getStride(), rewriter); + if (auto error = strideAttr.takeError()) { + return LogicalResult::failure(); + } + + auto paddingAttr = + attrToDenseI32ArrayAttr(adaptor.getPadding(), rewriter, 4); + if (auto error = paddingAttr.takeError()) { + return LogicalResult::failure(); + } + + auto paddingArrayRef = paddingAttr->asArrayRef(); + if (paddingArrayRef[0] != paddingArrayRef[1] || + paddingArrayRef[2] != paddingArrayRef[3]) { + return rewriter.notifyMatchFailure( + op, + "TTNN only supports padding height/width attributes. Thus, " + "padding_top/padding_left must equal padding_bottom/padding_right " + "for the op to execute as expected."); + } + + // Padding only supports 2 values in ttnn + auto reducedPaddingAttr = + rewriter.getDenseI32ArrayAttr({paddingArrayRef[0], paddingArrayRef[1]}); + + auto dilationAttr = + attrToDenseI32ArrayAttr(adaptor.getDilation(), rewriter); + if (auto error = dilationAttr.takeError()) { + return LogicalResult::failure(); + } + + auto groupsAttr = rewriter.getI32IntegerAttr(adaptor.getGroups()); + + // Convolution in ttnn returns a tensor in a flattened shape + // (1 x 1 x N * H * W x C) + llvm::ArrayRef output_shape = outputTy.getShape(); + llvm::SmallVector flattenedOutputShape = { 1, 1, output_shape[0] * output_shape[1] * output_shape[2], output_shape[3]}; - - output_ty = mlir::cast(getTypeConverter()->convertType( - output_ty.cloneWith(flattenedOutputShape, output_ty.getElementType()))); + outputTy = mlir::cast(getTypeConverter()->convertType( + outputTy.cloneWith(flattenedOutputShape, outputTy.getElementType()))); // Using a tensor::EmptyOp so that the rewriter for EmptyOp can handle the // attribute determination auto convDPSOutput = rewriter.replaceOpWithNewOp( adaptor.getOutput().getDefiningOp(), flattenedOutputShape, - output_ty.getElementType()); + outputTy.getElementType()); // Must set the type to the output type to maintain the layout attributes - convDPSOutput.getResult().setType(output_ty); + convDPSOutput.getResult().setType(outputTy); ttnn::Conv2dOp new_conv = rewriter.create( - op.getLoc(), output_ty, flattenedInput, adaptor.getWeight(), - adaptor.getBias(), convDPSOutput, device, in_channels, out_channels, - batch_size, input_height, input_width, kernel_height, kernel_width, - stride_height, stride_width, padding_height, padding_width, - dilation_height, dilation_width, groups); + op.getLoc(), outputTy, adaptor.getInput(), adaptor.getWeight(), + adaptor.getBias(), convDPSOutput, device, inChannelsAttr, + outChannelsAttr, batchSizeAttr, inputHeightAttr, inputWidthAttr, + kernelSizeAttr, *strideAttr, reducedPaddingAttr, *dilationAttr, + groupsAttr, nullptr); Value output = ttir_to_ttnn::utils::generateReshape(new_conv, output_shape, rewriter); @@ -1008,6 +1005,38 @@ class Conv2dOpConversionPattern : public OpConversionPattern { rewriter.replaceOp(op, output); return success(); } + +private: + llvm::Expected + attrToDenseI32ArrayAttr(mlir::Attribute attr, + ConversionPatternRewriter &rewriter, + uint32_t elementCount = 2) const { + + DenseI32ArrayAttr arrayAttr; + if (elementCount == 2) { + // Handles attributes requiring 2 spatial dimensions (e.g., stride, + // dilation). Converts the attribute into a pair of integers. + auto pair = ttmlir::utils::getPairOfInteger(attr); + if (auto error = pair.takeError()) { + return error; + } + + arrayAttr = rewriter.getDenseI32ArrayAttr({pair->first, pair->second}); + } else if (elementCount == 4) { + // Handles attributes requiring 4 spatial dimensions (e.g., padding in + // this case). Converts the attribute into a quadruple of integers. + auto quadruple = ttmlir::utils::getQuadrupleOfInteger(attr); + if (auto error = quadruple.takeError()) { + return error; + } + + arrayAttr = rewriter.getDenseI32ArrayAttr( + {std::get<0>(*quadruple), std::get<1>(*quadruple), + std::get<2>(*quadruple), std::get<3>(*quadruple)}); + } + + return arrayAttr; + } }; } // namespace diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index c6365b1ad2..7d711cbe2e 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -167,23 +167,161 @@ mlir::tt::ttir::GetDimensionSizeOp::fold(FoldAdaptor adaptor) { // Conv2dOp verification ::mlir::LogicalResult mlir::tt::ttir::Conv2dOp::verify() { - ::mlir::RankedTensorType inputType = getInput().getType(); - ::mlir::RankedTensorType weightType = getWeight().getType(); - std::optional<::mlir::RankedTensorType> biasType = + mlir::RankedTensorType inputType = getInput().getType(); + mlir::RankedTensorType weightType = getWeight().getType(); + mlir::RankedTensorType outputType = getOutput().getType(); + std::optional bias = getBias().getImpl() ? std::make_optional(getBias().getType()) : std::nullopt; - if (inputType.getRank() < 3) { - return emitOpError("Input must be at least a 3D tensor"); + if (inputType.getRank() != 4) { + return emitOpError("Input must be a 4D tensor"); + } + + if (outputType.getRank() != 4) { + return emitOpError("Output must be a 4D tensor"); } + if (weightType.getRank() != 4) { return emitOpError("Weight must be a 4D tensor"); } - if (biasType.has_value()) { - if (biasType->getRank() != 4) { + + if (bias.has_value()) { + if (bias->getRank() != 4) { return emitOpError("Bias must be a 4D tensor"); } + auto biasShape = bias->getShape(); + if (biasShape[0] != 1 || biasShape[1] != 1 || biasShape[2] != 1) { + return emitOpError("Bias must only have data on the final dimenstion"); + } + } + + uint32_t batchSize = inputType.getDimSize(0); + if (batchSize != outputType.getDimSize(0)) { + return emitOpError( + "First dimension of the input tensor must match the first dimension of " + "the output tensor, got: " + + std::to_string(batchSize) + " and " + + std::to_string(outputType.getDimSize(0))); + } + + uint32_t inputHeight = inputType.getDimSize(1); + uint32_t inputWidth = inputType.getDimSize(2); + uint32_t inChannels = inputType.getDimSize(3); + uint32_t outChannels = outputType.getDimSize(3); + + auto stride = ttmlir::utils::getPairOfInteger(getStride()); + if (auto error = stride.takeError()) { + return emitOpError() << llvm::toString(std::move(error)) << " for stride"; + } + if (stride->first < 1 || stride->second < 1) { + return emitOpError("Stride values must be greater than 0"); + } + + auto padding = ttmlir::utils::getQuadrupleOfInteger(getPadding()); + if (auto error = padding.takeError()) { + return emitOpError() << llvm::toString(std::move(error)) << " for padding"; + } + + auto [paddingTop, paddingLeft, paddingBottom, paddingRight] = *padding; + if (paddingTop < 0 || paddingBottom < 0 || paddingLeft < 0 || + paddingRight < 0) { + return emitOpError("Padding values must be greater or equal than 0"); + } + int32_t verticalPadding = paddingTop + paddingBottom; + int32_t horizontalPadding = paddingLeft + paddingRight; + + auto dilation = ttmlir::utils::getPairOfInteger(getDilation()); + if (auto error = dilation.takeError()) { + return emitOpError() << llvm::toString(std::move(error)) << " for dilation"; + } + if (dilation->first < 1 || dilation->second < 1) { + return emitOpError("Dilation values must be greater than 0"); + } + + llvm::ArrayRef kernelSize = { + static_cast(weightType.getDimSize(2)), + static_cast(weightType.getDimSize(3))}; + + llvm::SmallVector paddedInputSize = { + inputHeight + verticalPadding, inputWidth + horizontalPadding}; + llvm::SmallVector effectiveKernelSize = { + static_cast(kernelSize[0] + + (kernelSize[0] - 1) * (dilation->first - 1)), + static_cast(kernelSize[1] + + (kernelSize[1] - 1) * (dilation->second - 1))}; + if (paddedInputSize[0] < effectiveKernelSize[0] || + paddedInputSize[1] < effectiveKernelSize[1]) { + return emitOpError( + "Calculated padded input size per channel: (" + + std::to_string(paddedInputSize[0]) + " x " + + std::to_string(paddedInputSize[1]) + "). Kernel size: (" + + std::to_string(effectiveKernelSize[0]) + " x " + + std::to_string(effectiveKernelSize[1]) + + "). Kernel size can't be greater than actual input size"); + } + + uint32_t groups = getGroups(); + if (inChannels % groups != 0) { + return emitOpError() << "Number of input channels from input tensor must " + "be divisible by the number of groups. " + << "Got " << inChannels << " input channels and " + << groups << " groups"; + } + + if (outChannels % groups != 0) { + return emitOpError() << "Number of output channels from output tensor must " + "be divisible by the number of groups. " + << "Got " << outChannels << " output channels and " + << groups << " groups"; + } + + llvm::ArrayRef kernelShape = weightType.getShape(); + if (outChannels != kernelShape[0]) { + return emitOpError() << "Number of output channels from output tensor must " + "match the first dimension of the weight tensor. " + << "Got " << outChannels << " output channels and " + << kernelShape[0] << " in the weight tensor"; + } + + if (inChannels / groups != kernelShape[1]) { + return emitOpError() << "Number of input channels per group must match " + "the second dimension of the weight tensor. " + << "Got " << (inChannels / groups) + << " input channels per group and " << kernelShape[1] + << " in the weight tensor"; + } + + if (bias) { + if (bias->getDimSize(bias->getRank() - 1) != outChannels) { + return emitOpError() << "Mismatch in bias tensor dimensions. " + << "Bias tensor has " + << bias->getDimSize(bias->getRank() - 1) + << " channels, " + << "but the output tensor has " << outChannels + << " channels"; + } } + + int32_t calculatedHOut = (inputHeight + verticalPadding - + dilation->first * (kernelSize[0] - 1) - 1) / + stride->first + + 1; + int32_t calculatedWOut = (inputWidth + horizontalPadding - + dilation->second * (kernelSize[1] - 1) - 1) / + stride->second + + 1; + if (calculatedHOut != outputType.getDimSize(1) || + calculatedWOut != outputType.getDimSize(2)) { + return emitOpError() + << "Mismatch between calculated and got output height and width. " + << "Calculated: (" << calculatedHOut << " x " << calculatedWOut + << "). " + << "Got output tensor height and width: (" + << outputType.getDimSize(1) << " x " << outputType.getDimSize(2) + << ")"; + } + return success(); } diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp index f19fe72540..512f242d87 100644 --- a/lib/Dialect/TTNN/IR/TTNNOps.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp @@ -55,27 +55,196 @@ ::mlir::LogicalResult mlir::tt::ttnn::ClampOp::verify() { // Conv2dOp verification ::mlir::LogicalResult mlir::tt::ttnn::Conv2dOp::verify() { - ::mlir::RankedTensorType inputType = getInput().getType(); - ::mlir::RankedTensorType weightType = getWeight().getType(); - std::optional<::mlir::RankedTensorType> biasType = + mlir::RankedTensorType inputType = getInput().getType(); + mlir::RankedTensorType weightType = getWeight().getType(); + mlir::RankedTensorType outputType = getOutput().getType(); + std::optional bias = getBias().getImpl() ? std::make_optional(getBias().getType()) : std::nullopt; if (inputType.getRank() < 3) { - return emitOpError("Input must be at least a 3D tensor"); + return emitOpError("Input must be a 4D tensor"); + } + + if (outputType.getRank() != 4) { + return emitOpError("Output must be a 4D tensor"); } + if (weightType.getRank() != 4) { return emitOpError("Weight must be a 4D tensor"); } - if (biasType.has_value()) { - if (biasType->getRank() != 4) { + + if (bias.has_value()) { + if (bias->getRank() != 4) { return emitOpError("Bias must be a 4D tensor"); } - auto biasShape = biasType->getShape(); + auto biasShape = bias->getShape(); if (biasShape[0] != 1 || biasShape[1] != 1 || biasShape[2] != 1) { return emitOpError("Bias must only have data on the final dimenstion"); } } + + uint32_t inChannels = getInChannels(); + if (inChannels != inputType.getDimSize(inputType.getRank() - 1)) { + return emitOpError( + "Input channels attribute must match the last dimension of the input " + "tensor, got: " + + std::to_string(inChannels) + " and " + + std::to_string(inputType.getDimSize(inputType.getRank() - 1))); + } + + uint32_t outChannels = getOutChannels(); + if (outChannels != outputType.getDimSize(outputType.getRank() - 1)) { + return emitOpError( + "Output channels attribute must match the last dimension of the output " + "tensor, got: " + + std::to_string(outChannels) + " and " + + std::to_string(outputType.getDimSize(outputType.getRank() - 1))); + } + + uint32_t batchSize = getBatchSize(); + if (batchSize != inputType.getDimSize(0)) { + return emitOpError("Batch size attribute must match the first dimension of " + "the input tensor, got: " + + std::to_string(batchSize) + " and " + + std::to_string(inputType.getDimSize(0))); + } + if (batchSize != outputType.getDimSize(0)) { + return emitOpError("Batch size attribute must match the first dimension of " + "the output tensor, got: " + + std::to_string(batchSize) + " and " + + std::to_string(outputType.getDimSize(0))); + } + + uint32_t inputHeight = getInputHeight(); + if (inputHeight != inputType.getDimSize(inputType.getRank() - 3)) { + return emitOpError( + "Input height attribute must match the second dimension of the input " + "tensor, got: " + + std::to_string(inputHeight) + " and " + + std::to_string(inputType.getDimSize(inputType.getRank() - 3))); + } + + uint32_t inputWidth = getInputWidth(); + if (inputWidth != inputType.getDimSize(inputType.getRank() - 2)) { + return emitOpError( + "Input width attribute must match the third dimension of the input " + "tensor, got: " + + std::to_string(inputWidth) + " and " + + std::to_string(inputType.getDimSize(inputType.getRank() - 2))); + } + + llvm::ArrayRef stride = getStride(); + if (!std::all_of(stride.begin(), stride.end(), + [](int32_t value) { return value >= 1; })) { + return emitOpError( + "Stride attribute must be greater than or equal to 1, got: " + + std::to_string(stride[0]) + ", " + std::to_string(stride[1])); + } + + llvm::ArrayRef padding = getPadding(); + if (!std::all_of(padding.begin(), padding.end(), + [](int32_t value) { return value >= 0; })) { + return emitOpError( + "Padding attribute must be greater than or equal to 0, got: " + + std::to_string(padding[0]) + ", " + std::to_string(padding[1])); + } + + llvm::ArrayRef dilation = getDilation(); + if (!std::all_of(dilation.begin(), dilation.end(), + [](int32_t value) { return value >= 1; })) { + return emitOpError( + "Dilation attribute must be greater than or equal to 1, got: " + + std::to_string(dilation[0]) + ", " + std::to_string(dilation[1])); + } + + llvm::ArrayRef kernelSize = getKernelSize(); + if (kernelSize[0] != weightType.getDimSize(2) || + kernelSize[1] != weightType.getDimSize(3)) { + return emitOpError("Kernel size attribute must match the last two " + "dimensions of the weight tensor, got: " + + std::to_string(kernelSize[0]) + ", " + + std::to_string(kernelSize[1]) + " and " + + std::to_string(weightType.getDimSize(2)) + ", " + + std::to_string(weightType.getDimSize(3))); + } + + llvm::SmallVector paddedInputSize = { + inputHeight + 2 * padding[0], inputWidth + 2 * padding[1]}; + llvm::SmallVector effectiveKernelSize = { + static_cast(kernelSize[0] + + (kernelSize[0] - 1) * (dilation[0] - 1)), + static_cast(kernelSize[1] + + (kernelSize[1] - 1) * (dilation[1] - 1))}; + if (paddedInputSize[0] < effectiveKernelSize[0] || + paddedInputSize[1] < effectiveKernelSize[1]) { + return emitOpError( + "Calculated padded input size per channel: (" + + std::to_string(paddedInputSize[0]) + " x " + + std::to_string(paddedInputSize[1]) + "). Kernel size: (" + + std::to_string(effectiveKernelSize[0]) + " x " + + std::to_string(effectiveKernelSize[1]) + + "). Kernel size can't be greater than actual input size"); + } + + uint32_t groups = getGroups(); + if (inChannels % groups != 0) { + return emitOpError() << "Number of input channels from input tensor must " + "be divisible by the number of groups. " + << "Got " << inChannels << " input channels and " + << groups << " groups."; + } + + if (outChannels % groups != 0) { + return emitOpError() << "Number of output channels from output tensor must " + "be divisible by the number of groups. " + << "Got " << outChannels << " output channels and " + << groups << " groups."; + } + + llvm::ArrayRef kernelShape = weightType.getShape(); + if (outChannels != kernelShape[0]) { + return emitOpError() << "Number of output channels from output tensor must " + "match the first dimension of the weight tensor. " + << "Got " << outChannels << " output channels and " + << kernelShape[0] << " in the weight tensor."; + } + + if (inChannels / groups != kernelShape[1]) { + return emitOpError() << "Number of input channels per group must match " + "the second dimension of the weight tensor. " + << "Got " << (inChannels / groups) + << " input channels per group and " << kernelShape[1] + << " in the weight tensor."; + } + + if (bias) { + if (bias->getDimSize(bias->getRank() - 1) != outChannels) { + return emitOpError() << "Mismatch in bias tensor dimensions. " + << "Bias tensor has " + << bias->getDimSize(bias->getRank() - 1) + << " channels, " + << "but the output tensor has " << outChannels + << " channels."; + } + } + + int32_t calculatedHOut = + (inputHeight + 2 * padding[0] - dilation[0] * (kernelSize[0] - 1) - 1) / + stride[0] + + 1; + int32_t calculatedWOut = + (inputWidth + 2 * padding[1] - dilation[1] * (kernelSize[1] - 1) - 1) / + stride[1] + + 1; + if (calculatedHOut < 0 || calculatedWOut < 0) { + return emitOpError() << "Given input size per channel: (" << inputHeight + << " x " << inputWidth << "). " + << "Calculated output size per channel: (" + << calculatedHOut << " x " << calculatedWOut << "). " + << "Output size is too small"; + } + return success(); } @@ -143,13 +312,13 @@ ::mlir::LogicalResult mlir::tt::ttnn::ConvTranspose2dOp::verify() { uint32_t inputHeight = getInputHeight(); if (inputHeight != inputType.getDimSize(inputType.getRank() - 3)) { - return emitOpError("Input height attribute must match the third " + return emitOpError("Input height attribute must match the second " "dimension of the input tensor"); } uint32_t inputWidth = getInputWidth(); if (inputWidth != inputType.getDimSize(inputType.getRank() - 2)) { - return emitOpError("Input width attribute must match the second " + return emitOpError("Input width attribute must match the third " "dimension of the input tensor"); } diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index 1433795511..4d6cdceab0 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -149,6 +149,44 @@ memoryConfigToFlatbuffer(FlatbufferObjectCache &cache, return memoryConfigDesc; } +::flatbuffers::Offset<::tt::target::Conv2dConfigDesc> +conv2dConfigToFlatbuffer(FlatbufferObjectCache &cache, + ::mlir::tt::ttnn::Conv2dConfigAttr conv2dConfig) { + ::tt::target::DataType dtype = + ::tt::mlir::ttnn::utils::toTargetDataType(conv2dConfig.getDtype()); + ::tt::target::DataType weightsDtype = + ::tt::mlir::ttnn::utils::toTargetDataType(conv2dConfig.getWeightsDtype()); + ::flatbuffers::Offset<::flatbuffers::String> activation = + toFlatbuffer(cache, conv2dConfig.getActivation().getValue()); + ::flatbuffers::Optional<::tt::target::TensorMemoryLayout> shardLayout = + conv2dConfig.getShardLayout() + ? std::optional{::tt::mlir::ttnn::utils::toTargetTensorMemoryLayout( + conv2dConfig.getShardLayout().getValue())} + : std::nullopt; + ::flatbuffers::Optional coreGrid = std::nullopt; + ::tt::target::TensorLayout outputLayout = + ::tt::mlir::ttnn::utils::toTargetTensorLayout( + conv2dConfig.getOutputLayout()); + + ::flatbuffers::Offset<::tt::target::Conv2dConfigDesc> conv2dConfigDesc = + ::tt::target::CreateConv2dConfigDesc( + *cache.fbb, dtype, weightsDtype, activation, + conv2dConfig.getInputChannelsAlignment().getInt(), + conv2dConfig.getDeallocateActivation().getValue(), + conv2dConfig.getReallocateHaloOutput().getValue(), + conv2dConfig.getActBlockHOverride().getInt(), + conv2dConfig.getActBlockWDiv().getInt(), + conv2dConfig.getReshardIfNotOptimal().getValue(), + conv2dConfig.getOverrideShardingConfig().getValue(), shardLayout, + coreGrid, conv2dConfig.getTransposeShards().getValue(), outputLayout, + conv2dConfig.getEnableActDoubleBuffer().getValue(), + conv2dConfig.getEnableWeightsDoubleBuffer().getValue(), + conv2dConfig.getEnableSplitReader().getValue(), + conv2dConfig.getEnableSubblockPadding().getValue()); + + return conv2dConfigDesc; +} + ::flatbuffers::Offset<::tt::target::DeviceRef> createDeviceRef(FlatbufferObjectCache &cache, Value device) { auto deviceType = mlir::cast(device.getType()); @@ -493,26 +531,40 @@ createOp(FlatbufferObjectCache &cache, MorehCumSumOp op) { ::flatbuffers::Offset<::tt::target::ttnn::Conv2dOp> createOp(FlatbufferObjectCache &cache, Conv2dOp op) { - auto in0 = + auto input = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); - auto in1 = cache.at<::tt::target::TensorRef>( + auto weight = cache.at<::tt::target::TensorRef>( getOperandThroughDPSOps(op.getWeight())); - auto in2 = op.getODSOperands(2).empty() - ? flatbuffers::Offset<::tt::target::TensorRef>() - : cache.at<::tt::target::TensorRef>( - getOperandThroughDPSOps(op.getBias())); + auto bias = op.getODSOperands(2).empty() + ? flatbuffers::Offset<::tt::target::TensorRef>() + : cache.at<::tt::target::TensorRef>( + getOperandThroughDPSOps(op.getBias())); auto output = cache.at<::tt::target::TensorRef>( getOperandThroughDPSOps(op.getResult())); auto device = getOperandThroughDPSOps(op.getDevice()); + + ::flatbuffers::Offset<::flatbuffers::Vector> kernelSize = + toFlatbuffer(cache, op.getKernelSize()); + ::flatbuffers::Offset<::flatbuffers::Vector> stride = + toFlatbuffer(cache, op.getStride()); + ::flatbuffers::Offset<::flatbuffers::Vector> padding = + toFlatbuffer(cache, op.getPadding()); + ::flatbuffers::Offset<::flatbuffers::Vector> dilation = + toFlatbuffer(cache, op.getDilation()); + + std::optional<::flatbuffers::Offset<::tt::target::Conv2dConfigDesc>> + conv2dConfig = + op.getConv2dConfig() ? std::optional{conv2dConfigToFlatbuffer( + cache, *op.getConv2dConfig())} + : std::nullopt; + return ::tt::target::ttnn::CreateConv2dOp( - *cache.fbb, in0, in1, in2, output, + *cache.fbb, input, weight, bias, output, cache.at<::tt::target::DeviceRef>(device), op.getInChannels(), op.getOutChannels(), op.getBatchSize(), op.getInputHeight(), - op.getInputWidth(), op.getKernelHeight(), op.getKernelWidth(), - op.getStrideHeight(), op.getStrideWidth(), op.getPaddingHeight(), - op.getPaddingWidth(), op.getDilationHeight(), op.getDilationWidth(), - op.getGroups()); + op.getInputWidth(), kernelSize, stride, padding, dilation, op.getGroups(), + conv2dConfig ? *conv2dConfig : 0); } ::flatbuffers::Offset<::tt::target::ttnn::ConvTranspose2dOp> diff --git a/runtime/lib/ttnn/operations/conv/conv2d.cpp b/runtime/lib/ttnn/operations/conv/conv2d.cpp index 26d71df1ac..85b3543302 100644 --- a/runtime/lib/ttnn/operations/conv/conv2d.cpp +++ b/runtime/lib/ttnn/operations/conv/conv2d.cpp @@ -21,9 +21,19 @@ void run(const ::tt::target::ttnn::Conv2dOp *op, ProgramContext &context) { std::optional<::ttnn::Tensor> bias = op->bias() ? std::make_optional(tensorPool.at(op->bias()->global_id())) : std::nullopt; - auto config = ::ttnn::operations::conv::Conv2dConfig(); - config.dtype = utils::getDataType(op->input()); - config.weights_dtype = utils::getDataType(op->weight()); + + std::array kernelSize, stride, padding, dilation; + std::copy_n(op->kernel_size()->begin(), 2, kernelSize.begin()); + std::copy_n(op->stride()->begin(), 2, stride.begin()); + std::copy_n(op->padding()->begin(), 2, padding.begin()); + std::copy_n(op->dilation()->begin(), 2, dilation.begin()); + + std::optional<::ttnn::operations::conv::Conv2dConfig> conv2dConfig = + op->conv2d_config() + ? std::make_optional( + ::tt::runtime::ttnn::operations::utils::createConv2dConfig( + op->conv2d_config())) + : std::nullopt; // Use defaults for now, until compiler drives this. std::optional<::ttnn::DeviceComputeKernelConfig> computeConfig = std::nullopt; @@ -37,11 +47,8 @@ void run(const ::tt::target::ttnn::Conv2dOp *op, ProgramContext &context) { return std::get<0>(::ttnn::operations::conv::conv2d::conv2d( input, weight, &(targetDevice.get()), op->in_channels(), op->out_channels(), op->batch_size(), op->input_height(), - op->input_width(), {op->kernel_height(), op->kernel_width()}, - {op->stride_height(), op->stride_width()}, - {op->padding_height(), op->padding_width()}, - {op->dilation_height(), op->dilation_width()}, op->groups(), bias, - config, computeConfig, outMemConfig)); + op->input_width(), kernelSize, stride, padding, dilation, + op->groups(), bias, conv2dConfig, computeConfig, outMemConfig)); }, targetDevice); diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp index e2bfee6cf7..1b8231d74b 100644 --- a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp @@ -102,4 +102,34 @@ ::tt::tt_metal::DistributedTensorConfig distributedTensorConfigFromFlatbuffer( } } } + +::ttnn::operations::conv::conv2d::Conv2dConfig +createConv2dConfig(const ::tt::target::Conv2dConfigDesc *memcfg) { + std::optional shardLayout = std::nullopt; + if (memcfg->shard_layout()) { + shardLayout = ::tt::runtime::ttnn::utils::toTTNNTensorMemoryLayout( + memcfg->shard_layout().value()); + } + + ::ttnn::operations::conv::conv2d::Conv2dConfig conv2dConfig = { + .dtype = ::tt::runtime::ttnn::utils::toTTNNDataType(memcfg->dtype()), + .weights_dtype = + ::tt::runtime::ttnn::utils::toTTNNDataType(memcfg->weights_dtype()), + .activation = memcfg->activation()->str(), + .input_channels_alignment = memcfg->input_channels_alignment(), + .deallocate_activation = memcfg->deallocate_activation(), + .reallocate_halo_output = memcfg->reallocate_halo_output(), + .act_block_h_override = memcfg->act_block_h_override(), + .act_block_w_div = memcfg->act_block_w_div(), + .reshard_if_not_optimal = memcfg->reshard_if_not_optimal(), + .override_sharding_config = memcfg->override_sharding_config(), + .shard_layout = shardLayout, + .core_grid = std::nullopt, + .transpose_shards = memcfg->transpose_shards(), + .output_layout = + ::tt::runtime::ttnn::utils::toTTNNLayout(memcfg->output_layout()), + }; + + return conv2dConfig; +} } // namespace tt::runtime::ttnn::operations::utils diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h index f589362af2..e1e4e56640 100644 --- a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h @@ -40,5 +40,8 @@ inline ::ttnn::Shape toTTNNShape(const flatbuffers::Vector &vec) { return ::ttnn::Shape(rawShape); } +::ttnn::operations::conv::conv2d::Conv2dConfig +createConv2dConfig(const ::tt::target::Conv2dConfigDesc *memcfg); + } // namespace tt::runtime::ttnn::operations::utils #endif diff --git a/test/ttmlir/Dialect/TTIR/conv2d/conv2d_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/conv2d/conv2d_tests_negative.mlir new file mode 100644 index 0000000000..55e73e2c07 --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/conv2d/conv2d_tests_negative.mlir @@ -0,0 +1,277 @@ +// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s +// Negative tests for conv2d operation + +// Verify that the parsing fails if tensors don't have four dimensions +module attributes {} { + func.func @conv2d_invalid_input_shape(%arg0: tensor<32x32x64xbf16>, %arg1: tensor<64x64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x30x30x64xbf16> { + %0 = tensor.empty() : tensor<1x30x30x64xbf16> + // CHECK: error: 'ttir.conv2d' op Input must be a 4D tensor + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + dilation = 1: i32, + groups = 1: i32 + }> : (tensor<32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x30x30x64xbf16>) -> tensor<1x30x30x64xbf16> + return %1 : tensor<1x30x30x64xbf16> + } +} + +// ----- +module attributes {} { + func.func @conv2d_invalid_weight_shape(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x30x30x64xbf16> { + %0 = tensor.empty() : tensor<1x30x30x64xbf16> + // CHECK: error: 'ttir.conv2d' op Weight must be a 4D tensor + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + dilation = 1: i32, + groups = 1: i32 + }> : (tensor<1x32x32x64xbf16>, tensor<64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x30x30x64xbf16>) -> tensor<1x30x30x64xbf16> + return %1 : tensor<1x30x30x64xbf16> + } +} + +// ----- +module attributes {} { + func.func @conv2d_invalid_bias_shape(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x64x3x3xbf16>, %arg2: tensor<1x1x64xbf16>) -> tensor<1x30x30x64xbf16> { + %0 = tensor.empty() : tensor<1x30x30x64xbf16> + // CHECK: error: 'ttir.conv2d' op Bias must be a 4D tensor + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + dilation = 1: i32, + groups = 1: i32 + }> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x64xbf16>, tensor<1x30x30x64xbf16>) -> tensor<1x30x30x64xbf16> + return %1 : tensor<1x30x30x64xbf16> + } +} + +// ----- +module attributes {} { + func.func @conv2d_invalid_output_shape(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<30x30x64xbf16> { + %0 = tensor.empty() : tensor<30x30x64xbf16> + // CHECK: error: 'ttir.conv2d' op Output must be a 4D tensor + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + dilation = 1: i32, + groups = 1: i32 + }> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<30x30x64xbf16>) -> tensor<30x30x64xbf16> + return %1 : tensor<30x30x64xbf16> + } +} + +// Verify that the parsing fails if attributes are not integers or pair of integers +// ----- +module attributes {} { + func.func @conv2d_invalid_stride_shape(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x30x30x64xbf16> { + %0 = tensor.empty() : tensor<1x30x30x64xbf16> + // CHECK: error: 'ttir.conv2d' op Expected integer or pair of integers, got tuple of size 3 for stride + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = array, + padding = 0: i32, + dilation = 1: i32, + groups = 1: i32 + }> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x30x30x64xbf16>) -> tensor<1x30x30x64xbf16> + return %1 : tensor<1x30x30x64xbf16> + } +} + +// ----- +module attributes {} { + func.func @conv2d_invalid_padding_shape(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x30x30x64xbf16> { + %0 = tensor.empty() : tensor<1x30x30x64xbf16> + // CHECK: error: 'ttir.conv2d' op Expected integer, pair, or tuple of size 4, but got tuple of size 3 for padding + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = array, + dilation = 1: i32, + groups = 1: i32 + }> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x30x30x64xbf16>) -> tensor<1x30x30x64xbf16> + return %1 : tensor<1x30x30x64xbf16> + } +} + +// ----- +module attributes {} { + func.func @conv2d_invalid_dilation_shape(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x30x30x64xbf16> { + %0 = tensor.empty() : tensor<1x30x30x64xbf16> + // CHECK: error: 'ttir.conv2d' op Expected integer or pair of integers, got tuple of size 3 for dilation + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + dilation = array, + groups = 1: i32 + }> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x30x30x64xbf16>) -> tensor<1x30x30x64xbf16> + return %1 : tensor<1x30x30x64xbf16> + } +} + +// Verify that the parsing fails if attributes have invalid values +// ----- +module attributes {} { + func.func @conv2d_invalid_stride_values(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x30x30x64xbf16> { + %0 = tensor.empty() : tensor<1x30x30x64xbf16> + // CHECK: error: 'ttir.conv2d' op Stride values must be greater than 0 + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = array, + padding = 0: i32, + dilation = 1: i32, + groups = 1: i32 + }> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x30x30x64xbf16>) -> tensor<1x30x30x64xbf16> + return %1 : tensor<1x30x30x64xbf16> + } +} + +// ----- +module attributes {} { + func.func @conv2d_invalid_padding_values(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x30x30x64xbf16> { + %0 = tensor.empty() : tensor<1x30x30x64xbf16> + // CHECK: error: 'ttir.conv2d' op Padding values must be greater or equal than 0 + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = array, + dilation = 1: i32, + groups = 1: i32 + }> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x30x30x64xbf16>) -> tensor<1x30x30x64xbf16> + return %1 : tensor<1x30x30x64xbf16> + } +} + +// ----- +module attributes {} { + func.func @conv2d_invalid_dilation_values(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x30x30x64xbf16> { + %0 = tensor.empty() : tensor<1x30x30x64xbf16> + // CHECK: error: 'ttir.conv2d' op Dilation values must be greater than 0 + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + dilation = array, + groups = 1: i32 + }> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x30x30x64xbf16>) -> tensor<1x30x30x64xbf16> + return %1 : tensor<1x30x30x64xbf16> + } +} + +// Verify the parsing fails if number of channels are incorrect +// ----- +module attributes {} { + func.func @conv2d_input_channels_not_divisible_by_groups(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<100x64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x30x30x100xbf16> { + %0 = tensor.empty() : tensor<1x30x30x100xbf16> + // CHECK: error: 'ttir.conv2d' op Number of input channels from input tensor must be divisible by the number of groups. Got 64 input channels and 10 groups + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + dilation = 1: i32, + groups = 10: i32 + }> : (tensor<1x32x32x64xbf16>, tensor<100x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x30x30x100xbf16>) -> tensor<1x30x30x100xbf16> + return %1 : tensor<1x30x30x100xbf16> + } +} + +// ----- +module attributes {} { + func.func @conv2d_output_channels_not_divisible_by_groups(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<128x64x3x3xbf16>, %arg2: tensor<1x1x1x102xbf16>) -> tensor<1x30x30x102xbf16> { + %0 = tensor.empty() : tensor<1x30x30x102xbf16> + // CHECK: error: 'ttir.conv2d' op Number of output channels from output tensor must be divisible by the number of groups. Got 102 output channels and 8 groups + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + dilation = 1: i32, + groups = 8: i32 + }> : (tensor<1x32x32x64xbf16>, tensor<128x64x3x3xbf16>, tensor<1x1x1x102xbf16>, tensor<1x30x30x102xbf16>) -> tensor<1x30x30x102xbf16> + return %1 : tensor<1x30x30x102xbf16> + } +} + +// ----- +module attributes {} { + func.func @conv2d_input_channels_missmatch_with_weight(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x128x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x30x30x64xbf16> { + %0 = tensor.empty() : tensor<1x30x30x64xbf16> + // CHECK: error: 'ttir.conv2d' op Number of input channels per group must match the second dimension of the weight tensor. Got 64 input channels per group and 128 in the weight tensor + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + dilation = 1: i32, + groups = 1: i32 + }> : (tensor<1x32x32x64xbf16>, tensor<64x128x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x30x30x64xbf16>) -> tensor<1x30x30x64xbf16> + return %1 : tensor<1x30x30x64xbf16> + } +} + +// ----- +module attributes {} { + func.func @conv2d_output_channels_missmatch_with_weight(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<128x64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x30x30x64xbf16> { + %0 = tensor.empty() : tensor<1x30x30x64xbf16> + // CHECK: error: 'ttir.conv2d' op Number of output channels from output tensor must match the first dimension of the weight tensor. Got 64 output channels and 128 in the weight tensor + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + dilation = 1: i32, + groups = 1: i32 + }> : (tensor<1x32x32x64xbf16>, tensor<128x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x30x30x64xbf16>) -> tensor<1x30x30x64xbf16> + return %1 : tensor<1x30x30x64xbf16> + } +} + +// ----- +module attributes {} { + func.func @conv2d_output_channels_missmatch_with_bias(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x64x3x3xbf16>, %arg2: tensor<1x1x1x128xbf16>) -> tensor<1x30x30x64xbf16> { + %0 = tensor.empty() : tensor<1x30x30x64xbf16> + // CHECK: error: 'ttir.conv2d' op Mismatch in bias tensor dimensions. Bias tensor has 128 channels, but the output tensor has 64 channels + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + dilation = 1: i32, + groups = 1: i32 + }> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x128xbf16>, tensor<1x30x30x64xbf16>) -> tensor<1x30x30x64xbf16> + return %1 : tensor<1x30x30x64xbf16> + } +} + +// ----- +module attributes {} { + func.func @conv2d_input_size_smaller_than_kernel_size(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x30x30x64xbf16> { + %0 = tensor.empty() : tensor<1x30x30x64xbf16> + // CHECK: error: 'ttir.conv2d' op Calculated padded input size per channel: (56 x 56). Kernel size: (65 x 65). Kernel size can't be greater than actual input size + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 12: i32, + dilation = 32: i32, + groups = 1: i32 + }> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x30x30x64xbf16>) -> tensor<1x30x30x64xbf16> + return %1 : tensor<1x30x30x64xbf16> + } +} + +// ----- +module attributes {} { + func.func @conv2d_calculated_output_size_per_channel_missmatch_with_output_tensor(%arg0: tensor<1x128x256x36xbf16>, %arg1: tensor<72x6x16x32xbf16>, %arg2: tensor<1x1x1x72xbf16>) -> tensor<1x32x32x72xbf16> { + %0 = tensor.empty() : tensor<1x32x32x72xbf16> + // CHECK: error: 'ttir.conv2d' op Mismatch between calculated and got output height and width. Calculated: (9 x 9). Got output tensor height and width: (32 x 32) + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 10: i32, + padding = array, + dilation = array, + groups = 6: i32 + }> : (tensor<1x128x256x36xbf16>, tensor<72x6x16x32xbf16>, tensor<1x1x1x72xbf16>, tensor<1x32x32x72xbf16>) -> tensor<1x32x32x72xbf16> + return %1 : tensor<1x32x32x72xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTIR/conv2d/conv2d_tests_positive.mlir b/test/ttmlir/Dialect/TTIR/conv2d/conv2d_tests_positive.mlir new file mode 100644 index 0000000000..73cd6722dc --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/conv2d/conv2d_tests_positive.mlir @@ -0,0 +1,133 @@ +// RUN: ttmlir-opt %s | FileCheck %s + +module attributes {} { + func.func @conv2d_simple(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x30x30x64xbf16> { + %0 = tensor.empty() : tensor<1x30x30x64xbf16> + // CHECK: %[[C:.*]] = "ttir.conv2d"[[C:.*]] + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + dilation = 1: i32, + groups = 1: i32 + }> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x30x30x64xbf16>) -> tensor<1x30x30x64xbf16> + return %1 : tensor<1x30x30x64xbf16> + } + + func.func @conv2d_stride_1(%arg0: tensor<3x32x32x8xbf16>, %arg1: tensor<16x8x3x3xbf16>, %arg2: tensor<1x1x1x16xbf16>) -> tensor<3x15x15x16xbf16> { + %0 = tensor.empty() : tensor<3x15x15x16xbf16> + // CHECK: %[[C:.*]] = "ttir.conv2d"[[C:.*]] + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 2: i32, + padding = 0: i32, + dilation = 1: i32, + groups = 1: i32 + }> : (tensor<3x32x32x8xbf16>, tensor<16x8x3x3xbf16>, tensor<1x1x1x16xbf16>, tensor<3x15x15x16xbf16>) -> tensor<3x15x15x16xbf16> + return %1 : tensor<3x15x15x16xbf16> + } + + func.func @conv2d_stride_2(%arg0: tensor<4x32x32x16xbf16>, %arg1: tensor<8x16x3x3xbf16>, %arg2: tensor<1x1x1x8xbf16>) -> tensor<4x8x5x8xbf16> { + %0 = tensor.empty() : tensor<4x8x5x8xbf16> + // CHECK: %[[C:.*]] = "ttir.conv2d"[[C:.*]] + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = array, + padding = 0: i32, + dilation = 1: i32, + groups = 1: i32 + }> : (tensor<4x32x32x16xbf16>, tensor<8x16x3x3xbf16>, tensor<1x1x1x8xbf16>, tensor<4x8x5x8xbf16>) -> tensor<4x8x5x8xbf16> + return %1 : tensor<4x8x5x8xbf16> + } + + func.func @conv2d_padding_1(%arg0: tensor<32x32x32x4xbf16>, %arg1: tensor<8x4x3x3xbf16>, %arg2: tensor<1x1x1x8xbf16>) -> tensor<32x38x38x8xbf16> { + %0 = tensor.empty() : tensor<32x38x38x8xbf16> + // CHECK: %[[C:.*]] = "ttir.conv2d"[[C:.*]] + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 4: i32, + dilation = 1: i32, + groups = 1: i32 + }> : (tensor<32x32x32x4xbf16>, tensor<8x4x3x3xbf16>, tensor<1x1x1x8xbf16>, tensor<32x38x38x8xbf16>) -> tensor<32x38x38x8xbf16> + return %1 : tensor<32x38x38x8xbf16> + } + + func.func @conv2d_padding_2(%arg0: tensor<16x32x32x32xbf16>, %arg1: tensor<128x32x3x3xbf16>, %arg2: tensor<1x1x1x128xbf16>) -> tensor<16x54x46x128xbf16> { + %0 = tensor.empty() : tensor<16x54x46x128xbf16> + // CHECK: %[[C:.*]] = "ttir.conv2d"[[C:.*]] + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = array, + dilation = 1: i32, + groups = 1: i32 + }> : (tensor<16x32x32x32xbf16>, tensor<128x32x3x3xbf16>, tensor<1x1x1x128xbf16>, tensor<16x54x46x128xbf16>) -> tensor<16x54x46x128xbf16> + return %1 : tensor<16x54x46x128xbf16> + } + + func.func @conv2d_padding_3(%arg0: tensor<8x32x32x64xbf16>, %arg1: tensor<256x64x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<8x48x42x256xbf16> { + %0 = tensor.empty() : tensor<8x48x42x256xbf16> + // CHECK: %[[C:.*]] = "ttir.conv2d"[[C:.*]] + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = array, + dilation = 1: i32, + groups = 1: i32 + }> : (tensor<8x32x32x64xbf16>, tensor<256x64x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<8x48x42x256xbf16>) -> tensor<8x48x42x256xbf16> + return %1 : tensor<8x48x42x256xbf16> + } + + func.func @conv2d_dilation_1(%arg0: tensor<16x32x32x128xbf16>, %arg1: tensor<64x128x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<16x24x24x64xbf16> { + %0 = tensor.empty() : tensor<16x24x24x64xbf16> + // CHECK: %[[C:.*]] = "ttir.conv2d"[[C:.*]] + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + dilation = 4: i32, + groups = 1: i32 + }> : (tensor<16x32x32x128xbf16>, tensor<64x128x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<16x24x24x64xbf16>) -> tensor<16x24x24x64xbf16> + return %1 : tensor<16x24x24x64xbf16> + } + + func.func @conv2d_dilation_2(%arg0: tensor<32x32x32x16xbf16>, %arg1: tensor<64x16x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<32x20x28x64xbf16> { + %0 = tensor.empty() : tensor<32x20x28x64xbf16> + // CHECK: %[[C:.*]] = "ttir.conv2d"[[C:.*]] + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + dilation = array, + groups = 1: i32 + }> : (tensor<32x32x32x16xbf16>, tensor<64x16x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<32x20x28x64xbf16>) -> tensor<32x20x28x64xbf16> + return %1 : tensor<32x20x28x64xbf16> + } + + func.func @conv2d_groups(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<128x16x3x3xbf16>, %arg2: tensor<1x1x1x128xbf16>) -> tensor<1x30x30x128xbf16> { + %0 = tensor.empty() : tensor<1x30x30x128xbf16> + // CHECK: %[[C:.*]] = "ttir.conv2d"[[C:.*]] + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + dilation = 1: i32, + groups = 4: i32 + }> : (tensor<1x32x32x64xbf16>, tensor<128x16x3x3xbf16>, tensor<1x1x1x128xbf16>, tensor<1x30x30x128xbf16>) -> tensor<1x30x30x128xbf16> + return %1 : tensor<1x30x30x128xbf16> + } + + func.func @conv2d_complex(%arg0: tensor<1x128x256x36xbf16>, %arg1: tensor<72x6x16x32xbf16>, %arg2: tensor<1x1x1x72xbf16>) -> tensor<1x9x9x72xbf16> { + %0 = tensor.empty() : tensor<1x9x9x72xbf16> + // CHECK: %[[C:.*]] = "ttir.conv2d"[[C:.*]] + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 10: i32, + padding = array, + dilation = array, + groups = 6: i32 + }> : (tensor<1x128x256x36xbf16>, tensor<72x6x16x32xbf16>, tensor<1x1x1x72xbf16>, tensor<1x9x9x72xbf16>) -> tensor<1x9x9x72xbf16> + return %1 : tensor<1x9x9x72xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/convolution/simple_conv.mlir b/test/ttmlir/Dialect/TTNN/convolution/simple_conv.mlir deleted file mode 100644 index 46e9334a9c..0000000000 --- a/test/ttmlir/Dialect/TTNN/convolution/simple_conv.mlir +++ /dev/null @@ -1,9 +0,0 @@ -// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -module attributes {} { - func.func @forward(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x32x32x64xbf16> { - %0 = tensor.empty() : tensor<1x32x32x64xbf16> - // CHECK: %[[C:.*]] = "ttnn.conv2d"[[C:.*]] - %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) <{stride_height=1: si32, stride_width=1: si32, dilation_height=1: si32, dilation_width=1: si32, groups=1: si32, padding_left=1: si32, padding_right=1: si32, padding_top=1: si32, padding_bottom=1: si32, is_convtranspose2d=0: si32, output_height_transpose=0: si32, output_width_transpose=0: si32, stride_transpose=0: si32}> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x32x32x64xbf16>) -> tensor<1x32x32x64xbf16> - return %1 : tensor<1x32x32x64xbf16> - } -} diff --git a/test/ttmlir/Dialect/TTNN/convolution/simple_conv2d.mlir b/test/ttmlir/Dialect/TTNN/convolution/simple_conv2d.mlir new file mode 100644 index 0000000000..4188697e41 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/convolution/simple_conv2d.mlir @@ -0,0 +1,16 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s + +module attributes {} { + func.func @forward(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x32x32x64xbf16> { + %0 = tensor.empty() : tensor<1x32x32x64xbf16> + // CHECK: %[[C:.*]] = "ttnn.conv2d"[[C:.*]] + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = array, + padding = array, + dilation = 1: i32, + groups = 1: i32 + }> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x32x32x64xbf16>) -> tensor<1x32x32x64xbf16> + return %1 : tensor<1x32x32x64xbf16> + } +} diff --git a/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_conv.mlir b/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_conv2d.mlir similarity index 51% rename from test/ttmlir/Silicon/TTNN/n150/perf/test_perf_conv.mlir rename to test/ttmlir/Silicon/TTNN/n150/perf/test_perf_conv2d.mlir index 13708ef16a..838b6661fd 100644 --- a/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_conv.mlir +++ b/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_conv2d.mlir @@ -1,11 +1,18 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + module attributes {} { func.func @forward(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x32x32x64xbf16> { %0 = tensor.empty() : tensor<1x32x32x64xbf16> // CHECK: %[[C:.*]] = "ttnn.conv2d"[[C:.*]] - %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) <{stride_height=1: si32, stride_width=1: si32, dilation_height=1: si32, dilation_width=1: si32, groups=1: si32, padding_left=1: si32, padding_right=1: si32, padding_top=1: si32, padding_bottom=1: si32, is_convtranspose2d=0: si32, output_height_transpose=0: si32, output_width_transpose=0: si32, stride_transpose=0: si32}> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x32x32x64xbf16>) -> tensor<1x32x32x64xbf16> + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = array, + padding = array, + dilation = 1: i32, + groups = 1: i32 + }> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x32x32x64xbf16>) -> tensor<1x32x32x64xbf16> return %1 : tensor<1x32x32x64xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_conv2d_config.mlir b/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_conv2d_config.mlir new file mode 100644 index 0000000000..65c0d76604 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/n150/perf/test_perf_conv2d_config.mlir @@ -0,0 +1,47 @@ +// RUN: ttmlir-translate --ttnn-to-flatbuffer %s > %t.ttnn + +#device = #tt.device (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]> +#dram = #ttnn.buffer_type +#system_desc = #tt.system_desc<[{role = host, target_triple = "x86_64-pc-linux"}], [{arch = , grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 99104, erisc_l1_unreserved_base = 104480, dram_unreserved_base = 32, dram_unreserved_end = 1073147200, physical_cores = {worker = [ 18x18, 18x19, 18x20, 18x21, 18x22, 18x23, 18x24, 18x25, 19x18, 19x19, 19x20, 19x21, 19x22, 19x23, 19x24, 19x25, 20x18, 20x19, 20x20, 20x21, 20x22, 20x23, 20x24, 20x25, 21x18, 21x19, 21x20, 21x21, 21x22, 21x23, 21x24, 21x25, 22x18, 22x19, 22x20, 22x21, 22x22, 22x23, 22x24, 22x25, 23x18, 23x19, 23x20, 23x21, 23x22, 23x23, 23x24, 23x25, 24x18, 24x19, 24x20, 24x21, 24x22, 24x23, 24x24, 24x25, 25x18, 25x19, 25x20, 25x21, 25x22, 25x23, 25x24, 25x25] dram = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0x10, 0x11] eth_inactive = [ 16x18, 16x19, 16x20, 16x21, 16x22, 16x23, 16x24, 16x25, 17x19, 17x20, 17x22, 17x23, 17x24]}, supported_data_types = [, , , , , , , , , , , ], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]> +#system_memory = #ttnn.buffer_type +#ttnn_layout = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 1024 + d1 * 32 + d2, d3), <1x1>, memref<1024x64xbf16, #system_memory>> +#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 3 + d2, d3), <1x1>, memref<12288x3xbf16, #system_memory>> +#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 + d1 + d2, d3), <1x1>, memref<1x64xbf16, #system_memory>> +#ttnn_layout3 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 900 + d1 * 30 + d2, d3), <1x1>, memref<900x64xbf16, #system_memory>> +#ttnn_layout4 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 900 + d1 * 30 + d2, d3), <1x1>, memref<29x2x!tt.tile<32x32, bf16>, #dram>, > +#ttnn_layout5 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 900 + d1 * 30 + d2, d3), <1x1>, memref<29x2x!tt.tile<32x32, bf16>, #system_memory>> + +#conv2d_config = #ttnn.conv2d_config< + dtype = bf16, + weightsDtype = bf16, + activation = "", + inputChannelsAlignment = 32 : i32, + deallocateActivation = false, + reallocateHaloOutput = true, + actBlockHOverride = 0 : i32, + actBlockWDiv = 1 : i32, + reshardIfNotOptimal = false, + overrideShardingConfig = false, + shardLayout = #ttnn.tensor_memory_layout, + transposeShards = true, + outputLayout = tile, + enableActDoubleBuffer = false, + enableWeightsDoubleBuffer = false, + enableSplitReader = false, + enableSubblockPadding = false +> + +module attributes {tt.device = #device, tt.system_desc = #system_desc} { + func.func @forward(%arg0: tensor<1x32x32x64xbf16, #ttnn_layout>, %arg1: tensor<64x64x3x3xbf16, #ttnn_layout1>, %arg2: tensor<1x1x1x64xbf16, #ttnn_layout2>) -> tensor<1x30x30x64xbf16, #ttnn_layout3> { + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + %1 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<29x2>>, >, shape = #ttnn.shape<1x1x900x64>}> : (!tt.device<#device>) -> tensor<1x1x900x64xbf16, #ttnn_layout4> + %2 = "ttnn.conv2d"(%arg0, %arg1, %arg2, %1, %0) <{batch_size = 1 : i32, conv2d_config = #conv2d_config, dilation = array, groups = 1 : i32, in_channels = 64 : i32, input_height = 32 : i32, input_width = 32 : i32, kernel_size = array, out_channels = 64 : i32, padding = array, stride = array}> : (tensor<1x32x32x64xbf16, #ttnn_layout>, tensor<64x64x3x3xbf16, #ttnn_layout1>, tensor<1x1x1x64xbf16, #ttnn_layout2>, tensor<1x1x900x64xbf16, #ttnn_layout4>, !tt.device<#device>) -> tensor<1x1x900x64xbf16, #ttnn_layout4> + %3 = "ttnn.reshape"(%2) <{shape = [1 : i32, 30 : i32, 30 : i32, 64 : i32]}> : (tensor<1x1x900x64xbf16, #ttnn_layout4>) -> tensor<1x30x30x64xbf16, #ttnn_layout4> + "ttnn.deallocate"(%1) <{force = false}> : (tensor<1x1x900x64xbf16, #ttnn_layout4>) -> () + %4 = "ttnn.from_device"(%3) : (tensor<1x30x30x64xbf16, #ttnn_layout4>) -> tensor<1x30x30x64xbf16, #ttnn_layout5> + "ttnn.deallocate"(%3) <{force = false}> : (tensor<1x30x30x64xbf16, #ttnn_layout4>) -> () + %5 = "ttnn.to_layout"(%4) <{layout = #ttnn.layout}> : (tensor<1x30x30x64xbf16, #ttnn_layout5>) -> tensor<1x30x30x64xbf16, #ttnn_layout3> + "ttnn.deallocate"(%4) <{force = false}> : (tensor<1x30x30x64xbf16, #ttnn_layout5>) -> () + return %5 : tensor<1x30x30x64xbf16, #ttnn_layout3> + } +} diff --git a/test/ttmlir/Silicon/TTNN/n150/simple_conv.mlir b/test/ttmlir/Silicon/TTNN/n150/simple_conv.mlir deleted file mode 100644 index 13708ef16a..0000000000 --- a/test/ttmlir/Silicon/TTNN/n150/simple_conv.mlir +++ /dev/null @@ -1,11 +0,0 @@ -// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir -// RUN: FileCheck %s --input-file=%t.mlir -// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -module attributes {} { - func.func @forward(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x32x32x64xbf16> { - %0 = tensor.empty() : tensor<1x32x32x64xbf16> - // CHECK: %[[C:.*]] = "ttnn.conv2d"[[C:.*]] - %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) <{stride_height=1: si32, stride_width=1: si32, dilation_height=1: si32, dilation_width=1: si32, groups=1: si32, padding_left=1: si32, padding_right=1: si32, padding_top=1: si32, padding_bottom=1: si32, is_convtranspose2d=0: si32, output_height_transpose=0: si32, output_width_transpose=0: si32, stride_transpose=0: si32}> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x32x32x64xbf16>) -> tensor<1x32x32x64xbf16> - return %1 : tensor<1x32x32x64xbf16> - } -} diff --git a/test/ttmlir/Silicon/TTNN/n150/simple_conv2d.mlir b/test/ttmlir/Silicon/TTNN/n150/simple_conv2d.mlir new file mode 100644 index 0000000000..5eb1c5e7a1 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/n150/simple_conv2d.mlir @@ -0,0 +1,18 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +module attributes {} { + func.func @forward(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x30x30x64xbf16> { + %0 = tensor.empty() : tensor<1x30x30x64xbf16> + // CHECK: %[[C:.*]] = "ttnn.conv2d"[[C:.*]] + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + dilation = 1: i32, + groups = 1: i32 + }> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x30x30x64xbf16>) -> tensor<1x30x30x64xbf16> + return %1 : tensor<1x30x30x64xbf16> + } +}