Skip to content

Commit

Permalink
#162: Support for multi-device system descriptor, commonize system de…
Browse files Browse the repository at this point in the history
…sc between backends (#220)
  • Loading branch information
jnie-TT authored Aug 10, 2024
1 parent f4fda50 commit 0b393ec
Show file tree
Hide file tree
Showing 21 changed files with 287 additions and 179 deletions.
7 changes: 4 additions & 3 deletions include/ttmlir-c/TTAttrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipDescAttrGet(
MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipCoordAttrGet(
MlirContext ctx, unsigned rack, unsigned shelf, unsigned y, unsigned x);

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipChannelAttrGet(MlirContext ctx,
unsigned endpoint0,
unsigned endpoint1);
MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipChannelAttrGet(
MlirContext ctx, unsigned deviceId0, int64_t *ethernetCoreCoord0,
size_t ethernetCoreCoord0Size, unsigned deviceId1,
int64_t *ethernetCoreCoord1, size_t ethernetCoreCoord1Size);

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTSystemDescAttrGet(
MlirContext ctx, MlirAttribute *chipDescs, size_t chipDescsSize,
Expand Down
7 changes: 5 additions & 2 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,11 @@ def TT_ChipChannelAttr : TT_Attr<"ChipChannel", "chip_channel"> {
TT chip_channel attribute
}];

let parameters = (ins "unsigned":$endpoint0, "unsigned":$endpoint1);
let assemblyFormat = "`<` $endpoint0 `,` $endpoint1 `>`";
let parameters = (ins "unsigned":$deviceId0,
ArrayRefParameter<"int64_t">:$ethernetCoreCoord0,
"unsigned":$deviceId1,
ArrayRefParameter<"int64_t">:$ethernetCoreCoord1);
let assemblyFormat = "`<` $deviceId0 `,` $ethernetCoreCoord0 `,` $deviceId1 `,` $ethernetCoreCoord1 `>`";
}

def TT_SystemDescAttr : TT_Attr<"SystemDesc", "system_desc"> {
Expand Down
6 changes: 4 additions & 2 deletions include/ttmlir/Target/Common/types.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,10 @@ struct ChipCoord {
}

struct ChipChannel {
endpoint0: uint32;
endpoint1: uint32;
device_id0: uint32;
ethernet_core_coord0: Dim2d;
device_id1: uint32;
ethernet_core_coord1: Dim2d;
}

table SystemDesc {
Expand Down
9 changes: 7 additions & 2 deletions include/ttmlir/Target/Utils/MLIRToFlatbuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,13 @@ inline ::tt::target::ChipCoord toFlatbuffer(FlatbufferObjectCache &cache,

inline ::tt::target::ChipChannel toFlatbuffer(FlatbufferObjectCache &cache,
ChipChannelAttr chipChannel) {
return ::tt::target::ChipChannel(chipChannel.getEndpoint0(),
chipChannel.getEndpoint1());
return ::tt::target::ChipChannel(
chipChannel.getDeviceId0(),
::tt::target::Dim2d(chipChannel.getEthernetCoreCoord0()[0],
chipChannel.getEthernetCoreCoord0()[1]),
chipChannel.getDeviceId1(),
::tt::target::Dim2d(chipChannel.getEthernetCoreCoord1()[0],
chipChannel.getEthernetCoreCoord1()[1]));
}

inline ::tt::target::Dim2d toFlatbuffer(FlatbufferObjectCache &cache,
Expand Down
15 changes: 12 additions & 3 deletions lib/CAPI/TTAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,18 @@ MlirAttribute ttmlirTTChipCoordAttrGet(MlirContext ctx, unsigned rack,
return wrap(ChipCoordAttr::get(unwrap(ctx), rack, shelf, y, x));
}

MlirAttribute ttmlirTTChipChannelAttrGet(MlirContext ctx, unsigned endpoint0,
unsigned endpoint1) {
return wrap(ChipChannelAttr::get(unwrap(ctx), endpoint0, endpoint1));
MlirAttribute ttmlirTTChipChannelAttrGet(MlirContext ctx, unsigned deviceId0,
int64_t *ethernetCoreCoord0,
size_t ethernetCoreCoord0Size,
unsigned deviceId1,
int64_t *ethernetCoreCoord1,
size_t ethernetCoreCoord1Size) {
std::vector<int64_t> ethCoord0Vec(
ethernetCoreCoord0, ethernetCoreCoord0 + ethernetCoreCoord0Size);
std::vector<int64_t> ethCoord1Vec(
ethernetCoreCoord1, ethernetCoreCoord1 + ethernetCoreCoord1Size);
return wrap(ChipChannelAttr::get(unwrap(ctx), deviceId0, ethCoord0Vec,
deviceId1, ethCoord1Vec));
}

MlirAttribute ttmlirTTSystemDescAttrGet(
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TT/IR/TTOpsTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ mlir::tt::SystemDescAttr::getDefault(MLIRContext *context) {
{
tt::ChipDescAttr::get(
context, tt::ArchAttr::get(context, tt::Arch::WormholeB0), {8, 8},
(1 << 20), 12, (1 << 20), 16, 32, 32),
1499136, 12, (1 << 30), 16, 32, 32),
},
// Chip Descriptor Indices
{
Expand Down
13 changes: 8 additions & 5 deletions python/TTModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,14 @@ void populateTTModule(py::module &m) {
});

py::class_<tt::ChipChannelAttr>(m, "ChipChannelAttr")
.def_static("get",
[](MlirContext ctx, unsigned endpoint0, unsigned endpoint1) {
return wrap(tt::ChipChannelAttr::get(unwrap(ctx), endpoint0,
endpoint1));
});
.def_static("get", [](MlirContext ctx, unsigned deviceId0,
std::vector<int64_t> ethernetCoreCoord0,
unsigned deviceId1,
std::vector<int64_t> ethernetCoreCoord1) {
return wrap(tt::ChipChannelAttr::get(unwrap(ctx), deviceId0,
ethernetCoreCoord0, deviceId1,
ethernetCoreCoord1));
});

py::class_<tt::SystemDescAttr>(m, "SystemDescAttr")
.def_static("get", [](MlirContext ctx,
Expand Down
1 change: 1 addition & 0 deletions runtime/include/tt/runtime/detail/ttmetal.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#pragma clang diagnostic ignored "-Wlogical-op-parentheses"
#pragma clang diagnostic ignored "-Wundefined-inline"
#define FMT_HEADER_ONLY
#include "impl/device/device_mesh.hpp"
#include "tt_metal/host_api.hpp"
#pragma clang diagnostic pop

Expand Down
4 changes: 4 additions & 0 deletions runtime/include/tt/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@

namespace tt::runtime {

namespace system_desc {
// Returns the system descriptor for the current system together with the
// device ids it covers. This is the single entry point that
// runtime.cpp's getCurrentSystemDesc() forwards to when either the TTNN or
// the TTMetal runtime is enabled.
std::pair<SystemDesc, DeviceIds> getCurrentSystemDesc();
} // namespace system_desc

DeviceRuntime getCurrentRuntime();

std::vector<DeviceRuntime> getAvailableRuntimes();
Expand Down
26 changes: 17 additions & 9 deletions runtime/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,6 @@ else()
endif()

message(STATUS "Runtimes Enabled: TTNN[${TT_RUNTIME_ENABLE_TTNN}] TTMETAL[${TT_RUNTIME_ENABLE_TTMETAL}]")
add_library(TTRuntime STATIC runtime.cpp)
set_property(TARGET TTRuntime PROPERTY CXX_STANDARD 20)
if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTNN)
target_compile_definitions(TTRuntime PUBLIC TT_RUNTIME_ENABLE_TTNN)
endif()
if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTMETAL)
target_compile_definitions(TTRuntime PUBLIC TT_RUNTIME_ENABLE_TTMETAL)
endif()

add_library(TTBinary STATIC binary.cpp)
target_include_directories(TTBinary
Expand All @@ -32,6 +24,21 @@ target_include_directories(TTBinary
)
add_dependencies(TTBinary FBS_GENERATION)

if (TTMLIR_ENABLE_RUNTIME AND (TT_RUNTIME_ENABLE_TTNN OR TT_RUNTIME_ENABLE_TTMETAL))
add_subdirectory(common)
else()
add_library(TTRuntimeSysDesc INTERFACE)
endif()

add_library(TTRuntime STATIC runtime.cpp)
set_property(TARGET TTRuntime PROPERTY CXX_STANDARD 20)
if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTNN)
target_compile_definitions(TTRuntime PUBLIC TT_RUNTIME_ENABLE_TTNN)
endif()
if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTMETAL)
target_compile_definitions(TTRuntime PUBLIC TT_RUNTIME_ENABLE_TTMETAL)
endif()

target_include_directories(TTRuntime
PUBLIC
${PROJECT_SOURCE_DIR}/runtime/include
Expand All @@ -41,8 +48,9 @@ target_include_directories(TTRuntime
target_link_libraries(TTRuntime
PRIVATE
TTBinary
TTRuntimeSysDesc
TTRuntimeTTNN
TTRuntimeTTMetal
)

add_dependencies(TTRuntime TTBinary FBS_GENERATION)
add_dependencies(TTRuntime TTBinary TTRuntimeSysDesc FBS_GENERATION)
9 changes: 9 additions & 0 deletions runtime/lib/common/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Static library holding the system-descriptor query code (system_desc.cpp).
# The parent runtime/lib/CMakeLists.txt only adds this directory when the
# runtime is enabled with TTNN or TTMetal; otherwise it defines
# TTRuntimeSysDesc as an empty INTERFACE library instead.
add_library(TTRuntimeSysDesc STATIC system_desc.cpp)
# Built as C++20, matching the TTRuntime target.
set_property(TARGET TTRuntimeSysDesc PROPERTY CXX_STANDARD 20)
# Runtime public headers plus the build-tree directory under
# include/ttmlir/Target/Common (presumably where the generated flatbuffer
# headers land -- confirm against the FBS_GENERATION target).
target_include_directories(TTRuntimeSysDesc
PUBLIC
${PROJECT_SOURCE_DIR}/runtime/include
${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common
)
# tt-metal headers are only exposed while building this project
# (BUILD_INTERFACE), not to consumers of an installed package.
target_include_directories(TTRuntimeSysDesc PUBLIC "$<BUILD_INTERFACE:${TTMETAL_INCLUDE_DIRS}>")
# Ensure the tt-metal and FBS_GENERATION targets are built before this library
# is compiled.
add_dependencies(TTRuntimeSysDesc tt-metal FBS_GENERATION)
179 changes: 179 additions & 0 deletions runtime/lib/common/system_desc.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#include "tt/runtime/types.h"
#include "tt/runtime/utils.h"
#include "ttmlir/Target/TTNN/Target.h"
#include "ttmlir/Version.h"
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <exception>
#include <memory>
#include <numeric>
#include <set>
#include <stdexcept>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wctad-maybe-unsupported"
#pragma clang diagnostic ignored "-Wcovered-switch-default"
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wignored-qualifiers"
#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
#pragma clang diagnostic ignored "-Wvla-extension"
#pragma clang diagnostic ignored "-Wsign-compare"
#pragma clang diagnostic ignored "-Wcast-qual"
#pragma clang diagnostic ignored "-Wdeprecated-this-capture"
#pragma clang diagnostic ignored "-Wnon-virtual-dtor"
#pragma clang diagnostic ignored "-Wsuggest-override"
#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
#pragma clang diagnostic ignored "-Wnested-anon-types"
#pragma clang diagnostic ignored "-Wreorder-ctor"
#pragma clang diagnostic ignored "-Wmismatched-tags"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-local-typedef"
#define FMT_HEADER_ONLY
#include "host_api.hpp"
#include "hostdevcommon/common_values.hpp"
#include "impl/device/device_mesh.hpp"
#pragma clang diagnostic pop

namespace tt::runtime::system_desc {
// Converts a tt-metal core coordinate into a flatbuffer Dim2d. The y
// component is passed as the first constructor argument and x as the second.
static ::tt::target::Dim2d toFlatbuffer(const CoreCoord &coreCoord) {
  const ::tt::target::Dim2d dim(coreCoord.y, coreCoord.x);
  return dim;
}

// Maps a tt-metal architecture enum onto the flatbuffer Arch enum.
// Throws std::runtime_error for any architecture without a mapping.
static ::tt::target::Arch toFlatbuffer(::tt::ARCH arch) {
  if (arch == ::tt::ARCH::GRAYSKULL) {
    return ::tt::target::Arch::Grayskull;
  }
  if (arch == ::tt::ARCH::WORMHOLE_B0) {
    return ::tt::target::Arch::Wormhole_b0;
  }
  if (arch == ::tt::ARCH::BLACKHOLE) {
    return ::tt::target::Arch::Blackhole;
  }

  throw std::runtime_error("Unsupported arch");
}

static std::vector<::tt::target::ChipChannel>
getAllDeviceConnections(const vector<::tt::tt_metal::Device *> &devices) {
std::set<std::tuple<chip_id_t, CoreCoord, chip_id_t, CoreCoord>>
connectionSet;

auto addConnection = [&connectionSet](
chip_id_t deviceId0, CoreCoord ethCoreCoord0,
chip_id_t deviceId1, CoreCoord ethCoreCoord1) {
if (deviceId0 > deviceId1) {
std::swap(deviceId0, deviceId1);
std::swap(ethCoreCoord0, ethCoreCoord1);
}
connectionSet.emplace(deviceId0, ethCoreCoord0, deviceId1, ethCoreCoord1);
};

for (const ::tt::tt_metal::Device *device : devices) {
std::unordered_set<CoreCoord> activeEthernetCores =
device->get_active_ethernet_cores(true);
for (const CoreCoord &ethernetCore : activeEthernetCores) {
std::tuple<chip_id_t, CoreCoord> connectedDevice =
device->get_connected_ethernet_core(ethernetCore);
addConnection(device->id(), ethernetCore, std::get<0>(connectedDevice),
std::get<1>(connectedDevice));
}
}

std::vector<::tt::target::ChipChannel> allConnections;
allConnections.resize(connectionSet.size());

std::transform(
connectionSet.begin(), connectionSet.end(), allConnections.begin(),
[](const std::tuple<chip_id_t, CoreCoord, chip_id_t, CoreCoord>
&connection) {
return ::tt::target::ChipChannel(
std::get<0>(connection), toFlatbuffer(std::get<1>(connection)),
std::get<2>(connection), toFlatbuffer(std::get<3>(connection)));
});

return allConnections;
}

// Builds a size-prefixed SystemDescRoot flatbuffer describing the devices of
// `deviceMesh` and returns it wrapped in a heap-allocated SystemDesc.
//
// Devices are processed in ascending id order. For each device a ChipDesc is
// created, its id is recorded as the chip descriptor index, and its capability
// flags are derived from whether it is MMIO capable. Chip coordinates are not
// populated yet: a single all-zero ChipCoord is emitted regardless of the
// number of devices.
//
// Throws std::runtime_error if the finished buffer fails flatbuffer
// verification.
static std::unique_ptr<::tt::runtime::SystemDesc>
getCurrentSystemDescImpl(const ::tt::tt_metal::DeviceMesh &deviceMesh) {
  // Sort a local copy of the device list by id.
  auto byAscendingId = [](const ::tt::tt_metal::Device *lhs,
                          const ::tt::tt_metal::Device *rhs) {
    return lhs->id() < rhs->id();
  };
  std::vector<::tt::tt_metal::Device *> sortedDevices =
      deviceMesh.get_devices();
  std::sort(sortedDevices.begin(), sortedDevices.end(), byAscendingId);

  std::vector<::flatbuffers::Offset<tt::target::ChipDesc>> descOffsets;
  std::vector<uint32_t> descIndices;
  std::vector<::tt::target::ChipCapability> capabilities;
  // Ignore for now
  std::vector<::tt::target::ChipCoord> coords = {
      ::tt::target::ChipCoord(0, 0, 0, 0)};
  ::flatbuffers::FlatBufferBuilder builder;

  for (const ::tt::tt_metal::Device *dev : sortedDevices) {
    // Chip descriptor for this device.
    ::tt::target::Dim2d grid =
        toFlatbuffer(dev->compute_with_storage_grid_size());
    auto descOffset = ::tt::target::CreateChipDesc(
        builder, toFlatbuffer(dev->arch()), &grid, dev->l1_size_per_core(),
        dev->num_dram_channels(), dev->dram_size_per_channel(), L1_ALIGNMENT,
        PCIE_ALIGNMENT, DRAM_ALIGNMENT);
    descOffsets.push_back(descOffset);
    descIndices.push_back(dev->id());

    // MMIO-capable devices are marked with both PCIE and HostMMIO.
    ::tt::target::ChipCapability capability =
        ::tt::target::ChipCapability::NONE;
    if (dev->is_mmio_capable()) {
      capability = capability | ::tt::target::ChipCapability::PCIE |
                   ::tt::target::ChipCapability::HostMMIO;
    }
    capabilities.push_back(capability);
  }

  // Ethernet links between the devices.
  std::vector<::tt::target::ChipChannel> channels =
      getAllDeviceConnections(sortedDevices);

  // Assemble the SystemDesc table and its versioned root.
  auto systemDescOffset = ::tt::target::CreateSystemDescDirect(
      builder, &descOffsets, &descIndices, &capabilities, &coords, &channels);
  ::ttmlir::Version compilerVersion = ::ttmlir::getVersion();
  ::tt::target::Version fbVersion(compilerVersion.major, compilerVersion.minor,
                                  compilerVersion.patch);
  auto rootOffset = ::tt::target::CreateSystemDescRootDirect(
      builder, &fbVersion, ::ttmlir::getGitHash(), "unknown",
      systemDescOffset);
  ::tt::target::FinishSizePrefixedSystemDescRootBuffer(builder, rootOffset);

  // Verify the finished buffer before handing it out.
  ::flatbuffers::Verifier verifier(builder.GetBufferPointer(),
                                   builder.GetSize());
  if (!::tt::target::VerifySizePrefixedSystemDescRootBuffer(verifier)) {
    throw std::runtime_error("Failed to verify system desc root buffer");
  }

  // Copy the bytes out of the builder into a buffer obtained from
  // malloc_shared so they outlive `builder`.
  auto byteCount = builder.GetSize();
  uint8_t *bytes = builder.GetBufferPointer();
  auto sharedBuffer = ::tt::runtime::utils::malloc_shared(byteCount);
  std::memcpy(sharedBuffer.get(), bytes, byteCount);
  return std::make_unique<::tt::runtime::SystemDesc>(sharedBuffer);
}

// Opens every available device as a device mesh, builds the system descriptor
// from it, closes the devices again and returns the descriptor together with
// the list of device ids (0 .. numDevices-1).
//
// The mesh is shaped as (numDevices / numPciDevices) x numPciDevices, so the
// device count must be a multiple of the PCIe device count.
//
// Any exception raised while building the descriptor is rethrown after the
// devices have been closed.
std::pair<::tt::runtime::SystemDesc, DeviceIds> getCurrentSystemDesc() {
  size_t numDevices = ::tt::tt_metal::GetNumAvailableDevices();
  size_t numPciDevices = ::tt::tt_metal::GetNumPCIeDevices();
  // Guard the modulo and division below: with zero PCIe devices they would be
  // an integer division by zero.
  TT_FATAL(numPciDevices > 0, "No PCIe devices found");
  TT_FATAL(numDevices % numPciDevices == 0,
           "Unexpected non-rectangular grid of devices");
  std::vector<chip_id_t> deviceIds(numDevices);
  std::iota(deviceIds.begin(), deviceIds.end(), 0);
  ::tt::tt_metal::DeviceGrid grid =
      std::make_pair(numDevices / numPciDevices, numPciDevices);
  // NOTE(review): the trailing `1` is forwarded as-is to the DeviceMesh
  // constructor -- confirm its meaning against the tt-metal API.
  ::tt::tt_metal::DeviceMesh deviceMesh = ::tt::tt_metal::DeviceMesh(
      grid, deviceIds, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1);
  // Capture any failure so the devices are closed on both the success and the
  // error path before the exception propagates.
  std::exception_ptr eptr = nullptr;
  std::unique_ptr<::tt::runtime::SystemDesc> desc;
  try {
    desc = getCurrentSystemDescImpl(deviceMesh);
  } catch (...) {
    eptr = std::current_exception();
  }
  deviceMesh.close_devices();
  if (eptr) {
    std::rethrow_exception(eptr);
  }
  return std::make_pair(*desc, deviceIds);
}

} // namespace tt::runtime::system_desc
13 changes: 3 additions & 10 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "tt/runtime/runtime.h"
#include "tt/runtime/utils.h"
#include "ttmlir/Target/TTNN/Target.h"
#include "ttmlir/Version.h"

#if defined(TT_RUNTIME_ENABLE_TTNN)
Expand Down Expand Up @@ -77,16 +78,8 @@ void setCompatibleRuntime(const Binary &binary) {
}

std::pair<SystemDesc, DeviceIds> getCurrentSystemDesc() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getCurrentSystemDesc();
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getCurrentSystemDesc();
}
#if defined(TT_RUNTIME_ENABLE_TTNN) || defined(TT_RUNTIME_ENABLE_TTMETAL)
return system_desc::getCurrentSystemDesc();
#endif
throw std::runtime_error("runtime is not enabled");
}
Expand Down
3 changes: 2 additions & 1 deletion runtime/lib/ttmetal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ target_include_directories(TTRuntimeTTMetal PUBLIC
${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common
)
target_include_directories(TTRuntimeTTMetal PUBLIC "$<BUILD_INTERFACE:${TTMETAL_INCLUDE_DIRS}>")
add_dependencies(TTRuntimeTTMetal tt-metal FBS_GENERATION)
target_link_libraries(TTRuntimeTTMetal PUBLIC TTMETAL_LIBRARY)
add_dependencies(TTRuntimeTTMetal TTMETAL_LIBRARY tt-metal FBS_GENERATION)
Loading

0 comments on commit 0b393ec

Please sign in to comment.