Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for multi-device system descriptor #220

Merged
merged 2 commits into from
Aug 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions include/ttmlir-c/TTAttrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipDescAttrGet(
MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipCoordAttrGet(
MlirContext ctx, unsigned rack, unsigned shelf, unsigned y, unsigned x);

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipChannelAttrGet(MlirContext ctx,
unsigned endpoint0,
unsigned endpoint1);
MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipChannelAttrGet(
MlirContext ctx, unsigned deviceId0, int64_t *ethernetCoreCoord0,
size_t ethernetCoreCoord0Size, unsigned deviceId1,
int64_t *ethernetCoreCoord1, size_t ethernetCoreCoord1Size);

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTSystemDescAttrGet(
MlirContext ctx, MlirAttribute *chipDescs, size_t chipDescsSize,
Expand Down
7 changes: 5 additions & 2 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,11 @@ def TT_ChipChannelAttr : TT_Attr<"ChipChannel", "chip_channel"> {
TT chip_channel attribute
}];

let parameters = (ins "unsigned":$endpoint0, "unsigned":$endpoint1);
let assemblyFormat = "`<` $endpoint0 `,` $endpoint1 `>`";
let parameters = (ins "unsigned":$deviceId0,
ArrayRefParameter<"int64_t">:$ethernetCoreCoord0,
"unsigned":$deviceId1,
ArrayRefParameter<"int64_t">:$ethernetCoreCoord1);
let assemblyFormat = "`<` $deviceId0 `,` $ethernetCoreCoord0 `,` $deviceId1 `,` $ethernetCoreCoord1 `>`";
}

def TT_SystemDescAttr : TT_Attr<"SystemDesc", "system_desc"> {
Expand Down
6 changes: 4 additions & 2 deletions include/ttmlir/Target/Common/types.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,10 @@ struct ChipCoord {
}

struct ChipChannel {
endpoint0: uint32;
endpoint1: uint32;
device_id0: uint32;
ethernet_core_coord0: Dim2d;
device_id1: uint32;
ethernet_core_coord1: Dim2d;
}

table SystemDesc {
Expand Down
9 changes: 7 additions & 2 deletions include/ttmlir/Target/Utils/MLIRToFlatbuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,13 @@ inline ::tt::target::ChipCoord toFlatbuffer(FlatbufferObjectCache &cache,

inline ::tt::target::ChipChannel toFlatbuffer(FlatbufferObjectCache &cache,
ChipChannelAttr chipChannel) {
return ::tt::target::ChipChannel(chipChannel.getEndpoint0(),
chipChannel.getEndpoint1());
return ::tt::target::ChipChannel(
chipChannel.getDeviceId0(),
::tt::target::Dim2d(chipChannel.getEthernetCoreCoord0()[0],
chipChannel.getEthernetCoreCoord0()[1]),
chipChannel.getDeviceId1(),
::tt::target::Dim2d(chipChannel.getEthernetCoreCoord1()[0],
chipChannel.getEthernetCoreCoord1()[1]));
}

inline ::tt::target::Dim2d toFlatbuffer(FlatbufferObjectCache &cache,
Expand Down
15 changes: 12 additions & 3 deletions lib/CAPI/TTAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,18 @@ MlirAttribute ttmlirTTChipCoordAttrGet(MlirContext ctx, unsigned rack,
return wrap(ChipCoordAttr::get(unwrap(ctx), rack, shelf, y, x));
}

MlirAttribute ttmlirTTChipChannelAttrGet(MlirContext ctx, unsigned endpoint0,
unsigned endpoint1) {
return wrap(ChipChannelAttr::get(unwrap(ctx), endpoint0, endpoint1));
MlirAttribute ttmlirTTChipChannelAttrGet(MlirContext ctx, unsigned deviceId0,
int64_t *ethernetCoreCoord0,
size_t ethernetCoreCoord0Size,
unsigned deviceId1,
int64_t *ethernetCoreCoord1,
size_t ethernetCoreCoord1Size) {
std::vector<int64_t> ethCoord0Vec(
ethernetCoreCoord0, ethernetCoreCoord0 + ethernetCoreCoord0Size);
std::vector<int64_t> ethCoord1Vec(
ethernetCoreCoord1, ethernetCoreCoord1 + ethernetCoreCoord1Size);
return wrap(ChipChannelAttr::get(unwrap(ctx), deviceId0, ethCoord0Vec,
deviceId1, ethCoord1Vec));
}

MlirAttribute ttmlirTTSystemDescAttrGet(
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TT/IR/TTOpsTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ mlir::tt::SystemDescAttr::getDefault(MLIRContext *context) {
{
tt::ChipDescAttr::get(
context, tt::ArchAttr::get(context, tt::Arch::WormholeB0), {8, 8},
(1 << 20), 12, (1 << 20), 16, 32, 32),
1499136, 12, (1 << 30), 16, 32, 32),
},
// Chip Descriptor Indices
{
Expand Down
13 changes: 8 additions & 5 deletions python/TTModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,14 @@ void populateTTModule(py::module &m) {
});

py::class_<tt::ChipChannelAttr>(m, "ChipChannelAttr")
.def_static("get",
[](MlirContext ctx, unsigned endpoint0, unsigned endpoint1) {
return wrap(tt::ChipChannelAttr::get(unwrap(ctx), endpoint0,
endpoint1));
});
.def_static("get", [](MlirContext ctx, unsigned deviceId0,
std::vector<int64_t> ethernetCoreCoord0,
unsigned deviceId1,
std::vector<int64_t> ethernetCoreCoord1) {
return wrap(tt::ChipChannelAttr::get(unwrap(ctx), deviceId0,
ethernetCoreCoord0, deviceId1,
ethernetCoreCoord1));
});

py::class_<tt::SystemDescAttr>(m, "SystemDescAttr")
.def_static("get", [](MlirContext ctx,
Expand Down
1 change: 1 addition & 0 deletions runtime/include/tt/runtime/detail/ttmetal.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#pragma clang diagnostic ignored "-Wlogical-op-parentheses"
#pragma clang diagnostic ignored "-Wundefined-inline"
#define FMT_HEADER_ONLY
#include "impl/device/device_mesh.hpp"
#include "tt_metal/host_api.hpp"
#pragma clang diagnostic pop

Expand Down
4 changes: 4 additions & 0 deletions runtime/include/tt/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@

namespace tt::runtime {

namespace system_desc {
std::pair<SystemDesc, DeviceIds> getCurrentSystemDesc();
} // namespace system_desc

DeviceRuntime getCurrentRuntime();

std::vector<DeviceRuntime> getAvailableRuntimes();
Expand Down
26 changes: 17 additions & 9 deletions runtime/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,6 @@ else()
endif()

message(STATUS "Runtimes Enabled: TTNN[${TT_RUNTIME_ENABLE_TTNN}] TTMETAL[${TT_RUNTIME_ENABLE_TTMETAL}]")
add_library(TTRuntime STATIC runtime.cpp)
set_property(TARGET TTRuntime PROPERTY CXX_STANDARD 20)
if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTNN)
target_compile_definitions(TTRuntime PUBLIC TT_RUNTIME_ENABLE_TTNN)
endif()
if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTMETAL)
target_compile_definitions(TTRuntime PUBLIC TT_RUNTIME_ENABLE_TTMETAL)
endif()

add_library(TTBinary STATIC binary.cpp)
target_include_directories(TTBinary
Expand All @@ -32,6 +24,21 @@ target_include_directories(TTBinary
)
add_dependencies(TTBinary FBS_GENERATION)

if (TTMLIR_ENABLE_RUNTIME AND (TT_RUNTIME_ENABLE_TTNN OR TT_RUNTIME_ENABLE_TTMETAL))
add_subdirectory(common)
else()
add_library(TTRuntimeSysDesc INTERFACE)
endif()

add_library(TTRuntime STATIC runtime.cpp)
set_property(TARGET TTRuntime PROPERTY CXX_STANDARD 20)
if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTNN)
target_compile_definitions(TTRuntime PUBLIC TT_RUNTIME_ENABLE_TTNN)
endif()
if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTMETAL)
target_compile_definitions(TTRuntime PUBLIC TT_RUNTIME_ENABLE_TTMETAL)
endif()

target_include_directories(TTRuntime
PUBLIC
${PROJECT_SOURCE_DIR}/runtime/include
Expand All @@ -41,8 +48,9 @@ target_include_directories(TTRuntime
target_link_libraries(TTRuntime
PRIVATE
TTBinary
TTRuntimeSysDesc
TTRuntimeTTNN
TTRuntimeTTMetal
)

add_dependencies(TTRuntime TTBinary FBS_GENERATION)
add_dependencies(TTRuntime TTBinary TTRuntimeSysDesc FBS_GENERATION)
9 changes: 9 additions & 0 deletions runtime/lib/common/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
add_library(TTRuntimeSysDesc STATIC system_desc.cpp)
set_property(TARGET TTRuntimeSysDesc PROPERTY CXX_STANDARD 20)
target_include_directories(TTRuntimeSysDesc
PUBLIC
${PROJECT_SOURCE_DIR}/runtime/include
${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common
)
target_include_directories(TTRuntimeSysDesc PUBLIC "$<BUILD_INTERFACE:${TTMETAL_INCLUDE_DIRS}>")
add_dependencies(TTRuntimeSysDesc tt-metal FBS_GENERATION)
179 changes: 179 additions & 0 deletions runtime/lib/common/system_desc.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#include "tt/runtime/types.h"
#include "tt/runtime/utils.h"
#include "ttmlir/Target/TTNN/Target.h"
#include "ttmlir/Version.h"
#include <cstdint>
#include <vector>

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wctad-maybe-unsupported"
#pragma clang diagnostic ignored "-Wcovered-switch-default"
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wignored-qualifiers"
#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
#pragma clang diagnostic ignored "-Wvla-extension"
#pragma clang diagnostic ignored "-Wsign-compare"
#pragma clang diagnostic ignored "-Wcast-qual"
#pragma clang diagnostic ignored "-Wdeprecated-this-capture"
#pragma clang diagnostic ignored "-Wnon-virtual-dtor"
#pragma clang diagnostic ignored "-Wsuggest-override"
#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
#pragma clang diagnostic ignored "-Wnested-anon-types"
#pragma clang diagnostic ignored "-Wreorder-ctor"
#pragma clang diagnostic ignored "-Wmismatched-tags"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-local-typedef"
#define FMT_HEADER_ONLY
#include "host_api.hpp"
#include "hostdevcommon/common_values.hpp"
#include "impl/device/device_mesh.hpp"
#pragma clang diagnostic pop

namespace tt::runtime::system_desc {
static ::tt::target::Dim2d toFlatbuffer(const CoreCoord &coreCoord) {
return ::tt::target::Dim2d(coreCoord.y, coreCoord.x);
}

static ::tt::target::Arch toFlatbuffer(::tt::ARCH arch) {
switch (arch) {
case ::tt::ARCH::GRAYSKULL:
return ::tt::target::Arch::Grayskull;
case ::tt::ARCH::WORMHOLE_B0:
return ::tt::target::Arch::Wormhole_b0;
case ::tt::ARCH::BLACKHOLE:
return ::tt::target::Arch::Blackhole;
default:
break;
}

throw std::runtime_error("Unsupported arch");
}

static std::vector<::tt::target::ChipChannel>
getAllDeviceConnections(const vector<::tt::tt_metal::Device *> &devices) {
std::set<std::tuple<chip_id_t, CoreCoord, chip_id_t, CoreCoord>>
connectionSet;

auto addConnection = [&connectionSet](
chip_id_t deviceId0, CoreCoord ethCoreCoord0,
chip_id_t deviceId1, CoreCoord ethCoreCoord1) {
if (deviceId0 > deviceId1) {
std::swap(deviceId0, deviceId1);
std::swap(ethCoreCoord0, ethCoreCoord1);
}
connectionSet.emplace(deviceId0, ethCoreCoord0, deviceId1, ethCoreCoord1);
};

for (const ::tt::tt_metal::Device *device : devices) {
std::unordered_set<CoreCoord> activeEthernetCores =
device->get_active_ethernet_cores(true);
for (const CoreCoord &ethernetCore : activeEthernetCores) {
std::tuple<chip_id_t, CoreCoord> connectedDevice =
device->get_connected_ethernet_core(ethernetCore);
addConnection(device->id(), ethernetCore, std::get<0>(connectedDevice),
std::get<1>(connectedDevice));
}
}

std::vector<::tt::target::ChipChannel> allConnections;
allConnections.resize(connectionSet.size());

std::transform(
connectionSet.begin(), connectionSet.end(), allConnections.begin(),
[](const std::tuple<chip_id_t, CoreCoord, chip_id_t, CoreCoord>
&connection) {
return ::tt::target::ChipChannel(
std::get<0>(connection), toFlatbuffer(std::get<1>(connection)),
std::get<2>(connection), toFlatbuffer(std::get<3>(connection)));
});

return allConnections;
}

static std::unique_ptr<::tt::runtime::SystemDesc>
getCurrentSystemDescImpl(const ::tt::tt_metal::DeviceMesh &deviceMesh) {
std::vector<::tt::tt_metal::Device *> devices = deviceMesh.get_devices();
std::sort(devices.begin(), devices.end(),
[](const ::tt::tt_metal::Device *a,
const ::tt::tt_metal::Device *b) { return a->id() < b->id(); });

std::vector<::flatbuffers::Offset<tt::target::ChipDesc>> chipDescs;
std::vector<uint32_t> chipDescIndices;
std::vector<::tt::target::ChipCapability> chipCapabilities;
// Ignore for now
std::vector<::tt::target::ChipCoord> chipCoords = {
::tt::target::ChipCoord(0, 0, 0, 0)};
::flatbuffers::FlatBufferBuilder fbb;

for (const ::tt::tt_metal::Device *device : devices) {
// Construct chip descriptor
::tt::target::Dim2d deviceGrid =
toFlatbuffer(device->compute_with_storage_grid_size());
chipDescs.push_back(::tt::target::CreateChipDesc(
fbb, toFlatbuffer(device->arch()), &deviceGrid,
device->l1_size_per_core(), device->num_dram_channels(),
device->dram_size_per_channel(), L1_ALIGNMENT, PCIE_ALIGNMENT,
DRAM_ALIGNMENT));
chipDescIndices.push_back(device->id());
// Derive chip capability
::tt::target::ChipCapability chipCapability =
::tt::target::ChipCapability::NONE;
if (device->is_mmio_capable()) {
chipCapability = chipCapability | ::tt::target::ChipCapability::PCIE |
::tt::target::ChipCapability::HostMMIO;
}
chipCapabilities.push_back(chipCapability);
}
// Extract chip connected channels
std::vector<::tt::target::ChipChannel> allConnections =
getAllDeviceConnections(devices);
// Create SystemDesc
auto systemDesc = ::tt::target::CreateSystemDescDirect(
fbb, &chipDescs, &chipDescIndices, &chipCapabilities, &chipCoords,
&allConnections);
::ttmlir::Version ttmlirVersion = ::ttmlir::getVersion();
::tt::target::Version version(ttmlirVersion.major, ttmlirVersion.minor,
ttmlirVersion.patch);
auto root = ::tt::target::CreateSystemDescRootDirect(
fbb, &version, ::ttmlir::getGitHash(), "unknown", systemDesc);
::tt::target::FinishSizePrefixedSystemDescRootBuffer(fbb, root);
::flatbuffers::Verifier verifier(fbb.GetBufferPointer(), fbb.GetSize());
if (not ::tt::target::VerifySizePrefixedSystemDescRootBuffer(verifier)) {
throw std::runtime_error("Failed to verify system desc root buffer");
}
uint8_t *buf = fbb.GetBufferPointer();
auto size = fbb.GetSize();
auto handle = ::tt::runtime::utils::malloc_shared(size);
std::memcpy(handle.get(), buf, size);
return std::make_unique<::tt::runtime::SystemDesc>(handle);
}

std::pair<::tt::runtime::SystemDesc, DeviceIds> getCurrentSystemDesc() {
size_t numDevices = ::tt::tt_metal::GetNumAvailableDevices();
size_t numPciDevices = ::tt::tt_metal::GetNumPCIeDevices();
TT_FATAL(numDevices % numPciDevices == 0,
"Unexpected non-rectangular grid of devices");
std::vector<chip_id_t> deviceIds(numDevices);
std::iota(deviceIds.begin(), deviceIds.end(), 0);
::tt::tt_metal::DeviceGrid grid =
std::make_pair(numDevices / numPciDevices, numPciDevices);
::tt::tt_metal::DeviceMesh deviceMesh = ::tt::tt_metal::DeviceMesh(
grid, deviceIds, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1);
std::exception_ptr eptr = nullptr;
std::unique_ptr<::tt::runtime::SystemDesc> desc;
try {
desc = getCurrentSystemDescImpl(deviceMesh);
} catch (...) {
eptr = std::current_exception();
}
deviceMesh.close_devices();
if (eptr) {
std::rethrow_exception(eptr);
}
return std::make_pair(*desc, deviceIds);
}

} // namespace tt::runtime::system_desc
13 changes: 3 additions & 10 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "tt/runtime/runtime.h"
#include "tt/runtime/utils.h"
#include "ttmlir/Target/TTNN/Target.h"
#include "ttmlir/Version.h"

#if defined(TT_RUNTIME_ENABLE_TTNN)
Expand Down Expand Up @@ -77,16 +78,8 @@ void setCompatibleRuntime(const Binary &binary) {
}

std::pair<SystemDesc, DeviceIds> getCurrentSystemDesc() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getCurrentSystemDesc();
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getCurrentSystemDesc();
}
#if defined(TT_RUNTIME_ENABLE_TTNN) || defined(TT_RUNTIME_ENABLE_TTMETAL)
return system_desc::getCurrentSystemDesc();
#endif
throw std::runtime_error("runtime is not enabled");
}
Expand Down
3 changes: 2 additions & 1 deletion runtime/lib/ttmetal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ target_include_directories(TTRuntimeTTMetal PUBLIC
${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common
)
target_include_directories(TTRuntimeTTMetal PUBLIC "$<BUILD_INTERFACE:${TTMETAL_INCLUDE_DIRS}>")
add_dependencies(TTRuntimeTTMetal tt-metal FBS_GENERATION)
target_link_libraries(TTRuntimeTTMetal PUBLIC TTMETAL_LIBRARY)
add_dependencies(TTRuntimeTTMetal TTMETAL_LIBRARY tt-metal FBS_GENERATION)
Loading
Loading