Refactor conv2d op and add conv2d config support
jserbedzijaTT committed Jan 31, 2025
1 parent a2bed30 commit 147c7b9
Showing 21 changed files with 1,115 additions and 166 deletions.
53 changes: 43 additions & 10 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -1050,22 +1050,55 @@ def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> {
def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
let summary = "Conv2d operation.";
let description = [{
Applies a 2D convolution over an input image composed of several input planes.

Inputs:
- `input` AnyRankedTensor: expected in the following format (N, H_in, W_in, C) where:
- N is the batch size
- H_in is the height of the input planes
- W_in is the width of the input planes
- C is the number of channels

- `weight` AnyRankedTensor: expected in the following format (O, C/G, K_H, K_W) where:
- O is the number of output channels
- C is the number of input channels
- G is the number of groups
- K_H is the height of the kernel
- K_W is the width of the kernel

- `bias` Optional<AnyRankedTensor>: expected in the following format (1, 1, 1, O).

- `output` AnyRankedTensor: expected in the following format (N, H_out, W_out, O) where:
- `H_out = (H_in + padding_top + padding_bottom - dilation[0] * (K_H - 1) - 1) / stride[0] + 1`
- `W_out = (W_in + padding_left + padding_right - dilation[1] * (K_W - 1) - 1) / stride[1] + 1`

Attributes:
- `stride` (i32 | array<2xi32>): Stride of the convolution.
- `padding` (i32 | array<2xi32> | array<4xi32>): Padding on height/width. A single i32 applies to all four sides, `[pH, pW]` is symmetric per dimension, and `[pT, pL, pB, pR]` is fully asymmetric.
- `dilation` (i32 | array<2xi32>): Spacing between kernel elements.
- `groups` i32: Number of blocked connections from input channels to output channels. Input and output channels must both be divisible by `groups`.

Example:
%input = tensor.empty() : () -> tensor<1x32x32x64xbf16>
%weight = tensor.empty() : () -> tensor<64x64x3x3xbf16>
%bias = tensor.empty() : () -> tensor<1x1x1x64xbf16>
%output = tensor.empty() : () -> tensor<1x30x30x64xbf16>
%0 = "ttir.conv2d"(%input, %weight, %bias, %output)
<{
stride = 1: i32,
padding = 0: i32,
dilation = 1: i32,
groups = 1: i32
}> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x30x30x64xbf16>) -> tensor<1x30x30x64xbf16>
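
For the example above, the output-shape formulas give `H_out = (32 + 0 + 0 - 1 * (3 - 1) - 1) / 1 + 1 = 30` and likewise `W_out = 30`, matching the `1x30x30x64` result type.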
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$weight,
Optional<AnyRankedTensor>:$bias,
AnyRankedTensor:$output,
SI32Attr:$stride_height,
SI32Attr:$stride_width,
SI32Attr:$dilation_height,
SI32Attr:$dilation_width,
SI32Attr:$groups,
SI32Attr:$padding_left,
SI32Attr:$padding_right,
SI32Attr:$padding_top,
SI32Attr:$padding_bottom);
AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$stride,
AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$padding,
AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$dilation,
I32Attr:$groups);

let results = (outs AnyRankedTensor:$result);

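The switch to `AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>` means `stride`, `padding`, and `dilation` each accept a scalar or an array form. Below is a minimal sketch of how a verifier might normalize them using the `getPairOfInteger`/`getQuadrupleOfInteger` helpers this commit adds in `include/ttmlir/Utils.h`; the free function, the `mlir::tt::ttir` namespace spelling, and the generated accessor names (`getStride()`, `getPadding()`) are assumptions, not code from this commit.

    // Sketch only: normalize the flexible conv2d attributes into concrete
    // per-dimension values using the helpers added later in this commit.
    #include "llvm/Support/Error.h"
    #include "ttmlir/Utils.h"

    ::mlir::LogicalResult verifyConv2dAttrs(mlir::tt::ttir::Conv2dOp op) {
      // `stride` is i32 or array<2xi32> -> (stride_h, stride_w).
      auto stride = ttmlir::utils::getPairOfInteger<int32_t>(op.getStride());
      if (auto error = stride.takeError()) {
        return op.emitOpError()
               << llvm::toString(std::move(error)) << " for stride attribute";
      }

      // `padding` is i32, array<2xi32>, or array<4xi32>
      // -> (top, left, bottom, right).
      auto padding =
          ttmlir::utils::getQuadrupleOfInteger<int32_t>(op.getPadding());
      if (auto error = padding.takeError()) {
        return op.emitOpError()
               << llvm::toString(std::move(error)) << " for padding attribute";
      }

      auto [paddingTop, paddingLeft, paddingBottom, paddingRight] = *padding;
      // ... output-shape checks using stride->first / stride->second and the
      //     four padding values would go here ...
      (void)paddingTop; (void)paddingLeft; (void)paddingBottom; (void)paddingRight;
      return ::mlir::success();
    }
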
37 changes: 24 additions & 13 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -971,15 +971,12 @@ def TTNN_Conv2dOp : TTNN_NamedDPSOp<"conv2d"> {
I32Attr:$batch_size,
I32Attr:$input_height,
I32Attr:$input_width,
I32Attr:$kernel_height,
I32Attr:$kernel_width,
I32Attr:$stride_height,
I32Attr:$stride_width,
I32Attr:$padding_height,
I32Attr:$padding_width,
I32Attr:$dilation_height,
I32Attr:$dilation_width,
I32Attr:$groups);
DenseI32ArrayAttr:$kernel_size,
DenseI32ArrayAttr:$stride,
DenseI32ArrayAttr:$padding,
DenseI32ArrayAttr:$dilation,
I32Attr:$groups,
OptionalAttr<TTNN_Conv2dConfigAttr>:$conv2d_config);

let results = (outs AnyRankedTensor:$result);

@@ -996,10 +993,24 @@ def TTNN_ConvTranspose2dOp : TTNN_NamedDPSOp<"conv_transpose2d"> {
Applies a 2D transposed convolution operator over an input image composed of several input planes.

Inputs:
- `input` AnyRankedTensor: NHWC format (batch_size x height x width x channels)
- `weight` AnyRankedTensor: OIHW format (output_channels x input_channels x height x width)
- `bias` Optional<AnyRankedTensor>: (1 x 1 x 1 x output_channels)
- `output` AnyRankedTensor: (1 x 1 x (batch_size * height * width) x channels)
- `input` (AnyRankedTensor): Expected format `(N x H_in x W_in x C)`
- `N`: Batch size
- `H_in`: Height of the input
- `W_in`: Width of the input
- `C`: Number of input channels

- `weight` (AnyRankedTensor): Expected format `(C x O / G x K_H x K_W)`
- `C`: Number of input channels
- `O`: Number of output channels
- `G`: Number of groups
- `K_H`: Kernel height
- `K_W`: Kernel width

- `bias` (Optional<AnyRankedTensor>): Expected format `(1 x 1 x 1 x O / G)`

- `output` (AnyRankedTensor): Expected format `(1 x 1 x (N * H_out * W_out) x O)`
- `H_out = (H_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1`
- `W_out = (W_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1`
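
For example, assuming `H_in = 32`, `stride[0] = 2`, `padding[0] = 1`, `dilation[0] = 1`, `kernel_size[0] = 3`, and `output_padding[0] = 1`: `H_out = (32 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 1 + 1 = 64`.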

Attributes:
- `in_channels` i32: The number of input channels.
30 changes: 30 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
@@ -98,6 +98,36 @@ def TTNN_MemoryConfigAttr : TTNN_Attr<"MemoryConfig", "memory_config"> {
}];
}

def TTNN_Conv2dConfigAttr : TTNN_Attr<"Conv2dConfig", "conv2d_config"> {
let summary = "TTNN Conv2dConfig attribute";
let description = [{
Configuration options for the TTNN conv2d operation: data types, activation, sharding, and double-buffering behavior.
}];

let parameters = (ins
"DataType":$dtype,
"DataType":$weightsDtype,
"StringAttr":$activation,
"IntegerAttr":$inputChannelsAlignment,
"BoolAttr":$deallocateActivation,
"BoolAttr":$reallocateHaloOutput,
"IntegerAttr":$actBlockHOverride,
"IntegerAttr":$actBlockWDiv,
"BoolAttr":$reshardIfNotOptimal,
"BoolAttr":$overrideShardingConfig,
OptionalParameter<"TensorMemoryLayoutAttr", "TTNN tensor memory layout">:$shardLayout,
// TODO: Finish adding this attribute
OptionalParameter<"Attribute", "TTNN core grid">:$coreGrid,
"BoolAttr":$transposeShards,
"Layout":$outputLayout,
"BoolAttr":$enableActDoubleBuffer,
"BoolAttr":$enableWeightsDoubleBuffer,
"BoolAttr":$enableSplitReader,
"BoolAttr":$enableSubblockPadding);

let assemblyFormat = "`<` params `>`";
}

def TTNN_MeshShapeAttr : TTNN_Attr<"MeshShape", "mesh_shape"> {
let summary = "TTNN Mesh Shape";
let description = [{
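As a rough illustration of how `Conv2dConfigAttr` could be constructed in C++, here is a sketch using the TableGen-generated builder. The parameter order follows the declaration above, but the header path, namespaces, enum spellings (`DataType::BFloat16`, `Layout::Tile`), and all values are assumptions, not code from this commit.

    // Sketch only: build a Conv2dConfigAttr with illustrative values.
    #include "mlir/IR/Builders.h"
    #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"  // assumed header path

    mlir::tt::ttnn::Conv2dConfigAttr
    makeExampleConv2dConfig(mlir::MLIRContext *ctx) {
      mlir::Builder builder(ctx);
      return mlir::tt::ttnn::Conv2dConfigAttr::get(
          ctx,
          mlir::tt::DataType::BFloat16,       // dtype
          mlir::tt::DataType::BFloat16,       // weightsDtype
          builder.getStringAttr(""),          // activation (none)
          builder.getI32IntegerAttr(32),      // inputChannelsAlignment
          builder.getBoolAttr(false),         // deallocateActivation
          builder.getBoolAttr(true),          // reallocateHaloOutput
          builder.getI32IntegerAttr(0),       // actBlockHOverride
          builder.getI32IntegerAttr(1),       // actBlockWDiv
          builder.getBoolAttr(false),         // reshardIfNotOptimal
          builder.getBoolAttr(false),         // overrideShardingConfig
          /*shardLayout=*/nullptr,            // optional, omitted here
          /*coreGrid=*/nullptr,               // optional, TODO in this commit
          builder.getBoolAttr(true),          // transposeShards
          mlir::tt::ttnn::Layout::Tile,       // outputLayout
          builder.getBoolAttr(false),         // enableActDoubleBuffer
          builder.getBoolAttr(false),         // enableWeightsDoubleBuffer
          builder.getBoolAttr(false),         // enableSplitReader
          builder.getBoolAttr(false));        // enableSubblockPadding
    }
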
21 changes: 21 additions & 0 deletions include/ttmlir/Target/Common/types.fbs
@@ -97,6 +97,27 @@ table MemoryConfigDesc {
shard_spec: ShardSpec;
}

table Conv2dConfigDesc {
dtype: DataType;
weights_dtype: DataType;
activation: string;
input_channels_alignment: uint32;
deallocate_activation: bool;
reallocate_halo_output: bool;
act_block_h_override: uint32;
act_block_w_div: uint32;
reshard_if_not_optimal: bool;
override_sharding_config: bool;
shard_layout: TensorMemoryLayout = null;
core_grid: bool = null;
transpose_shards: bool;
output_layout: TensorLayout;
enable_act_double_buffer: bool;
enable_weights_double_buffer: bool;
enable_split_reader: bool;
enable_subblock_padding: bool;
}

table ReplicateTensor {
replication_factor: uint32;
}
13 changes: 5 additions & 8 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -284,15 +284,12 @@ table Conv2dOp {
batch_size: uint32;
input_height: uint32;
input_width: uint32;
kernel_height: uint32;
kernel_width: uint32;
stride_height: uint32;
stride_width: uint32;
padding_height: uint32;
padding_width: uint32;
dilation_height: uint32;
dilation_width: uint32;
kernel_size: [int32];
stride: [int32];
padding: [int32];
dilation: [int32];
groups: uint32;
conv2d_config: Conv2dConfigDesc;
}

table ConvTranspose2dOp {
45 changes: 45 additions & 0 deletions include/ttmlir/Utils.h
@@ -248,6 +248,51 @@ getPairOfInteger(mlir::Attribute attr) {
return std::make_pair(x, y);
}

/// For a given mlir::Attribute attr, returns a tuple of four integers of type
/// ReturnTy. If attr is an IntegerAttr, it's interpreted as a
/// (value(attr), value(attr), value(attr), value(attr)) tuple, where
/// value(attr) is of type ScalarTy. If attr is a
/// DenseArrayAttr<VectorElementTy> of size 2, it's interpreted as a (attr[0],
/// attr[1], attr[0], attr[1]) tuple. If attr is a
/// DenseArrayAttr<VectorElementTy> of size 4, it is returned directly as
/// (attr[0], attr[1], attr[2], attr[3]). Otherwise, returns an error message.
template <typename ScalarTy, typename VectorElementTy = ScalarTy,
typename ReturnTy = ScalarTy>
inline llvm::Expected<std::tuple<ReturnTy, ReturnTy, ReturnTy, ReturnTy>>
getQuadrupleOfInteger(mlir::Attribute attr) {
ReturnTy x{}, y{}, z{}, w{};

// If attr is IntegerAttr, interpret it as (attr, attr, attr, attr)
if (auto value = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
x = y = z = w = integerAs<ScalarTy>(value.getValue());
}
// If attr is DenseArrayAttr, handle based on its size
else if (auto tuple = mlir::dyn_cast<
::mlir::detail::DenseArrayAttrImpl<VectorElementTy>>(attr)) {
if (tuple.size() == 2) {
x = tuple[0];
y = tuple[1];
z = tuple[0];
w = tuple[1];
} else if (tuple.size() == 4) {
x = tuple[0];
y = tuple[1];
z = tuple[2];
w = tuple[3];
} else {
return llvm::createStringError("Expected integer, pair, or tuple of size "
"4, but got tuple of size %lu",
tuple.size());
}
}
// If attr is of an unsupported type
else {
return llvm::createStringError("Unexpected attribute type");
}

return std::make_tuple(x, y, z, w);
}

} // namespace ttmlir::utils

#endif
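
A small usage sketch of `getQuadrupleOfInteger`, showing the three attribute shapes it accepts; the surrounding function and the values are illustrative only.

    // Sketch only: the three input shapes getQuadrupleOfInteger accepts.
    #include "mlir/IR/Builders.h"
    #include "ttmlir/Utils.h"

    void demoQuadruple(mlir::MLIRContext *ctx) {
      mlir::Builder builder(ctx);

      // IntegerAttr 2                  -> (2, 2, 2, 2)
      auto fromScalar = ttmlir::utils::getQuadrupleOfInteger<int32_t>(
          builder.getI32IntegerAttr(2));

      // DenseI32ArrayAttr [1, 2]       -> (1, 2, 1, 2)
      auto fromPair = ttmlir::utils::getQuadrupleOfInteger<int32_t>(
          builder.getDenseI32ArrayAttr({1, 2}));

      // DenseI32ArrayAttr [1, 2, 3, 4] -> (1, 2, 3, 4)
      auto fromQuad = ttmlir::utils::getQuadrupleOfInteger<int32_t>(
          builder.getDenseI32ArrayAttr({1, 2, 3, 4}));

      if (fromScalar && fromPair && fromQuad) {
        // For ttir.conv2d padding this order is (top, left, bottom, right).
        auto [top, left, bottom, right] = *fromQuad;
        (void)top; (void)left; (void)bottom; (void)right;
      }
    }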
36 changes: 16 additions & 20 deletions lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
@@ -379,31 +379,29 @@ struct ConvolutionToConv2dPattern : public ConvolutionDecompositionPattern {
return failure();
}

auto strideHeightAttr = rewriter.getSI32IntegerAttr(
adaptor.getWindowStrides()[SPATIAL_DIM_HEIGHT]);
auto strideWidthAttr = rewriter.getSI32IntegerAttr(
adaptor.getWindowStrides()[SPATIAL_DIM_WIDTH]);
auto dilationHeightAttr = rewriter.getSI32IntegerAttr(
adaptor.getWeightDilation()[SPATIAL_DIM_HEIGHT]);
auto dilationWidthAttr = rewriter.getSI32IntegerAttr(
adaptor.getWeightDilation()[SPATIAL_DIM_WIDTH]);
auto strideAttr = rewriter.getDenseI32ArrayAttr({
static_cast<int32_t>(adaptor.getWindowStrides()[SPATIAL_DIM_HEIGHT]),
static_cast<int32_t>(adaptor.getWindowStrides()[SPATIAL_DIM_WIDTH]),
});
auto dilationAttr = rewriter.getDenseI32ArrayAttr({
static_cast<int32_t>(adaptor.getWeightDilation()[SPATIAL_DIM_HEIGHT]),
static_cast<int32_t>(adaptor.getWeightDilation()[SPATIAL_DIM_WIDTH]),
});

// Padding is a list of 2-tuples; the order of the 2-tuples is
// most-significant spatial dimension first. For Conv2d, the most
// significant spatial dimension is the height, followed by the width.
auto paddingMatrix =
getPaddingMatrix<NUM_SPATIAL_DIMS>(adaptor.getPadding());
auto paddingTopAttr =
rewriter.getSI32IntegerAttr(paddingMatrix[SPATIAL_DIM_HEIGHT][0]);
auto paddingBottomAttr =
rewriter.getSI32IntegerAttr(paddingMatrix[SPATIAL_DIM_HEIGHT][1]);
auto paddingLeftAttr =
rewriter.getSI32IntegerAttr(paddingMatrix[SPATIAL_DIM_WIDTH][0]);
auto paddingRightAttr =
rewriter.getSI32IntegerAttr(paddingMatrix[SPATIAL_DIM_WIDTH][1]);
auto paddingAttr = rewriter.getDenseI32ArrayAttr({
static_cast<int32_t>(paddingMatrix[SPATIAL_DIM_HEIGHT][0]),
static_cast<int32_t>(paddingMatrix[SPATIAL_DIM_WIDTH][0]),
static_cast<int32_t>(paddingMatrix[SPATIAL_DIM_HEIGHT][1]),
static_cast<int32_t>(paddingMatrix[SPATIAL_DIM_WIDTH][1]),
});

auto groupsAttr =
rewriter.getSI32IntegerAttr(adaptor.getFeatureGroupCount());
rewriter.getI32IntegerAttr(adaptor.getFeatureGroupCount());

llvm::ArrayRef<int64_t> outputShape = op.getResult().getType().getShape();
llvm::SmallVector<int64_t> newOutputShape{
@@ -445,9 +443,7 @@ struct ConvolutionToConv2dPattern : public ConvolutionDecompositionPattern {
weightDPSOutput, kernelPermutation);
ttir::Conv2dOp newConv = rewriter.create<ttir::Conv2dOp>(
op.getLoc(), outputType, input, weight, adaptor.getBias(),
convDPSOutput, strideHeightAttr, strideWidthAttr, dilationHeightAttr,
dilationWidthAttr, groupsAttr, paddingLeftAttr, paddingRightAttr,
paddingTopAttr, paddingBottomAttr);
convDPSOutput, strideAttr, paddingAttr, dilationAttr, groupsAttr);

// Applying the inverse of permutation to the output will restore the
// tensor to the original layout.
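As a concrete check of the padding reordering above: assuming a padding matrix `[[1, 2], [3, 4]]` (height pair first, then width), `paddingAttr` becomes `[1, 3, 2, 4]`, i.e. `[top, left, bottom, right]`, which is exactly the `array<4xi32>` form `ttir.conv2d` accepts.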