From fc498b951ea816fe4868550e8ca0fe7301396a3d Mon Sep 17 00:00:00 2001
From: Nicholas Smith
Date: Mon, 2 Sep 2024 16:08:08 -0500
Subject: [PATCH] Support dram memory space in metal direct

This change adds support for the device DRAM memory space in the metal
backend.  This required a few bits of refactoring, including:

- `DeviceAttr` splits the worker grid out from the L1 map; they now
  respectively map the physical compute cores and the physical L1 memory map.
- `DeviceAttr` gets a new map, `dramMap`, which maps a linear tensor
  coordinate to a physical DRAM address.
- Change `projectOnto` to take an arbitrary affine map instead of a grid
  attr.  This lets us pass a unique affine map for L1, DRAM, or Eth (in the
  future) through the same interface.
- `PhysicalCoreCoordMapping` can now take an L1 grid or a DRAM grid.
- Give explicit enum index names to the `DeviceAttr` affine map results.
- Noc datamovement programs can now be generated as reads or writes, which is
  necessary for writing to DRAM since DRAM cores do not have a RISC-V core.
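To illustrate how the two maps are meant to be consumed (a hypothetical
driver snippet, not part of this patch; `device`, `layout`, and
`tensorShape` stand in for values obtained elsewhere):

  // Pick the physical map for the tensor's memory space, bind the layout's
  // shard shape into its symbols via projectOnto, then evaluate one index.
  AffineMap dramMap = device.getMapForMemorySpace(MemorySpace::DeviceDRAM);
  AffineMap dst = layout.projectOnto(layout.getLinear(), dramMap, tensorShape);
  SmallVector<int64_t> loc = dst.compose({0, 5});
  // loc unpacks positionally as (DeviceIdx, CoreCoordY, CoreCoordX,
  // ShardOffset) per MemoryMapResultIdx.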
Closes #359
---
 include/ttmlir/Dialect/TT/IR/TTOpsTypes.h  |  19 ++
 include/ttmlir/Dialect/TT/IR/TTOpsTypes.td |  29 ++-
 .../Dialect/TT/Utils/PhysicalCoreCoord.h   |  79 +++++-
 .../ttmlir/Target/Utils/MLIRToFlatbuffer.h |  16 +-
 include/ttmlir/Utils.h                     |   4 +
 lib/Dialect/TT/IR/TTOpsTypes.cpp           | 225 ++++++++++++++----
 lib/Dialect/TTIR/Transforms/Passes.cpp     |  27 ++-
 lib/Dialect/TTMetal/Transforms/Passes.cpp  | 163 ++++++++-----
 lib/Target/TTNN/TTNNToFlatbuffer.cpp       |   2 +-
 python/TTModule.cpp                        |  27 ++-
 test/python/device_attr.py                 |  80 +++++--
 test/ttmlir/Silicon/TTMetal/to_layout.mlir |  33 +++
 12 files changed, 518 insertions(+), 186 deletions(-)

diff --git a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.h b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.h
index a48e1dded..5adf8e21b 100644
--- a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.h
+++ b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.h
@@ -12,6 +12,25 @@
 #include "ttmlir/Dialect/TT/IR/TTOpsEnums.h.inc"
 
 namespace mlir::tt {
+struct PhysGridResultIdx {
+  enum : int64_t {
+    DeviceIdx = 0,
+    CoreCoordY = 1,
+    CoreCoordX = 2,
+    NumIndices = 3,
+  };
+};
+
+struct MemoryMapResultIdx {
+  enum : int64_t {
+    DeviceIdx = 0,
+    CoreCoordY = 1,
+    CoreCoordX = 2,
+    ShardOffset = 3,
+    NumIndices = 4,
+  };
+};
+
 inline bool isSystemMemorySpace(MemorySpace memorySpace) {
   return memorySpace == MemorySpace::System ||
          memorySpace == MemorySpace::SystemMMIO;
diff --git a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
index 0142db0d5..b8012fdc6 100644
--- a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
+++ b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
@@ -290,7 +290,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> {
     llvm::SmallVector<int64_t> getStride(ArrayRef<int64_t> logicalShape) const;
     llvm::SmallVector<int64_t> getPhysicalShape(ArrayRef<int64_t> logicalShape) const;
     llvm::SmallVector<int64_t> getShardShape() const;
-    AffineMap projectOnto(AffineMap linearMap, ArrayRef<int64_t> logicalTensorShape, GridAttr grid) const;
+    AffineMap projectOnto(AffineMap linearMap, AffineMap physicalMemoryMap, ArrayRef<int64_t> logicalTensorShape) const;
     AffineMap getIdentityTileLinearMap() const;
     llvm::SmallVector<int64_t> getTiledShape(ArrayRef<int64_t> logicalTensorShape) const;
   }];
@@ -321,18 +321,33 @@ def TT_BufferAttr : TT_Attr<"Buffer", "buffer", []> {
 }
 
 def TT_DeviceAttr : TT_Attr<"Device", "device", []> {
-  let summary = "Device attribute in TT dialect";
+  let summary = "Device attribute in TT dialect.";
   let description = [{
+    Describes the physical layout of a device in the system and is made up of
+    a few components:
+    - A grid attribute that describes the device's compute grid shape and
+      carries an affine map describing how the logical grid maps to the
+      physical grid.
+    - Two affine maps that describe how a tensor layout's linear attribute
+      maps to the L1 and DRAM memory spaces.
+    - An array of chip ids that this device is made up of.
   }];
-  let parameters = (ins TT_GridAttr:$grid, ArrayRefParameter<"unsigned">:$chipIds);
-  let assemblyFormat = "`<` qualified($grid) `,` `[` $chipIds `]` `>`";
+  let parameters = (ins TT_GridAttr:$workerGrid,
+                        "AffineMap":$l1Map,
+                        "AffineMap":$dramMap,
+                        ArrayRefParameter<"unsigned">:$chipIds);
+  let assemblyFormat = "`<` `workerGrid` `=` qualified($workerGrid) `,` `l1Map` `=` qualified($l1Map) `,` `dramMap` `=` qualified($dramMap) `,` `chipIds` `=` `[` $chipIds `]` `>`";
   let extraClassDeclaration = [{
-      static DeviceAttr get(::mlir::MLIRContext *context, ArrayRef<int64_t> shape, AffineMap physicalGridMapping, ArrayRef<unsigned> chipIds) {
-        return DeviceAttr::get(context, GridAttr::get(context, shape, physicalGridMapping), chipIds);
-      }
       static DeviceAttr get(::mlir::MLIRContext *context, SystemDescAttr systemDesc, ArrayRef<unsigned> chipIds);
       static DeviceAttr get(::mlir::MLIRContext *context, SystemDescAttr systemDesc, bool enableMultichip = false);
+      AffineMap getMapForMemorySpace(MemorySpace memorySpace) const {
+        switch (memorySpace) {
+        case MemorySpace::DeviceL1:
+          return getL1Map();
+        case MemorySpace::DeviceDRAM:
+          return getDramMap();
+        default:
+          llvm_unreachable("Unsupported memory space");
+        }
+      }
   }];
   let genVerifyDecl = 1;
diff --git a/include/ttmlir/Dialect/TT/Utils/PhysicalCoreCoord.h b/include/ttmlir/Dialect/TT/Utils/PhysicalCoreCoord.h
index b691d994b..c97e700e2 100644
--- a/include/ttmlir/Dialect/TT/Utils/PhysicalCoreCoord.h
+++ b/include/ttmlir/Dialect/TT/Utils/PhysicalCoreCoord.h
@@ -21,12 +21,20 @@ struct PhysicalCoreCoord {
 
   std::int64_t &operator[](std::size_t i) {
     assert(i < 3);
-    return i == 0 ? d : i == 1 ? y : x;
+    switch (i) {
+    case 0:
+      return d;
+    case 1:
+      return y;
+    case 2:
+      return x;
+    default:
+      llvm_unreachable("invalid index");
+    }
   }
 
   std::int64_t operator[](std::size_t i) const {
-    assert(i < 3);
-    return i == 0 ? d : i == 1 ? y : x;
+    return (*const_cast<PhysicalCoreCoord *>(this))[i];
   }
 
   bool operator==(PhysicalCoreCoord const &other) const {
@@ -36,32 +44,79 @@ struct PhysicalCoreCoord {
 
 class PhysicalCoreCoordMapping {
 public:
-  PhysicalCoreCoordMapping(ArrayRef<ChipDescAttr> chipDescs) {
-    ArrayRef<int64_t> firstChipGrid = chipDescs.front().getGrid();
+  static PhysicalCoreCoordMapping
+  getWorkerMapping(ArrayRef<unsigned> chipIds,
+                   ArrayRef<ChipDescAttr> chipDescs) {
+    SmallVector<std::array<int64_t, 2>> physCores;
+    ArrayRef<int64_t> firstChipGrid = chipDescs[chipIds.front()].getGrid();
     assert(firstChipGrid.size() == 2);
-    grid = {firstChipGrid[0], firstChipGrid[1]};
+    std::array<int64_t, 2> grid = {firstChipGrid[0], firstChipGrid[1]};
 
-    workers.reserve(chipDescs.size() * grid[0] * grid[1]);
-    for (auto chipDesc : chipDescs) {
+    physCores.reserve(chipIds.size() * grid[0] * grid[1]);
+    for (auto chipId : chipIds) {
+      auto chipDesc = chipDescs[chipId];
       auto chipGrid = chipDesc.getGrid();
       assert(chipGrid == firstChipGrid);
       ChipPhysicalCoresAttr chipPhysicalCores = chipDesc.getChipPhysicalCores();
       assert(chipPhysicalCores.getWorker().size() ==
              static_cast<size_t>(grid[0] * grid[1]));
       for (auto worker : chipPhysicalCores.getWorker()) {
-        workers.push_back({worker.getY(), worker.getX()});
+        physCores.push_back({worker.getY(), worker.getX()});
       }
     }
-    assert(workers.size() == chipDescs.size() * grid[0] * grid[1]);
+    assert(physCores.size() == chipIds.size() * grid[0] * grid[1]);
+    return PhysicalCoreCoordMapping(grid, physCores);
+  }
+
+  static PhysicalCoreCoordMapping
+  getDramMapping(ArrayRef<unsigned> chipIds,
+                 ArrayRef<ChipDescAttr> chipDescs) {
+    ArrayRef<CoreCoordAttr> firstChipDramCores =
+        chipDescs[chipIds.front()].getChipPhysicalCores().getDram();
+
+    std::array<int64_t, 2> grid = {
+        1, static_cast<int64_t>(firstChipDramCores.size())};
+    SmallVector<std::array<int64_t, 2>> physCores;
+    physCores.reserve(chipIds.size() * grid[0] * grid[1]);
+    for (auto chipId : chipIds) {
+      auto chipDesc = chipDescs[chipId];
+      ChipPhysicalCoresAttr chipPhysicalCores = chipDesc.getChipPhysicalCores();
+      assert(chipPhysicalCores.getDram().size() ==
+             static_cast<size_t>(grid[0] * grid[1]));
+      for (auto dram : chipPhysicalCores.getDram()) {
+        physCores.push_back({dram.getY(), dram.getX()});
+      }
+    }
+    assert(physCores.size() == chipIds.size() * grid[0] * grid[1]);
+    return PhysicalCoreCoordMapping(grid, physCores);
+  }
+
+  static PhysicalCoreCoordMapping
+  getMemorySpaceMapping(ArrayRef<unsigned> chipIds,
+                        ArrayRef<ChipDescAttr> chipDescs,
+                        MemorySpace memorySpace) {
+    switch (memorySpace) {
+    case MemorySpace::DeviceL1:
+      return getWorkerMapping(chipIds, chipDescs);
+    case MemorySpace::DeviceDRAM:
+      return getDramMapping(chipIds, chipDescs);
+    default:
+      llvm_unreachable("unsupported memory space");
+    }
   }
 
   std::array<int64_t, 2> operator[](PhysicalCoreCoord coord) const {
-    return workers[coord.d * grid[0] * grid[1] + coord.y * grid[1] + coord.x];
+    return physCores[coord.d * grid[0] * grid[1] + coord.y * grid[1] + coord.x];
   }
 
+private:
+  PhysicalCoreCoordMapping(std::array<int64_t, 2> grid,
+                           SmallVector<std::array<int64_t, 2>> physCores)
+      : grid(grid), physCores(physCores) {}
+
 private:
   std::array<int64_t, 2> grid;
-  SmallVector<std::array<int64_t, 2>> workers;
+  SmallVector<std::array<int64_t, 2>> physCores;
 };
 } // namespace mlir::tt
diff --git a/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h b/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h
index f1a0a5a1f..45cda3e41 100644
--- a/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h
+++ b/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h
@@ -289,19 +289,23 @@ toFlatbuffer(FlatbufferObjectCache &cache, GridAttr tensorGrid,
   ::ttmlir::utils::sample(
       tensorGridShape, [&](ArrayRef<int64_t> virtualCoreCoord) {
         SmallVector<int64_t> coreCoord = mapping.compose(virtualCoreCoord);
-        assert(coreCoord.size() == 3 && "expected a 2D core");
-        assert(coreCoord[0] == 0 && "expected single device");
+        assert(coreCoord.size() == PhysGridResultIdx::NumIndices &&
+               "expected a 2D core");
+        assert(coreCoord[PhysGridResultIdx::DeviceIdx] == 0 &&
+               "expected single device");
         if (!coreRangeSet.empty() &&
-            ((coreRangeSet.back().loc().y() == coreCoord[1]) &&
+            ((coreRangeSet.back().loc().y() ==
+              coreCoord[PhysGridResultIdx::CoreCoordY]) &&
             (coreRangeSet.back().loc().x() + coreRangeSet.back().size().x()) ==
-                coreCoord[2])) {
+                coreCoord[PhysGridResultIdx::CoreCoordX])) {
           coreRangeSet.back() = ::tt::target::Dim2dRange(
               coreRangeSet.back().loc(),
               ::tt::target::Dim2d(coreRangeSet.back().size().y(),
                                   coreRangeSet.back().size().x() + 1));
         } else {
           coreRangeSet.push_back(::tt::target::Dim2dRange(
-              ::tt::target::Dim2d(coreCoord[1], coreCoord[2]),
+              ::tt::target::Dim2d(coreCoord[PhysGridResultIdx::CoreCoordY],
+                                  coreCoord[PhysGridResultIdx::CoreCoordX]),
               ::tt::target::Dim2d(1, 1)));
         }
         if (coreRangeSet.size() > 1 &&
@@ -401,7 +405,7 @@ layoutAttrToFlatbuffer(FlatbufferObjectCache &cache, Attribute attr,
   auto strideInt64 = layoutAttr.getStride(logicalShape);
   std::vector<int32_t> stride(strideInt64.begin(), strideInt64.end());
   auto coreRangeSet =
-      toFlatbuffer(cache, layoutAttr.getGrid(), deviceAttr.getGrid());
+      toFlatbuffer(cache, layoutAttr.getGrid(), deviceAttr.getWorkerGrid());
   return ::tt::target::CreateLayoutDescDirect(
       *cache.fbb, &stride, toFlatbuffer(cache, layoutAttr.getOobVal()),
       &coreRangeSet,
diff --git a/include/ttmlir/Utils.h b/include/ttmlir/Utils.h
index 19baee00c..3df9f14b9 100644
--- a/include/ttmlir/Utils.h
+++ b/include/ttmlir/Utils.h
@@ -15,6 +15,10 @@ template <typename T> T alignUp(T ptr, T alignment) {
   return (ptr + alignment - 1) & ~(alignment - 1);
 }
 
+template <typename T> T alignDown(T ptr, T alignment) {
+  return ptr & ~(alignment - 1);
+}
+
 template <typename Vector, typename Fn>
 inline void sample(Vector const &shape, Fn fn) {
   llvm::SmallVector<std::int64_t> strides(shape.size());
diff --git a/lib/Dialect/TT/IR/TTOpsTypes.cpp b/lib/Dialect/TT/IR/TTOpsTypes.cpp
index c3866b34c..5e8f2843b 100644
--- a/lib/Dialect/TT/IR/TTOpsTypes.cpp
+++ b/lib/Dialect/TT/IR/TTOpsTypes.cpp
@@ -84,7 +84,8 @@ mlir::tt::SystemDescAttr::getDefault(MLIRContext *context) {
       {
           tt::ChipDescAttr::get(
               context, tt::ArchAttr::get(context, tt::Arch::WormholeB0),
-              gridShape, 1499136, 12, (1 << 30), 16, 32, 32, 0, 0, 0, (1 << 30),
+              gridShape, 1499136, 12, (1 << 30), 16, 32, 32, 1024, 1024, 1024,
+              (1 << 30),
               tt::ChipPhysicalCoresAttr::get(context, workerCores, dramCores,
                                              {}, {}),
               supported_data_types, supported_tile_sizes),
@@ -666,56 +667,38 @@ mlir::AffineMap LayoutAttr::getIdentityTileLinearMap() const {
       getContext());
 }
 
-// Projects tensor layout onto the device grid. Uses given linear map to derive
-// the shard shape and the projection of shard indexes onto the logical grid.
-// Then it composes the logical grid projection with physical grid mapping.
+// Projects tensor layout onto a physical memory map. Uses given linear map to
+// derive the shard shape and the projection of shard indexes onto the logical
+// grid. Then it composes the logical grid projection with physical memory
+// mapping.
 mlir::AffineMap
-LayoutAttr::projectOnto(mlir::AffineMap linearMap,
-                        llvm::ArrayRef<int64_t> logicalTensorShape,
-                        GridAttr grid) const {
-  assert(getGrid().getShape().size() == grid.getShape().size() &&
+LayoutAttr::projectOnto(mlir::AffineMap linearMap,
+                        mlir::AffineMap physicalMemoryMap,
+                        llvm::ArrayRef<int64_t> logicalTensorShape) const {
+  assert(getGrid().getShape().size() == physicalMemoryMap.getNumDims() &&
          "Layout and device grids must have same number of dimensions");
-  assert(getLinear().getNumResults() == grid.getShape().size() &&
-         "Linear map and device grid must have same number of dimensions");
-  for (auto [layoutGridDim, otherGridDim] :
-       llvm::zip(getGrid().getShape(), grid.getShape())) {
-    assert(layoutGridDim <= otherGridDim &&
-           "Layout grid dimensions must be less than or equal to device grid");
-  }
+  assert(getLinear().getNumResults() == physicalMemoryMap.getNumDims() &&
+         "Linear map and physical map must have same number of dimensions");
 
   mlir::SmallVector<int64_t> logicalShardShape = calculateLogicalShardShape(
       getContext(), logicalTensorShape, linearMap, getGrid());
 
-  // Compute the projection of the layout onto its own logical grid.
-  // Simultaneously compute the indexing of shards within each core.
-  mlir::SmallVector<mlir::AffineExpr> logicalGridProjection(
-      linearMap.getNumResults());
-  mlir::AffineExpr shardIndexing = getAffineConstantExpr(0, getContext());
-  int shardVolume = 1;
-  assert(logicalShardShape.size() == linearMap.getNumResults() &&
-         "Logical shard shape and linear map must have same number of dims");
-  for (int i = linearMap.getNumResults() - 1; i >= 0; i--) {
-    mlir::AffineExpr expr = linearMap.getResult(i);
-    mlir::AffineExpr shardDim =
-        getAffineConstantExpr(logicalShardShape[i], getContext());
-    mlir::AffineExpr shardVolumeExpr =
-        getAffineConstantExpr(shardVolume, getContext());
-    logicalGridProjection[i] = expr.floorDiv(shardDim);
-    shardIndexing = (expr % shardDim) * shardVolumeExpr + shardIndexing;
-    shardVolume *= logicalShardShape[i];
+  SmallVector<AffineExpr> dimReplacements;
+  for (unsigned i = 0; i < linearMap.getNumDims(); ++i) {
+    dimReplacements.push_back(getAffineDimExpr(i, getContext()));
   }
 
-  // Compose the logical grid projection with the device grid mapping, now we
-  // have a projection onto the physical grid.
-  mlir::AffineMap gridProjection =
-      grid.getMapping().compose(mlir::AffineMap::get(
-          logicalTensorShape.size(), 0, logicalGridProjection, getContext()));
+  assert(physicalMemoryMap.getNumSymbols() == logicalShardShape.size() &&
+         "Physical memory map must have same number of symbols as logical "
+         "shard rank");
+  SmallVector<AffineExpr> symReplacements;
+  for (unsigned i = 0; i < physicalMemoryMap.getNumSymbols(); ++i) {
+    symReplacements.push_back(
+        getAffineConstantExpr(logicalShardShape[i], getContext()));
+  }
 
-  // Finally we append the indexing of shards within each core.
-  mlir::SmallVector<mlir::AffineExpr> projection(gridProjection.getResults());
-  projection.push_back(shardIndexing);
-  return mlir::AffineMap::get(logicalTensorShape.size(), 0, projection,
-                              getContext());
+  return physicalMemoryMap.compose(linearMap).replaceDimsAndSymbols(
+      dimReplacements, symReplacements, linearMap.getNumDims(), 0);
 }
 
 mlir::Type BufferAttr::getElementType() const {
@@ -731,6 +714,128 @@ llvm::SmallVector<int64_t> BufferAttr::getShape() const {
   return bufferShape;
 }
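The symbol-binding step above can be seen in isolation in the following
minimal standalone sketch (it assumes only the core MLIR affine API and is
not part of the patch; `bindShardShape` is a hypothetical helper name):

  // Keep the dims of a memory map as-is and substitute each shard-shape
  // symbol (s0, s1, ...) with a constant, as projectOnto does above.
  static mlir::AffineMap bindShardShape(mlir::AffineMap physicalMemoryMap,
                                        llvm::ArrayRef<int64_t> shardShape,
                                        mlir::MLIRContext *ctx) {
    llvm::SmallVector<mlir::AffineExpr> dims, syms;
    for (unsigned i = 0; i < physicalMemoryMap.getNumDims(); ++i)
      dims.push_back(mlir::getAffineDimExpr(i, ctx));
    for (unsigned i = 0; i < physicalMemoryMap.getNumSymbols(); ++i)
      syms.push_back(mlir::getAffineConstantExpr(shardShape[i], ctx));
    return physicalMemoryMap.replaceDimsAndSymbols(
        dims, syms, physicalMemoryMap.getNumDims(), /*numResultSyms=*/0);
  }

For example, with shardShape = {4, 16} the L1 map documented below becomes
(d0, d1) -> (0, d0 floordiv 4, d1 floordiv 16, (d0 mod 4) * 16 + d1 mod 16).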
+//
+// This function creates an affine map that represents mapping the tensor's
+// linear layout onto the 2d physical device grid. A typical example will look
+// like:
+//   (d0, d1)[s0, s1] -> (          # Uses affine symbols s0, s1 to represent shard dims
+//     0,                           # Device index
+//     d0 floordiv s0,              # CoreCoordY
+//     d1 floordiv s1,              # CoreCoordX
+//     (d0 mod s0) * s1 + d1 mod s1 # Element offset within shard
+//   )
+//
+static mlir::AffineMap createL1Map(::mlir::MLIRContext *context,
+                                   GridAttr workerGrid,
+                                   SystemDescAttr systemDesc,
+                                   ::llvm::ArrayRef<unsigned> chipIds) {
+  mlir::AffineMap workerMap = workerGrid.getMapping();
+  mlir::SmallVector<mlir::AffineExpr> l1MapResults(workerMap.getNumDims());
+  mlir::AffineExpr shardIndexing = getAffineConstantExpr(0, context);
+  mlir::AffineExpr shardVolumeExpr = getAffineConstantExpr(1, context);
+
+  // Compute the projection of the layout onto its own logical grid.
+  // Simultaneously compute the indexing of shards within each core.
+  for (int i = workerMap.getNumDims() - 1; i >= 0; i--) {
+    mlir::AffineExpr linearIdx = getAffineDimExpr(i, context);
+    mlir::AffineExpr shardDim = getAffineSymbolExpr(i, context);
+    l1MapResults[i] = linearIdx.floorDiv(shardDim);
+    shardIndexing = (linearIdx % shardDim) * shardVolumeExpr + shardIndexing;
+    shardVolumeExpr = shardVolumeExpr * shardDim;
+  }
+
+  // Compose the logical grid projection with the device grid mapping, now we
+  // have a projection onto the physical grid.
+  mlir::AffineMap gridProjection = workerMap.compose(mlir::AffineMap::get(
+      workerMap.getNumDims(), workerMap.getNumDims(), l1MapResults, context));
+
+  // Finally we append the indexing of shards within each core.
+  mlir::SmallVector<mlir::AffineExpr> l1Map(gridProjection.getResults());
+  l1Map.push_back(shardIndexing);
+  return mlir::AffineMap::get(workerMap.getNumDims(), workerMap.getNumDims(),
+                              l1Map, context);
+}
+
+//
+// This function creates an affine map that represents mapping the tensor's
+// linear layout onto physical dram banks. A typical example will end up
+// looking pretty complicated:
+//   (d0, d1)[s0, s1] -> (
+//     0,                                  # Device index
+//     0,                                  # CoreCoordY
+//     (addr floordiv 8192) mod 12,        # Channel Idx / CoreCoordX
+//     addr floordiv 98304 + addr mod 8192 # Offset within channel
+//   )
+//
+// Where `addr` is the linearized address as though it were indexing all of
+// DRAM flat.
+// Then we do some additional calculations to break up the channels into
+// interleaved pages:
+//   addr = (((d1 floordiv s1) * 8 + d0 floordiv s0) * (s1 * s0) +
+//           (d0 mod s0) * s1 + d1 mod s1)
+//
+static mlir::AffineMap createDramMap(::mlir::MLIRContext *context,
+                                     GridAttr workerGrid, ArchAttr arch,
+                                     mlir::ArrayRef<CoreCoordAttr> dramCores,
+                                     unsigned dramPageSize) {
+  mlir::AffineMap workerMap = workerGrid.getMapping();
+  assert(workerMap.getNumResults() == PhysGridResultIdx::NumIndices);
+  mlir::AffineExpr addr = getAffineConstantExpr(0, context);
+  mlir::AffineExpr shardIndexing = getAffineConstantExpr(0, context);
+  mlir::AffineExpr shardVolumeExpr = getAffineConstantExpr(1, context);
+  mlir::AffineExpr gridVolumeExpr = getAffineConstantExpr(1, context);
+
+  for (int i = workerMap.getNumDims() - 1; i >= 0; i--) {
+    mlir::AffineExpr linearIdx = getAffineDimExpr(i, context);
+    mlir::AffineExpr shardDim = getAffineSymbolExpr(i, context);
+    addr = addr * gridVolumeExpr + linearIdx.floorDiv(shardDim);
+    shardIndexing = (linearIdx % shardDim) * shardVolumeExpr + shardIndexing;
+    shardVolumeExpr = shardVolumeExpr * shardDim;
+    gridVolumeExpr = gridVolumeExpr * workerGrid.getShape()[i];
+  }
+
+  addr = addr * shardVolumeExpr + shardIndexing;
+
+  mlir::AffineExpr pageSizeExpr = getAffineConstantExpr(dramPageSize, context);
+  mlir::AffineExpr numDramCores =
+      getAffineConstantExpr(dramCores.size(), context);
+  mlir::SmallVector<mlir::AffineExpr> dramMapResults = {
+      addr.floorDiv(pageSizeExpr) % numDramCores,
+      addr.floorDiv(pageSizeExpr * numDramCores) + addr % pageSizeExpr,
+  };
+
+  // Dram logical coords are 1d, so constant 0 index for
+  // MemoryMapResultIdx::CoreCoordY
+  dramMapResults.insert(dramMapResults.begin(),
+                        getAffineConstantExpr(0, context));
+  dramMapResults.insert(dramMapResults.begin(),
+                        workerMap.getResult(MemoryMapResultIdx::DeviceIdx));
+  assert(dramMapResults.size() == MemoryMapResultIdx::NumIndices);
+
+  return mlir::AffineMap::get(workerMap.getNumDims(), workerMap.getNumDims(),
+                              dramMapResults, context);
+}
+
+static mlir::AffineMap createDramMap(::mlir::MLIRContext *context,
+                                     GridAttr workerGrid,
+                                     SystemDescAttr systemDesc,
+                                     ::llvm::ArrayRef<unsigned> chipIds,
+                                     unsigned dramPageSize) {
+  auto chipDesc = systemDesc.getChipDescs().front();
+  auto chipPhysicalCores = chipDesc.getChipPhysicalCores();
+  auto firstDramCores = chipPhysicalCores.getDram();
+  assert(!firstDramCores.empty() && "expected at least one dram core");
+
+  for (unsigned chipId : chipIds) {
+    auto chipDesc = systemDesc.getChipDescs()[chipId];
+    auto chipPhysicalCores = chipDesc.getChipPhysicalCores();
+    auto dramCores = chipPhysicalCores.getDram();
+    assert(dramCores.size() == firstDramCores.size());
+  }
+
+  return createDramMap(context, workerGrid, chipDesc.getArch(), firstDramCores,
+                       dramPageSize);
+}
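The same interleaving arithmetic in scalar form, as a sanity-check sketch
(not part of the patch; the 8192-byte page and 12 banks are the illustrative
wormhole_b0 constants from the comment above):

  #include <array>
  #include <cstdint>

  std::array<std::int64_t, 2> dramChannelAndOffset(std::int64_t addr) {
    constexpr std::int64_t pageSize = 8192; // dramPageSize
    constexpr std::int64_t numBanks = 12;   // dramCores.size()
    std::int64_t channel = (addr / pageSize) % numBanks; // CoreCoordX
    // Mirrors `addr floordiv 98304 + addr mod 8192` in the map above.
    std::int64_t offset = addr / (pageSize * numBanks) + addr % pageSize;
    return {channel, offset};
  }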
+
 DeviceAttr
 DeviceAttr::get(::mlir::MLIRContext *context, SystemDescAttr systemDesc,
                 ArrayRef<unsigned> chipIds) {
@@ -759,8 +864,12 @@ DeviceAttr DeviceAttr::get(::mlir::MLIRContext *context,
   SmallVector<AffineExpr> gridExprs = {dZ, dY, dX};
   auto gridMap = AffineMap::get(virtualGrid.size(), 0, gridExprs, context);
-
-  return get(context, GridAttr::get(context, virtualGrid, gridMap), chipIds);
+  auto workerGrid = GridAttr::get(context, virtualGrid, gridMap);
+  auto l1Map = createL1Map(context, workerGrid, systemDesc, chipIds);
+  constexpr unsigned dramPageSize = 8192;
+  auto dramMap =
+      createDramMap(context, workerGrid, systemDesc, chipIds, dramPageSize);
+  return get(context, workerGrid, l1Map, dramMap, chipIds);
 }
 
 DeviceAttr
 DeviceAttr::get(::mlir::MLIRContext *context,
@@ -775,24 +884,38 @@ DeviceAttr DeviceAttr::get(::mlir::MLIRContext *context,
 ::mlir::LogicalResult
 DeviceAttr::verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
-                   ::mlir::tt::GridAttr grid,
+                   ::mlir::tt::GridAttr workerGrid, ::mlir::AffineMap l1Map,
+                   ::mlir::AffineMap dramMap,
                    ::llvm::ArrayRef<unsigned> chipIds) {
   if (chipIds.empty()) {
     emitError() << "expected at least one chip";
     return ::mlir::failure();
   }
-  auto gridShape = grid.getShape();
-  for (auto dim : gridShape) {
+  auto workerGridShape = workerGrid.getShape();
+  for (auto dim : workerGridShape) {
     if (dim <= 0) {
       emitError() << "expected positive grid dimensions";
       return ::mlir::failure();
     }
   }
-  auto physicalGridMapping = grid.getMapping();
-  if (physicalGridMapping.getNumResults() != 3) {
-    emitError() << "expected physical grid mapping to have 3 results";
+  auto physicalGridMapping = workerGrid.getMapping();
+  if (physicalGridMapping.getNumResults() != PhysGridResultIdx::NumIndices) {
+    emitError() << "expected physical grid mapping to have "
+                   "PhysGridResultIdx::NumIndices results";
+    return ::mlir::failure();
+  }
+
+  if (l1Map.getNumResults() != MemoryMapResultIdx::NumIndices) {
+    emitError()
+        << "expected l1Map to have MemoryMapResultIdx::NumIndices results";
+    return ::mlir::failure();
+  }
+
+  if (dramMap.getNumResults() != MemoryMapResultIdx::NumIndices) {
+    emitError()
+        << "expected dramMap to have MemoryMapResultIdx::NumIndices results";
     return ::mlir::failure();
   }
@@ -897,8 +1020,8 @@ mlir::Type TileType::getElementType() const {
 }
 
 SystemDescAttr mlir::tt::getCurrentScopeSystemDesc(mlir::Operation *op) {
-  // Walk up scope levels until we find the top level ModuleOp which carries the
-  // system desc
+  // Walk up scope levels until we find the top level ModuleOp which carries
+  // the system desc
   while (op) {
     if (mlir::isa<ModuleOp>(op)) {
       auto systemDesc = op->getAttrOfType<SystemDescAttr>(SystemDescAttr::name);
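Since the verifier above pins both maps to MemoryMapResultIdx::NumIndices
results, consumers can unpack an evaluated coordinate positionally. A short
sketch (hypothetical `boundMap` and `tensorIndex`; the map is assumed to be
fully bound, i.e. already run through projectOnto):

  // boundMap: (d0, d1) -> (dev, y, x, offset); tensorIndex: a concrete index.
  llvm::SmallVector<std::int64_t> results = boundMap.compose(tensorIndex);
  assert(results.size() == MemoryMapResultIdx::NumIndices);
  PhysicalCoreCoord core(results); // consumes DeviceIdx, CoreCoordY, CoreCoordX
  std::int64_t shardOffset = results[MemoryMapResultIdx::ShardOffset];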
diff --git a/lib/Dialect/TTIR/Transforms/Passes.cpp b/lib/Dialect/TTIR/Transforms/Passes.cpp
index 5d05858fb..66fcec712 100644
--- a/lib/Dialect/TTIR/Transforms/Passes.cpp
+++ b/lib/Dialect/TTIR/Transforms/Passes.cpp
@@ -737,7 +737,7 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
     auto device = getCurrentScopeDevice(getOperation());
     assert(device && "Device not found");
     TTIRLayoutTensorTypeConverter typeConverter(
-        &getContext(), initMemorySpace, device.getGrid());
+        &getContext(), initMemorySpace, device.getWorkerGrid());
     RewritePatternSet patterns(&getContext());
     patterns.add<TTIRLayoutTensorTypeRewriter>(typeConverter, &getContext());
     FrozenRewritePatternSet patternSet(std::move(patterns));
@@ -916,14 +916,16 @@ inline uint64_t getTensorMemrefSizeBytes(RankedTensorType ty) {
 
 class TTIRAllocate : public impl::TTIRAllocateBase<TTIRAllocate> {
   struct SimpleAllocator {
-    static constexpr uint64_t kBaseAddress = 1llu << 18llu;
     uint64_t addressAlignment;
+    SmallVector<uint64_t> currPtr;
 
-    SimpleAllocator(uint64_t addressAlignment)
-        : addressAlignment(addressAlignment) {}
-
-    SmallVector<uint64_t> currPtr = SmallVector<uint64_t>(
-        getMaxEnumValForMemorySpace() + 1llu, kBaseAddress);
+    SimpleAllocator(uint64_t l1BaseAddress, uint64_t dramBaseAddress,
+                    uint64_t addressAlignment)
+        : addressAlignment(addressAlignment) {
+      currPtr = SmallVector<uint64_t>(getMaxEnumValForMemorySpace() + 1llu);
+      currPtr[static_cast<uint32_t>(MemorySpace::DeviceL1)] = l1BaseAddress;
+      currPtr[static_cast<uint32_t>(MemorySpace::DeviceDRAM)] = dramBaseAddress;
+    }
 
     uint64_t allocate(uint64_t size, MemorySpace memorySpace) {
       if (isSystemMemorySpace(memorySpace)) {
@@ -972,12 +974,17 @@ class TTIRAllocate : public impl::TTIRAllocateBase<TTIRAllocate> {
     ModuleOp module = getOperation();
     IRRewriter rewriter(&getContext());
 
+    SystemDescAttr systemDesc = getCurrentScopeSystemDesc(module);
+    ChipDescAttr chipDesc = systemDesc.getChipDescs().front();
+
     module->walk([&](func::FuncOp func) {
       assert(func.getBody().hasOneBlock());
       auto systemDesc = getCurrentScopeSystemDesc(func);
       assert(systemDesc);
       auto addressAlignment = systemDesc.getAddressAlignBytes();
-      SimpleAllocator allocator(addressAlignment);
+      SimpleAllocator allocator(chipDesc.getL1UnreservedBase(),
+                                chipDesc.getDramUnreservedBase(),
+                                addressAlignment);
       Liveness liveness(func.getOperation());
       const LivenessBlockInfo *livenessInfo =
           liveness.getLiveness(&func.getBody().front());
@@ -1040,7 +1047,7 @@ class TTIRGridSet : public impl::TTIRGridSetBase<TTIRGridSet> {
     assert(moduleOp->hasAttr(tt::DeviceAttr::name));
     GridAttr max_grid =
         mlir::cast<tt::DeviceAttr>(moduleOp->getAttr(tt::DeviceAttr::name))
-            .getGrid();
+            .getWorkerGrid();
 
     SystemDescAttr systemDesc = mlir::cast<tt::SystemDescAttr>(
         moduleOp->getAttr(tt::SystemDescAttr::name));
diff --git a/lib/Dialect/TTMetal/Transforms/Passes.cpp b/lib/Dialect/TTMetal/Transforms/Passes.cpp
index 8b2e74cf9..32974e4a3 100644
--- a/lib/Dialect/TTMetal/Transforms/Passes.cpp
+++ b/lib/Dialect/TTMetal/Transforms/Passes.cpp
@@ -78,20 +78,23 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern<ttir::ToLayoutOp> {
 public:
   using OpRewritePattern<ttir::ToLayoutOp>::OpRewritePattern;
 
-  struct NocRead {
-    PhysicalCoreCoord srcCoord;
+  struct NocTx {
+    enum class Type { Read, Write };
+
+    Type type;
+    PhysicalCoreCoord coreCoord;
     std::int64_t srcOffset = 0;
     std::int64_t dstOffset = 0;
     std::int64_t size = 0;
 
-    NocRead(PhysicalCoreCoord srcCoord, std::int64_t srcOffset,
-            std::int64_t dstOffset, std::int64_t size)
-        : srcCoord(srcCoord), srcOffset(srcOffset), dstOffset(dstOffset),
-          size(size) {}
+    NocTx(Type type, PhysicalCoreCoord coreCoord, std::int64_t srcOffset,
+          std::int64_t dstOffset, std::int64_t size)
+        : type(type), coreCoord(coreCoord), srcOffset(srcOffset),
+          dstOffset(dstOffset), size(size) {}
 
     bool isContiguous(PhysicalCoreCoord nextCoord, std::int64_t nextSrcOffset,
                       std::int64_t nextDstOffset) const {
-      return (nextCoord == srcCoord) && (nextSrcOffset == srcOffset + size) &&
+      return (nextCoord == coreCoord) && (nextSrcOffset == srcOffset + size) &&
              (nextDstOffset == dstOffset + size);
     }
   };
@@ -102,66 +105,71 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern<ttir::ToLayoutOp> {
   // lambda with the current index. It walks the shape in innermost-major
   // order. It also coalesces the noc transactions.
   //
-  // The return value is a map of destination physical cores where each core
-  // has an associated list of noc reads to be performed.
-  llvm::MapVector<PhysicalCoreCoord, SmallVector<NocRead>>
-  calculateDataMovement(ArrayRef<int64_t> tensorShape, int64_t elemSize,
-                        AffineMap src, AffineMap dst) const {
-
-    // For now it's just a simple pull model, but eventually we want to
-    // leverage both NoCs and both read and write
-    llvm::MapVector<PhysicalCoreCoord, SmallVector<NocRead>> dst2srcMap;
-    assert(src.getNumResults() == 4);
-    assert(dst.getNumResults() == 4);
-
-    ::ttmlir::utils::sample(tensorShape, [&dst2srcMap, src, dst, elemSize](
-                                             ArrayRef<std::int64_t> index) {
+  // The return value is a map of physical cores where each core has
+  // an associated list of noc reads/writes to be performed.
+  llvm::MapVector<PhysicalCoreCoord, SmallVector<NocTx>>
+  calculateDataMovement(ArrayRef<int64_t> tensorShape, std::int64_t elemSize,
+                        AffineMap src, AffineMap dst, NocTx::Type type) const {
+    bool read = type == NocTx::Type::Read;
+    llvm::MapVector<PhysicalCoreCoord, SmallVector<NocTx>> txMap;
+    assert(src.getNumResults() == MemoryMapResultIdx::NumIndices);
+    assert(dst.getNumResults() == MemoryMapResultIdx::NumIndices);
+
+    ::ttmlir::utils::sample(tensorShape, [&txMap, src, dst, elemSize, read,
+                                          type](ArrayRef<std::int64_t> index) {
       SmallVector<int64_t> srcResults = src.compose(index);
       SmallVector<int64_t> dstResults = dst.compose(index);
       assert(srcResults.size() == src.getNumResults());
       assert(dstResults.size() == dst.getNumResults());
       PhysicalCoreCoord srcCoord(srcResults);
       PhysicalCoreCoord dstCoord(dstResults);
       std::int64_t srcOffset = srcResults.back() * elemSize;
       std::int64_t dstOffset = dstResults.back() * elemSize;
-      SmallVector<NocRead> &srcs = dst2srcMap[dstCoord];
-      if (not srcs.empty() &&
-          srcs.back().isContiguous(srcCoord, srcOffset, dstOffset)) {
-        srcs.back().size += elemSize;
+      SmallVector<NocTx> &txs = txMap[read ? dstCoord : srcCoord];
+      if (not txs.empty() && txs.back().isContiguous(read ? srcCoord : dstCoord,
+                                                     srcOffset, dstOffset)) {
+        txs.back().size += elemSize;
       } else {
-        srcs.push_back(NocRead(srcCoord, srcOffset, dstOffset, elemSize));
+        txs.push_back(NocTx(type, read ? srcCoord : dstCoord, srcOffset,
+                            dstOffset, elemSize));
       }
     });
 
-    return dst2srcMap;
+    return txMap;
   }
 
-  void buildNocAsyncRead(mlir::Location loc, std::int64_t inputBaseAddress,
-                         std::int64_t outputBaseAddress,
-                         std::int64_t addressAlignment, NocRead read,
-                         PhysicalCoreCoordMapping const &physicalCoordMapping,
-                         mlir::OpBuilder &nocBuilder) const {
-    assert(read.srcOffset % addressAlignment == 0);
-    assert(read.dstOffset % addressAlignment == 0);
-    assert(read.size % addressAlignment == 0);
-    auto [yPhys, xPhys] = physicalCoordMapping[read.srcCoord];
+  void buildNocAsyncTx(mlir::Location loc, std::int64_t inputBaseAddress,
+                       std::int64_t outputBaseAddress,
+                       std::int64_t addressAlignment, NocTx nocTx,
+                       PhysicalCoreCoordMapping const &physicalCoordMapping,
+                       mlir::OpBuilder &nocBuilder) const {
+    assert(nocTx.srcOffset % addressAlignment == 0);
+    assert(nocTx.dstOffset % addressAlignment == 0);
+    assert(nocTx.size % addressAlignment == 0);
+    auto [yPhys, xPhys] = physicalCoordMapping[nocTx.coreCoord];
     auto y = nocBuilder.create<arith::ConstantOp>(
         loc, nocBuilder.getI32Type(), nocBuilder.getI32IntegerAttr(yPhys));
     auto x = nocBuilder.create<arith::ConstantOp>(
         loc, nocBuilder.getI32Type(), nocBuilder.getI32IntegerAttr(xPhys));
-    auto srcOffset = nocBuilder.create<arith::ConstantOp>(
+    auto srcLocalL1Addr = nocBuilder.create<arith::ConstantOp>(
         loc, nocBuilder.getI32Type(),
-        nocBuilder.getI32IntegerAttr(inputBaseAddress + read.srcOffset));
-    auto srcRemoteNocAddr =
-        nocBuilder.create<ttkernel::GetNocAddrOp>(loc, x, y, srcOffset);
+        nocBuilder.getI32IntegerAttr(inputBaseAddress + nocTx.srcOffset));
     auto dstLocalL1Addr = nocBuilder.create<arith::ConstantOp>(
         loc, nocBuilder.getI32Type(),
-        nocBuilder.getI32IntegerAttr(outputBaseAddress + read.dstOffset));
+        nocBuilder.getI32IntegerAttr(outputBaseAddress + nocTx.dstOffset));
     auto size = nocBuilder.create<arith::ConstantOp>(
-        loc, nocBuilder.getI32Type(), nocBuilder.getI32IntegerAttr(read.size));
-    nocBuilder.create<ttkernel::NocAsyncReadOp>(loc, srcRemoteNocAddr,
-                                                dstLocalL1Addr, size);
+        loc, nocBuilder.getI32Type(), nocBuilder.getI32IntegerAttr(nocTx.size));
+    if (nocTx.type == NocTx::Type::Read) {
+      auto srcRemoteNocAddr =
+          nocBuilder.create<ttkernel::GetNocAddrOp>(loc, x, y, srcLocalL1Addr);
+      nocBuilder.create<ttkernel::NocAsyncReadOp>(loc, srcRemoteNocAddr,
+                                                  dstLocalL1Addr, size);
+    } else {
+      auto dstRemoteNocAddr =
+          nocBuilder.create<ttkernel::GetNocAddrOp>(loc, x, y, dstLocalL1Addr);
+      nocBuilder.create<ttkernel::NocAsyncWriteOp>(loc, srcLocalL1Addr,
+                                                   dstRemoteNocAddr, size);
+    }
   }
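The coalescing rule in calculateDataMovement, shown in isolation as a
standalone sketch (not part of the patch; `appendElement` is a hypothetical
helper name): consecutive element-sized samples that land on the same core
at contiguous src and dst offsets grow the previous transaction instead of
appending a new one.

  // E.g. three 4-byte elements at offsets 0, 4, 8 on one core collapse
  // into a single 12-byte NocTx rather than three transactions.
  void appendElement(SmallVectorImpl<NocTx> &txs, NocTx::Type type,
                     PhysicalCoreCoord coord, std::int64_t srcOff,
                     std::int64_t dstOff, std::int64_t elemSize) {
    if (!txs.empty() && txs.back().isContiguous(coord, srcOff, dstOff))
      txs.back().size += elemSize;
    else
      txs.push_back(NocTx(type, coord, srcOff, dstOff, elemSize));
  }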
 
   LogicalResult relayout(ttir::ToLayoutOp op, PatternRewriter &rewriter) const {
@@ -169,7 +177,6 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern<ttir::ToLayoutOp> {
     auto inputTy = mlir::cast<RankedTensorType>(op.getInput().getType());
     auto outputTy = mlir::cast<RankedTensorType>(op.getType());
     auto inputLayout = mlir::cast<tt::LayoutAttr>(inputTy.getEncoding());
     auto outputLayout = mlir::cast<tt::LayoutAttr>(outputTy.getEncoding());
-    assert(inputLayout.getMemorySpace() == outputLayout.getMemorySpace());
     tt::DeviceAttr device = op.getDevice();
     assert(device);
     tt::SystemDescAttr systemDesc = op.getSystemDesc();
@@ -206,22 +213,34 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern<ttir::ToLayoutOp> {
             ? outputLayout.getIdentityTileLinearMap()
             : outputLayout.getLinear();
 
-    AffineMap src =
-        inputLayout.projectOnto(inputLinearMap, inputShape, device.getGrid());
+    assert((inputLayout.getMemorySpace() == MemorySpace::DeviceL1 ||
+            outputLayout.getMemorySpace() == MemorySpace::DeviceL1) &&
+           "DRAM <-> DRAM is not supported yet");
+    NocTx::Type dataMovementType =
+        outputLayout.getMemorySpace() == MemorySpace::DeviceL1
+            ? NocTx::Type::Read
+            : NocTx::Type::Write;
+
+    AffineMap src = inputLayout.projectOnto(
+        inputLinearMap,
+        device.getMapForMemorySpace(inputLayout.getMemorySpace()), inputShape);
 
-    AffineMap dst = outputLayout.projectOnto(outputLinearMap, outputShape,
-                                             device.getGrid());
+    AffineMap dst = outputLayout.projectOnto(
+        outputLinearMap,
+        device.getMapForMemorySpace(outputLayout.getMemorySpace()),
+        outputShape);
 
-    auto dm = calculateDataMovement(
-        inputShape, inputLayout.getElementSizeBytes(), src, dst);
+    auto dm =
+        calculateDataMovement(inputShape, inputLayout.getElementSizeBytes(),
+                              src, dst, dataMovementType);
 
     auto noc0Attr =
         rewriter.getAttr<ttkernel::ThreadTypeAttr>(ttkernel::ThreadType::Noc0);
     SmallVector<Attribute> threadTypes(dm.size(), noc0Attr);
     SmallVector<Attribute> coreRanges;
     coreRanges.reserve(dm.size());
-    for (auto [dstCoord, srcs] : dm) {
-      SmallVector<int64_t> offset = {dstCoord.y, dstCoord.x};
+    for (auto [coreCoord, txs] : dm) {
+      SmallVector<int64_t> offset = {coreCoord.y, coreCoord.x};
       SmallVector<int64_t> size = {1, 1};
       coreRanges.push_back(
           rewriter.getAttr<ttmetal::CoreRangeAttr>(offset, size));
 
         rewriter.getArrayAttr(threadTypes), threadTypes.size());
 
     int i = 0;
-    PhysicalCoreCoordMapping physicalCoordMapping(systemDesc.getChipDescs());
+    PhysicalCoreCoordMapping physicalCoordMapping =
+        PhysicalCoreCoordMapping::getMemorySpaceMapping(
+            device.getChipIds(), systemDesc.getChipDescs(),
+            dataMovementType == NocTx::Type::Read
+                ? inputLayout.getMemorySpace()
+                : outputLayout.getMemorySpace());
     std::int64_t inputBaseAddress = lookupAddress(op.getInput());
     std::int64_t outputBaseAddress = lookupAddress(op.getOutput());
     assert(inputBaseAddress);
     assert(outputBaseAddress);
     assert(inputBaseAddress % addressAlignment == 0);
     assert(outputBaseAddress % addressAlignment == 0);
-    for (auto [dstCoord, srcs] : dm) {
+    for (auto [coreCoord, txs] : dm) {
       Block *nocBlock = rewriter.createBlock(&metalDispatch.getRegion(i++));
       OpBuilder nocBuilder(nocBlock, nocBlock->begin());
-      for (auto s : srcs) {
-        buildNocAsyncRead(op.getLoc(), inputBaseAddress, outputBaseAddress,
-                          addressAlignment, s, physicalCoordMapping,
-                          nocBuilder);
+      NocTx::Type type = txs.front().type;
+      for (auto tx : txs) {
+        assert(tx.type == type);
+        buildNocAsyncTx(op.getLoc(), inputBaseAddress, outputBaseAddress,
+                        addressAlignment, tx, physicalCoordMapping, nocBuilder);
+      }
+      if (type == NocTx::Type::Read) {
+        nocBuilder.create<ttkernel::NocAsyncReadBarrierOp>(op.getLoc());
+      } else {
+        nocBuilder.create<ttkernel::NocAsyncWriteBarrierOp>(op.getLoc());
       }
-      nocBuilder.create<ttkernel::NocAsyncReadBarrierOp>(op.getLoc());
       nocBuilder.create<ttkernel::ReturnOp>(op.getLoc(), ValueRange());
     }
@@ -366,10 +395,12 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern<ttir::ToLayoutOp> {
     auto components = op.compoundComponents();
     bool isCompound = (static_cast<int>(components.isLayoutChange) +
-                       static_cast<int>(components.isGridChange) +
-                       static_cast<int>(components.isFormatChange) +
-                       static_cast<int>(components.isMemorySpaceChange) +
-                       static_cast<int>(components.isMemoryLayoutChange)) > 1;
+                       static_cast<int>(components.isGridChange ||
+                                        components.isMemorySpaceChange) +
+                       static_cast<int>(components.isFormatChange)) > 1;
+    assert(!components.isMemoryLayoutChange &&
+           "Memory layout is not used in direct to metal path");
+    assert(!isCompound && "Only one change is allowed");
 
-    assert(!isCompound && "Only one change is allowed");
-    assert(!components.isMemoryLayoutChange &&
@@ -384,7 +415,7 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern<ttir::ToLayoutOp> {
         rewriter.replaceOpWithNewOp(
             op, outputTy, op.getInput(), op.getOutput());
       } else {
-        assert(false && "L1 <-> DRAM not supported yet");
+        return relayout(op, rewriter);
       }
     } else if (components.isLayoutChange || components.isGridChange) {
       return relayout(op, rewriter);
diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index 2b0ce6123..ba5edda10 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -59,7 +59,7 @@ createOp(FlatbufferObjectCache &cache, OpenDeviceOp op) {
   auto result = op.getResult();
   auto resultType = mlir::cast<tt::DeviceType>(result.getType());
   ::tt::target::Dim2d grid =
-      toFlatbuffer(cache, resultType.getDesc().getGrid());
+      toFlatbuffer(cache, resultType.getDesc().getWorkerGrid());
   auto chipIds = toFlatbuffer(cache, resultType.getDesc().getChipIds());
   auto out = cache.getOrCreate(result, createDeviceRef);
   return ::tt::target::ttnn::CreateOpenDeviceOp(*cache.fbb, &grid, chipIds,
diff --git a/python/TTModule.cpp b/python/TTModule.cpp
index a96c662a2..685c11a4b 100644
--- a/python/TTModule.cpp
+++ b/python/TTModule.cpp
@@ -157,6 +157,10 @@ void populateTTModule(py::module &m) {
       });
 
   py::class_<tt::SystemDescAttr>(m, "SystemDescAttr")
+      .def_static("get_default",
+                  [](MlirContext ctx) {
+                    return wrap(tt::SystemDescAttr::getDefault(unwrap(ctx)));
+                  })
      .def_static("get",
                  [](MlirContext ctx, std::vector<MlirAttribute> chipDescs,
                     std::vector<uint32_t> chipDescIndices,
@@ -227,13 +231,22 @@ void populateTTModule(py::module &m) {
       });
 
   py::class_<tt::DeviceAttr>(m, "DeviceAttr")
-      .def_static(
-          "get",
-          [](MlirContext ctx, std::vector<int64_t> shape,
-             MlirAffineMap physicalGridMapping, std::vector<unsigned> chipIds) {
-            return wrap(tt::DeviceAttr::get(
-                unwrap(ctx), shape, unwrap(physicalGridMapping), chipIds));
-          })
+      .def_static("from_system_desc",
+                  [](MlirContext ctx, MlirAttribute systemDesc) {
+                    return wrap(tt::DeviceAttr::get(
+                        unwrap(ctx),
+                        mlir::cast<tt::SystemDescAttr>(unwrap(systemDesc))));
+                  })
+      .def_static("get",
+                  [](MlirContext ctx, std::vector<int64_t> shape,
+                     MlirAffineMap workerGridMapping, MlirAffineMap l1Map,
+                     MlirAffineMap dramMap, std::vector<unsigned> chipIds) {
+                    return wrap(tt::DeviceAttr::get(
+                        unwrap(ctx),
+                        tt::GridAttr::get(unwrap(ctx), shape,
+                                          unwrap(workerGridMapping)),
+                        unwrap(l1Map), unwrap(dramMap), chipIds));
+                  })
       .def("unwrap", [](MlirAttribute const &self) {
         return mlir::cast<tt::DeviceAttr>(unwrap(self));
       });
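The python `get` binding above is a thin wrapper over the C++ attribute
constructor; the equivalent direct C++ construction, for reference (a
sketch; `ctx`, `shape`, the maps, and `chipIds` are assumed in scope):

  // Build the worker grid first, then assemble the device attribute.
  auto workerGrid = tt::GridAttr::get(ctx, shape, workerGridMapping);
  auto device = tt::DeviceAttr::get(ctx, workerGrid, l1Map, dramMap, chipIds);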
diff --git a/test/python/device_attr.py b/test/python/device_attr.py
index 590fae147..8179f34e7 100644
--- a/test/python/device_attr.py
+++ b/test/python/device_attr.py
@@ -26,7 +26,7 @@ def getTotalDevices(grid, physicalGrid=[8, 8]):
     return volume(grid) // volume(physicalGrid)
 
 
-def inferAffineMap(grid, physicalGrid=[8, 8]):
+def inferWorkerGridMap(grid, physicalGrid=[8, 8]):
     assert len(grid) >= 2
     mesh = grid[:-2] + [
         updiv(grid[-2], physicalGrid[-2]),
@@ -53,11 +53,35 @@ def inferWorkerGridMap(grid, physicalGrid=[8, 8]):
     return AffineMap.get(len(grid), 0, exprs, ctx)
 
 
-def createDeviceAttr(grid, physicalGrid=[8, 8], deviceStartIdx=0, affMap=None):
+def inferMemoryMap(grid):
+    assert len(grid) <= 4
+    zero = AffineConstantExpr.get(0, ctx)
+    exprs = [AffineDimExpr.get(i, ctx) for i in range(len(grid))]
+    while len(exprs) < 4:
+        exprs.insert(0, zero)
+    return AffineMap.get(len(grid), len(grid), exprs, ctx)
+
+
+def createDeviceAttr(
+    grid, physicalGrid=[8, 8], deviceStartIdx=0, workerGridMap=None, system_desc=None
+):
+    if system_desc is not None:
+        return tt.ir.DeviceAttr.from_system_desc(ctx, system_desc)
     totalDevices = getTotalDevices(grid, physicalGrid=physicalGrid)
-    affineMap = affMap if affMap is not None else inferAffineMap(grid, physicalGrid)
+    workerGridMap = (
+        workerGridMap
+        if workerGridMap is not None
+        else inferWorkerGridMap(grid, physicalGrid)
+    )
+    l1Map = inferMemoryMap(grid)
+    dramMap = inferMemoryMap(grid)
     return tt.ir.DeviceAttr.get(
-        ctx, grid, affineMap, list(range(deviceStartIdx, deviceStartIdx + totalDevices))
+        ctx,
+        grid,
+        workerGridMap,
+        l1Map,
+        dramMap,
+        list(range(deviceStartIdx, deviceStartIdx + totalDevices)),
     )
@@ -71,89 +95,93 @@ def createDeviceAttr(grid, physicalGrid=[8, 8], deviceStartIdx=0, workerGridMap=None):
 d1 = d(1)
 d2 = d(2)
 
+print("=== From SystemDesc ===")
+# CHECK: tt.device<workerGrid = #tt.grid<8x8, (d0, d1) -> (0, d0, d1)>, l1Map = [[L1:.*]], dramMap = [[DRAM:.*]], chipIds = [0]>
+print("", createDeviceAttr([8, 8], system_desc=tt.ir.SystemDescAttr.get_default(ctx)))
+
 # ------------------------------------------------------------------------------
 print("=== Simple single device ===")
-# CHECK: tt.device<#tt.grid<8x8, (d0, d1) -> (0, d0, d1)>, [0]>
+# CHECK: tt.device<workerGrid = #tt.grid<8x8, (d0, d1) -> (0, d0, d1)>, l1Map = [[L1:.*]], dramMap = [[DRAM:.*]], chipIds = [0]>
 print("", createDeviceAttr([8, 8]))
 
 # ------------------------------------------------------------------------------
 print("\n=== Data parallel over batch ===")
+# CHECK: tt.device<workerGrid = #tt.grid<2x8x8, (d0, d1, d2) -> (d0 + d1 floordiv 8 + d2 floordiv 8, d1, d2)>, l1Map = [[L1:.*]], dramMap = [[DRAM:.*]], chipIds = [0, 1]>
 print("divide batch by 2\n", createDeviceAttr([2, 8, 8]))
+# CHECK: tt.device<workerGrid = #tt.grid<4x8x8, (d0, d1, d2) -> (d0 + d1 floordiv 8 + d2 floordiv 8, d1, d2)>, l1Map = [[L1:.*]], dramMap = [[DRAM:.*]], chipIds = [0, 1, 2, 3]>
 print("divide batch by 4\n", createDeviceAttr([4, 8, 8]))
floordiv 8, d1, d2)>, [0, 1, 2, 3]> +# CHECK: tt.device (d0 + d1 floordiv 8 + d2 floordiv 8, d1, d2)>, l1Map = [[L1:.*]], dramMap = [[DRAM:.*]], chipIds = [0, 1, 2, 3]> print("divide batch by 4\n", createDeviceAttr([4, 8, 8])) # ------------------------------------------------------------------------------ print("\n=== Data parallel over 2d ===") -# CHECK: tt.device<#tt.grid<8x16, (d0, d1) -> ((d0 floordiv 8) * 2 + d1 floordiv 8, d0, d1 mod 8)>, [0, 1]> +# CHECK: tt.device ((d0 floordiv 8) * 2 + d1 floordiv 8, d0, d1 mod 8)>, l1Map = [[L1:.*]], dramMap = [[DRAM:.*]], chipIds = [0, 1]> print( "Reinterpret 2 devices as grid side by side, 1x2 mesh\n", createDeviceAttr([8, 16]) ) -# CHECK: tt.device<#tt.grid<16x8, (d0, d1) -> (d0 floordiv 8 + d1 floordiv 8, d0 mod 8, d1)>, [0, 1]> +# CHECK: tt.device (d0 floordiv 8 + d1 floordiv 8, d0 mod 8, d1)>, l1Map = [[L1:.*]], dramMap = [[DRAM:.*]], chipIds = [0, 1]> print( "Reinterpret 2 devices as grid top to bottom, 2x1 mesh\n", createDeviceAttr([16, 8]) ) -# CHECK: tt.device<#tt.grid<16x32, (d0, d1) -> ((d0 floordiv 8) * 4 + d1 floordiv 8, d0 mod 8, d1 mod 8)>, [0, 1, 2, 3, 4, 5, 6, 7]> +# CHECK: tt.device ((d0 floordiv 8) * 4 + d1 floordiv 8, d0 mod 8, d1 mod 8)>, l1Map = [[L1:.*]], dramMap = [[DRAM:.*]], chipIds = [0, 1, 2, 3, 4, 5, 6, 7]> print("8 devices 2x4 mesh\n", createDeviceAttr([16, 32])) -# CHECK: tt.device<#tt.grid<32x16, (d0, d1) -> ((d0 floordiv 8) * 2 + d1 floordiv 8, d0 mod 8, d1 mod 8)>, [0, 1, 2, 3, 4, 5, 6, 7]> +# CHECK: tt.device ((d0 floordiv 8) * 2 + d1 floordiv 8, d0 mod 8, d1 mod 8)>, l1Map = [[L1:.*]], dramMap = [[DRAM:.*]], chipIds = [0, 1, 2, 3, 4, 5, 6, 7]> print("8 devices 4x2 mesh\n", createDeviceAttr([32, 16])) # ------------------------------------------------------------------------------ print("\n=== Data parallel over 2d and batch (3d) ===") -# CHECK: tt.device<#tt.grid<2x8x16, (d0, d1, d2) -> (d0 * 2 + (d1 floordiv 8) * 2 + d2 floordiv 8, d1, d2 mod 8)>, [0, 1, 2, 3]> +# CHECK: tt.device (d0 * 2 + (d1 floordiv 8) * 2 + d2 floordiv 8, d1, d2 mod 8)>, l1Map = [[L1:.*]], dramMap = [[DRAM:.*]], chipIds = [0, 1, 2, 3]> print("divide batch by 2, 2x1x2 mesh\n", createDeviceAttr([2, 8, 16])) -# CHECK: tt.device<#tt.grid<3x24x8, (d0, d1, d2) -> (d0 * 3 + d1 floordiv 8 + d2 floordiv 8, d1 mod 8, d2)>, [0, 1, 2, 3, 4, 5, 6, 7, 8]> +# CHECK: tt.device (d0 * 3 + d1 floordiv 8 + d2 floordiv 8, d1 mod 8, d2)>, l1Map = [[L1:.*]], dramMap = [[DRAM:.*]], chipIds = [0, 1, 2, 3, 4, 5, 6, 7, 8]> print("divide batch by 3, 3x3x1 mesh\n", createDeviceAttr([3, 24, 8])) # ------------------------------------------------------------------------------ print("\n=== nD ===") -# CHECK: tt.device<#tt.grid<3x2x8x8, (d0, d1, d2, d3) -> (d0 * 2 + d1 + d2 floordiv 8 + d3 floordiv 8, d2, d3)>, [0, 1, 2, 3, 4, 5]> +# CHECK: tt.device (d0 * 2 + d1 + d2 floordiv 8 + d3 floordiv 8, d2, d3)>, l1Map = [[L1:.*]], dramMap = [[DRAM:.*]], chipIds = [0, 1, 2, 3, 4, 5]> print("", createDeviceAttr([3, 2, 8, 8])) # ------------------------------------------------------------------------------ print("\n=== Data parallel batch on single device ===") -# CHECK: tt.device<#tt.grid<2x4x8, (d0, d1, d2) -> (0, d0 * 4 + d1, d2)>, [0]> +# CHECK: tt.device (0, d0 * 4 + d1, d2)>, l1Map = [[L1:.*]], dramMap = [[DRAM:.*]], chipIds = [0]> print( "divide batch by 2, top 4 rows get batch 0, bottom 4 rows get batch 1\n", - createDeviceAttr([2, 4, 8], affMap=amap(3, [c0, d0 * 4 + d1, d2])), + createDeviceAttr([2, 4, 8], workerGridMap=amap(3, [c0, d0 * 4 + d1, d2])), ) # 
 
 # ------------------------------------------------------------------------------
 print("\n=== Pipeline parallel ===")
+# CHECK: tt.device<workerGrid = #tt.grid<2x8x16, (d0, d1, d2) -> (d0 * 2 + (d1 floordiv 8) * 2 + d2 floordiv 8, d1, d2 mod 8)>, l1Map = [[L1:.*]], dramMap = [[DRAM:.*]], chipIds = [0, 1, 2, 3]>
 print("view devices 0-3 in one way\n", createDeviceAttr([2, 8, 16], deviceStartIdx=0))
+# CHECK: tt.device<workerGrid = #tt.grid<16x16, (d0, d1) -> ((d0 floordiv 8) * 2 + d1 floordiv 8, d0 mod 8, d1 mod 8)>, l1Map = [[L1:.*]], dramMap = [[DRAM:.*]], chipIds = [4, 5, 6, 7]>
 print("view devices 4-7 in another way\n", createDeviceAttr([16, 16], deviceStartIdx=4))
 
 # ------------------------------------------------------------------------------
 print("\n=== Reinterpreted Grids ===")
+# CHECK: tt.device<workerGrid = #tt.grid<8x8, (d0, d1) -> (0, d1, d0)>, l1Map = [[L1:.*]], dramMap = [[DRAM:.*]], chipIds = [0]>
-print("transposed\n", createDeviceAttr([8, 8], affMap=amap(2, [c0, d1, d0])))
+print("transposed\n", createDeviceAttr([8, 8], workerGridMap=amap(2, [c0, d1, d0])))
+# CHECK: tt.device<workerGrid = #tt.grid<1x64, (d0, d1) -> (0, d0 * 8 + d1 floordiv 8, d1 mod 8)>, l1Map = [[L1:.*]], dramMap = [[DRAM:.*]], chipIds = [0]>
 print(
     "extra wide\n",
     createDeviceAttr(
-        [1, 64], affMap=amap(2, [c0, d0 * 8 + floordiv(d1, c(8)), d1 % 8])
+        [1, 64], workerGridMap=amap(2, [c0, d0 * 8 + floordiv(d1, c(8)), d1 % 8])
     ),
 )
+# CHECK: tt.device<workerGrid = #tt.grid<64x1, (d0, d1) -> (0, d1 * 8 + d0 floordiv 8, d0 mod 8)>, l1Map = [[L1:.*]], dramMap = [[DRAM:.*]], chipIds = [0]>
 print(
     "extra tall transposed\n",
     createDeviceAttr(
         [64, 1],
-        affMap=amap(2, [c0, d1 * 8 + floordiv(d0, c(8)), d0 % 8]),
+        workerGridMap=amap(2, [c0, d1 * 8 + floordiv(d0, c(8)), d0 % 8]),
     ),
 )
+# CHECK: tt.device<workerGrid = #tt.grid<8x8, (d0, d1) -> (0, d0, (d0 + d1) mod 8)>, l1Map = [[L1:.*]], dramMap = [[DRAM:.*]], chipIds = [0]>
 print(
     "staircase systolic\n",
-    createDeviceAttr([8, 8], affMap=amap(2, [c0, d0, (d0 + d1) % 8])),
+    createDeviceAttr([8, 8], workerGridMap=amap(2, [c0, d0, (d0 + d1) % 8])),
 )
diff --git a/test/ttmlir/Silicon/TTMetal/to_layout.mlir b/test/ttmlir/Silicon/TTMetal/to_layout.mlir
index b964f5681..015e65175 100644
--- a/test/ttmlir/Silicon/TTMetal/to_layout.mlir
+++ b/test/ttmlir/Silicon/TTMetal/to_layout.mlir
@@ -3,6 +3,7 @@
 // RUN: ttmlir-translate --ttmetal-to-flatbuffer %t.mlir > %t.ttm
 
 #l1_ = #tt.memory_space<l1>
+#dram = #tt.memory_space<dram>
 #layout = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<4x16xf32, #l1_>>
 #layout1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<2x8xf32, #l1_>>
@@ -24,3 +25,35 @@ func.func @tilize(%arg0: tensor<64x128xf32, #untilized>) -> tensor<64x128xf32, #tilized> {
   %3 = "ttir.to_layout"(%1, %2) : (tensor<64x128xf32, #tilized>, tensor<64x128xf32, #untilized>) -> tensor<64x128xf32, #untilized>
   return %3 : tensor<64x128xf32, #untilized>
 }
+
+#untilized_dram = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<16x64xf32, #dram>>
+#untilized_l1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<16x64xf32, #l1_>>
+#untilized2x2_dram = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<8x32xf32, #dram>>
+#untilized2x2_l1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<8x32xf32, #l1_>>
+#untilized1x4_l1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x4>, memref<16x16xf32, #l1_>>
+
+func.func @dram_to_l1(%arg0: tensor<16x64xf32, #untilized_dram>) -> tensor<16x64xf32, #untilized_l1> {
+  %0 = tensor.empty() : tensor<16x64xf32, #untilized_l1>
+  // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]]
+  %1 = "ttir.to_layout"(%arg0, %0) : (tensor<16x64xf32, #untilized_dram>, tensor<16x64xf32, #untilized_l1>) -> tensor<16x64xf32, #untilized_l1>
+  return %1 : tensor<16x64xf32, #untilized_l1>
+}
+
+func.func @l1_to_dram(%arg0: tensor<16x64xf32, #untilized_l1>) -> tensor<16x64xf32, #untilized_dram> {
+  %0 = tensor.empty() : tensor<16x64xf32, #untilized_dram>
+  // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]]
+  %1 = "ttir.to_layout"(%arg0, %0) : (tensor<16x64xf32, #untilized_l1>, tensor<16x64xf32, #untilized_dram>) -> tensor<16x64xf32, #untilized_dram>
+  return %1 : tensor<16x64xf32, #untilized_dram>
+}
+
+func.func @l1dram_reblock0(%arg0: tensor<16x64xf32, #untilized_l1>) -> tensor<16x64xf32, #untilized_l1> {
+  %0 = tensor.empty() : tensor<16x64xf32, #untilized2x2_dram>
+  // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]]
+  %1 = "ttir.to_layout"(%arg0, %0) : (tensor<16x64xf32, #untilized_l1>, tensor<16x64xf32, #untilized2x2_dram>) -> tensor<16x64xf32, #untilized2x2_dram>
+  %2 = tensor.empty() : tensor<16x64xf32, #untilized1x4_l1>
+  // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]]
+  %3 = "ttir.to_layout"(%1, %2) : (tensor<16x64xf32, #untilized2x2_dram>, tensor<16x64xf32, #untilized1x4_l1>) -> tensor<16x64xf32, #untilized1x4_l1>
+  %4 = tensor.empty() : tensor<16x64xf32, #untilized_l1>
+  // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]]
+  %5 = "ttir.to_layout"(%3, %4) : (tensor<16x64xf32, #untilized1x4_l1>, tensor<16x64xf32, #untilized_l1>) -> tensor<16x64xf32, #untilized_l1>
+  return %5 : tensor<16x64xf32, #untilized_l1>
+}