Skip to content

Commit

Permalink
Adding ttnn_to_dtype op in TTNN dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
sdjordjevicTT committed Jan 31, 2025
1 parent 6eda280 commit 6a219f2
Show file tree
Hide file tree
Showing 22 changed files with 802 additions and 96 deletions.
15 changes: 15 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,21 @@ def TTNN_TypecastOp : TTNN_Op<"typecast"> {
let results = (outs AnyRankedTensor:$result);
}

def TTNN_ToDTypeOp : TTNN_Op<"to_dtype"> {
let summary = "ToDType op.";
let description = [{
This op converts the data type of the input tensor based on the given data type on the host.

Args:
- :attr:`input`: the ttnn.Tensor
- :attr:`dtype`: `ttnn` data type.
}];

let arguments = (ins AnyRankedTensor:$input,
TT_DataTypeAttr:$dtype);
let results = (outs AnyRankedTensor:$result);
}

def TTNN_ToDeviceOp : TTNN_Op<"to_device"> {
let summary = "ToDevice op.";
let description = [{
Expand Down
7 changes: 7 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ table ToLayoutOp {
out: tt.target.TensorRef;
}

table ToDTypeOp {
in: tt.target.TensorRef;
dtype: tt.target.DataType;
out: tt.target.TensorRef;
}

table TypecastOp {
in: tt.target.TensorRef;
dtype: tt.target.DataType;
Expand Down Expand Up @@ -396,6 +402,7 @@ union OpType {
GetDeviceOp,
ToMemoryConfigOp,
ToLayoutOp,
ToDTypeOp,
TypecastOp,
ToDeviceOp,
FromDeviceOp,
Expand Down
28 changes: 28 additions & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,33 @@ class TypecastOpConversionPattern
};
} // namespace

// ToDTypeOp conversion pattern
//
namespace {
class ToDTypeOpConversionPattern
: public TTNNToEmitCBaseOpConversionPattern<ttnn::ToDTypeOp> {

public:
using TTNNToEmitCBaseOpConversionPattern<
ttnn::ToDTypeOp>::TTNNToEmitCBaseOpConversionPattern;

LogicalResult
matchAndRewrite(ttnn::ToDTypeOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

ArrayAttr arrayAttrs = rewriter.getArrayAttr(
{mlir::IntegerAttr::get(rewriter.getIndexType(), 0),
ttnn_to_emitc::utils::convertDType(rewriter, srcOp.getDtypeAttr())});

rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
srcOp, this->getTypeConverter()->convertType(srcOp.getType()),
this->convertOpName(srcOp), arrayAttrs, nullptr, adaptor.getOperands());

return success();
}
};
} // namespace

// ToMemoryConfig conversion pattern
//
namespace {
Expand Down Expand Up @@ -1128,6 +1155,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
// clang-format off
patterns.add<ToLayoutOpConversionPattern,
ToMemoryConfigOpConversionPattern,
DefaultOpConversionPattern<ttnn::ToDTypeOp>,
TypecastOpConversionPattern,
ToDeviceOpConversionPattern,
FromDeviceOpConversionPattern,
Expand Down
Loading

0 comments on commit 6a219f2

Please sign in to comment.