Skip to content

Commit

Permalink
#1271 Add CPU Desc to System Desc (#1293)
Browse files Browse the repository at this point in the history
* Add CPUDesc field, which contains host vs device specifier + target triple info, as part of SystemDesc throughout the stack
  • Loading branch information
vwellsTT authored Nov 15, 2024
1 parent 8f3f90a commit 3617528
Show file tree
Hide file tree
Showing 12 changed files with 161 additions and 33 deletions.
11 changes: 6 additions & 5 deletions include/ttmlir-c/TTAttrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@ MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipChannelAttrGet(
int64_t *ethernetCoreCoord1, size_t ethernetCoreCoord1Size);

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTSystemDescAttrGet(
MlirContext ctx, MlirAttribute *chipDescs, size_t chipDescsSize,
unsigned *chipDescIndices, size_t chipDescIndicesSize,
MlirAttribute *chipCapabilities, size_t chipCapabilitiesSize,
MlirAttribute *chipCoords, size_t chipCoordsSize,
MlirAttribute *chipChannels, size_t chipChannelsSize);
MlirContext ctx, MlirAttribute *cpuDescs, size_t cpuDescsSize,
MlirAttribute *chipDescs, size_t chipDescsSize, unsigned *chipDescIndices,
size_t chipDescIndicesSize, MlirAttribute *chipCapabilities,
size_t chipCapabilitiesSize, MlirAttribute *chipCoords,
size_t chipCoordsSize, MlirAttribute *chipChannels,
size_t chipChannelsSize);

MLIR_CAPI_EXPORTED MlirAttribute
ttmlirTTLayoutAttrGet(MlirContext ctx, MlirAffineMap linear, unsigned oobVal,
Expand Down
12 changes: 12 additions & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -189,4 +189,16 @@ def TT_BufferAccess : I32BitEnumAttr<"BufferAccess", "TT Buffer Access",
let cppNamespace = "::mlir::tt";
}

// Enumerators of the CPURole enum. Each case gives the C++ symbol, its integer
// value, and the keyword used in the textual IR ("host" / "device").
def TT_CPURoleHost : I32EnumAttrCase<"Host", 0, "host">;
def TT_CPURoleDevice : I32EnumAttrCase<"Device", 1, "device">;

// 32-bit integer enum `CPURole`, listing the two cases declared above.
def TT_CPURole : I32EnumAttr<"CPURole", "TT CPU Role",
[
TT_CPURoleHost,
TT_CPURoleDevice,
]> {
// Do not generate a specialized attribute class from this def; the dialect
// attribute wrapping this enum is declared separately through EnumAttr.
let genSpecializedAttr = 0;
// The generated C++ enum is placed in the ::mlir::tt namespace.
let cppNamespace = "::mlir::tt";
}

#endif
21 changes: 19 additions & 2 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,22 @@ def TT_ChipDescAttr : TT_Attr<"ChipDesc", "chip_desc"> {
}];
}

// Dialect attribute wrapping the TT_CPURole enum under the mnemonic
// `cpu_role`. Its assembly format is just the enum value itself.
def TT_CPURoleAttr : EnumAttr<TT_Dialect, TT_CPURole, "cpu_role"> {
let assemblyFormat = "$value";
}

// Attribute `cpu_desc`: describes one CPU by its role and its target triple.
def TT_CPUDescAttr : TT_Attr<"CPUDesc", "cpu_desc"> {
let summary = "TT cpu_desc attribute";
let description = [{
TT cpu_desc attribute
}];

// role:          a CPURole enum value (Host or Device).
// target_triple: a StringAttr holding the triple text.
let parameters = (ins "CPURole":$role,
"StringAttr":$target_triple);
// Textual form: `{role = <role>, target_triple = <string>}`.
let assemblyFormat = [{`{` `role` `=` $role `,`
`target_triple` `=` $target_triple `}`}];
}

def TT_ChipCoordAttr : TT_Attr<"ChipCoord", "chip_coord"> {
let summary = "TT chip_coord attribute";
let description = [{
Expand Down Expand Up @@ -177,12 +193,13 @@ def TT_SystemDescAttr : TT_Attr<"SystemDesc", "system_desc"> {
TT system_desc attribute
}];

let parameters = (ins ArrayRefParameter<"ChipDescAttr">:$chipDescs,
let parameters = (ins ArrayRefParameter<"CPUDescAttr">:$cpuDescs,
ArrayRefParameter<"ChipDescAttr">:$chipDescs,
ArrayRefParameter<"unsigned">:$chipDescIndices,
ArrayRefParameter<"ChipCapabilityAttr">:$chipCapabilities,
ArrayRefParameter<"ChipCoordAttr">:$chipCoords,
OptionalArrayRefParameter<"ChipChannelAttr">:$chipChannels);
let assemblyFormat = "`<` `[` $chipDescs `]` `,` `[` $chipDescIndices `]` `,` `[` $chipCapabilities `]` `,` `[` $chipCoords `]` (`,` `[` $chipChannels^ `]`)? `>`";
let assemblyFormat = "`<` `[` $cpuDescs `]` `,` `[` $chipDescs `]` `,` `[` $chipDescIndices `]` `,` `[` $chipCapabilities `]` `,` `[` $chipCoords `]` (`,` `[` $chipChannels^ `]`)? `>`";

let extraClassDeclaration = [{
static tt::SystemDescAttr getDefault(MLIRContext *context);
Expand Down
12 changes: 12 additions & 0 deletions include/ttmlir/Target/Common/types.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,19 @@ table ChipPhysicalCores {
eth_inactive: [Dim2d];
}

// Role of a CPU described in the system descriptor, stored as a uint8.
// The numeric values are expected to match ::mlir::tt::CPURole; the system
// descriptor loader static_asserts this for both enumerators.
enum CPURole: uint8
{
Host = 0,
Device = 1,
}

// Descriptor of a single CPU: its role plus its target triple as a string.
table CPUDesc {
role: CPURole;
target_triple: string;
}

table SystemDesc {
cpu_descs: [CPUDesc];
chip_descs: [ChipDesc];
chip_desc_indices: [uint32];
chip_capabilities: [ChipCapability];
Expand Down
24 changes: 21 additions & 3 deletions include/ttmlir/Target/Utils/MLIRToFlatbuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,16 +293,34 @@ toFlatbuffer(FlatbufferObjectCache &cache, ChipDescAttr chipDesc) {
chipDesc.getNumCBs());
}

/// Converts the MLIR-side CPURole enum into its flatbuffer schema counterpart.
///
/// The cache parameter is unused; it is accepted only so that this overload
/// has the same (cache, value) call shape as the other toFlatbuffer overloads.
///
/// \param role the CPU role to convert.
/// \return the ::tt::target::CPURole enumerator with the same name.
inline ::tt::target::CPURole toFlatbuffer(FlatbufferObjectCache &,
                                          CPURole role) {
  // No `default:` label on purpose, so the compiler can still warn when a new
  // enumerator is added to CPURole without being handled here.
  switch (role) {
  case CPURole::Host:
    return ::tt::target::CPURole::Host;
  case CPURole::Device:
    return ::tt::target::CPURole::Device;
  }
  // Only reachable if `role` holds a value outside the declared enumerators.
  // Returning here keeps the function from flowing off its end (undefined
  // behaviour); Host is the schema's zero value.
  return ::tt::target::CPURole::Host;
}

/// Serializes a CPUDescAttr into a ::tt::target::CPUDesc flatbuffer table.
///
/// \param cache   object cache whose `fbb` builder receives the data.
/// \param cpuDesc attribute providing the role and the target triple.
/// \return offset of the newly created CPUDesc table inside the builder.
inline flatbuffers::Offset<::tt::target::CPUDesc>
toFlatbuffer(FlatbufferObjectCache &cache, CPUDescAttr cpuDesc) {
  // Write the triple straight from the attribute's storage using the
  // pointer + length overload, instead of materializing a temporary
  // std::string first.
  const auto targetTriple = cpuDesc.getTargetTriple().getValue();
  const auto targetTripleOffset =
      cache.fbb->CreateString(targetTriple.data(), targetTriple.size());

  return ::tt::target::CreateCPUDesc(
      *cache.fbb, toFlatbuffer(cache, cpuDesc.getRole()), targetTripleOffset);
}

inline flatbuffers::Offset<::tt::target::SystemDesc>
toFlatbuffer(FlatbufferObjectCache &cache, SystemDescAttr systemDesc) {
auto cpuDescs = toFlatbuffer(cache, systemDesc.getCpuDescs());
auto chipDescs = toFlatbuffer(cache, systemDesc.getChipDescs());
auto chipDescIndices = toFlatbuffer(cache, systemDesc.getChipDescIndices());
auto chipCapabilities = toFlatbuffer(cache, systemDesc.getChipCapabilities());
auto chipCoords = toFlatbuffer(cache, systemDesc.getChipCoords());
auto chipChannels = toFlatbuffer(cache, systemDesc.getChipChannels());
return ::tt::target::CreateSystemDesc(*cache.fbb, chipDescs, chipDescIndices,
chipCapabilities, chipCoords,
chipChannels);
return ::tt::target::CreateSystemDesc(*cache.fbb, cpuDescs, chipDescs,
chipDescIndices, chipCapabilities,
chipCoords, chipChannels);
}

inline std::vector<::tt::target::Dim2dRange>
Expand Down
25 changes: 16 additions & 9 deletions lib/CAPI/TTAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,14 @@ MlirAttribute ttmlirTTChipChannelAttrGet(MlirContext ctx, unsigned deviceId0,
}

MlirAttribute ttmlirTTSystemDescAttrGet(
MlirContext ctx, MlirAttribute *chipDescs, size_t chipDescsSize,
unsigned *chipDescIndices, size_t chipDescIndicesSize,
MlirAttribute *chipCapabilities, size_t chipCapabilitiesSize,
MlirAttribute *chipCoords, size_t chipCoordsSize,
MlirAttribute *chipChannels, size_t chipChannelsSize) {
llvm::ArrayRef<MlirAttribute> chipDescsRef(chipDescs, chipDescsSize),
MlirContext ctx, MlirAttribute *cpuDescs, size_t cpuDescsSize,
MlirAttribute *chipDescs, size_t chipDescsSize, unsigned *chipDescIndices,
size_t chipDescIndicesSize, MlirAttribute *chipCapabilities,
size_t chipCapabilitiesSize, MlirAttribute *chipCoords,
size_t chipCoordsSize, MlirAttribute *chipChannels,
size_t chipChannelsSize) {
llvm::ArrayRef<MlirAttribute> cpuDescsRef(cpuDescs, cpuDescsSize),
chipDescsRef(chipDescs, chipDescsSize),
chipCapabilitiesRef(chipCapabilities, chipCapabilitiesSize),
chipCoordsRef(chipCoords, chipCoordsSize),
chipChannelsRef(chipChannels, chipChannelsSize);
Expand Down Expand Up @@ -107,9 +109,14 @@ MlirAttribute ttmlirTTSystemDescAttrGet(
mlir::cast<ChipChannelAttr>(unwrap(chipChannel)));
}

return wrap(SystemDescAttr::get(unwrap(ctx), chipDescsUnwrapped,
chipDescIndicesRef, chipCapabilitiesUnwrapped,
chipCoordsUnwrapped, chipChannelsUnwrapped));
std::vector<tt::CPUDescAttr> cpuDescsUnwrapped;
for (auto cpuDesc : cpuDescsRef) {
cpuDescsUnwrapped.push_back(mlir::cast<CPUDescAttr>(unwrap(cpuDesc)));
}

return wrap(SystemDescAttr::get(
unwrap(ctx), cpuDescsUnwrapped, chipDescsUnwrapped, chipDescIndicesRef,
chipCapabilitiesUnwrapped, chipCoordsUnwrapped, chipChannelsUnwrapped));
}

MlirAttribute ttmlirTTLayoutAttrGet(MlirContext ctx, MlirAffineMap linear,
Expand Down
28 changes: 26 additions & 2 deletions lib/Dialect/TT/IR/TTOpsTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ mlir::tt::SystemDescAttr::getDefault(MLIRContext *context) {
}
return tt::SystemDescAttr::get(
context,
// CPU Descriptors
{tt::CPUDescAttr::get(
context, tt::CPURole::Host,
mlir::StringAttr::get(context, "x86_64-pc-linux-gnu"))},
// Chip Descriptors
{
tt::ChipDescAttr::get(
Expand Down Expand Up @@ -123,13 +127,33 @@ mlir::tt::SystemDescAttr::getFromPath(MLIRContext *context, std::string &path) {
// Read relevant information from binary
auto const *binary_system_desc =
::tt::target::GetSizePrefixedSystemDescRoot(buffer.get())->system_desc();
auto const *binary_cpu_desc = binary_system_desc->cpu_descs();
auto const *binary_chip_desc = binary_system_desc->chip_descs();
auto const *binary_chip_desc_indices =
binary_system_desc->chip_desc_indices();
auto const *chip_capabilities = binary_system_desc->chip_capabilities();
auto const *binary_chip_coords = binary_system_desc->chip_coords();
auto const *chip_channel_connections = binary_system_desc->chip_channels();

// Acquire cpu descs
std::vector<tt::CPUDescAttr> cpu_desc_list;
for (auto const *element : *binary_cpu_desc) {
static_assert(static_cast<std::underlying_type_t<::tt::target::CPURole>>(
::mlir::tt::CPURole::Device) ==
static_cast<std::underlying_type_t<::tt::target::CPURole>>(
::tt::target::CPURole::Device));
static_assert(static_cast<std::underlying_type_t<::tt::target::CPURole>>(
::mlir::tt::CPURole::Host) ==
static_cast<std::underlying_type_t<::tt::target::CPURole>>(
::tt::target::CPURole::Host));
const auto *flatbufferTargetTripleString = element->target_triple();
cpu_desc_list.emplace_back(tt::CPUDescAttr::get(
context, static_cast<mlir::tt::CPURole>(element->role()),
mlir::StringAttr::get(
context, std::string(flatbufferTargetTripleString->c_str(),
flatbufferTargetTripleString->size()))));
}

// Acquire chip descs
std::vector<tt::ChipDescAttr> chip_desc_list;
for (auto const *element : *binary_chip_desc) {
Expand Down Expand Up @@ -299,8 +323,8 @@ mlir::tt::SystemDescAttr::getFromPath(MLIRContext *context, std::string &path) {

// Generate system desc attribute
auto system_desc_attr = tt::SystemDescAttr::get(
context, chip_desc_list, chip_indices_list, chip_capabilities_list,
chip_coordinate_list, chip_channel_list);
context, cpu_desc_list, chip_desc_list, chip_indices_list,
chip_capabilities_list, chip_coordinate_list, chip_channel_list);

return system_desc_attr;
}
Expand Down
12 changes: 9 additions & 3 deletions python/TTModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,8 @@ void populateTTModule(py::module &m) {
})
.def_static(
"get",
[](MlirContext ctx, std::vector<MlirAttribute> chipDescs,
[](MlirContext ctx, std::vector<MlirAttribute> cpuDescs,
std::vector<MlirAttribute> chipDescs,
std::vector<unsigned> chipDescIndices,
std::vector<MlirAttribute> chipCapabilities,
std::vector<MlirAttribute> chipCoords,
Expand All @@ -301,9 +302,14 @@ void populateTTModule(py::module &m) {
chipChannelsUnwrapped.push_back(
mlir::cast<tt::ChipChannelAttr>(unwrap(chipChannel)));
}
std::vector<tt::CPUDescAttr> cpuDescsUnwrapped;
for (const auto &cpuDesc : cpuDescs) {
cpuDescsUnwrapped.push_back(
mlir::cast<tt::CPUDescAttr>(unwrap(cpuDesc)));
}
return wrap(tt::SystemDescAttr::get(
unwrap(ctx), chipDescsUnwrapped, chipDescIndices,
chipCapabilitiesUnwrapped, chipCoordsUnwrapped,
unwrap(ctx), cpuDescsUnwrapped, chipDescsUnwrapped,
chipDescIndices, chipCapabilitiesUnwrapped, chipCoordsUnwrapped,
chipChannelsUnwrapped));
})
.def_property_readonly(
Expand Down
27 changes: 27 additions & 0 deletions runtime/lib/common/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,30 @@
# Compose a "<arch>-<vendor>-<os>" triple for the system CMake is targeting and
# expose it to the sources in this directory as the TARGET_TRIPLE macro.
#credit: https://github.com/google/libcxx/blob/master/cmake/Modules/GetTriple.cmake
# Get the architecture.
set(arch "${CMAKE_SYSTEM_PROCESSOR}")
if (arch STREQUAL "x86")
  set(arch "i686")
endif()
# Get the vendor.
# The variable is named without ${} so that if() looks it up exactly once and
# never re-interprets its value as another variable name.
if (CMAKE_SYSTEM_NAME STREQUAL "Darwin")
  set(vendor "apple")
else()
  set(vendor "pc")
endif()
# Get os.
if (CMAKE_SYSTEM_NAME STREQUAL "Windows")
  set(os "win32")
else()
  string(TOLOWER "${CMAKE_SYSTEM_NAME}" os)
endif()
set(triple "${arch}-${vendor}-${os}")
# NOTE: the code this was adapted from exported its results through
# set(${out} ... PARENT_SCOPE) output parameters. No out/out_arch/out_vendor/
# out_os variables are defined in this file, so those statements did not export
# anything useful and have been dropped.
message(STATUS "Target triple: ${triple}")

# Applies to every target created in this directory after this point.
add_definitions(-DTARGET_TRIPLE="${triple}")

add_library(TTRuntimeSysDesc STATIC system_desc.cpp)
set_property(TARGET TTRuntimeSysDesc PROPERTY CXX_STANDARD 20)
target_include_directories(TTRuntimeSysDesc
Expand Down
12 changes: 9 additions & 3 deletions runtime/lib/common/system_desc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ static std::unique_ptr<::tt::runtime::SystemDesc> getCurrentSystemDescImpl(

auto dramUnreservedEnd = calculateDRAMUnreservedEnd(device);

chipDescs.push_back(::tt::target::CreateChipDesc(
chipDescs.emplace_back(::tt::target::CreateChipDesc(
fbb, toFlatbuffer(device->arch()), &deviceGrid,
device->l1_size_per_core(), device->num_dram_channels(),
device->dram_size_per_channel(), L1_ALIGNMENT, PCIE_ALIGNMENT,
Expand All @@ -248,10 +248,16 @@ static std::unique_ptr<::tt::runtime::SystemDesc> getCurrentSystemDescImpl(
// Extract chip connected channels
std::vector<::tt::target::ChipChannel> allConnections =
getAllDeviceConnections(devices);
// Store CPUDesc
std::vector<::flatbuffers::Offset<tt::target::CPUDesc>> cpuDescs;
cpuDescs.emplace_back(::tt::target::CreateCPUDesc(
fbb, ::tt::target::CPURole::Host,
fbb.CreateString(std::string(TARGET_TRIPLE))));

// Create SystemDesc
auto systemDesc = ::tt::target::CreateSystemDescDirect(
fbb, &chipDescs, &chipDescIndices, &chipCapabilities, &chipCoords,
&allConnections);
fbb, &cpuDescs, &chipDescs, &chipDescIndices, &chipCapabilities,
&chipCoords, &allConnections);
::ttmlir::Version ttmlirVersion = ::ttmlir::getVersion();
::tt::target::Version version(ttmlirVersion.major, ttmlirVersion.minor,
ttmlirVersion.patch);
Expand Down
5 changes: 2 additions & 3 deletions test/ttmlir/Dialect/TTNN/optimizer/test_grid_set.mlir
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
// RUN: ttmlir-opt --ttnn-optimizer %s | FileCheck %s
// RUN: ttmlir-opt --ttir-load-system-desc --ttnn-optimizer %s | FileCheck %s
#device = #tt.device<workerGrid = #tt.grid<8x8, (d0, d1) -> (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]>
#dram = #tt.memory_space<dram>
#system = #tt.memory_space<system>
#system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [<f32>, <f16>, <bf16>, <bfp_f8>, <bfp_bf8>, <bfp_f4>, <bfp_bf4>, <bfp_f2>, <bfp_bf2>, <u32>, <u16>, <u8>], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]>
#layout = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #system>>
#layout1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #dram>, interleaved>
#layout2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #dram>, interleaved>
module attributes {tt.device = #device, tt.system_desc = #system_desc} {
module attributes {tt.device = #device} {
func.func @forward(%arg0: tensor<64x128xf32, #layout>, %arg1: tensor<64x128xf32, #layout>) -> tensor<64x128xf32, #layout> {
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
%1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<64x128>>>}> : (tensor<64x128xf32, #layout>, !tt.device<#device>) -> tensor<64x128xf32, #layout1>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
// RUN: ttmlir-opt --ttnn-optimizer="memory-layout-analysis-enabled=true memreconfig-enabled=true insert-memreconfig=add_0_1_2=0 override-output-layout=add_1_2=1x1:dram:interleaved:row_major:f32" %s | FileCheck %s
// RUN: ttmlir-opt --ttir-load-system-desc --ttnn-optimizer="memory-layout-analysis-enabled=true memreconfig-enabled=true insert-memreconfig=add_0_1_2=0 override-output-layout=add_1_2=1x1:dram:interleaved:row_major:f32" %s | FileCheck %s
#device = #tt.device<workerGrid = #tt.grid<8x8, (d0, d1) -> (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]>
#dram = #tt.memory_space<dram>
#system = #tt.memory_space<system>
#system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [<f32>, <f16>, <bf16>, <bfp_f8>, <bfp_bf8>, <bfp_f4>, <bfp_bf4>, <bfp_f2>, <bfp_bf2>, <u32>, <u16>, <u8>], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]>
#layout = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #system>>
#layout1 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #dram>, interleaved>
module attributes {tt.device = #device, tt.system_desc = #system_desc} {
module attributes {tt.device = #device} {
func.func @main(%arg0: tensor<1x32x32xf32, #layout>, %arg1: tensor<1x32x32xf32, #layout>, %arg2: tensor<1x32x32xf32, #layout>) -> tensor<1x32x32xf32, #layout> {
// CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #dram>, interleaved>
// CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #l1_>, width_sharded>
Expand Down

0 comments on commit 3617528

Please sign in to comment.