diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index e0533fc184..38f8a5ccc5 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -316,22 +316,17 @@ def TTNN_EmbeddingOp : TTNN_Op<"embedding"> {
   let hasVerifier = 1;
 }

-def TTNN_SoftmaxOp : TTNN_NamedDPSOp<"softmax"> {
+def TTNN_SoftmaxOp : TTNN_Op<"softmax"> {
   let summary = "Softmax op.";
   let description = [{
     Softmax operation.
   }];

   let arguments = (ins AnyRankedTensor:$input,
-                       AnyRankedTensor:$output,
                        SI32Attr: $dimension);

   let results = (outs AnyRankedTensor:$result);

-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 }

diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index a895acc182..31f10ab097 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -301,9 +301,10 @@ class SoftmaxOpConversionPattern : public OpConversionPattern<ttir::SoftmaxOp> {
   LogicalResult
   matchAndRewrite(ttir::SoftmaxOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    removeDpsOp(rewriter, adaptor);
     rewriter.replaceOpWithNewOp<ttnn::SoftmaxOp>(
         op, this->getTypeConverter()->convertType(op.getType()),
-        adaptor.getInput(), adaptor.getOutput(), adaptor.getDimension());
+        adaptor.getInput(), adaptor.getDimension());
     return success();
   }
 };
diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp
index 78650dff05..506a45f6d3 100644
--- a/lib/Dialect/TTNN/IR/TTNNOps.cpp
+++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -107,7 +107,7 @@ ::mlir::LogicalResult mlir::tt::ttnn::EmbeddingOp::verify() {

 ::mlir::LogicalResult mlir::tt::ttnn::SoftmaxOp::verify() {
   ::mlir::RankedTensorType inputType = getInput().getType();
-  ::mlir::RankedTensorType outputType = getOutput().getType();
+  ::mlir::RankedTensorType outputType = getResult().getType();

   // Shapes of input and output of a softmax operation must be the same
   if (inputType.getShape() != outputType.getShape()) {
diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index cfed0120fa..78d0dff2ac 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -390,8 +390,8 @@ ::flatbuffers::Offset<::tt::target::ttnn::SoftmaxOp>
 createSoftmaxOp(FlatbufferObjectCache &cache, SoftmaxOp op) {
   auto in = cache.at<::tt::target::TensorRef>(
       getOperandThroughDPSOps(op.getInput()));
-  auto out = cache.at<::tt::target::TensorRef>(
-      getOperandThroughDPSOps(op.getResult()));
+  auto out = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
+                               kHostAllocatedAddress, kHostAllocatedSize);
   int32_t dimension = op.getDimension();

   return ::tt::target::ttnn::CreateSoftmaxOp(*cache.fbb, in, out, dimension);
diff --git a/test/ttmlir/Dialect/TTNN/softmax/simple_softmax.mlir b/test/ttmlir/Dialect/TTNN/softmax/simple_softmax.mlir
index f9a0a8fff0..ec05a3006e 100644
--- a/test/ttmlir/Dialect/TTNN/softmax/simple_softmax.mlir
+++ b/test/ttmlir/Dialect/TTNN/softmax/simple_softmax.mlir
@@ -2,12 +2,10 @@
 #any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<512x1024xbf16>) -> tensor<512x1024xbf16> {
-    // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<512x1024xbf16>
     // CHECK: %[[C:.*]] = "ttnn.softmax"[[C:.*]]
     // Check for positive dimension attribute
     %1 = "ttir.softmax"(%arg0, %0) <{dimension = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16>
-    // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %2 = tensor.empty() : tensor<512x1024xbf16>
     // CHECK: %[[C:.*]] = "ttnn.softmax"[[C:.*]]
     // Check for negative dimension attribute
diff --git a/test/ttmlir/Silicon/TTNN/sharded/simple_eltwise_sharded.mlir b/test/ttmlir/Silicon/TTNN/sharded/simple_eltwise_sharded.mlir
index b398d070fa..902f3aa598 100644
--- a/test/ttmlir/Silicon/TTNN/sharded/simple_eltwise_sharded.mlir
+++ b/test/ttmlir/Silicon/TTNN/sharded/simple_eltwise_sharded.mlir
@@ -83,12 +83,10 @@ func.func @sqrt(%arg0: tensor<224x64xf32>) -> tensor<224x64xf32> {
 }

 func.func @softmax(%arg0: tensor<224x64xbf16>) -> tensor<224x64xbf16> {
-  // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
   %0 = tensor.empty() : tensor<224x64xbf16>
   // CHECK: %[[C:.*]] = "ttnn.softmax"[[C:.*]]
   // Check for positive dimension attribute
   %1 = "ttir.softmax"(%arg0, %0) <{dimension = 1 : si32, operand_constraints = [#l1_block_sharded, #l1_block_sharded]}> : (tensor<224x64xbf16>, tensor<224x64xbf16>) -> tensor<224x64xbf16>
-  // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
   %2 = tensor.empty() : tensor<224x64xbf16>
   // CHECK: %[[C:.*]] = "ttnn.softmax"[[C:.*]]
   // Check for negative dimension attribute
diff --git a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
index 5977e46b7b..3d5cbba155 100644
--- a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
@@ -106,12 +106,10 @@ func.func @rsqrt(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
 }

 func.func @softmax(%arg0: tensor<512x1024xbf16>) -> tensor<512x1024xbf16> {
-  // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
   %0 = tensor.empty() : tensor<512x1024xbf16>
   // CHECK: %[[C:.*]] = "ttnn.softmax"[[C:.*]]
   // Check for positive dimension attribute
   %1 = "ttir.softmax"(%arg0, %0) <{dimension = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16>
-  // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
   %2 = tensor.empty() : tensor<512x1024xbf16>
   // CHECK: %[[C:.*]] = "ttnn.softmax"[[C:.*]]
   // Check for negative dimension attribute
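
Illustration (not part of the patch): with the DPS output operand removed from ttnn.softmax, lowering ttir.softmax through the TTIR-to-TTNN pipeline should no longer materialize a ttnn.empty destination tensor for it. A minimal before/after IR sketch, assuming the generic printed form and omitting ttnn.empty operands, TTNN layout encodings, and memory attributes for brevity:

  // Before: softmax consumed an explicitly allocated output buffer.
  %out = "ttnn.empty"(...) : (...) -> tensor<512x1024xbf16>
  %sm  = "ttnn.softmax"(%in, %out) <{dimension = 1 : si32}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16>

  // After: softmax takes only the input and the dimension attribute;
  // its result tensor is produced by the op itself.
  %sm = "ttnn.softmax"(%in) <{dimension = 1 : si32}> : (tensor<512x1024xbf16>) -> tensor<512x1024xbf16>

The elided ttnn.empty operands and layout attributes are placeholders; the CHECK lines deleted from the tests above correspond to that ttnn.empty no longer being emitted before each ttnn.softmax.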