#1102: Remove tt::target::DataType::None (#1103)
jnie-TT authored Oct 30, 2024
1 parent caa7b90 commit 22a06f2
Showing 8 changed files with 37 additions and 27 deletions.
14 changes: 14 additions & 0 deletions include/ttmlir/Dialect/TTNN/Types/Types.h
@@ -0,0 +1,14 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#ifndef TTMLIR_DIALECT_TTNN_TYPES_TYPES_H
+#define TTMLIR_DIALECT_TTNN_TYPES_TYPES_H
+#include <cstdint>
+
+namespace mlir::tt::ttnn {
+static constexpr const uint32_t TILE_HEIGHT = 32;
+static constexpr const uint32_t TILE_WIDTH = 32;
+} // namespace mlir::tt::ttnn
+
+#endif
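
Note: this new header hoists the tile dimensions that were previously file-local constants in lib/Dialect/TTNN/IR/TTNNOps.cpp (removed in the diff below). A minimal usage sketch, assuming only the two constants from this commit; the isTileAligned helper is hypothetical, for illustration:

#include <cstdint>

#include "ttmlir/Dialect/TTNN/Types/Types.h"

// Hypothetical helper: true when a (height, width) shard shape divides
// the fixed (32, 32) TTNN tile shape evenly, mirroring the check in
// ToMemoryConfigOp::verify below.
inline bool isTileAligned(int64_t height, int64_t width) {
  return height % mlir::tt::ttnn::TILE_HEIGHT == 0 &&
         width % mlir::tt::ttnn::TILE_WIDTH == 0;
}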
25 changes: 12 additions & 13 deletions include/ttmlir/Target/Common/types.fbs
@@ -17,19 +17,18 @@ enum Arch: uint {
}

enum DataType: uint16 {
-None = 0,
-Float32 = 1,
-Float16 = 2,
-BFloat16 = 3,
-BFP_Float8 = 4,
-BFP_BFloat8 = 5,
-BFP_Float4 = 6,
-BFP_BFloat4 = 7,
-BFP_Float2 = 8,
-BFP_BFloat2 = 9,
-UInt32 = 10,
-UInt16 = 11,
-UInt8 = 12,
+Float32 = 0,
+Float16 = 1,
+BFloat16 = 2,
+BFP_Float8 = 3,
+BFP_BFloat8 = 4,
+BFP_Float4 = 5,
+BFP_BFloat4 = 6,
+BFP_Float2 = 7,
+BFP_BFloat2 = 8,
+UInt32 = 9,
+UInt16 = 10,
+UInt8 = 11,
}

enum OOBVal: ushort {
2 changes: 1 addition & 1 deletion include/ttmlir/Target/TTNN/program.fbs
@@ -18,7 +18,7 @@ table ToMemoryConfigOp {
table ToLayoutOp {
in: tt.target.TensorRef;
layout: tt.target.TensorLayout;
-dtype: tt.target.DataType;
+dtype: tt.target.DataType = null;
memcfg: tt.target.MemoryConfigDesc;
device: tt.target.DeviceRef;
out: tt.target.TensorRef;
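
Note: in FlatBuffers, giving a scalar field the default = null marks it optional, so "field absent" becomes representable on the wire and the DataType enum no longer needs a None sentinel. A minimal sketch of the presence check this enables, assuming the flatc-generated C++ accessor; the helper name is illustrative:

// With "dtype: tt.target.DataType = null;" flatc generates an accessor
// returning ::flatbuffers::Optional<::tt::target::DataType>.
bool hasExplicitDtype(const ::tt::target::ttnn::ToLayoutOp *op) {
  // The Optional is falsy when the field was never written, so "unset"
  // is distinguishable from every real DataType value.
  return static_cast<bool>(op->dtype());
}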
4 changes: 3 additions & 1 deletion lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -9,6 +9,7 @@
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
#include "ttmlir/Dialect/TTNN/Types/Types.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"

#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -251,7 +252,8 @@ class ToLayoutOpConversionPattern

if (newOutputLayoutEnum == ttnn::Layout::Tile) {
TileType tileType =
-TileType::get(rewriter.getContext(), {32, 32}, outputDtype);
+TileType::get(rewriter.getContext(),
+{ttnn::TILE_HEIGHT, ttnn::TILE_WIDTH}, outputDtype);
llvm::SmallVector<int64_t> newShardShape =
tileType.getTiledShape(llvm::SmallVector<int64_t>(
oldShardShape.begin(), oldShardShape.end()));
3 changes: 0 additions & 3 deletions lib/Dialect/TT/IR/TTOpsTypes.cpp
@@ -176,9 +176,6 @@ mlir::tt::SystemDescAttr::getFromPath(MLIRContext *context, std::string &path) {

for (auto it : *(element->supported_data_types())) {
switch (it) {
-case ::tt::target::DataType::None:
-assert(false && "Unexpected None DataType");
-break;
case ::tt::target::DataType::Float32:
supported_data_types_attr.push_back(
tt::DataTypeAttr::get(context, tt::DataType::Float32));
7 changes: 2 additions & 5 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -5,6 +5,7 @@
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/Types/Types.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"
#include "ttmlir/Utils.h"

@@ -19,9 +20,6 @@

namespace mlir::tt::ttnn {

-constexpr int TTNN_TILE_HEIGHT = 32;
-constexpr int TTNN_TILE_WIDTH = 32;
-
//===----------------------------------------------------------------------===//
// Conv2dOp
//===----------------------------------------------------------------------===//
@@ -549,8 +547,7 @@ ::mlir::LogicalResult mlir::tt::ttnn::ToMemoryConfigOp::verify() {
}
if (outputMemoryLayout == ::mlir::tt::TensorMemoryLayout::BlockSharded) {
// TTNN tiles are (32, 32), shard shape must evenly divide the tile shape
-if (shardShape[0] % TTNN_TILE_HEIGHT != 0 or
-shardShape[1] % TTNN_TILE_WIDTH != 0) {
+if (shardShape[0] % TILE_HEIGHT != 0 or shardShape[1] % TILE_WIDTH != 0) {
return emitOpError(
"Shard shape must divide tile shape (32, 32) evenly");
}
5 changes: 3 additions & 2 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -145,8 +145,9 @@ createOp(FlatbufferObjectCache &cache, ToLayoutOp op) {
return ::tt::target::ttnn::CreateToLayoutOp(
*cache.fbb, input, layout,
dtype.has_value()
-? ::tt::mlir::ttnn::utils::toTargetDataType(dtype.value())
-: ::tt::target::DataType::None,
+? ::flatbuffers::Optional<::tt::target::DataType>(
+::tt::mlir::ttnn::utils::toTargetDataType(dtype.value()))
+: ::flatbuffers::nullopt,
memoryConfig.has_value()
? cache.getOrCreate(memoryConfig.value(), memoryConfigToFlatbuffer)
: 0,
4 changes: 2 additions & 2 deletions runtime/lib/ttnn/operations/layout/to_layout.cpp
@@ -35,8 +35,8 @@ void run(const ::tt::target::ttnn::ToLayoutOp *op, ProgramContext &context) {
std::optional<::ttnn::MemoryConfig> memoryConfig = std::nullopt;
::ttnn::Device *device = nullptr;

-if (op->dtype() != ::tt::target::DataType::None) {
-dtype = ::tt::runtime::ttnn::utils::toTTNNDataType(op->dtype());
+if (op->dtype()) {
+dtype = ::tt::runtime::ttnn::utils::toTTNNDataType(*(op->dtype()));
}

if (op->memcfg()) {
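
Note: with the sentinel removed, the runtime check above is presence-based rather than value-based. A self-contained sketch of that contract, assuming toTTNNDataType returns ::ttnn::DataType; the readDtype wrapper is illustrative, not repository code:

#include <optional>

// An absent flatbuffer field leaves the optional unset instead of
// carrying a None sentinel through the data-type conversion.
std::optional<::ttnn::DataType>
readDtype(const ::tt::target::ttnn::ToLayoutOp *op) {
  std::optional<::ttnn::DataType> dtype = std::nullopt;
  if (op->dtype()) { // ::flatbuffers::Optional is falsy when unset
    dtype = ::tt::runtime::ttnn::utils::toTTNNDataType(*(op->dtype()));
  }
  return dtype;
}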
