Skip to content

Commit

Permalink
Set dispatch core to Ethernet on N300, ensuring 8x8 grid across devic…
Browse files Browse the repository at this point in the history
…es (#1897)

Current behaviour always sets dispatch core type to worker. This leads
to a different grid size on N300 `7x8` and N150 `8x8`, since N300 is
dual row harvested. This change moves dispatch to Ethernet cores on
N300, freeing up another row and giving the same grid size.

While I was adding, I gave the option of overriding from FEs when
opening the device.
  • Loading branch information
AleksKnezevic authored Jan 23, 2025
1 parent 78601d6 commit 7214507
Show file tree
Hide file tree
Showing 14 changed files with 144 additions and 52 deletions.
12 changes: 11 additions & 1 deletion lib/OpModel/TTNN/SingletonDeviceContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,17 @@
namespace mlir::tt::op_model::ttnn {

SingletonDeviceContext::SingletonDeviceContext() {
m_device = ::tt::tt_metal::CreateDevice(0);

// todo: this replicates logic in runtime/include/tt/runtime/detail/common.h,
// move to shared location
size_t numDevices = ::tt::tt_metal::GetNumAvailableDevices();
size_t numPCIeDevices = ::tt::tt_metal::GetNumPCIeDevices();
::tt::tt_metal::DispatchCoreType dispatchCoreType =
numDevices == numPCIeDevices ? ::tt::tt_metal::DispatchCoreType::WORKER
: ::tt::tt_metal::DispatchCoreType::ETH;
m_device = ::tt::tt_metal::CreateDevice(
0, /* num_hw_cqs = */ 1, /* l1_small_size = */ DEFAULT_L1_SMALL_SIZE,
/* trace_region_size = */ DEFAULT_TRACE_REGION_SIZE, dispatchCoreType);
}

SingletonDeviceContext::~SingletonDeviceContext() {
Expand Down
41 changes: 41 additions & 0 deletions runtime/include/tt/runtime/detail/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TT_RUNTIME_DETAIL_COMMON_H
#define TT_RUNTIME_DETAIL_COMMON_H

#include <optional>

#define FMT_HEADER_ONLY
#include "tt-metalium/host_api.hpp"

#include "tt/runtime/detail/logger.h"
#include "tt/runtime/types.h"

namespace tt::runtime::common {

inline ::tt::tt_metal::DispatchCoreType
getDispatchCoreType(std::optional<DispatchCoreType> dispatchCoreType) {

::tt::tt_metal::DispatchCoreType type;
if (dispatchCoreType.has_value()) {
if (dispatchCoreType == DispatchCoreType::ETH) {
type = ::tt::tt_metal::DispatchCoreType::ETH;
} else if (dispatchCoreType == DispatchCoreType::WORKER) {
type = ::tt::tt_metal::DispatchCoreType::WORKER;
} else {
LOG_FATAL("Unsupported dispatch core type");
}
} else {
size_t numDevices = ::tt::tt_metal::GetNumAvailableDevices();
size_t numPCIeDevices = ::tt::tt_metal::GetNumPCIeDevices();
type = numDevices == numPCIeDevices
? ::tt::tt_metal::DispatchCoreType::WORKER
: ::tt::tt_metal::DispatchCoreType::ETH;
}
return type;
}

} // namespace tt::runtime::common
#endif // TT_RUNTIME_DETAIL_COMMON_H
6 changes: 4 additions & 2 deletions runtime/include/tt/runtime/detail/ttmetal.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ tt::target::DataType getTensorDataType(Tensor tensor);

size_t getNumAvailableDevices();

Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1,
std::optional<size_t> l1SmallSize = std::nullopt);
Device
openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1,
std::optional<size_t> l1SmallSize = std::nullopt,
std::optional<DispatchCoreType> dispatchCoreType = std::nullopt);

void closeDevice(Device device);

Expand Down
6 changes: 4 additions & 2 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,10 @@ tt::target::DataType getTensorDataType(Tensor tensor);

size_t getNumAvailableDevices();

Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1,
std::optional<size_t> l1SmallSize = std::nullopt);
Device
openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1,
std::optional<size_t> l1SmallSize = std::nullopt,
std::optional<DispatchCoreType> dispatchCoreType = std::nullopt);

void closeDevice(Device device);

Expand Down
6 changes: 4 additions & 2 deletions runtime/include/tt/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@ tt::target::DataType getTensorDataType(Tensor tensor);

size_t getNumAvailableDevices();

Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1,
std::optional<size_t> l1SmallSize = std::nullopt);
Device
openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1,
std::optional<size_t> l1SmallSize = std::nullopt,
std::optional<DispatchCoreType> dispatchCoreType = std::nullopt);

void closeDevice(Device device);

Expand Down
5 changes: 5 additions & 0 deletions runtime/include/tt/runtime/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ enum class DeviceRuntime {
TTMetal,
};

enum class DispatchCoreType {
WORKER,
ETH,
};

namespace detail {
struct ObjectImpl {

Expand Down
8 changes: 7 additions & 1 deletion runtime/lib/common/system_desc.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#include "tt/runtime/detail/common.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/types.h"
#include "tt/runtime/utils.h"
Expand Down Expand Up @@ -257,6 +258,8 @@ static std::unique_ptr<::tt::runtime::SystemDesc> getCurrentSystemDescImpl(

std::pair<::tt::runtime::SystemDesc, DeviceIds> getCurrentSystemDesc() {
size_t numDevices = ::tt::tt_metal::GetNumAvailableDevices();
::tt::tt_metal::DispatchCoreType dispatchCoreType =
tt::runtime::common::getDispatchCoreType(std::nullopt);
std::vector<chip_id_t> deviceIds(numDevices);
std::iota(deviceIds.begin(), deviceIds.end(), 0);
::tt::tt_metal::distributed::MeshShape meshShape = {1, numDevices};
Expand All @@ -265,7 +268,10 @@ std::pair<::tt::runtime::SystemDesc, DeviceIds> getCurrentSystemDesc() {
::tt::tt_metal::distributed::MeshDeviceConfig{.mesh_shape =
meshShape},
DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1,
::tt::tt_metal::DispatchCoreType::WORKER);
dispatchCoreType);
CoreCoord logical_grid_size = meshDevice->compute_with_storage_grid_size();
LOG_INFO("Grid size = { ", logical_grid_size.x, ", ", logical_grid_size.y,
"}");
std::exception_ptr eptr = nullptr;
std::unique_ptr<::tt::runtime::SystemDesc> desc;
try {
Expand Down
9 changes: 6 additions & 3 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,16 +234,19 @@ size_t getNumAvailableDevices() {
}

Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs,
std::optional<size_t> l1SmallSize) {
std::optional<size_t> l1SmallSize,
std::optional<DispatchCoreType> dispatchCoreType) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::openDevice(deviceIds, numHWCQs, l1SmallSize);
return ::tt::runtime::ttnn::openDevice(deviceIds, numHWCQs, l1SmallSize,
dispatchCoreType);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::openDevice(deviceIds, numHWCQs, l1SmallSize);
return ::tt::runtime::ttmetal::openDevice(deviceIds, numHWCQs, l1SmallSize,
dispatchCoreType);
}
#endif
LOG_FATAL("runtime is not enabled");
Expand Down
14 changes: 11 additions & 3 deletions runtime/lib/ttmetal/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <variant>

#include "tt/runtime/detail/common.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttmetal.h"
#include "tt/runtime/runtime.h"
Expand Down Expand Up @@ -80,16 +81,23 @@ size_t getNumAvailableDevices() {
}

Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs,
std::optional<size_t> l1SmallSize) {
std::optional<size_t> l1SmallSize,
std::optional<DispatchCoreType> dispatchCoreType) {
LOG_ASSERT(deviceIds.size(), "No devices specified");

::tt::tt_metal::DispatchCoreType type =
tt::runtime::common::getDispatchCoreType(dispatchCoreType);

::tt::tt_metal::distributed::MeshShape grid = {1, deviceIds.size()};
size_t l1SmallSizeValue = l1SmallSize.value_or(DEFAULT_L1_SMALL_SIZE);
std::shared_ptr<::tt::tt_metal::distributed::MeshDevice> meshDevice =
::tt::tt_metal::distributed::MeshDevice::create(
::tt::tt_metal::distributed::MeshDeviceConfig{.mesh_shape = grid},
l1SmallSizeValue, DEFAULT_TRACE_REGION_SIZE, numHWCQs,
::tt::tt_metal::DispatchCoreType::WORKER);
l1SmallSizeValue, DEFAULT_TRACE_REGION_SIZE, numHWCQs, type);

CoreCoord logical_grid_size = meshDevice->compute_with_storage_grid_size();
LOG_INFO("Grid size = { ", logical_grid_size.x, ", ", logical_grid_size.y,
"}");

return Device(std::static_pointer_cast<void>(meshDevice),
DeviceRuntime::TTMetal);
Expand Down
15 changes: 12 additions & 3 deletions runtime/lib/ttnn/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "tt/runtime/detail/common.h"
#include "tt/runtime/detail/debug.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
Expand Down Expand Up @@ -193,14 +194,22 @@ size_t getNumAvailableDevices() {
}

Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs,
std::optional<size_t> l1SmallSize) {
std::optional<size_t> l1SmallSize,
std::optional<DispatchCoreType> dispatchCoreType) {

::tt::tt_metal::DispatchCoreType type =
tt::runtime::common::getDispatchCoreType(dispatchCoreType);

LOG_ASSERT(deviceIds.size(), "No devices specified");
::tt::tt_metal::distributed::MeshShape grid = {1, deviceIds.size()};
size_t l1SmallSizeValue = l1SmallSize.value_or(kL1SmallSize);
std::shared_ptr<::ttnn::MeshDevice> meshDevice = ::ttnn::MeshDevice::create(
::tt::tt_metal::distributed::MeshDeviceConfig{.mesh_shape = grid},
l1SmallSizeValue, DEFAULT_TRACE_REGION_SIZE, numHWCQs,
::tt::tt_metal::DispatchCoreType::WORKER);
l1SmallSizeValue, DEFAULT_TRACE_REGION_SIZE, numHWCQs, type);

CoreCoord logical_grid_size = meshDevice->compute_with_storage_grid_size();
LOG_INFO("Grid size = { ", logical_grid_size.x, ", ", logical_grid_size.y,
"}");

bool enableAsync = debug::Env::get().enableAsyncTTNN;
for (::ttnn::IDevice *device : meshDevice->get_devices()) {
Expand Down
4 changes: 4 additions & 0 deletions runtime/tools/python/ttrt/runtime/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ PYBIND11_MODULE(_C, m) {
.value("Disabled", ::tt::runtime::DeviceRuntime::Disabled)
.value("TTNN", ::tt::runtime::DeviceRuntime::TTNN)
.value("TTMetal", ::tt::runtime::DeviceRuntime::TTMetal);
py::enum_<::tt::runtime::DispatchCoreType>(m, "DispatchCoreType")
.value("WORKER", ::tt::runtime::DispatchCoreType::WORKER)
.value("ETH", ::tt::runtime::DispatchCoreType::ETH);
m.def("get_current_runtime", &tt::runtime::getCurrentRuntime,
"Get the backend device runtime type");
m.def("get_available_runtimes", &tt::runtime::getAvailableRuntimes,
Expand Down Expand Up @@ -119,6 +122,7 @@ PYBIND11_MODULE(_C, m) {
m.def("open_device", &tt::runtime::openDevice, py::arg("device_ids"),
py::arg("num_hw_cqs") = size_t{1},
py::arg("l1_small_size") = py::none(),
py::arg("dispatch_core_type") = py::none(),
"Open a mesh of devices for execution");
m.def("close_device", &tt::runtime::closeDevice, "Close a mesh device");
m.def("to_host", &tt::runtime::toHost, py::arg("tensor"),
Expand Down
Loading

0 comments on commit 7214507

Please sign in to comment.