Skip to content

Commit

Permalink
Commonize system descriptor API, update metal APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
jnie-TT committed Aug 10, 2024
1 parent a38f35e commit 259278e
Show file tree
Hide file tree
Showing 14 changed files with 233 additions and 242 deletions.
1 change: 1 addition & 0 deletions runtime/include/tt/runtime/detail/ttmetal.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#pragma clang diagnostic ignored "-Wlogical-op-parentheses"
#pragma clang diagnostic ignored "-Wundefined-inline"
#define FMT_HEADER_ONLY
#include "impl/device/device_mesh.hpp"
#include "tt_metal/host_api.hpp"
#pragma clang diagnostic pop

Expand Down
4 changes: 4 additions & 0 deletions runtime/include/tt/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@

namespace tt::runtime {

namespace system_desc {
// Builds a system descriptor for the devices available on this host and
// returns it together with the ids of the devices it describes.
std::pair<SystemDesc, DeviceIds> getCurrentSystemDesc();
} // namespace system_desc

DeviceRuntime getCurrentRuntime();

std::vector<DeviceRuntime> getAvailableRuntimes();
Expand Down
26 changes: 17 additions & 9 deletions runtime/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,6 @@ else()
endif()

message(STATUS "Runtimes Enabled: TTNN[${TT_RUNTIME_ENABLE_TTNN}] TTMETAL[${TT_RUNTIME_ENABLE_TTMETAL}]")
add_library(TTRuntime STATIC runtime.cpp)
set_property(TARGET TTRuntime PROPERTY CXX_STANDARD 20)
if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTNN)
target_compile_definitions(TTRuntime PUBLIC TT_RUNTIME_ENABLE_TTNN)
endif()
if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTMETAL)
target_compile_definitions(TTRuntime PUBLIC TT_RUNTIME_ENABLE_TTMETAL)
endif()

add_library(TTBinary STATIC binary.cpp)
target_include_directories(TTBinary
Expand All @@ -32,6 +24,21 @@ target_include_directories(TTBinary
)
add_dependencies(TTBinary FBS_GENERATION)

if (TTMLIR_ENABLE_RUNTIME AND (TT_RUNTIME_ENABLE_TTNN OR TT_RUNTIME_ENABLE_TTMETAL))
add_subdirectory(common)
else()
add_library(TTRuntimeSysDesc INTERFACE)
endif()

add_library(TTRuntime STATIC runtime.cpp)
set_property(TARGET TTRuntime PROPERTY CXX_STANDARD 20)
if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTNN)
target_compile_definitions(TTRuntime PUBLIC TT_RUNTIME_ENABLE_TTNN)
endif()
if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTMETAL)
target_compile_definitions(TTRuntime PUBLIC TT_RUNTIME_ENABLE_TTMETAL)
endif()

target_include_directories(TTRuntime
PUBLIC
${PROJECT_SOURCE_DIR}/runtime/include
Expand All @@ -41,8 +48,9 @@ target_include_directories(TTRuntime
target_link_libraries(TTRuntime
PRIVATE
TTBinary
TTRuntimeSysDesc
TTRuntimeTTNN
TTRuntimeTTMetal
)

add_dependencies(TTRuntime TTBinary FBS_GENERATION)
add_dependencies(TTRuntime TTBinary TTRuntimeSysDesc FBS_GENERATION)
9 changes: 9 additions & 0 deletions runtime/lib/common/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Static library built from system_desc.cpp, the system descriptor query that
# TTRuntime links against.
add_library(TTRuntimeSysDesc STATIC system_desc.cpp)
set_property(TARGET TTRuntimeSysDesc PROPERTY CXX_STANDARD 20)
# Runtime headers from the source tree plus the Target/Common headers that
# live in the build tree.
target_include_directories(TTRuntimeSysDesc
PUBLIC
${PROJECT_SOURCE_DIR}/runtime/include
${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common
)
# tt-metal include paths are added through BUILD_INTERFACE, i.e. only for
# consumers inside this build.
target_include_directories(TTRuntimeSysDesc PUBLIC "$<BUILD_INTERFACE:${TTMETAL_INCLUDE_DIRS}>")
# NOTE(review): system_desc.cpp includes tt-metal headers, but unlike
# TTRuntimeTTMetal this target does not list tt-metal in add_dependencies --
# confirm the build order is still guaranteed.
add_dependencies(TTRuntimeSysDesc FBS_GENERATION)
179 changes: 179 additions & 0 deletions runtime/lib/common/system_desc.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#include "tt/runtime/types.h"
#include "tt/runtime/utils.h"
#include "ttmlir/Target/TTNN/Target.h"
#include "ttmlir/Version.h"
#include <cstdint>
#include <vector>

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wctad-maybe-unsupported"
#pragma clang diagnostic ignored "-Wcovered-switch-default"
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wignored-qualifiers"
#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
#pragma clang diagnostic ignored "-Wvla-extension"
#pragma clang diagnostic ignored "-Wsign-compare"
#pragma clang diagnostic ignored "-Wcast-qual"
#pragma clang diagnostic ignored "-Wdeprecated-this-capture"
#pragma clang diagnostic ignored "-Wnon-virtual-dtor"
#pragma clang diagnostic ignored "-Wsuggest-override"
#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
#pragma clang diagnostic ignored "-Wnested-anon-types"
#pragma clang diagnostic ignored "-Wreorder-ctor"
#pragma clang diagnostic ignored "-Wmismatched-tags"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-local-typedef"
#define FMT_HEADER_ONLY
#include "host_api.hpp"
#include "hostdevcommon/common_values.hpp"
#include "impl/device/device_mesh.hpp"
#pragma clang diagnostic pop

namespace tt::runtime::system_desc {
// Converts a core coordinate into a flatbuffer Dim2d. The coordinate's y
// component is passed first and its x component second.
static ::tt::target::Dim2d toFlatbuffer(const CoreCoord &coreCoord) {
  const auto yValue = coreCoord.y;
  const auto xValue = coreCoord.x;
  return ::tt::target::Dim2d(yValue, xValue);
}

// Maps a tt-metal architecture enumerator onto the matching flatbuffer Arch
// value. Throws std::runtime_error for any architecture other than Grayskull,
// Wormhole B0 or Blackhole.
static ::tt::target::Arch toFlatbuffer(::tt::ARCH arch) {
  if (arch == ::tt::ARCH::GRAYSKULL) {
    return ::tt::target::Arch::Grayskull;
  }
  if (arch == ::tt::ARCH::WORMHOLE_B0) {
    return ::tt::target::Arch::Wormhole_b0;
  }
  if (arch == ::tt::ARCH::BLACKHOLE) {
    return ::tt::target::Arch::Blackhole;
  }
  throw std::runtime_error("Unsupported arch");
}

static std::vector<::tt::target::ChipChannel>
getAllDeviceConnections(const vector<::tt::tt_metal::Device *> &devices) {
std::set<std::tuple<chip_id_t, CoreCoord, chip_id_t, CoreCoord>>
connectionSet;

auto addConnection = [&connectionSet](
chip_id_t deviceId0, CoreCoord ethCoreCoord0,
chip_id_t deviceId1, CoreCoord ethCoreCoord1) {
if (deviceId0 > deviceId1) {
std::swap(deviceId0, deviceId1);
std::swap(ethCoreCoord0, ethCoreCoord1);
}
connectionSet.emplace(deviceId0, ethCoreCoord0, deviceId1, ethCoreCoord1);
};

for (const ::tt::tt_metal::Device *device : devices) {
std::unordered_set<CoreCoord> activeEthernetCores =
device->get_active_ethernet_cores(true);
for (const CoreCoord &ethernetCore : activeEthernetCores) {
std::tuple<chip_id_t, CoreCoord> connectedDevice =
device->get_connected_ethernet_core(ethernetCore);
addConnection(device->id(), ethernetCore, std::get<0>(connectedDevice),
std::get<1>(connectedDevice));
}
}

std::vector<::tt::target::ChipChannel> allConnections;
allConnections.resize(connectionSet.size());

std::transform(
connectionSet.begin(), connectionSet.end(), allConnections.begin(),
[](const std::tuple<chip_id_t, CoreCoord, chip_id_t, CoreCoord>
&connection) {
return ::tt::target::ChipChannel(
std::get<0>(connection), toFlatbuffer(std::get<1>(connection)),
std::get<2>(connection), toFlatbuffer(std::get<3>(connection)));
});

return allConnections;
}

// Serializes a system descriptor for the devices of an already opened mesh.
//
// Devices are visited in ascending id order. For each one a chip descriptor,
// a descriptor index and a capability entry are recorded; the ethernet
// connections between the devices are then added, the flatbuffer is finished,
// verified and copied into shared storage owned by the returned SystemDesc.
//
// Throws std::runtime_error when the finished buffer fails flatbuffer
// verification or when a device reports an unsupported architecture.
//
// NOTE(review): the descriptor index list is filled with device ids; that
// equals the descriptor position only if ids are contiguous from 0 -- confirm.
static std::unique_ptr<::tt::runtime::SystemDesc>
getCurrentSystemDescImpl(const ::tt::tt_metal::DeviceMesh &deviceMesh) {
  std::vector<::tt::tt_metal::Device *> sortedDevices =
      deviceMesh.get_devices();
  const auto byAscendingId = [](const ::tt::tt_metal::Device *lhs,
                                const ::tt::tt_metal::Device *rhs) {
    return lhs->id() < rhs->id();
  };
  std::sort(sortedDevices.begin(), sortedDevices.end(), byAscendingId);

  std::vector<::flatbuffers::Offset<tt::target::ChipDesc>> descOffsets;
  std::vector<uint32_t> descIndices;
  std::vector<::tt::target::ChipCapability> capabilities;
  // Chip coordinates are not filled in yet; one placeholder entry is emitted.
  std::vector<::tt::target::ChipCoord> coords = {
      ::tt::target::ChipCoord(0, 0, 0, 0)};
  ::flatbuffers::FlatBufferBuilder builder;

  for (const ::tt::tt_metal::Device *chip : sortedDevices) {
    // Per-chip descriptor: architecture, grid, memory sizes and alignments.
    ::tt::target::Dim2d grid =
        toFlatbuffer(chip->compute_with_storage_grid_size());
    const auto descOffset = ::tt::target::CreateChipDesc(
        builder, toFlatbuffer(chip->arch()), &grid, chip->l1_size_per_core(),
        chip->num_dram_channels(), chip->dram_size_per_channel(), L1_ALIGNMENT,
        PCIE_ALIGNMENT, DRAM_ALIGNMENT);
    descOffsets.push_back(descOffset);
    descIndices.push_back(chip->id());

    // MMIO-capable chips are marked PCIE and HostMMIO; all others stay NONE.
    ::tt::target::ChipCapability capability =
        ::tt::target::ChipCapability::NONE;
    if (chip->is_mmio_capable()) {
      capability = capability | ::tt::target::ChipCapability::PCIE |
                   ::tt::target::ChipCapability::HostMMIO;
    }
    capabilities.push_back(capability);
  }

  // Ethernet links between the chips.
  std::vector<::tt::target::ChipChannel> channels =
      getAllDeviceConnections(sortedDevices);

  // Assemble the descriptor and wrap it in the versioned root table.
  auto descTable = ::tt::target::CreateSystemDescDirect(
      builder, &descOffsets, &descIndices, &capabilities, &coords, &channels);
  ::ttmlir::Version mlirVersion = ::ttmlir::getVersion();
  ::tt::target::Version fbVersion(mlirVersion.major, mlirVersion.minor,
                                  mlirVersion.patch);
  auto rootTable = ::tt::target::CreateSystemDescRootDirect(
      builder, &fbVersion, ::ttmlir::getGitHash(), "unknown", descTable);
  ::tt::target::FinishSizePrefixedSystemDescRootBuffer(builder, rootTable);

  // Reject a malformed buffer before handing it out.
  ::flatbuffers::Verifier verifier(builder.GetBufferPointer(),
                                   builder.GetSize());
  if (!::tt::target::VerifySizePrefixedSystemDescRootBuffer(verifier)) {
    throw std::runtime_error("Failed to verify system desc root buffer");
  }

  // Copy the bytes out of the builder so they outlive it.
  const uint8_t *serialized = builder.GetBufferPointer();
  const auto byteCount = builder.GetSize();
  auto storage = ::tt::runtime::utils::malloc_shared(byteCount);
  std::memcpy(storage.get(), serialized, byteCount);
  return std::make_unique<::tt::runtime::SystemDesc>(storage);
}

// Builds the system descriptor for all devices available on this host.
//
// Opens a device mesh spanning device ids 0..numDevices-1, builds the
// descriptor from it and closes the devices again. The devices are also
// closed when building throws; the exception is rethrown afterwards.
//
// Returns the descriptor together with the ids of the devices it covers.
// TT_FATAL is raised when no PCIe device is reported or when the device count
// is not a multiple of the PCIe device count.
std::pair<::tt::runtime::SystemDesc, DeviceIds> getCurrentSystemDesc() {
  size_t numDevices = ::tt::tt_metal::GetNumAvailableDevices();
  size_t numPciDevices = ::tt::tt_metal::GetNumPCIeDevices();
  // Guard first: the modulo and the division below are undefined for zero.
  TT_FATAL(numPciDevices > 0, "No PCIe devices found");
  TT_FATAL(numDevices % numPciDevices == 0,
           "Unexpected non-rectangular grid of devices");
  std::vector<chip_id_t> deviceIds(numDevices);
  std::iota(deviceIds.begin(), deviceIds.end(), 0);
  ::tt::tt_metal::DeviceGrid grid =
      std::make_pair(numDevices / numPciDevices, numPciDevices);
  ::tt::tt_metal::DeviceMesh deviceMesh = ::tt::tt_metal::DeviceMesh(
      grid, deviceIds, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1);
  // Capture any failure so the devices are closed before it propagates.
  std::exception_ptr eptr = nullptr;
  std::unique_ptr<::tt::runtime::SystemDesc> desc;
  try {
    desc = getCurrentSystemDescImpl(deviceMesh);
  } catch (...) {
    eptr = std::current_exception();
  }
  deviceMesh.close_devices();
  if (eptr) {
    std::rethrow_exception(eptr);
  }
  return std::make_pair(*desc, deviceIds);
}

} // namespace tt::runtime::system_desc
13 changes: 3 additions & 10 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "tt/runtime/runtime.h"
#include "tt/runtime/utils.h"
#include "ttmlir/Target/TTNN/Target.h"
#include "ttmlir/Version.h"

#if defined(TT_RUNTIME_ENABLE_TTNN)
Expand Down Expand Up @@ -77,16 +78,8 @@ void setCompatibleRuntime(const Binary &binary) {
}

std::pair<SystemDesc, DeviceIds> getCurrentSystemDesc() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getCurrentSystemDesc();
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getCurrentSystemDesc();
}
#if defined(TT_RUNTIME_ENABLE_TTNN) || defined(TT_RUNTIME_ENABLE_TTMETAL)
return system_desc::getCurrentSystemDesc();
#endif
throw std::runtime_error("runtime is not enabled");
}
Expand Down
1 change: 1 addition & 0 deletions runtime/lib/ttmetal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ target_include_directories(TTRuntimeTTMetal PUBLIC
${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common
)
target_include_directories(TTRuntimeTTMetal PUBLIC "$<BUILD_INTERFACE:${TTMETAL_INCLUDE_DIRS}>")
target_link_libraries(TTRuntimeTTMetal PUBLIC TTMETAL_LIBRARY)
add_dependencies(TTRuntimeTTMetal tt-metal FBS_GENERATION)
68 changes: 0 additions & 68 deletions runtime/lib/ttmetal/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,6 @@ using DeviceMesh = std::vector<::tt::tt_metal::Device *>;
using MetalTensor =
std::variant<TensorDesc, std::shared_ptr<::tt::tt_metal::Buffer>>;

// Translates a tt-metal architecture enumerator into the flatbuffer Arch
// value of the same name; any other architecture raises std::runtime_error.
static ::tt::target::Arch toFlatbuffer(::tt::ARCH arch) {
  switch (arch) {
  case ::tt::ARCH::GRAYSKULL:
    return ::tt::target::Arch::Grayskull;
  case ::tt::ARCH::WORMHOLE_B0:
    return ::tt::target::Arch::Wormhole_b0;
  case ::tt::ARCH::BLACKHOLE:
    return ::tt::target::Arch::Blackhole;
  default:
    throw std::runtime_error("Unsupported arch");
  }
}

// Builds a flatbuffer Dim2d from a core coordinate, passing the y component
// first and the x component second.
static ::tt::target::Dim2d toFlatbuffer(CoreCoord coreCoord) {
  const auto yValue = coreCoord.y;
  const auto xValue = coreCoord.x;
  return ::tt::target::Dim2d(yValue, xValue);
}

static ::tt::target::metal::TTMetalBinary const *getBinary(Flatbuffer binary) {
bool isTTMetal =
::tt::target::metal::SizePrefixedTTMetalBinaryBufferHasIdentifier(
Expand All @@ -48,55 +29,6 @@ static ::tt::target::metal::TTMetalBinary const *getBinary(Flatbuffer binary) {
return ::tt::target::metal::GetSizePrefixedTTMetalBinary(binary.handle.get());
}

std::pair<SystemDesc, DeviceIds> getCurrentSystemDesc() {
::tt::tt_metal::Device *device = ::tt::tt_metal::CreateDevice(0);
std::vector<int> chipIds = {
device->id(),
};
::flatbuffers::FlatBufferBuilder fbb;
::ttmlir::Version ttmlirVersion = ::ttmlir::getVersion();
::tt::target::Version version(ttmlirVersion.major, ttmlirVersion.minor,
ttmlirVersion.patch);
::tt::target::Dim2d deviceGrid =
toFlatbuffer(device->compute_with_storage_grid_size());
std::vector<::flatbuffers::Offset<tt::target::ChipDesc>> chipDescs = {
::tt::target::CreateChipDesc(
fbb, toFlatbuffer(device->arch()), &deviceGrid, (1 << 20), 12,
(1 << 20), L1_ALIGNMENT, PCIE_ALIGNMENT, DRAM_ALIGNMENT),
};
std::vector<uint32_t> chipDescIndices = {
0,
};
::tt::target::ChipCapability chipCapability =
::tt::target::ChipCapability::PCIE;
if (device->is_mmio_capable()) {
chipCapability = chipCapability | ::tt::target::ChipCapability::HostMMIO;
}
std::vector<::tt::target::ChipCapability> chipCapabilities = {
chipCapability,
};
std::vector<::tt::target::ChipCoord> chipCoord = {
::tt::target::ChipCoord(0, 0, 0, 0),
};
std::vector<::tt::target::ChipChannel> chipChannel;
auto systemDesc = ::tt::target::CreateSystemDescDirect(
fbb, &chipDescs, &chipDescIndices, &chipCapabilities, &chipCoord,
&chipChannel);
auto root = ::tt::target::CreateSystemDescRootDirect(
fbb, &version, ::ttmlir::getGitHash(), "unknown", systemDesc);
::tt::target::FinishSizePrefixedSystemDescRootBuffer(fbb, root);
::flatbuffers::Verifier verifier(fbb.GetBufferPointer(), fbb.GetSize());
if (not ::tt::target::VerifySizePrefixedSystemDescRootBuffer(verifier)) {
throw std::runtime_error("Failed to verify system desc root buffer");
}
uint8_t *buf = fbb.GetBufferPointer();
auto size = fbb.GetSize();
auto handle = utils::malloc_shared(size);
std::memcpy(handle.get(), buf, size);
::tt::tt_metal::CloseDevice(device);
return std::make_pair(SystemDesc(handle), chipIds);
}

Tensor createTensor(std::shared_ptr<void> data,
std::vector<std::uint32_t> const &shape,
std::vector<std::uint32_t> const &stride,
Expand Down
Loading

0 comments on commit 259278e

Please sign in to comment.