Skip to content

Commit

Permalink
#0: Fix issues with failing test_new_all_gather.py tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
cfjchu committed Jan 16, 2025
1 parent dcea6ec commit f5a1b95
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions tt_metal/distributed/mesh_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ MeshDevice::~MeshDevice() {}

IDevice* MeshDevice::get_device_index(size_t device_index) const {
TT_FATAL(device_index >= 0 and device_index < num_devices(), "Invalid device index");
return this->get_devices().at(device_index);
const auto& devices = scoped_devices_->get_devices();
return devices.at(device_index);
}

IDevice* MeshDevice::get_device(chip_id_t physical_device_id) const {
Expand Down Expand Up @@ -671,7 +672,7 @@ MemoryBlockTable MeshDevice::get_memory_block_table(const BufferType& buffer_typ
MeshSubDeviceManagerId MeshDevice::mesh_create_sub_device_manager(
tt::stl::Span<const SubDevice> sub_devices, DeviceAddr local_l1_size) {
MeshSubDeviceManagerId mesh_sub_device_manager_id(*this);
const auto& devices = this->get_devices();
const auto& devices = scoped_devices_->get_devices();
for (uint32_t i = 0; i < devices.size(); i++) {
auto* device = devices[i];
auto& sub_device_manager_id = mesh_sub_device_manager_id.sub_device_manager_ids[i];
Expand All @@ -688,7 +689,7 @@ MeshSubDeviceManagerId MeshDevice::mesh_create_sub_device_manager(
std::tuple<MeshSubDeviceManagerId, SubDeviceId> MeshDevice::mesh_create_sub_device_manager_with_fabric(tt::stl::Span<const SubDevice> sub_devices, DeviceAddr local_l1_size) {
MeshSubDeviceManagerId mesh_sub_device_manager_id(*this);
SubDeviceId fabric_sub_device_id;
const auto& devices = this->get_devices();
const auto& devices = scoped_devices_->get_devices();
for (uint32_t i = 0; i < devices.size(); i++) {
auto* device = devices[i];
auto& sub_device_manager_id = mesh_sub_device_manager_id.sub_device_manager_ids[i];
Expand All @@ -704,7 +705,7 @@ std::tuple<MeshSubDeviceManagerId, SubDeviceId> MeshDevice::mesh_create_sub_devi
}

void MeshDevice::mesh_load_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id) {
const auto& devices = this->get_devices();
const auto& devices = scoped_devices_->get_devices();
for (uint32_t i = 0; i < devices.size(); i++) {
auto* device = devices[i];
auto sub_device_manager_id = mesh_sub_device_manager_id.sub_device_manager_ids[i];
Expand All @@ -713,13 +714,12 @@ void MeshDevice::mesh_load_sub_device_manager(MeshSubDeviceManagerId mesh_sub_de
}
}
void MeshDevice::mesh_clear_loaded_sub_device_manager() {
const auto& devices = this->get_devices();
for (auto* device : devices) {
for (auto* device : scoped_devices_->get_devices()) {
device->push_work([device]() { device->clear_loaded_sub_device_manager(); });
}
}
void MeshDevice::mesh_remove_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id) {
const auto& devices = this->get_devices();
const auto& devices = scoped_devices_->get_devices();
for (uint32_t i = 0; i < devices.size(); i++) {
auto* device = devices[i];
auto sub_device_manager_id = mesh_sub_device_manager_id.sub_device_manager_ids[i];
Expand All @@ -729,13 +729,13 @@ void MeshDevice::mesh_remove_sub_device_manager(MeshSubDeviceManagerId mesh_sub_
}

void MeshDevice::mesh_set_sub_device_stall_group(tt::stl::Span<const SubDeviceId> sub_device_ids) {
for (auto* device : this->get_devices()) {
for (auto* device : scoped_devices_->get_devices()) {
device->push_work([device, sub_device_ids=std::vector<SubDeviceId>(sub_device_ids.begin(), sub_device_ids.end())]() { device->set_sub_device_stall_group(sub_device_ids); });
}
}

void MeshDevice::mesh_reset_sub_device_stall_group() {
for (auto* device : this->get_devices()) {
for (auto* device : scoped_devices_->get_devices()) {
device->push_work([device]() { device->reset_sub_device_stall_group(); });
}
}
Expand Down

0 comments on commit f5a1b95

Please sign in to comment.