diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h index b4516ae1b..d6188a62d 100644 --- a/runtime/include/tt/runtime/detail/ttmetal.h +++ b/runtime/include/tt/runtime/detail/ttmetal.h @@ -39,6 +39,7 @@ #pragma clang diagnostic ignored "-Wlogical-op-parentheses" #pragma clang diagnostic ignored "-Wundefined-inline" #define FMT_HEADER_ONLY +#include "impl/device/device_mesh.hpp" #include "tt_metal/host_api.hpp" #pragma clang diagnostic pop diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h index 705a0bd46..dfe98a7a9 100644 --- a/runtime/include/tt/runtime/runtime.h +++ b/runtime/include/tt/runtime/runtime.h @@ -13,6 +13,10 @@ namespace tt::runtime { +namespace system_desc { +std::pair getCurrentSystemDesc(); +} // namespace system_desc + DeviceRuntime getCurrentRuntime(); std::vector getAvailableRuntimes(); diff --git a/runtime/lib/CMakeLists.txt b/runtime/lib/CMakeLists.txt index 3ee448560..1792f24bf 100644 --- a/runtime/lib/CMakeLists.txt +++ b/runtime/lib/CMakeLists.txt @@ -15,14 +15,6 @@ else() endif() message(STATUS "Runtimes Enabled: TTNN[${TT_RUNTIME_ENABLE_TTNN}] TTMETAL[${TT_RUNTIME_ENABLE_TTMETAL}]") -add_library(TTRuntime STATIC runtime.cpp) -set_property(TARGET TTRuntime PROPERTY CXX_STANDARD 20) -if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTNN) - target_compile_definitions(TTRuntime PUBLIC TT_RUNTIME_ENABLE_TTNN) -endif() -if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTMETAL) - target_compile_definitions(TTRuntime PUBLIC TT_RUNTIME_ENABLE_TTMETAL) -endif() add_library(TTBinary STATIC binary.cpp) target_include_directories(TTBinary @@ -32,6 +24,21 @@ target_include_directories(TTBinary ) add_dependencies(TTBinary FBS_GENERATION) +if (TTMLIR_ENABLE_RUNTIME AND (TT_RUNTIME_ENABLE_TTNN OR TT_RUNTIME_ENABLE_TTMETAL)) + add_subdirectory(common) +else() + add_library(TTRuntimeSysDesc INTERFACE) +endif() + +add_library(TTRuntime STATIC runtime.cpp) +set_property(TARGET TTRuntime PROPERTY CXX_STANDARD 20) +if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTNN) + target_compile_definitions(TTRuntime PUBLIC TT_RUNTIME_ENABLE_TTNN) +endif() +if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTMETAL) + target_compile_definitions(TTRuntime PUBLIC TT_RUNTIME_ENABLE_TTMETAL) +endif() + target_include_directories(TTRuntime PUBLIC ${PROJECT_SOURCE_DIR}/runtime/include @@ -41,8 +48,9 @@ target_include_directories(TTRuntime target_link_libraries(TTRuntime PRIVATE TTBinary + TTRuntimeSysDesc TTRuntimeTTNN TTRuntimeTTMetal ) -add_dependencies(TTRuntime TTBinary FBS_GENERATION) +add_dependencies(TTRuntime TTBinary TTRuntimeSysDesc FBS_GENERATION) diff --git a/runtime/lib/common/CMakeLists.txt b/runtime/lib/common/CMakeLists.txt new file mode 100644 index 000000000..05c386393 --- /dev/null +++ b/runtime/lib/common/CMakeLists.txt @@ -0,0 +1,9 @@ +add_library(TTRuntimeSysDesc STATIC system_desc.cpp) +set_property(TARGET TTRuntimeSysDesc PROPERTY CXX_STANDARD 20) +target_include_directories(TTRuntimeSysDesc + PUBLIC + ${PROJECT_SOURCE_DIR}/runtime/include + ${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common +) +target_include_directories(TTRuntimeSysDesc PUBLIC "$") +add_dependencies(TTRuntimeSysDesc tt-metal FBS_GENERATION) diff --git a/runtime/lib/common/system_desc.cpp b/runtime/lib/common/system_desc.cpp new file mode 100644 index 000000000..e2ffc1c79 --- /dev/null +++ b/runtime/lib/common/system_desc.cpp @@ -0,0 +1,179 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 +#include "tt/runtime/types.h" +#include "tt/runtime/utils.h" +#include "ttmlir/Target/TTNN/Target.h" +#include "ttmlir/Version.h" +#include +#include + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wctad-maybe-unsupported" +#pragma clang diagnostic ignored "-Wcovered-switch-default" +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wignored-qualifiers" +#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" +#pragma clang diagnostic ignored "-Wvla-extension" +#pragma clang diagnostic ignored "-Wsign-compare" +#pragma clang diagnostic ignored "-Wcast-qual" +#pragma clang diagnostic ignored "-Wdeprecated-this-capture" +#pragma clang diagnostic ignored "-Wnon-virtual-dtor" +#pragma clang diagnostic ignored "-Wsuggest-override" +#pragma clang diagnostic ignored "-Wgnu-anonymous-struct" +#pragma clang diagnostic ignored "-Wnested-anon-types" +#pragma clang diagnostic ignored "-Wreorder-ctor" +#pragma clang diagnostic ignored "-Wmismatched-tags" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-local-typedef" +#define FMT_HEADER_ONLY +#include "host_api.hpp" +#include "hostdevcommon/common_values.hpp" +#include "impl/device/device_mesh.hpp" +#pragma clang diagnostic pop + +namespace tt::runtime::system_desc { +static ::tt::target::Dim2d toFlatbuffer(const CoreCoord &coreCoord) { + return ::tt::target::Dim2d(coreCoord.y, coreCoord.x); +} + +static ::tt::target::Arch toFlatbuffer(::tt::ARCH arch) { + switch (arch) { + case ::tt::ARCH::GRAYSKULL: + return ::tt::target::Arch::Grayskull; + case ::tt::ARCH::WORMHOLE_B0: + return ::tt::target::Arch::Wormhole_b0; + case ::tt::ARCH::BLACKHOLE: + return ::tt::target::Arch::Blackhole; + default: + break; + } + + throw std::runtime_error("Unsupported arch"); +} + +static std::vector<::tt::target::ChipChannel> +getAllDeviceConnections(const vector<::tt::tt_metal::Device *> &devices) { + std::set> + connectionSet; + + auto addConnection = [&connectionSet]( + chip_id_t deviceId0, CoreCoord ethCoreCoord0, + chip_id_t deviceId1, CoreCoord ethCoreCoord1) { + if (deviceId0 > deviceId1) { + std::swap(deviceId0, deviceId1); + std::swap(ethCoreCoord0, ethCoreCoord1); + } + connectionSet.emplace(deviceId0, ethCoreCoord0, deviceId1, ethCoreCoord1); + }; + + for (const ::tt::tt_metal::Device *device : devices) { + std::unordered_set activeEthernetCores = + device->get_active_ethernet_cores(true); + for (const CoreCoord ðernetCore : activeEthernetCores) { + std::tuple connectedDevice = + device->get_connected_ethernet_core(ethernetCore); + addConnection(device->id(), ethernetCore, std::get<0>(connectedDevice), + std::get<1>(connectedDevice)); + } + } + + std::vector<::tt::target::ChipChannel> allConnections; + allConnections.resize(connectionSet.size()); + + std::transform( + connectionSet.begin(), connectionSet.end(), allConnections.begin(), + [](const std::tuple + &connection) { + return ::tt::target::ChipChannel( + std::get<0>(connection), toFlatbuffer(std::get<1>(connection)), + std::get<2>(connection), toFlatbuffer(std::get<3>(connection))); + }); + + return allConnections; +} + +static std::unique_ptr<::tt::runtime::SystemDesc> +getCurrentSystemDescImpl(const ::tt::tt_metal::DeviceMesh &deviceMesh) { + std::vector<::tt::tt_metal::Device *> devices = deviceMesh.get_devices(); + std::sort(devices.begin(), devices.end(), + [](const ::tt::tt_metal::Device *a, + const ::tt::tt_metal::Device *b) { return a->id() < b->id(); }); + + std::vector<::flatbuffers::Offset> chipDescs; + std::vector chipDescIndices; + std::vector<::tt::target::ChipCapability> chipCapabilities; + // Ignore for now + std::vector<::tt::target::ChipCoord> chipCoords = { + ::tt::target::ChipCoord(0, 0, 0, 0)}; + ::flatbuffers::FlatBufferBuilder fbb; + + for (const ::tt::tt_metal::Device *device : devices) { + // Construct chip descriptor + ::tt::target::Dim2d deviceGrid = + toFlatbuffer(device->compute_with_storage_grid_size()); + chipDescs.push_back(::tt::target::CreateChipDesc( + fbb, toFlatbuffer(device->arch()), &deviceGrid, + device->l1_size_per_core(), device->num_dram_channels(), + device->dram_size_per_channel(), L1_ALIGNMENT, PCIE_ALIGNMENT, + DRAM_ALIGNMENT)); + chipDescIndices.push_back(device->id()); + // Derive chip capability + ::tt::target::ChipCapability chipCapability = + ::tt::target::ChipCapability::NONE; + if (device->is_mmio_capable()) { + chipCapability = chipCapability | ::tt::target::ChipCapability::PCIE | + ::tt::target::ChipCapability::HostMMIO; + } + chipCapabilities.push_back(chipCapability); + } + // Extract chip connected channels + std::vector<::tt::target::ChipChannel> allConnections = + getAllDeviceConnections(devices); + // Create SystemDesc + auto systemDesc = ::tt::target::CreateSystemDescDirect( + fbb, &chipDescs, &chipDescIndices, &chipCapabilities, &chipCoords, + &allConnections); + ::ttmlir::Version ttmlirVersion = ::ttmlir::getVersion(); + ::tt::target::Version version(ttmlirVersion.major, ttmlirVersion.minor, + ttmlirVersion.patch); + auto root = ::tt::target::CreateSystemDescRootDirect( + fbb, &version, ::ttmlir::getGitHash(), "unknown", systemDesc); + ::tt::target::FinishSizePrefixedSystemDescRootBuffer(fbb, root); + ::flatbuffers::Verifier verifier(fbb.GetBufferPointer(), fbb.GetSize()); + if (not ::tt::target::VerifySizePrefixedSystemDescRootBuffer(verifier)) { + throw std::runtime_error("Failed to verify system desc root buffer"); + } + uint8_t *buf = fbb.GetBufferPointer(); + auto size = fbb.GetSize(); + auto handle = ::tt::runtime::utils::malloc_shared(size); + std::memcpy(handle.get(), buf, size); + return std::make_unique<::tt::runtime::SystemDesc>(handle); +} + +std::pair<::tt::runtime::SystemDesc, DeviceIds> getCurrentSystemDesc() { + size_t numDevices = ::tt::tt_metal::GetNumAvailableDevices(); + size_t numPciDevices = ::tt::tt_metal::GetNumPCIeDevices(); + TT_FATAL(numDevices % numPciDevices == 0, + "Unexpected non-rectangular grid of devices"); + std::vector deviceIds(numDevices); + std::iota(deviceIds.begin(), deviceIds.end(), 0); + ::tt::tt_metal::DeviceGrid grid = + std::make_pair(numDevices / numPciDevices, numPciDevices); + ::tt::tt_metal::DeviceMesh deviceMesh = ::tt::tt_metal::DeviceMesh( + grid, deviceIds, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1); + std::exception_ptr eptr = nullptr; + std::unique_ptr<::tt::runtime::SystemDesc> desc; + try { + desc = getCurrentSystemDescImpl(deviceMesh); + } catch (...) { + eptr = std::current_exception(); + } + deviceMesh.close_devices(); + if (eptr) { + std::rethrow_exception(eptr); + } + return std::make_pair(*desc, deviceIds); +} + +} // namespace tt::runtime::system_desc diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index 84b2523fa..3e605434e 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -4,6 +4,7 @@ #include "tt/runtime/runtime.h" #include "tt/runtime/utils.h" +#include "ttmlir/Target/TTNN/Target.h" #include "ttmlir/Version.h" #if defined(TT_RUNTIME_ENABLE_TTNN) @@ -77,16 +78,8 @@ void setCompatibleRuntime(const Binary &binary) { } std::pair getCurrentSystemDesc() { -#if defined(TT_RUNTIME_ENABLE_TTNN) - if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::getCurrentSystemDesc(); - } -#endif - -#if defined(TT_RUNTIME_ENABLE_TTMETAL) - if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::getCurrentSystemDesc(); - } +#if defined(TT_RUNTIME_ENABLE_TTNN) || defined(TT_RUNTIME_ENABLE_TTMETAL) + return system_desc::getCurrentSystemDesc(); #endif throw std::runtime_error("runtime is not enabled"); } diff --git a/runtime/lib/ttmetal/CMakeLists.txt b/runtime/lib/ttmetal/CMakeLists.txt index 3c2488049..f36f25d2f 100644 --- a/runtime/lib/ttmetal/CMakeLists.txt +++ b/runtime/lib/ttmetal/CMakeLists.txt @@ -10,4 +10,5 @@ target_include_directories(TTRuntimeTTMetal PUBLIC ${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common ) target_include_directories(TTRuntimeTTMetal PUBLIC "$") -add_dependencies(TTRuntimeTTMetal tt-metal FBS_GENERATION) +target_link_libraries(TTRuntimeTTMetal PUBLIC TTMETAL_LIBRARY) +add_dependencies(TTRuntimeTTMetal TTMETAL_LIBRARY tt-metal FBS_GENERATION) diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp index 76513521e..39fce01de 100644 --- a/runtime/lib/ttmetal/runtime.cpp +++ b/runtime/lib/ttmetal/runtime.cpp @@ -19,25 +19,6 @@ using DeviceMesh = std::vector<::tt::tt_metal::Device *>; using MetalTensor = std::variant>; -static ::tt::target::Arch toFlatbuffer(::tt::ARCH arch) { - switch (arch) { - case ::tt::ARCH::GRAYSKULL: - return ::tt::target::Arch::Grayskull; - case ::tt::ARCH::WORMHOLE_B0: - return ::tt::target::Arch::Wormhole_b0; - case ::tt::ARCH::BLACKHOLE: - return ::tt::target::Arch::Blackhole; - default: - break; - } - - throw std::runtime_error("Unsupported arch"); -} - -static ::tt::target::Dim2d toFlatbuffer(CoreCoord coreCoord) { - return ::tt::target::Dim2d(coreCoord.y, coreCoord.x); -} - static ::tt::target::metal::TTMetalBinary const *getBinary(Flatbuffer binary) { bool isTTMetal = ::tt::target::metal::SizePrefixedTTMetalBinaryBufferHasIdentifier( @@ -48,55 +29,6 @@ static ::tt::target::metal::TTMetalBinary const *getBinary(Flatbuffer binary) { return ::tt::target::metal::GetSizePrefixedTTMetalBinary(binary.handle.get()); } -std::pair getCurrentSystemDesc() { - ::tt::tt_metal::Device *device = ::tt::tt_metal::CreateDevice(0); - std::vector chipIds = { - device->id(), - }; - ::flatbuffers::FlatBufferBuilder fbb; - ::ttmlir::Version ttmlirVersion = ::ttmlir::getVersion(); - ::tt::target::Version version(ttmlirVersion.major, ttmlirVersion.minor, - ttmlirVersion.patch); - ::tt::target::Dim2d deviceGrid = - toFlatbuffer(device->compute_with_storage_grid_size()); - std::vector<::flatbuffers::Offset> chipDescs = { - ::tt::target::CreateChipDesc( - fbb, toFlatbuffer(device->arch()), &deviceGrid, (1 << 20), 12, - (1 << 20), L1_ALIGNMENT, PCIE_ALIGNMENT, DRAM_ALIGNMENT), - }; - std::vector chipDescIndices = { - 0, - }; - ::tt::target::ChipCapability chipCapability = - ::tt::target::ChipCapability::PCIE; - if (device->is_mmio_capable()) { - chipCapability = chipCapability | ::tt::target::ChipCapability::HostMMIO; - } - std::vector<::tt::target::ChipCapability> chipCapabilities = { - chipCapability, - }; - std::vector<::tt::target::ChipCoord> chipCoord = { - ::tt::target::ChipCoord(0, 0, 0, 0), - }; - std::vector<::tt::target::ChipChannel> chipChannel; - auto systemDesc = ::tt::target::CreateSystemDescDirect( - fbb, &chipDescs, &chipDescIndices, &chipCapabilities, &chipCoord, - &chipChannel); - auto root = ::tt::target::CreateSystemDescRootDirect( - fbb, &version, ::ttmlir::getGitHash(), "unknown", systemDesc); - ::tt::target::FinishSizePrefixedSystemDescRootBuffer(fbb, root); - ::flatbuffers::Verifier verifier(fbb.GetBufferPointer(), fbb.GetSize()); - if (not ::tt::target::VerifySizePrefixedSystemDescRootBuffer(verifier)) { - throw std::runtime_error("Failed to verify system desc root buffer"); - } - uint8_t *buf = fbb.GetBufferPointer(); - auto size = fbb.GetSize(); - auto handle = utils::malloc_shared(size); - std::memcpy(handle.get(), buf, size); - ::tt::tt_metal::CloseDevice(device); - return std::make_pair(SystemDesc(handle), chipIds); -} - Tensor createTensor(std::shared_ptr data, std::vector const &shape, std::vector const &stride, diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index d7deba86b..9f6f66202 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -2,160 +2,14 @@ // // SPDX-License-Identifier: Apache-2.0 #include "tt/runtime/runtime.h" -#include "hostdevcommon/common_values.hpp" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/utils.h" -#include "utils.h" -#include "impl/device/device_mesh.hpp" -#include "hostdevcommon/common_values.hpp" - #include "ttmlir/Target/TTNN/Target.h" #include "ttmlir/Version.h" +#include "utils.h" namespace tt::runtime::ttnn { -static ::tt::target::Arch toFlatbuffer(::tt::ARCH arch) { - switch (arch) { - case ::tt::ARCH::GRAYSKULL: - return ::tt::target::Arch::Grayskull; - case ::tt::ARCH::WORMHOLE_B0: - return ::tt::target::Arch::Wormhole_b0; - case ::tt::ARCH::BLACKHOLE: - return ::tt::target::Arch::Blackhole; - default: - break; - } - - throw std::runtime_error("Unsupported arch"); -} - -static ::tt::target::Dim2d toFlatbuffer(const CoreCoord &coreCoord) { - return ::tt::target::Dim2d(coreCoord.y, coreCoord.x); -} - -static std::vector<::tt::target::ChipChannel> -getAllDeviceConnections(const vector<::ttnn::Device *> &devices) { - std::set> - connectionSet; - - auto addConnection = [&connectionSet]( - chip_id_t deviceId0, CoreCoord ethCoreCoord0, - chip_id_t deviceId1, CoreCoord ethCoreCoord1) { - if (deviceId0 > deviceId1) { - std::swap(deviceId0, deviceId1); - std::swap(ethCoreCoord0, ethCoreCoord1); - } - connectionSet.emplace(deviceId0, ethCoreCoord0, deviceId1, ethCoreCoord1); - }; - - for (const ::ttnn::Device *device : devices) { - std::unordered_set activeEthernetCores = - device->get_active_ethernet_cores(true); - for (const CoreCoord ðernetCore : activeEthernetCores) { - std::tuple connectedDevice = - device->get_connected_ethernet_core(ethernetCore); - addConnection(device->id(), ethernetCore, std::get<0>(connectedDevice), - std::get<1>(connectedDevice)); - } - } - - std::vector<::tt::target::ChipChannel> allConnections; - allConnections.resize(connectionSet.size()); - - std::transform( - connectionSet.begin(), connectionSet.end(), allConnections.begin(), - [](const std::tuple - &connection) { - return ::tt::target::ChipChannel( - std::get<0>(connection), toFlatbuffer(std::get<1>(connection)), - std::get<2>(connection), toFlatbuffer(std::get<3>(connection))); - }); - - return allConnections; -} - -static std::unique_ptr -getCurrentSystemDescImpl(const ::tt::tt_metal::DeviceMesh &deviceMesh) { - std::vector<::ttnn::Device *> devices = deviceMesh.get_devices(); - std::sort(devices.begin(), devices.end(), - [](const ::ttnn::Device *a, const ::ttnn::Device *b) { - return a->id() < b->id(); - }); - - std::vector<::flatbuffers::Offset> chipDescs; - std::vector chipDescIndices; - std::vector<::tt::target::ChipCapability> chipCapabilities; - // Ignore for now - std::vector<::tt::target::ChipCoord> chipCoords = { - ::tt::target::ChipCoord(0, 0, 0, 0)}; - ::flatbuffers::FlatBufferBuilder fbb; - - for (const ::ttnn::Device *device : devices) { - // Construct chip descriptor - ::tt::target::Dim2d deviceGrid = - toFlatbuffer(device->compute_with_storage_grid_size()); - chipDescs.push_back(::tt::target::CreateChipDesc( - fbb, toFlatbuffer(device->arch()), &deviceGrid, - device->l1_size_per_core(), device->num_dram_channels(), - device->dram_size_per_channel(), L1_ALIGNMENT, PCIE_ALIGNMENT, - DRAM_ALIGNMENT)); - chipDescIndices.push_back(device->id()); - // Derive chip capability - ::tt::target::ChipCapability chipCapability = - ::tt::target::ChipCapability::NONE; - if (device->is_mmio_capable()) { - chipCapability = chipCapability | ::tt::target::ChipCapability::PCIE | - ::tt::target::ChipCapability::HostMMIO; - } - chipCapabilities.push_back(chipCapability); - } - // Extract chip connected channels - std::vector<::tt::target::ChipChannel> allConnections = - getAllDeviceConnections(devices); - // Create SystemDesc - auto systemDesc = ::tt::target::CreateSystemDescDirect( - fbb, &chipDescs, &chipDescIndices, &chipCapabilities, &chipCoords, - &allConnections); - ::ttmlir::Version ttmlirVersion = ::ttmlir::getVersion(); - ::tt::target::Version version(ttmlirVersion.major, ttmlirVersion.minor, - ttmlirVersion.patch); - auto root = ::tt::target::CreateSystemDescRootDirect( - fbb, &version, ::ttmlir::getGitHash(), "unknown", systemDesc); - ::tt::target::FinishSizePrefixedSystemDescRootBuffer(fbb, root); - ::flatbuffers::Verifier verifier(fbb.GetBufferPointer(), fbb.GetSize()); - if (not ::tt::target::VerifySizePrefixedSystemDescRootBuffer(verifier)) { - throw std::runtime_error("Failed to verify system desc root buffer"); - } - uint8_t *buf = fbb.GetBufferPointer(); - auto size = fbb.GetSize(); - auto handle = ::tt::runtime::utils::malloc_shared(size); - std::memcpy(handle.get(), buf, size); - return std::make_unique(handle); -} - -std::pair getCurrentSystemDesc() { - size_t numDevices = ::tt::tt_metal::GetNumAvailableDevices(); - size_t numPciDevices = ::tt::tt_metal::GetNumPCIeDevices(); - TT_FATAL(numDevices % numPciDevices == 0, - "Unexpected non-rectangular grid of devices"); - std::vector deviceIds(numDevices); - std::iota(deviceIds.begin(), deviceIds.end(), 0); - ::tt::tt_metal::DeviceGrid deviceGrid = std::make_pair(numDevices / numPciDevices, numPciDevices); - ::tt::tt_metal::DeviceMesh deviceMesh(deviceGrid, deviceIds, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1); - std::exception_ptr eptr = nullptr; - std::unique_ptr desc; - try { - desc = getCurrentSystemDescImpl(deviceMesh); - } catch (...) { - eptr = std::current_exception(); - } - deviceMesh.close_devices(); - if (eptr) { - std::rethrow_exception(eptr); - } - return std::make_pair(*desc, deviceIds); -} - template static BorrowedStorage createStorage(void *ptr, std::uint32_t numElements) { return BorrowedStorage( diff --git a/runtime/test/CMakeLists.txt b/runtime/test/CMakeLists.txt index e3a48c925..20fbdf79b 100644 --- a/runtime/test/CMakeLists.txt +++ b/runtime/test/CMakeLists.txt @@ -1,3 +1,7 @@ +if (NOT TTMLIR_ENABLE_RUNTIME OR (NOT TT_RUNTIME_ENABLE_TTNN AND NOT TT_RUNTIME_ENABLE_TTMETAL)) + message(FATAL_ERROR "Runtime tests require -DTTMLIR_ENABLE_RUNTIME=ON and at least one backend runtime to be enabled") +endif() + enable_testing() include(FetchContent) FetchContent_Declare( @@ -44,10 +48,18 @@ target_link_libraries(TTRuntimeTEST INTERFACE function(add_runtime_gtest test_name) add_executable(${test_name} ${ARGN}) + set_property(TARGET ${test_name} PROPERTY CXX_STANDARD 20) add_dependencies(${test_name} TTRuntimeTEST) target_link_libraries(${test_name} PRIVATE TTRuntimeTEST) gtest_discover_tests(${test_name}) endfunction() -add_subdirectory(ttnn) -add_subdirectory(ttmetal) +add_subdirectory(common) + +if (TT_RUNTIME_ENABLE_TTNN) + add_subdirectory(ttnn) +endif() + +if (TT_RUNTIME_ENABLE_TTMETAL) + add_subdirectory(ttmetal) +endif() diff --git a/runtime/test/common/CMakeLists.txt b/runtime/test/common/CMakeLists.txt new file mode 100644 index 000000000..2443a19d4 --- /dev/null +++ b/runtime/test/common/CMakeLists.txt @@ -0,0 +1 @@ +add_runtime_gtest(sys_desc_sanity test_generate_sys_desc.cpp) diff --git a/runtime/test/ttnn/test_generate_sys_desc.cpp b/runtime/test/common/test_generate_sys_desc.cpp similarity index 65% rename from runtime/test/ttnn/test_generate_sys_desc.cpp rename to runtime/test/common/test_generate_sys_desc.cpp index df36d69ba..0d2e19f01 100644 --- a/runtime/test/ttnn/test_generate_sys_desc.cpp +++ b/runtime/test/common/test_generate_sys_desc.cpp @@ -1,12 +1,9 @@ // SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 -#ifndef TT_RUNTIME_ENABLE_TTNN -#error "TT_RUNTIME_ENABLE_TTNN must be defined" -#endif #include "tt/runtime/runtime.h" #include -TEST(TTNNSysDesc, Sanity) { +TEST(GenerateSysDesc, Sanity) { auto sysDesc = ::tt::runtime::getCurrentSystemDesc(); } diff --git a/runtime/test/ttnn/CMakeLists.txt b/runtime/test/ttnn/CMakeLists.txt index db2b75b7e..66106a110 100644 --- a/runtime/test/ttnn/CMakeLists.txt +++ b/runtime/test/ttnn/CMakeLists.txt @@ -1,2 +1 @@ add_runtime_gtest(subtract_test test_subtract.cpp) -add_runtime_gtest(sys_desc_sanity test_generate_sys_desc.cpp) diff --git a/runtime/tools/python/setup.py b/runtime/tools/python/setup.py index ce201d6e6..d9b09bf09 100644 --- a/runtime/tools/python/setup.py +++ b/runtime/tools/python/setup.py @@ -45,7 +45,7 @@ ] dylibs = [] -linklibs = ["TTBinary"] +linklibs = ["TTBinary", "TTRuntimeSysDesc"] if enable_ttnn: dylibs += ["_ttnn.so"] linklibs += ["TTRuntimeTTNN", ":_ttnn.so"] @@ -75,6 +75,7 @@ libraries=["TTRuntime"] + linklibs + ["flatbuffers"], library_dirs=[ f"{src_dir}/build/runtime/lib", + f"{src_dir}/build/runtime/lib/common", f"{src_dir}/build/runtime/lib/ttnn", f"{src_dir}/build/runtime/lib/ttmetal", f"{toolchain}/lib",