diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index b965d6a614..201bac9109 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -447,6 +447,38 @@ def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
   let hasVerifier = 1;
 }
 
+def TTIR_MaxPool2dOp : TTIR_DPSOp<"max_pool2d"> {
+  let summary = "Applies a 2D max pooling over an input signal composed of several input planes.";
+  let description = [{
+    Applies a 2D max pooling over an input signal composed of several input planes.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input,
+                       AnyRankedTensor:$output,
+                       SI32Attr:$kernel_height,
+                       SI32Attr:$kernel_width,
+                       SI32Attr:$stride_height,
+                       SI32Attr:$stride_width,
+                       SI32Attr:$dilation_height,
+                       SI32Attr:$dilation_width,
+                       BoolAttr:$ceil_mode,
+                       SI32Attr:$padding_left,
+                       SI32Attr:$padding_right,
+                       SI32Attr:$padding_top,
+                       SI32Attr:$padding_bottom,
+                       TT_OperandConstraintArrayAttr:$operand_constraints,
+                       OptionalAttr<SI32Attr>:$original_height,
+                       OptionalAttr<SI32Attr>:$original_width);
+
+  let results = (outs AnyRankedTensor:$result);
+
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
+  }];
+
+  let hasVerifier = 1;
+}
+
 def TTIR_ReshapeOp: TTIR_DPSOp<"reshape"> {
   let summary = "Reshape op.";
   let description = [{
diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td
index 886287b6b3..b4be042ff0 100644
--- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td
+++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td
@@ -58,6 +58,13 @@ def TTIRLayout: Pass<"ttir-layout", "::mlir::ModuleOp"> {
   ];
 }
 
+def TTIRSlidingWindow2dFixShapes: Pass<"ttir-sliding-window-2d-fix-shapes", "::mlir::ModuleOp"> {
+  let summary = "Insert reshapes on the input and output of 2-dimensional sliding window ops that collapse N, H, W on the input, i.e. (N, H, W, C) --> (1, 1, N*H*W, C), and unflatten the output, i.e. (1, 1, N*H*W, C) --> (N, H, W, C).";
+  let description = [{
+    Insert reshapes on the input and output of 2-dimensional sliding window ops that collapse N, H, W on the input, i.e. (N, H, W, C) --> (1, 1, N*H*W, C), and unflatten the output, i.e. (1, 1, N*H*W, C) --> (N, H, W, C).
+  }];
+}
+
 def TTIRSplitCompoundLayout: Pass<"ttir-split-compound-layout", "::mlir::ModuleOp"> {
   let summary = "Split compound layouts.";
   let description = [{
diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index 26b14654ce..9d5a822fc2 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -350,6 +350,38 @@ def TTNN_Conv2dOp : TTNN_NamedDPSOp<"conv2d"> {
   let hasVerifier = 1;
 }
 
+def TTNN_MaxPool2dOp : TTNN_NamedDPSOp<"max_pool2d"> {
+  let summary = "Applies a 2D max pooling over an input signal composed of several input planes.";
+  let description = [{
+    Applies a 2D max pooling over an input signal composed of several input planes.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input,
+                       AnyRankedTensor:$output,
+                       TT_Device:$device,
+                       SI32Attr:$batch_size,
+                       SI32Attr:$input_height,
+                       SI32Attr:$input_width,
+                       SI32Attr:$channels,
+                       SI32Attr:$kernel_height,
+                       SI32Attr:$kernel_width,
+                       SI32Attr:$stride_height,
+                       SI32Attr:$stride_width,
+                       SI32Attr:$dilation_height,
+                       SI32Attr:$dilation_width,
+                       BoolAttr:$ceil_mode,
+                       SI32Attr:$padding_height,
+                       SI32Attr:$padding_width);
+
+  let results = (outs AnyRankedTensor:$result);
+
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
+  }];
+
+  let hasVerifier = 1;
+}
+
 def TTNN_EmptyOp : TTNN_Op<"empty"> {
   let summary = "Empty op.";
   let description = [{
diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index 99443598c9..8709d81677 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -121,6 +121,25 @@ table Conv2dOp {
   groups: uint32;
 }
 
+table MaxPool2dOp {
+  in: tt.target.TensorRef;
+  out: tt.target.TensorRef;
+  device: tt.target.DeviceRef;
+  batch_size: uint32;
+  input_height: uint32;
+  input_width: uint32;
+  channels: uint32;
+  kernel_height: uint32;
+  kernel_width: uint32;
+  stride_height: uint32;
+  stride_width: uint32;
+  dilation_height: uint32;
+  dilation_width: uint32;
+  ceil_mode: bool;
+  padding_height: uint32;
+  padding_width: uint32;
+}
+
 table DeallocOp {
   in: tt.target.TensorRef;
 }
@@ -139,6 +158,7 @@ union OpType {
   Conv2dOp,
   ConcatOp,
   ReshapeOp,
+  MaxPool2dOp,
   DeallocOp
 }
 
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index dd2e413aa2..fb4d4d511a 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -364,18 +364,62 @@ class Conv2dOpConversionPattern : public OpConversionPattern<ttir::Conv2dOp> {
     auto dilation_width =
         rewriter.getI32IntegerAttr(adaptor.getDilationWidth());
     auto groups = rewriter.getI32IntegerAttr(adaptor.getGroups());
-
     rewriter.replaceOpWithNewOp<ttnn::Conv2dOp>(
         op, this->getTypeConverter()->convertType(op.getType()),
         adaptor.getInput(), adaptor.getWeight(), adaptor.getBias(),
         adaptor.getOutput(), device, in_channels, out_channels, batch_size,
-        input_width, input_height, kernel_height, kernel_width, stride_height,
+        input_height, input_width, kernel_height, kernel_width, stride_height,
         stride_width, padding_height, padding_width, dilation_height,
         dilation_width, groups);
     return success();
   }
 };
 
+class MaxPool2dOpConversionPattern
+    : public OpConversionPattern<ttir::MaxPool2dOp> {
+public:
+  using OpConversionPattern<ttir::MaxPool2dOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ttir::MaxPool2dOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    assert(adaptor.getPaddingBottom() == adaptor.getPaddingTop() &&
+           "TTNN max_pool2d does not support padding top/bottom/left/right "
+           "separately");
+    assert(adaptor.getPaddingLeft() == adaptor.getPaddingRight() &&
+           "TTNN max_pool2d does not support padding top/bottom/left/right "
+           "separately");
+
+    auto device = getOrInsertDevice(rewriter, op);
+    auto input_ty = mlir::cast<RankedTensorType>(adaptor.getInput().getType());
+    llvm::ArrayRef<std::int64_t> input_shape = input_ty.getShape();
+
+    auto batch_size =
+        rewriter.getSI32IntegerAttr(input_shape[input_shape.size() - 4]);
+    auto channels =
+        rewriter.getSI32IntegerAttr(input_shape[input_shape.size() - 1]);
+
+    assert(adaptor.getOriginalHeight().has_value() &&
+           "ttir::MaxPool2dOp must have original_height set before translating "
+           "to TTNN dialect.");
+    assert(adaptor.getOriginalWidth().has_value() &&
+           "ttir::MaxPool2dOp must have original_width set before translating "
+           "to TTNN dialect.");
+
+    rewriter.replaceOpWithNewOp<ttnn::MaxPool2dOp>(
+        op, this->getTypeConverter()->convertType(op.getType()),
+        adaptor.getInput(), adaptor.getOutput(), device, batch_size,
+        adaptor.getOriginalHeightAttr(), adaptor.getOriginalWidthAttr(),
+        channels, adaptor.getKernelHeightAttr(), adaptor.getKernelWidthAttr(),
+        adaptor.getStrideHeightAttr(), adaptor.getStrideWidthAttr(),
+        adaptor.getDilationHeightAttr(), adaptor.getDilationWidthAttr(),
+        adaptor.getCeilModeAttr(), adaptor.getPaddingTopAttr(),
+        adaptor.getPaddingRightAttr());
+    return success();
+  }
+};
+
 namespace mlir::tt {
 
 void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
@@ -407,7 +451,8 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
            SqueezeOpConversionPattern,
            UnsqueezeOpConversionPattern,
            MatmulOpConversionPattern,
-           Conv2dOpConversionPattern
+           Conv2dOpConversionPattern,
+           MaxPool2dOpConversionPattern
            >(typeConverter, ctx);
   // ANCHOR_END: op_rewriter_pattern_set
   // clang-format on
diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
index 06644bdce7..38c864b082 100644
--- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
+++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -179,6 +179,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
   // Conv ops
   //
   patterns.add<DefaultOpConversionPattern<ttnn::Conv2dOp>>(typeConverter, ctx);
+  patterns.add<DefaultOpConversionPattern<ttnn::MaxPool2dOp>>(typeConverter,
+                                                              ctx);
 
   // Other ops
   //
diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp
index ae2cff496f..c504989181 100644
--- a/lib/Dialect/TTIR/IR/TTIROps.cpp
+++ b/lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -8,6 +8,7 @@
 #include "ttmlir/Dialect/TTIR/IR/TTIROps.h"
 
 #include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.cpp.inc"
+#include
 #include
 
 #define GET_OP_CLASSES
@@ -268,6 +269,45 @@ ::mlir::LogicalResult mlir::tt::ttir::ReshapeOp::verify() {
   return success();
 }
 
+::mlir::LogicalResult mlir::tt::ttir::MaxPool2dOp::verify() {
+  ::mlir::RankedTensorType inputType = getInput().getType();
+  std::vector<int64_t> inputShape = getInput().getType().getShape().vec();
+
+  if (inputType.getRank() != 4) {
+    return emitOpError()
+           << "Input tensor rank must be 4. Received input with rank "
+           << inputType.getRank() << ". Shape: (" << inputShape << ").";
+  }
+
+  if (getOriginalHeight().has_value() != getOriginalWidth().has_value()) {
+    std::string with_value =
+        getOriginalHeight().has_value() ? "original_height" : "original_width";
+    return emitOpError()
+           << "If providing the original height and width as attributes, both "
+              "original_height and original_width must be set. However, only "
+           << with_value << " was provided.";
+  }
+
+  if (getOriginalHeight().has_value() && getOriginalWidth().has_value()) {
+    inputShape[1] = getOriginalHeight().value();
+    inputShape[2] = getOriginalWidth().value();
+  }
+
+  if (getKernelHeight() > inputShape[1]) {
+    return emitOpError() << "Kernel height " << getKernelHeight()
+                         << " is greater than input height " << inputShape[1]
+                         << ". This MaxPool2d configuration is invalid.";
+  }
+
+  if (getKernelWidth() > inputShape[2]) {
+    return emitOpError() << "Kernel width " << getKernelWidth()
+                         << " is greater than input width " << inputShape[2]
+                         << ". This MaxPool2d configuration is invalid.";
+  }
+
+  return success();
+}
+
 ::mlir::LogicalResult mlir::tt::ttir::SqueezeOp::verify() {
   ::mlir::RankedTensorType inputType = getInput().getType();
   ::mlir::RankedTensorType outputType = getOutput().getType();
diff --git a/lib/Dialect/TTIR/Transforms/Passes.cpp b/lib/Dialect/TTIR/Transforms/Passes.cpp
index 18dd7b6b18..654a219c59 100644
--- a/lib/Dialect/TTIR/Transforms/Passes.cpp
+++ b/lib/Dialect/TTIR/Transforms/Passes.cpp
@@ -22,9 +22,16 @@
 #include "ttmlir/Dialect/TTIR/Analysis/OptimalTargetGridAnalysis.h"
 #include "ttmlir/Dialect/TTIR/Transforms/Passes.h"
 #include "ttmlir/Utils.h"
+#include
+#include
+#include
+#include
+#include
+#include
 #include
 
 namespace mlir::tt::ttir {
+#define GEN_PASS_DEF_TTIRSLIDINGWINDOW2DFIXSHAPES
 #define GEN_PASS_DEF_TTIRGENERICKERNEL
 #define GEN_PASS_DEF_TTIRGENERICREGION
 #define GEN_PASS_DEF_TTIRGENERICREGIONOPERANDSTOMEMREF
@@ -772,6 +779,135 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
   }
 };
 
+std::vector<int64_t> collapseNHW(std::vector<int64_t> shape) {
+  std::vector<int64_t> collapsed(shape.size(), 1);
+
+  int64_t NHW = 1;
+  for (uint32_t i = 0; i < shape.size() - 1; i++) {
+    NHW *= shape[i];
+  }
+  collapsed[collapsed.size() - 2] = NHW;
+  collapsed[collapsed.size() - 1] = shape[shape.size() - 1];
+  return collapsed;
+}
+
+template <typename T>
+class UncollapsedSlidingWindow2dPatternRewriter : public OpRewritePattern<T> {
+public:
+  using OpRewritePattern<T>::OpRewritePattern;
+
+  ReshapeOp createReshapeOp(PatternRewriter &rewriter, Location loc,
+                            Value input, ::llvm::ArrayRef<int64_t> shapei64,
+                            ::mlir::ArrayAttr operandConstraints) const {
+    auto ty = mlir::cast<RankedTensorType>(input.getType());
+    auto output =
+        rewriter.create<tensor::EmptyOp>(loc, shapei64, ty.getElementType());
+
+    auto shape_attr = rewriter.getI32ArrayAttr(
+        {static_cast<int32_t>(shapei64[0]), static_cast<int32_t>(shapei64[1]),
+         static_cast<int32_t>(shapei64[2]), static_cast<int32_t>(shapei64[3])});
+    return rewriter.create<ReshapeOp>(
+        loc, output.getType(), input, output, shape_attr, operandConstraints);
+  }
+
+  MaxPool2dOp createMaxPool2dOp(PatternRewriter &rewriter, MaxPool2dOp op,
+                                Value input, int32_t input_height,
+                                int32_t input_width,
+                                RankedTensorType new_result_type) const {
+    auto output = rewriter.create<tensor::EmptyOp>(
+        op->getLoc(), new_result_type.getShape(),
+        new_result_type.getElementType());
+
+    auto input_height_attr = rewriter.getSI32IntegerAttr(input_height);
+    auto input_width_attr = rewriter.getSI32IntegerAttr(input_width);
+
+    MaxPool2dOp new_maxpool = rewriter.create<MaxPool2dOp>(
+        op.getLoc(), new_result_type, input, output, op.getKernelHeightAttr(),
+        op.getKernelWidthAttr(), op.getStrideHeightAttr(),
+        op.getStrideWidthAttr(), op.getDilationHeightAttr(),
+        op.getDilationWidthAttr(), op.getCeilModeAttr(),
+        op.getPaddingLeftAttr(), op.getPaddingRightAttr(),
+        op.getPaddingTopAttr(), op.getPaddingBottomAttr(),
+        op.getOperandConstraints(), input_height_attr, input_width_attr);
+
+    return new_maxpool;
+  }
+
+  LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final {
+    ::llvm::ArrayRef<std::int64_t> input_shape =
+        mlir::cast<mlir::RankedTensorType>(op.getInput().getType()).getShape();
+
+    if (input_shape.size() != 4) {
+      return failure();
+    }
+
+    if (input_shape[0] == 1 && input_shape[1] == 1) {
+      return failure();
+    }
+
+    if (!llvm::isa<MaxPool2dOp>(op)) {
+      return failure();
+    }
+
+    // By this point we are certain that the input tensor is not in the form
+    // (1, 1, N*H*W, C), and so we must insert reshapes on the input/output.
+
+    std::vector<int64_t> new_input_shape = collapseNHW(input_shape.vec());
+    ::llvm::ArrayRef<int64_t> new_input_shape_array(new_input_shape);
+
+    ReshapeOp input_reshape =
+        createReshapeOp(rewriter, op.getLoc(), op.getInput(),
+                        new_input_shape_array, op.getOperandConstraints());
+
+    std::vector<int64_t> new_result_shape =
+        collapseNHW(op.getResult().getType().getShape().vec());
+    ::llvm::ArrayRef<int64_t> new_result_shape_array(new_result_shape);
+
+    RankedTensorType new_result_type = RankedTensorType::get(
+        new_result_shape_array, op.getResult().getType().getElementType(),
+        op.getResult().getType().getEncoding());
+
+    Operation *new_op = createMaxPool2dOp(
+        rewriter, mlir::cast<MaxPool2dOp>(op), input_reshape,
+        static_cast<int32_t>(input_shape[1]),
+        static_cast<int32_t>(input_shape[2]), new_result_type);
+
+    ReshapeOp output_reshape = createReshapeOp(
+        rewriter, op.getLoc(), new_op->getResult(0),
+        op.getResult().getType().getShape().vec(), op.getOperandConstraints());
+
+    rewriter.replaceOp(op, output_reshape);
+    return success();
+  }
+};
+
+class TTIRSlidingWindow2dFixShapes
+    : public impl::TTIRSlidingWindow2dFixShapesBase<
+          TTIRSlidingWindow2dFixShapes> {
+public:
+  using impl::TTIRSlidingWindow2dFixShapesBase<
+      TTIRSlidingWindow2dFixShapes>::TTIRSlidingWindow2dFixShapesBase;
+
+  void runOnOperation() final {
+    {
+      RewritePatternSet patterns(&getContext());
+      patterns.add<UncollapsedSlidingWindow2dPatternRewriter<MaxPool2dOp>>(
+          &getContext());
+      FrozenRewritePatternSet patternSet(std::move(patterns));
+      if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
+        signalPassFailure();
+        return;
+      }
+    }
+  }
+
+  void getDependentDialects(mlir::DialectRegistry &registry) const override {
+    registry.insert();
+    registry.insert();
+    registry.insert();
+  }
+};
+
 class TTIRSplitCompoundLayoutRewriter : public OpRewritePattern {
 public:
   using OpRewritePattern::OpRewritePattern;
diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp
index 7c80fcd121..c5e28ea2cf 100644
--- a/lib/Dialect/TTNN/IR/TTNNOps.cpp
+++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -8,6 +8,7 @@
 #include "ttmlir/Dialect/TTKernel/IR/TTKernelOpsTypes.h"
 #include "ttmlir/Dialect/TTNN/IR/TTNN.h"
 #include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
+#include
 #include
 
 #define GET_OP_CLASSES
@@ -322,6 +323,54 @@ ::mlir::LogicalResult mlir::tt::ttnn::Conv2dOp::verify() {
   return success();
 }
 
+::mlir::LogicalResult mlir::tt::ttnn::MaxPool2dOp::verify() {
+  ::mlir::RankedTensorType inputType = getInput().getType();
+  ::llvm::ArrayRef<std::int64_t> inputShape = getInput().getType().getShape();
+  if (getKernelHeight() > getInputHeight()) {
+    return emitOpError() << "Kernel height " << getKernelHeight()
+                         << " is greater than input height " << getInputHeight()
+                         << ". This MaxPool2d configuration is invalid.";
+  }
+
+  if (getKernelWidth() > getInputWidth()) {
+    return emitOpError() << "Kernel width " << getKernelWidth()
+                         << " is greater than input width " << getInputWidth()
+                         << ". This MaxPool2d configuration is invalid.";
+  }
+
+  if (inputType.getRank() != 4) {
+    return emitOpError()
+           << "Input tensor rank must be 4. Received input with rank "
+           << inputType.getRank() << ". Shape: (" << inputShape << ").";
+  }
+
+  if (inputShape[0] != 1 || inputShape[1] != 1) {
+    return emitOpError() << "Maxpool input must be in the form (1, 1, N*H*W, "
+                            "C). Received shape ("
+                         << inputShape << ").";
+  }
+
+  if (inputShape[2] != getBatchSize() * getInputHeight() * getInputWidth()) {
+    return emitOpError() << "Maxpool shape (" << inputShape
+                         << ") at dim -2 must be equal to N*H*W. However, the "
+                            "attributes given are N="
+                         << getBatchSize() << ", H=" << getInputHeight()
+                         << ", W=" << getInputWidth() << ". " << getBatchSize()
+                         << "*" << getInputHeight() << "*" << getInputWidth()
+                         << " != " << inputShape[2] << ".";
+  }
+
+  if (inputShape[3] != getChannels()) {
+    return emitOpError() << "Maxpool shape (" << inputShape
+                         << ") at dim -1 must be equal to C. However, the "
+                            "attribute given is C="
+                         << getChannels() << ". " << inputShape[3]
+                         << " != " << getChannels();
+  }
+
+  return success();
+}
+
 ::mlir::LogicalResult AllocOp::verify() {
   auto layout = mlir::dyn_cast_or_null(
       getResult().getType().getEncoding());
diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
index ffc8e31cbf..42432e851c 100644
--- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
+++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
@@ -20,6 +20,7 @@ void createTTIRToTTNNBackendPipeline(
   ttir::TTIRLoadSystemDescOptions systemDescOptions;
   systemDescOptions.path = options.systemDescPath;
 
+  pm.addPass(mlir::tt::ttir::createTTIRSlidingWindow2dFixShapes());
   pm.addPass(mlir::tt::ttir::createTTIRLoadSystemDesc(systemDescOptions));
   pm.addPass(mlir::tt::ttir::createTTIRImplicitDevice());
 
diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index 888065c6e2..3290dc867b 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -264,6 +264,24 @@ createReshapeOp(FlatbufferObjectCache &cache, ReshapeOp op) {
   return ::tt::target::ttnn::CreateReshapeOp(*cache.fbb, in, out, shape);
 }
 
+template <typename MaxPool2dOp>
+::flatbuffers::Offset<::tt::target::ttnn::MaxPool2dOp>
+createMaxPool2dOp(FlatbufferObjectCache &cache, MaxPool2dOp op) {
+  auto in =
+      cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
+  auto out = cache.at<::tt::target::TensorRef>(
+      getOperandThroughDPSOps(op.getResult()));
+
+  auto device = getOperandThroughDPSOps(op.getDevice());
+  return ::tt::target::ttnn::CreateMaxPool2dOp(
+      *cache.fbb, in, out, cache.at<::tt::target::DeviceRef>(device),
+      op.getBatchSize(), op.getInputHeight(), op.getInputWidth(),
+      op.getChannels(), op.getKernelHeight(), op.getKernelWidth(),
+      op.getStrideHeight(), op.getStrideWidth(), op.getDilationHeight(),
+      op.getDilationWidth(), op.getCeilMode(), op.getPaddingHeight(),
+      op.getPaddingWidth());
+}
+
 template <typename SoftmaxOp>
 ::flatbuffers::Offset<::tt::target::ttnn::SoftmaxOp>
 createSoftmaxOp(FlatbufferObjectCache &cache, SoftmaxOp op) {
@@ -374,6 +392,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
     return createOperation(cache, createReshapeOp(cache, reshapeOp),
                            debugString);
   }
+  if (auto max_pool2dOp = dyn_cast<MaxPool2dOp>(op); max_pool2dOp) {
+    return createOperation(cache, createMaxPool2dOp(cache, max_pool2dOp),
+                           debugString);
+  }
   if (auto deallocOp = dyn_cast<DeallocOp>(op); deallocOp) {
     return createOperation(cache, createDeallocOp(cache, deallocOp),
                            debugString);
   }
diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp
index 01b7f0b4fe..3d043fa872 100644
--- a/runtime/lib/ttnn/program.cpp
+++ b/runtime/lib/ttnn/program.cpp
@@ -15,6 +15,8 @@
 #include "ttmlir/Target/TTNN/program_generated.h"
 #include "ttnn/device.hpp"
 #include "ttnn/operations/conv/conv2d/conv2d.hpp"
+#include "ttnn/operations/pool/maxpool/max_pool2d.hpp"
+#include "ttnn/tensor/tensor.hpp"
 #include "ttnn/tensor/types.hpp"
 #include "ttnn/types.hpp"
 #include "types_generated.h"
@@ -748,6 +750,25 @@ static void run(::tt::target::ttnn::Conv2dOp const *op,
   return;
 }
 
+static void run(::tt::target::ttnn::MaxPool2dOp const *op,
+                std::unordered_map &devicePool,
+                ProgramTensorPool &tensorPool) {
+  const ::ttnn::Tensor &input = tensorPool.at(op->in()->global_id());
+  const ::ttnn::operations::pool::MaxPoolNewOp operation =
+      ::ttnn::operations::pool::MaxPoolNewOp();
+
+  ::ttnn::Device &device = getDevice(op->device(), devicePool);
+  ::ttnn::Tensor out = operation.invoke(
+      0, input, op->batch_size(), op->input_height(), op->input_width(),
+      op->channels(), {op->kernel_height(), op->kernel_width()},
+      {op->stride_height(), op->stride_width()},
+      {op->padding_height(), op->padding_width()},
+      {op->dilation_height(), op->dilation_width()}, &device);
+
+  tensorPool.insert_or_assign(op->out()->global_id(), std::move(out));
+  return;
+}
+
 static void run(::tt::target::ttnn::DeallocOp const *op,
                 std::unordered_map &devicePool,
                 ProgramTensorPool &tensorPool) {
@@ -799,6 +820,7 @@ run(::tt::target::ttnn::Operation const *op,
     const std::unordered_map &allDevices,
     std::unordered_map &devicePool,
     ProgramTensorPool &tensorPool) {
+
   switch (op->type_type()) {
   case ::tt::target::ttnn::OpType::GetDeviceOp: {
     return run(op->type_as_GetDeviceOp(), allDevices, devicePool, tensorPool);
@@ -837,13 +859,17 @@ run(::tt::target::ttnn::Operation const *op,
   }
   case ::tt::target::ttnn::OpType::ConcatOp: {
     return run(op->type_as_ConcatOp(), devicePool, tensorPool);
+  }
   case ::tt::target::ttnn::OpType::ReshapeOp: {
     return run(op->type_as_ReshapeOp(), devicePool, tensorPool);
   }
   case ::tt::target::ttnn::OpType::DeallocOp: {
     return run(op->type_as_DeallocOp(), devicePool, tensorPool);
   }
-  default:
+  case ::tt::target::ttnn::OpType::MaxPool2dOp: {
+    return run(op->type_as_MaxPool2dOp(), devicePool, tensorPool);
+  }
+  default: {
     throw std::runtime_error("Unsupported operation type");
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/simple_maxpool2d.mlir b/test/ttmlir/Dialect/TTNN/simple_maxpool2d.mlir
new file mode 100644
index 0000000000..5116b2bfbb
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/simple_maxpool2d.mlir
@@ -0,0 +1,10 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
+#any_device = #tt.operand_constraint<any_device>
+module attributes {} {
+  func.func @forward(%arg0: tensor<1x128x128x32xbf16>) -> tensor<1x64x64x32xbf16> {
+    %0 = tensor.empty() : tensor<1x64x64x32xbf16>
+    // CHECK: %[[C:.*]] = "ttnn.max_pool2d"[[C:.*]]
+    %1 = "ttir.max_pool2d"(%arg0, %0) <{kernel_height=2: si32, kernel_width=2: si32, stride_height=2: si32, stride_width=2: si32, dilation_height=1: si32, dilation_width=1: si32, ceil_mode=false, padding_left=0: si32, padding_right=0: si32, padding_top=0: si32, padding_bottom=0: si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x128x128x32xbf16>, tensor<1x64x64x32xbf16>) -> tensor<1x64x64x32xbf16>
+    return %1 : tensor<1x64x64x32xbf16>
+  }
+}
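Note, not part of the patch: the standalone C++ sketch below mirrors the shape arithmetic that the new sliding-window fix-up pass and the TTNN verifier depend on, using the shapes from the test above. collapseNHW follows the helper added in Passes.cpp; pooledDim uses the standard 2D pooling output-size formula with ceil_mode = false, which is an assumption here (the TTIR op definition does not spell the formula out, but the 2x2/stride-2 test shapes are consistent with it). All values are worked out by hand, not produced by ttmlir.

// Standalone sketch: verifies the shape math for the simple_maxpool2d test.
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors collapseNHW() from Passes.cpp: (N, H, W, C) -> (1, 1, N*H*W, C).
static std::vector<int64_t> collapseNHW(const std::vector<int64_t> &shape) {
  std::vector<int64_t> collapsed(shape.size(), 1);
  int64_t nhw = 1;
  for (size_t i = 0; i + 1 < shape.size(); ++i) {
    nhw *= shape[i];
  }
  collapsed[collapsed.size() - 2] = nhw;
  collapsed[collapsed.size() - 1] = shape[shape.size() - 1];
  return collapsed;
}

// Assumed output-size formula for one spatial dim (ceil_mode = false).
static int64_t pooledDim(int64_t in, int64_t kernel, int64_t stride,
                         int64_t pad, int64_t dilation) {
  return (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
}

int main() {
  // Shapes from test/ttmlir/Dialect/TTNN/simple_maxpool2d.mlir (NHWC).
  const std::vector<int64_t> input = {1, 128, 128, 32};
  assert((collapseNHW(input) == std::vector<int64_t>{1, 1, 16384, 32}));

  // kernel 2x2, stride 2, no padding, dilation 1 -> 64x64 spatial output.
  assert(pooledDim(128, 2, 2, 0, 1) == 64);

  const std::vector<int64_t> output = {1, 64, 64, 32};
  assert((collapseNHW(output) == std::vector<int64_t>{1, 1, 4096, 32}));
  return 0;
}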