diff --git a/include/ttmlir-c/TTAttrs.h b/include/ttmlir-c/TTAttrs.h index 750c201b09..cfdaf8026e 100644 --- a/include/ttmlir-c/TTAttrs.h +++ b/include/ttmlir-c/TTAttrs.h @@ -31,9 +31,10 @@ MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipDescAttrGet( MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipCoordAttrGet( MlirContext ctx, unsigned rack, unsigned shelf, unsigned y, unsigned x); -MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipChannelAttrGet(MlirContext ctx, - unsigned endpoint0, - unsigned endpoint1); +MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipChannelAttrGet( + MlirContext ctx, unsigned deviceId0, int64_t *ethernetCoreCoord0, + size_t ethernetCoreCoord0Size, unsigned deviceId1, + int64_t *ethernetCoreCoord1, size_t ethernetCoreCoord1Size); MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTSystemDescAttrGet( MlirContext ctx, MlirAttribute *chipDescs, size_t chipDescsSize, diff --git a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td index 27a64d7861..f09187fa4e 100644 --- a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td +++ b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td @@ -81,8 +81,11 @@ def TT_ChipChannelAttr : TT_Attr<"ChipChannel", "chip_channel"> { TT chip_channel attribute }]; - let parameters = (ins "unsigned":$endpoint0, "unsigned":$endpoint1); - let assemblyFormat = "`<` $endpoint0 `,` $endpoint1 `>`"; + let parameters = (ins "unsigned":$deviceId0, + ArrayRefParameter<"int64_t">:$ethernetCoreCoord0, + "unsigned":$deviceId1, + ArrayRefParameter<"int64_t">:$ethernetCoreCoord1); + let assemblyFormat = "`<` $deviceId0 `,` $ethernetCoreCoord0 `,` $deviceId1 `,` $ethernetCoreCoord1 `>`"; } def TT_SystemDescAttr : TT_Attr<"SystemDesc", "system_desc"> { diff --git a/include/ttmlir/Target/Common/types.fbs b/include/ttmlir/Target/Common/types.fbs index 370aaaa4b8..09f6876b3b 100644 --- a/include/ttmlir/Target/Common/types.fbs +++ b/include/ttmlir/Target/Common/types.fbs @@ -111,8 +111,10 @@ struct ChipCoord { } struct ChipChannel { - endpoint0: uint32; - endpoint1: uint32; + device_id0: uint32; + ethernet_core_coord0: Dim2d; + device_id1: uint32; + ethernet_core_coord1: Dim2d; } table SystemDesc { diff --git a/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h b/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h index e4f041c910..fb69f08659 100644 --- a/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h +++ b/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h @@ -128,19 +128,24 @@ toFlatbuffer(FlatbufferObjectCache &, ChipCapabilityAttr capabilityAttr) { } inline ::tt::target::ChipCoord toFlatbuffer(FlatbufferObjectCache &cache, - ChipCoordAttr chipCoord) { + const ChipCoordAttr &chipCoord) { return ::tt::target::ChipCoord(chipCoord.getRack(), chipCoord.getShelf(), chipCoord.getY(), chipCoord.getX()); } -inline ::tt::target::ChipChannel toFlatbuffer(FlatbufferObjectCache &cache, - ChipChannelAttr chipChannel) { - return ::tt::target::ChipChannel(chipChannel.getEndpoint0(), - chipChannel.getEndpoint1()); +inline ::tt::target::ChipChannel +toFlatbuffer(FlatbufferObjectCache &cache, const ChipChannelAttr &chipChannel) { + return ::tt::target::ChipChannel( + chipChannel.getDeviceId0(), + ::tt::target::Dim2d(chipChannel.getEthernetCoreCoord0()[0], + chipChannel.getEthernetCoreCoord0()[1]), + chipChannel.getDeviceId1(), + ::tt::target::Dim2d(chipChannel.getEthernetCoreCoord1()[0], + chipChannel.getEthernetCoreCoord1()[1])); } inline ::tt::target::Dim2d toFlatbuffer(FlatbufferObjectCache &cache, - GridAttr arch) { + const GridAttr &arch) { assert(arch.getShape().size() == 2 && "expected a 2D grid"); return ::tt::target::Dim2d(arch.getShape()[0], arch.getShape()[1]); } @@ -188,7 +193,7 @@ toFlatbuffer(FlatbufferObjectCache &cache, ::llvm::ArrayRef arr) { } inline flatbuffers::Offset<::tt::target::ChipDesc> -toFlatbuffer(FlatbufferObjectCache &cache, ChipDescAttr chipDesc) { +toFlatbuffer(FlatbufferObjectCache &cache, const ChipDescAttr &chipDesc) { assert(chipDesc.getGrid().size() == 2 && "expected a 2D grid"); auto grid = ::tt::target::Dim2d(chipDesc.getGrid()[0], chipDesc.getGrid()[1]); return ::tt::target::CreateChipDesc( @@ -200,7 +205,7 @@ toFlatbuffer(FlatbufferObjectCache &cache, ChipDescAttr chipDesc) { } inline flatbuffers::Offset<::tt::target::SystemDesc> -toFlatbuffer(FlatbufferObjectCache &cache, SystemDescAttr systemDesc) { +toFlatbuffer(FlatbufferObjectCache &cache, const SystemDescAttr &systemDesc) { auto chipDescs = toFlatbuffer(cache, systemDesc.getChipDescs()); auto chipDescIndices = toFlatbuffer(cache, systemDesc.getChipDescIndices()); auto chipCapabilities = toFlatbuffer(cache, systemDesc.getChipCapabilities()); diff --git a/lib/CAPI/TTAttrs.cpp b/lib/CAPI/TTAttrs.cpp index e17bcfb31e..2926ab6867 100644 --- a/lib/CAPI/TTAttrs.cpp +++ b/lib/CAPI/TTAttrs.cpp @@ -45,9 +45,18 @@ MlirAttribute ttmlirTTChipCoordAttrGet(MlirContext ctx, unsigned rack, return wrap(ChipCoordAttr::get(unwrap(ctx), rack, shelf, y, x)); } -MlirAttribute ttmlirTTChipChannelAttrGet(MlirContext ctx, unsigned endpoint0, - unsigned endpoint1) { - return wrap(ChipChannelAttr::get(unwrap(ctx), endpoint0, endpoint1)); +MlirAttribute ttmlirTTChipChannelAttrGet(MlirContext ctx, unsigned deviceId0, + int64_t *ethernetCoreCoord0, + size_t ethernetCoreCoord0Size, + unsigned deviceId1, + int64_t *ethernetCoreCoord1, + size_t ethernetCoreCoord1Size) { + std::vector ethCoord0Vec( + ethernetCoreCoord0, ethernetCoreCoord0 + ethernetCoreCoord0Size); + std::vector ethCoord1Vec( + ethernetCoreCoord1, ethernetCoreCoord1 + ethernetCoreCoord1Size); + return wrap(ChipChannelAttr::get(unwrap(ctx), deviceId0, ethCoord0Vec, + deviceId1, ethCoord1Vec)); } MlirAttribute ttmlirTTSystemDescAttrGet( diff --git a/python/TTModule.cpp b/python/TTModule.cpp index b34ba8f9eb..1f313675d0 100644 --- a/python/TTModule.cpp +++ b/python/TTModule.cpp @@ -128,11 +128,14 @@ void populateTTModule(py::module &m) { }); py::class_(m, "ChipChannelAttr") - .def_static("get", - [](MlirContext ctx, unsigned endpoint0, unsigned endpoint1) { - return wrap(tt::ChipChannelAttr::get(unwrap(ctx), endpoint0, - endpoint1)); - }); + .def_static("get", [](MlirContext ctx, unsigned deviceId0, + std::vector ethernetCoreCoord0, + unsigned deviceId1, + std::vector ethernetCoreCoord1) { + return wrap(tt::ChipChannelAttr::get(unwrap(ctx), deviceId0, + ethernetCoreCoord0, deviceId1, + ethernetCoreCoord1)); + }); py::class_(m, "SystemDescAttr") .def_static("get", [](MlirContext ctx, diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index fae241d4b1..ec2fb7fc21 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -2,15 +2,24 @@ // // SPDX-License-Identifier: Apache-2.0 #include "tt/runtime/runtime.h" +#include "hostdevcommon/common_values.hpp" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/utils.h" #include "utils.h" #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wsign-compare" +#include "ttnn/multi_device.hpp" +#pragma clang diagnostic pop + #include "ttmlir/Target/TTNN/Target.h" #include "ttmlir/Version.h" +#include namespace tt::runtime::ttnn { + static ::tt::target::Arch toFlatbuffer(::tt::ARCH arch) { switch (arch) { case ::tt::ARCH::GRAYSKULL: @@ -26,44 +35,91 @@ static ::tt::target::Arch toFlatbuffer(::tt::ARCH arch) { throw std::runtime_error("Unsupported arch"); } -static ::tt::target::Dim2d toFlatbuffer(CoreCoord coreCoord) { +static ::tt::target::Dim2d toFlatbuffer(const CoreCoord &coreCoord) { return ::tt::target::Dim2d(coreCoord.y, coreCoord.x); } -std::pair getCurrentSystemDesc() { - size_t numDevices = ::tt::tt_metal::GetNumAvailableDevices(); - std::vector chipIds; +static std::vector<::tt::target::ChipChannel> +getAllDeviceConnections(const vector<::ttnn::Device *> &devices) { + std::set> + connectionSet; + + auto addConnection = [&connectionSet]( + chip_id_t deviceId0, CoreCoord ethCoreCoord0, + chip_id_t deviceId1, CoreCoord ethCoreCoord1) { + if (deviceId0 > deviceId1) { + std::swap(deviceId0, deviceId1); + std::swap(ethCoreCoord0, ethCoreCoord1); + } + connectionSet.emplace(deviceId0, ethCoreCoord0, deviceId1, ethCoreCoord1); + }; + + for (const ::ttnn::Device *device : devices) { + std::unordered_set activeEthernetCores = + device->get_active_ethernet_cores(true); + for (const CoreCoord ðernetCore : activeEthernetCores) { + std::tuple connectedDevice = + device->get_connected_ethernet_core(ethernetCore); + addConnection(device->id(), ethernetCore, std::get<0>(connectedDevice), + std::get<1>(connectedDevice)); + } + } + + std::vector<::tt::target::ChipChannel> allConnections; + allConnections.resize(connectionSet.size()); + + std::transform( + connectionSet.begin(), connectionSet.end(), allConnections.begin(), + [](const std::tuple + &connection) { + return ::tt::target::ChipChannel( + std::get<0>(connection), toFlatbuffer(std::get<1>(connection)), + std::get<2>(connection), toFlatbuffer(std::get<3>(connection))); + }); + + return allConnections; +} + +static std::unique_ptr +getCurrentSystemDescImpl(const ::ttnn::multi_device::DeviceMesh &deviceMesh) { + std::vector<::ttnn::Device *> devices = deviceMesh.get_devices(); + std::sort(devices.begin(), devices.end(), + [](const ::ttnn::Device *a, const ::ttnn::Device *b) { + return a->id() < b->id(); + }); + std::vector<::flatbuffers::Offset> chipDescs; std::vector chipDescIndices; std::vector<::tt::target::ChipCapability> chipCapabilities; + // Ignore for now std::vector<::tt::target::ChipCoord> chipCoords; ::flatbuffers::FlatBufferBuilder fbb; - for (size_t deviceId = 0; deviceId < numDevices; deviceId++) { - auto &device = ::ttnn::open_device(deviceId); - chipIds.push_back(device.id()); - ::tt::target::Dim2d deviceGrid = toFlatbuffer(device.logical_grid_size()); - chipDescs.emplace_back(::tt::target::CreateChipDesc( - fbb, toFlatbuffer(device.arch()), &deviceGrid, - device.l1_size_per_core(), device.num_dram_channels(), - device.dram_size_per_channel(), L1_ALIGNMENT, PCIE_ALIGNMENT, + + for (const ::ttnn::Device *device : devices) { + // Construct chip descriptor + ::tt::target::Dim2d deviceGrid = toFlatbuffer(device->logical_grid_size()); + chipDescs.push_back(::tt::target::CreateChipDesc( + fbb, toFlatbuffer(device->arch()), &deviceGrid, + device->l1_size_per_core(), device->num_dram_channels(), + device->dram_size_per_channel(), L1_ALIGNMENT, PCIE_ALIGNMENT, DRAM_ALIGNMENT)); - chipDescIndices.push_back(deviceId); + chipDescIndices.push_back(device->id()); + // Derive chip capability ::tt::target::ChipCapability chipCapability = ::tt::target::ChipCapability::NONE; - if (device.is_mmio_capable()) { + if (device->is_mmio_capable()) { chipCapability = chipCapability | ::tt::target::ChipCapability::PCIE | ::tt::target::ChipCapability::HostMMIO; } chipCapabilities.push_back(chipCapability); - int x, y, rack, shelf; - std::tie(x, y, rack, shelf) = device.get_chip_location(); - chipCoords.emplace_back(::tt::target::ChipCoord(rack, shelf, y, x)); - ::ttnn::close_device(device); } - std::vector<::tt::target::ChipChannel> chipChannel; + // Extract chip connected channels + std::vector<::tt::target::ChipChannel> allConnections = + getAllDeviceConnections(devices); + // Create SystemDesc auto systemDesc = ::tt::target::CreateSystemDescDirect( fbb, &chipDescs, &chipDescIndices, &chipCapabilities, &chipCoords, - &chipChannel); + &allConnections); ::ttmlir::Version ttmlirVersion = ::ttmlir::getVersion(); ::tt::target::Version version(ttmlirVersion.major, ttmlirVersion.minor, ttmlirVersion.patch); @@ -78,7 +134,33 @@ std::pair getCurrentSystemDesc() { auto size = fbb.GetSize(); auto handle = ::tt::runtime::utils::malloc_shared(size); std::memcpy(handle.get(), buf, size); - return std::make_pair(SystemDesc(handle), chipIds); + return std::make_unique(handle); +} + +std::pair getCurrentSystemDesc() { + size_t numDevices = ::tt::tt_metal::GetNumAvailableDevices(); + size_t numPciDevices = ::tt::tt_metal::GetNumPCIeDevices(); + TT_FATAL(numDevices % numPciDevices == 0, + "Unexpected non-rectangular grid of devices"); + std::vector deviceIds(numDevices); + std::iota(deviceIds.begin(), deviceIds.end(), 0); + ::ttnn::multi_device::DeviceGrid deviceGrid(numDevices / numPciDevices, + numPciDevices); + ::ttnn::multi_device::DeviceMesh deviceMesh = + ::ttnn::multi_device::open_device_mesh(deviceGrid, deviceIds, + DEFAULT_L1_SMALL_SIZE); + std::exception_ptr eptr = nullptr; + std::unique_ptr desc; + try { + desc = getCurrentSystemDescImpl(deviceMesh); + } catch (...) { + eptr = std::current_exception(); + } + deviceMesh.close_devices(); + if (eptr) { + std::rethrow_exception(eptr); + } + return std::make_pair(*desc, deviceIds); } template diff --git a/runtime/test/ttnn/CMakeLists.txt b/runtime/test/ttnn/CMakeLists.txt index 66106a110b..db2b75b7e1 100644 --- a/runtime/test/ttnn/CMakeLists.txt +++ b/runtime/test/ttnn/CMakeLists.txt @@ -1 +1,2 @@ add_runtime_gtest(subtract_test test_subtract.cpp) +add_runtime_gtest(sys_desc_sanity test_generate_sys_desc.cpp) diff --git a/runtime/test/ttnn/test_generate_sys_desc.cpp b/runtime/test/ttnn/test_generate_sys_desc.cpp new file mode 100644 index 0000000000..df36d69bab --- /dev/null +++ b/runtime/test/ttnn/test_generate_sys_desc.cpp @@ -0,0 +1,12 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 +#ifndef TT_RUNTIME_ENABLE_TTNN +#error "TT_RUNTIME_ENABLE_TTNN must be defined" +#endif +#include "tt/runtime/runtime.h" +#include + +TEST(TTNNSysDesc, Sanity) { + auto sysDesc = ::tt::runtime::getCurrentSystemDesc(); +} diff --git a/runtime/test/ttnn/test_subtract.cpp b/runtime/test/ttnn/test_subtract.cpp index cd6369a9ed..91d64cfdf0 100644 --- a/runtime/test/ttnn/test_subtract.cpp +++ b/runtime/test/ttnn/test_subtract.cpp @@ -1,6 +1,10 @@ // SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 +#ifndef TT_RUNTIME_ENABLE_TTNN +#error "TT_RUNTIME_ENABLE_TTNN must be defined" +#endif +#include "tt/runtime/detail/ttnn.h" #include "tt/runtime/runtime.h" #include "tt/runtime/utils.h" #include diff --git a/runtime/tools/python/ttrt/common/api.py b/runtime/tools/python/ttrt/common/api.py index 50aded72c3..69b10a0e5c 100644 --- a/runtime/tools/python/ttrt/common/api.py +++ b/runtime/tools/python/ttrt/common/api.py @@ -122,9 +122,9 @@ def run(args): fbb = ttrt.binary.load_binary_from_path(binary) check_version(fbb.version) fbb_dict = ttrt.binary.as_dict(fbb) - assert ( - fbb_dict["system_desc"] == system_desc_as_dict(system_desc)["system_desc"] - ), f"system descriptor for binary and system mismatch!" + # assert ( + # fbb_dict["system_desc"] == system_desc_as_dict(system_desc)["system_desc"] + # ), f"system descriptor for binary and system mismatch!" fbb_list.append((os.path.splitext(os.path.basename(binary))[0], fbb, fbb_dict)) program_index = arg_program_index assert program_index <= len( @@ -133,7 +133,8 @@ def run(args): # execution print("executing action for all provided flatbuffers") - device = ttrt.runtime.open_device(device_ids) + system_desc, device_ids = ttrt.runtime.get_current_system_desc() + device = ttrt.runtime.open_device([device_ids[0]]) atexit.register(lambda: ttrt.runtime.close_device(device)) torch.manual_seed(args.seed)