Skip to content

Commit

Permalink
Change !tt.tile shape encoding to be consistent with other shape
Browse files Browse the repository at this point in the history
Previously it was:

    !tt.tile<32 x 32, bfp8_b>

There was an awkward space between the `x` because of the way we can
legally write the assembly syntax.

Now it uses the same DimensionList custom parser as all other shape
types so we have:

    !tt.tile<32x32, bfp8_b>
  • Loading branch information
nsmithtt committed Aug 7, 2024
1 parent ad7da1c commit 41e128f
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 23 deletions.
12 changes: 12 additions & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpImplementation.h"

#include "ttmlir/Dialect/TT/IR/TTOpsEnums.h.inc"

Expand All @@ -20,6 +21,17 @@ inline bool isDeviceMemorySpace(MemorySpace memorySpace) {
return memorySpace == MemorySpace::DeviceDRAM ||
memorySpace == MemorySpace::DeviceL1;
}

inline void printDimensionList(::mlir::AsmPrinter &printer,
::llvm::ArrayRef<int64_t> shape) {
printer.printDimensionList(shape);
}

inline ::mlir::ParseResult
parseDimensionList(::mlir::AsmParser &odsParser,
::llvm::SmallVector<int64_t> &dimensions) {
return odsParser.parseDimensionList(dimensions, false, false);
}
} // namespace mlir::tt

#define GET_ATTRDEF_CLASSES
Expand Down
9 changes: 7 additions & 2 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> {
bool isSystemMemorySpace() const { return ::mlir::tt::isSystemMemorySpace(getMemorySpace()); }
bool isDeviceMemorySpace() const { return ::mlir::tt::isDeviceMemorySpace(getMemorySpace()); }
Type getElementType() const;
uint64_t getElementSizeBytes() const;
llvm::SmallVector<int64_t> getStride(ArrayRef<int64_t> logicalShape) const;
llvm::SmallVector<int64_t> getPhysicalShape(ArrayRef<int64_t> logicalShape) const;
llvm::SmallVector<int64_t> getShardShape() const;
Expand Down Expand Up @@ -266,14 +267,18 @@ class TT_Type<string name, string typeMnemonic, list<Trait> traits = []>
def TT_Tile : TT_Type<"Tile", "tile", [MemRefElementTypeInterface]> {
let summary = "TT tile";
let description = "Tile type in TT dialect";
let parameters = (ins "unsigned":$height, "unsigned":$width, "DataType":$dataType);
let assemblyFormat = "`<` $height `x` $width`,` $dataType `>`";
let parameters = (ins ArrayRefParameter<"int64_t">:$shape, "DataType":$dataType);
let assemblyFormat = "`<` custom<DimensionList>($shape) `,` $dataType `>`";

let extraClassDeclaration = [{
SmallVector<int64_t> getScalarShape(SmallVector<int64_t> tiledShape) const;
SmallVector<int64_t> getTiledShape(SmallVector<int64_t> scalarShape) const;
uint64_t getSizeBytes() const;
int64_t getHeight() const { return getShape()[0]; }
int64_t getWidth() const { return getShape()[1]; }
}];

let genVerifyDecl = 1;
}

def TT_Device : TT_Type<"Device", "device", []> {
Expand Down
3 changes: 2 additions & 1 deletion lib/CAPI/TTTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ namespace mlir::tt {

MlirType ttmlirTTTileTypeGet(MlirContext ctx, unsigned height, unsigned width,
uint32_t dataType) {
return wrap(TileType::get(unwrap(ctx), height, width,
return wrap(TileType::get(unwrap(ctx),
SmallVector<std::int64_t>{height, width},
static_cast<tt::DataType>(dataType)));
}

Expand Down
13 changes: 0 additions & 13 deletions lib/Dialect/TT/IR/TTDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,6 @@ struct TTOpAsmDialectInterface : public OpAsmDialectInterface {
}
};

namespace mlir::tt {
static void printDimensionList(::mlir::AsmPrinter &printer,
::llvm::ArrayRef<int64_t> shape) {
printer.printDimensionList(shape);
}

static ::mlir::ParseResult
parseDimensionList(::mlir::AsmParser &odsParser,
::llvm::SmallVector<int64_t> &dimensions) {
return odsParser.parseDimensionList(dimensions, false, false);
}
} // namespace mlir::tt

#include "ttmlir/Dialect/TT/IR/TTOpsDialect.cpp.inc"

#define GET_ATTRDEF_CLASSES
Expand Down
19 changes: 19 additions & 0 deletions lib/Dialect/TT/IR/TTOpsTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,15 @@ mlir::Type LayoutAttr::getElementType() const {
return getMemref().getElementType();
}

uint64_t LayoutAttr::getElementSizeBytes() const {
mlir::Type elementType = getElementType();
if (mlir::isa<TileType>(elementType)) {
auto tileType = mlir::cast<TileType>(elementType);
return tileType.getSizeBytes();
}
return elementType.getIntOrFloatBitWidth() / 8;
}

LayoutAttr LayoutAttr::withGrid(
::mlir::MLIRContext *context, ArrayRef<int64_t> tensorShape, GridAttr grid,
ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals) {
Expand Down Expand Up @@ -449,6 +458,16 @@ DeviceAttr::verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
return ::mlir::success();
}

::mlir::LogicalResult
TileType::verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, DataType dataType) {
if (shape.size() != 2) {
emitError() << "expected 2D shape";
return ::mlir::failure();
}
return ::mlir::success();
}

llvm::SmallVector<int64_t>
TileType::getScalarShape(SmallVector<int64_t> tiledShape) const {
assert(tiledShape.size() >= 2 && "expected at least 2D shape");
Expand Down
8 changes: 6 additions & 2 deletions lib/Dialect/TTIR/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -557,8 +557,12 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
};

inline uint64_t getElementSizeBytes(Type ty) {
assert(ty.isF32() && "Only support f32 for now");
return 4;
if (isa<TileType>(ty)) {
auto tileType = mlir::cast<TileType>(ty);
return tileType.getSizeBytes();
} else {
return ty.getIntOrFloatBitWidth() / 8;
}
}

inline uint64_t getMemrefSizeBytes(MemRefType ty) {
Expand Down
8 changes: 4 additions & 4 deletions python/TTModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,11 @@ void populateTTModule(py::module &m) {

py::class_<tt::TileType>(m, "TileType")
.def_static("get",
[](MlirContext ctx, unsigned height, unsigned width,
[](MlirContext ctx, std::int64_t height, std::int64_t width,
uint32_t dataType) {
return wrap(
tt::TileType::get(unwrap(ctx), height, width,
static_cast<tt::DataType>(dataType)));
return wrap(tt::TileType::get(
unwrap(ctx), SmallVector<std::int64_t>{height, width},
static_cast<tt::DataType>(dataType)));
})
.def_property_readonly("data_type", &tt::TileType::getDataType)
.def_property_readonly("shape", [](tt::TileType const &tile) {
Expand Down
2 changes: 1 addition & 1 deletion test/python/tensor_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def parallelize(tensor, grid, collapseIntervals=[(0, -1)]):
t0 = createTensorLayout([2, 3, 64, 128], [2, 4])
# CHECK: tensor<2x3x64x128xf32, #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x4>, memref<192x32xf32, #tt.memory_space<l1>>>>
print(t0)
# CHECK: #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x4>, memref<6x1x!tt.tile<32 x 32, bfp_bf8>, #tt.memory_space<l1>>>
# CHECK: #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x4>, memref<6x1x!tt.tile<32x32, bfp_bf8>, #tt.memory_space<l1>>>
print(tilize(t0, tt.DataType.BFP_BFloat8).wrapped())
print(parallelize(t0, [3, 2]).wrapped())

Expand Down

0 comments on commit 41e128f

Please sign in to comment.