From 0f7e0c20cb45e0fcd67420f8731a6843852bf1dc Mon Sep 17 00:00:00 2001 From: Sean Nijjar Date: Fri, 3 Jan 2025 15:03:51 +0000 Subject: [PATCH 1/7] #16391: propagate sub_device_ids to mesh - Further update all-gather-async tests --- .../ccl/test_all_gather_TG_post_commit.py | 190 ++++++++++-------- .../operations/ccl/test_new_all_gather.py | 145 ++++++------- ttnn/ttnn/distributed/distributed.py | 5 +- ttnn/ttnn/operations/core.py | 2 +- 4 files changed, 184 insertions(+), 158 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py index bd26d78e062..b025acf7ddb 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py @@ -180,17 +180,9 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows( False, ) output_mem_config = ttnn.MemoryConfig(tensor_memory_layout, buffer_type=buffer_type, shard_spec=output_shard_spec) - ttnn_tensor = ttnn.from_torch( - full_input_tensor_unfractured, - tile=ttnn.Tile(tile), - dtype=input_dtype, - device=mesh_device, - layout=layout, - memory_config=input_mem_config, - mesh_mapper=ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dims), - ) - ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device) + worker_subdevice_ids = [] + fabric_torn_down = True if use_all_gather_async: compute_grid_size = mesh_device.compute_with_storage_grid_size() worker_sub_device = ttnn.SubDevice( @@ -205,89 +197,117 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows( ] ) worker_sub_device_id = ttnn.SubDeviceId(0) + worker_subdevice_ids = [worker_sub_device_id] if create_persistent_fabric: - logger.info("Create persistent fabric interface") mesh_sub_device_manager_id = create_and_load_sub_device_manager_with_fabric_interface( mesh_device, [worker_sub_device], 0, 0, enable_persistent_fabric ) - logger.info("Done Create persistent fabric interface") - - # ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor) - if trace_mode: - ttnn_tensor_out = run_with_trace( - input_tensor=ttnn_tensor, - dim=dim, - cluster_axis=cluster_axis, - mesh_device=mesh_device, - num_links=num_links, - output_mem_config=output_mem_config, - all_gather_topology=ttnn.Topology.Linear, - num_iter=num_iters, + fabric_torn_down = False + elif teardown_persistent_fabric: + fabric_torn_down = False + + try: + ttnn_tensor = ttnn.from_torch( + full_input_tensor_unfractured, + tile=ttnn.Tile(tile), + dtype=input_dtype, + device=mesh_device, + layout=layout, + memory_config=input_mem_config, + mesh_mapper=ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dims), + sub_device_ids=worker_subdevice_ids, ) - else: - for _ in range(num_iters): - if use_all_gather_async: - logger.info("Running all-gather async") - ttnn_tensor_out = ttnn.experimental.all_gather_async( - ttnn_tensor, - dim, - cluster_axis=cluster_axis, - mesh_device=mesh_device, - topology=ttnn.Topology.Linear, - num_links=num_links, - memory_config=output_mem_config, - subdevice_id=worker_sub_device_id, - enable_persistent_fabric_mode=enable_persistent_fabric, - create_semaphore_handles=True, - ) - else: - ttnn_tensor_out = ttnn.all_gather( - ttnn_tensor, - dim=dim, - cluster_axis=cluster_axis, - mesh_device=mesh_device, - num_links=num_links, - memory_config=output_mem_config, - topology=ttnn.Topology.Linear, - ) + ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device, 
sub_device_ids=worker_subdevice_ids) + + # ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor) + if trace_mode: + ttnn_tensor_out = run_with_trace( + input_tensor=ttnn_tensor, + dim=dim, + cluster_axis=cluster_axis, + mesh_device=mesh_device, + num_links=num_links, + output_mem_config=output_mem_config, + all_gather_topology=ttnn.Topology.Linear, + num_iter=num_iters, + ) + else: + for _ in range(num_iters): + if use_all_gather_async: + ttnn_tensor_out = ttnn.experimental.all_gather_async( + ttnn_tensor, + dim, + cluster_axis=cluster_axis, + mesh_device=mesh_device, + topology=ttnn.Topology.Linear, + num_links=num_links, + memory_config=output_mem_config, + subdevice_id=worker_sub_device_id, + enable_persistent_fabric_mode=enable_persistent_fabric, + create_semaphore_handles=True, + ) + else: + ttnn_tensor_out = ttnn.all_gather( + ttnn_tensor, + dim=dim, + cluster_axis=cluster_axis, + mesh_device=mesh_device, + num_links=num_links, + memory_config=output_mem_config, + topology=ttnn.Topology.Linear, + ) + + if enable_persistent_fabric: + logger.info(f"Waiting for op completion") + for d in mesh_device.get_devices(): + ttnn.synchronize_device(d, sub_device_ids=worker_subdevice_ids) + logger.info(f"Done synchronizing with op") + + # ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor_out) + tt_output_tensor = ttnn.to_torch( + ttnn_tensor_out, + mesh_composer=ConcatMesh2dToTensor(mesh_device, mesh_shape=mesh_shape, dims=concat_dims), + sub_device_ids=worker_subdevice_ids, + ) + output_tensors_list = torch.chunk( + tt_output_tensor, num_all_gather_instances, dim=all_gather_instances_concat_dim + ) + output_golden = torch.zeros(tt_output_tensor.shape) + + # Repeat the input tensor to represent the fact that the full concatenated input tensor lives across every + # device in the line + repeat_factor = [1] * len(output_golden.shape) + repeat_factor[dim] = num_devices_per_line + output_golden[:, :, :, :] = full_input_tensor_unfractured.repeat(repeat_factor) + + eq = True + logger.info("Comparing output tensors") + if input_dtype == ttnn.bfloat16: + eq, output = comp_equal(tt_output_tensor, output_golden) + if not eq and debug is True: + logger.error(f"found mismatches") + report_mismatches(tt_output_tensor, output_golden, 100) + print_tile_corners_of_tensor(tt_output_tensor) + else: + eq, output = comp_pcc(tt_output_tensor, output_golden) + if not eq: + logger.error(f"output mismatch for tensor: {output}") - if enable_persistent_fabric: - logger.info(f"Waiting for op {i}") - for d in mesh_device.get_devices(): - ttnn.synchronize_device(d, sub_device_ids=[worker_sub_device_id]) - logger.info(f"Done iteration {i}") + logger.info("Done op call") - if enable_persistent_fabric and teardown_persistent_fabric: - logger.info("Tearing down persistent fabric interface") - teardown_fabric_interface(mesh_device) - logger.info("Done tearing down persistent fabric interface") + if enable_persistent_fabric and teardown_persistent_fabric: + logger.info("Tearing down persistent fabric interface") + teardown_fabric_interface(mesh_device) + logger.info("Done tearing down persistent fabric interface") + fabric_torn_down = True - # ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor_out) - tt_output_tensor = ttnn.to_torch( - ttnn_tensor_out, mesh_composer=ConcatMesh2dToTensor(mesh_device, mesh_shape=mesh_shape, dims=concat_dims) - ) - output_tensors_list = torch.chunk(tt_output_tensor, num_all_gather_instances, dim=all_gather_instances_concat_dim) - output_golden = 
torch.zeros(tt_output_tensor.shape) - - # Repeat the input tensor to represent the fact that the full concatenated input tensor lives across every - # device in the line - repeat_factor = [1] * len(output_golden.shape) - repeat_factor[dim] = num_devices_per_line - output_golden[:, :, :, :] = full_input_tensor_unfractured.repeat(repeat_factor) - - eq = True - if input_dtype == ttnn.bfloat16: - eq, output = comp_equal(tt_output_tensor, output_golden) - if not eq and debug is True: - logger.error(f"found mismatches") - report_mismatches(tt_output_tensor, output_golden, 100) - print_tile_corners_of_tensor(tt_output_tensor) - else: - eq, output = comp_pcc(tt_output_tensor, output_golden) - if not eq: - logger.error(f"output mismatch for tensor: {output}") - - assert eq, f"FAILED: {output}" + assert eq, f"FAILED: {output}" + + except Exception as e: + if create_persistent_fabric and not fabric_torn_down: + logger.error(f"Tearing down persistent fabric after failure") + teardown_fabric_interface(mesh_device) + raise e # Enumerate the post-commit cases explicitly diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py index 83ada0b11a1..a76912da0f6 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py @@ -196,16 +196,6 @@ def run_all_gather_impl( output_tensor[w, z, y : y + 32, x : x + 32] = tile_id tile_id += 1 - input_tensors = torch.chunk(output_tensor, num_devices, dim) - tt_input_tensors = [] - for i, t in enumerate(input_tensors): - tt_input_tensors.append( - ttnn.Tensor(t, input_dtype).to(layout).to(mesh_device.get_devices()[i], input_mem_config) - ) - logger.info(f"using device {mesh_device.get_devices()[i].id()}") - - input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors) - compute_grid_size = mesh_device.compute_with_storage_grid_size() worker_sub_device = ttnn.SubDevice( [ @@ -215,68 +205,91 @@ def run_all_gather_impl( ] ) worker_sub_device_id = ttnn.SubDeviceId(0) + worker_subdevice_ids = [worker_sub_device_id] + fabric_torn_down = True if create_persistent_fabric: mesh_sub_device_manager_id = create_and_load_sub_device_manager_with_fabric_interface( mesh_device, [worker_sub_device], 0, 0, enable_persistent_fabric ) - - if trace_mode: - tt_out_tensor = run_with_trace( - mesh_device, - all_gather_topology, - input_tensor_mesh, - dim, - num_links, - output_mem_config, - num_iter=num_iters, - subdevice_id=worker_sub_device_id, - ) - else: - for i in range(num_iters): - if use_cluster_axis_api: - tt_out_tensor = ttnn.experimental.all_gather_async( - input_tensor_mesh, - dim, - cluster_axis=cluster_axis, - mesh_device=mesh_device, - memory_config=output_mem_config, - topology=all_gather_topology, - subdevice_id=worker_sub_device_id, - enable_persistent_fabric_mode=enable_persistent_fabric, - num_preferred_links=num_links, - create_semaphore_handles=True, - ) - - else: - tt_out_tensor = ttnn.experimental.all_gather_async( - input_tensor_mesh, - dim, - num_links=num_links, - memory_config=output_mem_config, - topology=all_gather_topology, - subdevice_id=worker_sub_device_id, - enable_persistent_fabric_mode=enable_persistent_fabric, - ) + fabric_torn_down = False + + try: + input_tensors = torch.chunk(output_tensor, num_devices, dim) + tt_input_tensors = [] + for i, t in enumerate(input_tensors): + tt_input_tensors.append( + ttnn.Tensor(t, input_dtype) + .to(layout) + .to(mesh_device.get_devices()[i], input_mem_config, 
sub_device_ids=worker_subdevice_ids) + ) + logger.info(f"using device {mesh_device.get_devices()[i].id()}") + + input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors) + + if trace_mode: + tt_out_tensor = run_with_trace( + mesh_device, + all_gather_topology, + input_tensor_mesh, + dim, + num_links, + output_mem_config, + num_iter=num_iters, + subdevice_id=worker_sub_device_id, + ) + else: + for i in range(num_iters): + if use_cluster_axis_api: + tt_out_tensor = ttnn.experimental.all_gather_async( + input_tensor_mesh, + dim, + cluster_axis=cluster_axis, + mesh_device=mesh_device, + memory_config=output_mem_config, + topology=all_gather_topology, + subdevice_id=worker_sub_device_id, + enable_persistent_fabric_mode=enable_persistent_fabric, + num_preferred_links=num_links, + create_semaphore_handles=True, + ) + + else: + tt_out_tensor = ttnn.experimental.all_gather_async( + input_tensor_mesh, + dim, + num_links=num_links, + memory_config=output_mem_config, + topology=all_gather_topology, + subdevice_id=worker_sub_device_id, + enable_persistent_fabric_mode=enable_persistent_fabric, + ) logger.info(f"Waiting for op {i}") for d in mesh_device.get_devices(): ttnn.synchronize_device(d, sub_device_ids=[worker_sub_device_id]) logger.info(f"Done iteration {i}") - if enable_persistent_fabric and teardown_persistent_fabric: - teardown_fabric_interface(mesh_device) + for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)): + tt_output_tensor = t.cpu(sub_device_ids=worker_subdevice_ids).to(ttnn.ROW_MAJOR_LAYOUT).to_torch() + logger.info(f"Checking for device {t.device().id()}") - for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)): - tt_output_tensor = t.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch() - logger.info(f"Checking for device {t.device().id()}") + if input_dtype == ttnn.bfloat16: + eq, output = comp_equal(tt_output_tensor, output_tensor) + else: + eq, output = comp_pcc(tt_output_tensor, output_tensor) + if not eq: + logger.error(f"output mismatch for tensor {i}") + assert eq, f"{i} FAILED: {output}" - if input_dtype == ttnn.bfloat16: - eq, output = comp_equal(tt_output_tensor, output_tensor) - else: - eq, output = comp_pcc(tt_output_tensor, output_tensor) - if not eq: - logger.error(f"output mismatch for tensor {i}") - assert eq, f"{i} FAILED: {output}" + if enable_persistent_fabric and teardown_persistent_fabric: + teardown_fabric_interface(mesh_device) + fabric_torn_down = True + + except Exception as e: + if create_persistent_fabric and not fabric_torn_down: + logger.error(f"Tearing down persistent fabric after failure") + teardown_fabric_interface(mesh_device) + raise e # Enumerate the post-commit cases explicitly @@ -301,8 +314,7 @@ def run_all_gather_impl( ttnn.MemoryConfig(buffer_type=ttnn.BufferType.DRAM), ], ) -@pytest.mark.parametrize("num_iters", [8]) -@pytest.mark.parametrize("enable_async", [True]) +@pytest.mark.parametrize("enable_async, num_iters", [(True, 1), (False, 8)]) def test_all_gather( t3k_mesh_device, # pcie_mesh_device, @@ -395,8 +407,7 @@ def test_all_gather( ttnn.bfloat16, ], ) -@pytest.mark.parametrize("num_iters", [8]) -@pytest.mark.parametrize("enable_async", [True]) +@pytest.mark.parametrize("enable_async, num_iters", [(True, 1), (False, 8)]) def test_all_gather_sharded( t3k_mesh_device, # pcie_mesh_device, @@ -499,9 +510,6 @@ def test_line_all_gather_async_on_T3K_cols_transient_fabric_post_commit( ) -@pytest.mark.skip( - "persistent fabric test with cluster-axis API and multiple concurrent all-gather instances not enabled yet" -) 
@skip_for_grayskull("Requires eth connected devices to run") @pytest.mark.parametrize( "num_devices, num_links, per_chip_output_shape, dim, layout", @@ -636,9 +644,6 @@ def test_line_all_gather_async_on_T3K_rows_transient_fabric_post_commit( # Enumerate the post-commit cases explicitly -@pytest.mark.skip( - "persistent fabric test with cluster-axis API and multiple concurrent all-gather instances not enabled yet" -) @skip_for_grayskull("Requires eth connected devices to run") @pytest.mark.parametrize( "num_devices, num_links, per_chip_output_shape, dim, layout", diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py index 65a902d11cf..6b8ab1fd000 100644 --- a/ttnn/ttnn/distributed/distributed.py +++ b/ttnn/ttnn/distributed/distributed.py @@ -400,7 +400,7 @@ def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: T if self.dims[0] == self.dims[1]: raise ValueError("Both dimensions in 'dims' must be different") - def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": + def compose(self, tensor: ttnn.Tensor, sub_device_ids: List[ttnn.SubDeviceId] = []) -> "torch.Tensor": """ Compose the sharded tensors back into a single tensor. @@ -416,7 +416,8 @@ def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": import torch device_shards = [ - ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor) + ttnn.to_torch(tt_input_tensor, mesh_composer=None, sub_device_ids=sub_device_ids) + for tt_input_tensor in ttnn.get_device_tensors(tensor) ] rows, cols = self.mesh_shape diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index fa1aba4cfb7..a94c1ffe4e7 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -302,7 +302,7 @@ def to_torch( [ 0.9023, -0.5820, 0.5312]], dtype=torch.bfloat16) """ if mesh_composer: - return mesh_composer.compose(tensor) + return mesh_composer.compose(tensor, sub_device_ids=sub_device_ids) if ttnn.is_tensor_storage_on_device(tensor): tensor = ttnn.from_device(tensor, cq_id=cq_id, sub_device_ids=sub_device_ids) From ca046284cf8d2dbb5b97d5c47dd7e35cae42312d Mon Sep 17 00:00:00 2001 From: Sean Nijjar Date: Thu, 2 Jan 2025 23:04:37 +0000 Subject: [PATCH 2/7] Fix all-gather global semaphore (fake) lockstep allocator bug --- .../operations/ccl/test_new_all_gather.py | 5 +- .../experimental/ccl/CMakeLists.txt | 1 + .../device/all_gather_async_op.cpp | 51 ++++--------- .../device/all_gather_async_op.hpp | 2 +- .../common/ccl_global_semaphore_creation.cpp | 63 ++++++++++++++++ .../common/ccl_global_semaphore_creation.hpp | 19 +++++ .../device/reduce_scatter_async_op.cpp | 73 ++++--------------- 7 files changed, 111 insertions(+), 103 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/common/ccl_global_semaphore_creation.cpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/common/ccl_global_semaphore_creation.hpp diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py index a76912da0f6..2671ee3b30c 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py @@ -711,9 +711,6 @@ def test_line_all_gather_async_on_T3K_rows_persistent_fabric_post_commit( ) -@pytest.mark.skip( - "persistent fabric test with cluster-axis API and multiple concurrent all-gather instances not enabled yet" -) @skip_for_grayskull("Requires eth connected devices to run") 
@pytest.mark.parametrize(
     "num_devices1, num_links1, per_chip_output_shape1, dim1, layout1",
     [
@@ -738,7 +735,7 @@ def test_line_all_gather_async_on_T3K_rows_persistent_fabric_post_commit(
     ],
 )
 @pytest.mark.parametrize("replication_factor1", [4])
-@pytest.mark.parametrize("enable_async", [True])
+@pytest.mark.parametrize("enable_async", [False])
 @pytest.mark.parametrize(
     "num_devices2, num_links2, per_chip_output_shape2, dim2, layout2",
     [
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt b/ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt
index 851de9b13ad..3f9594bd0f7 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt
@@ -16,6 +16,7 @@ set(CCL_EXPERIMENTAL_TTNN_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_async/all_gather_async_pybind.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_async/device/all_gather_async_op.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_async/device/all_gather_async_program.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/common/ccl_global_semaphore_creation.cpp
         CACHE INTERNAL
             "CCL Experimental sources to reuse in ttnn build"
 )
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp
index 5815d33c3b1..e96f04b6697 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp
@@ -5,6 +5,7 @@
 #include "all_gather_async_op.hpp"
 #include "ttnn/operations/math.hpp"
 #include "ttnn/cpp/ttnn/global_semaphore.hpp"
+#include "ttnn/cpp/ttnn/operations/experimental/ccl/common/ccl_global_semaphore_creation.hpp"
 
 #include "tt_metal/host_api.hpp"
 
@@ -23,7 +24,7 @@ AllGatherAsync create_all_gather_async_struct(
     const std::optional<MemoryConfig>& memory_config,
     const std::vector<Device*>& devices,
     const ttnn::ccl::Topology topology,
-    const std::optional<std::vector<GlobalSemaphore>>& semaphores,
+    const std::optional<std::vector<std::shared_ptr<const GlobalSemaphore>>>& semaphores,
     bool enable_persistent_fabric_mode) {
     uint32_t num_devices = devices.size();
 
@@ -35,7 +36,7 @@ AllGatherAsync create_all_gather_async_struct(
         if (devices.at(i) == input_tensor.device()) {
             device_index = i;
             if (semaphores.has_value()) {
-                semaphore = semaphores.value().at(i);  // Get raw pointer
+                semaphore = *semaphores.value().at(i);  // Get raw pointer
             }
             if (i != 0) {
                 backward_device = devices.at(i - 1);
@@ -59,37 +60,6 @@ AllGatherAsync create_all_gather_async_struct(
         enable_persistent_fabric_mode};
 }
 
-std::optional<std::vector<GlobalSemaphore>> get_global_semaphores(
-    const std::vector<Device*>& devices,
-    const CoreRange& core_range,
-    std::optional<SubDeviceId> subdevice_id,
-    bool create_semaphore_handles) {
-    std::optional<std::vector<GlobalSemaphore>> semaphores_opt;
-    if (create_semaphore_handles) {
-        std::vector<GlobalSemaphore> semaphores;
-        for (const auto& device : devices) {
-            auto worker_subdevice_id =
-                subdevice_id.has_value() ? 
std::vector<SubDeviceId>{subdevice_id.value()} : std::vector<SubDeviceId>{};
-
-            auto sem =
-                global_semaphore::create_global_semaphore(device, core_range, 0, BufferType::L1, worker_subdevice_id);
-            log_trace(tt::LogOp, "Created semaphore at address {} for device {}", sem.address(), device->id());
-            semaphores.push_back(std::move(sem));
-        }
-        // HACK: assert every address is the same
-        TT_FATAL(
-            std::all_of(
-                semaphores.begin(),
-                semaphores.end(),
-                [&](const auto& sem) { return sem.address() == semaphores.front().address(); }),
-            "[Hack] All semaphores should have the same address");
-        semaphores_opt = std::move(semaphores);
-    } else {
-        semaphores_opt = std::nullopt;
-    }
-
-    return semaphores_opt;
-}
 }  // namespace all_gather_detail
 }  // namespace ccl
 
@@ -217,10 +187,13 @@ Tensor all_gather_async(
 
     // create this semaphore for all cores since we don't know which core will be used for teardown draining
     CoreCoord grid_size = devices[0]->compute_with_storage_grid_size();
-    auto core_grid = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1});
+    auto core_grid = CoreRangeSet({CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1})});
+
 
-    std::optional<std::vector<GlobalSemaphore>> semaphores_opt =
-        ttnn::ccl::all_gather_detail::get_global_semaphores(devices, core_grid, subdevice_id, create_semaphore_handles);
+    std::optional<std::vector<std::shared_ptr<const GlobalSemaphore>>> semaphores_opt;
+    if (create_semaphore_handles) {
+        semaphores_opt = ttnn::ccl::worker_detail::create_global_semaphores(devices, core_grid, subdevice_id);
+    }
 
     operation::launch_op(
         [dim,
@@ -285,8 +258,10 @@ Tensor all_gather_async(
     std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};
     CoreCoord grid_size = devices[0]->compute_with_storage_grid_size();
     auto core_grid = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1});
-    std::optional<std::vector<GlobalSemaphore>> semaphores_opt =
-        ttnn::ccl::all_gather_detail::get_global_semaphores(devices, core_grid, subdevice_id, create_semaphore_handles);
+    std::optional<std::vector<std::shared_ptr<const GlobalSemaphore>>> semaphore_handles_opt;
+    if (create_semaphore_handles) {
+        semaphore_handles_opt = ttnn::ccl::worker_detail::create_global_semaphores(devices, core_grid, subdevice_id);
+    }
 
     operation::launch_op(
         [gather_dim,
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp
index d8b7a9c6648..94e6a458abb 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp
@@ -92,7 +92,7 @@ AllGatherAsync create_all_gather_async_struct(
     const std::optional<MemoryConfig>& memory_config,
     const std::vector<Device*>& devices,
     const ccl::Topology topology,
-    const std::optional<std::vector<GlobalSemaphore>>& semaphores,
+    const std::optional<std::vector<GlobalSemaphore>>& semaphore_handles,
     bool enable_persistent_fabric_mode);
 }  // namespace all_gather_async_detail
 }  // namespace ccl
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/common/ccl_global_semaphore_creation.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/common/ccl_global_semaphore_creation.cpp
new file mode 100644
index 00000000000..fc87069edd0
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/common/ccl_global_semaphore_creation.cpp
@@ -0,0 +1,63 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "ttnn/cpp/ttnn/operations/experimental/ccl/common/ccl_global_semaphore_creation.hpp"
+#include "ttnn/cpp/ttnn/global_semaphore.hpp"
+
+namespace ttnn::ccl::worker_detail {
+
+std::vector<std::shared_ptr<const GlobalSemaphore>> 
create_global_semaphores(
+    const std::vector<Device*>& devices,
+    const CoreRangeSet& worker_cores,
+    std::optional<SubDeviceId> worker_subdevice_id_opt) {
+    std::vector<std::shared_ptr<const GlobalSemaphore>> semaphores;
+    for (Device* d : devices) {
+        CoreCoord grid_size = devices[0]->compute_with_storage_grid_size();
+        auto core_grid = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1});
+        auto worker_subdevice_id = worker_subdevice_id_opt.has_value()
+                                       ? std::vector<SubDeviceId>{worker_subdevice_id_opt.value()}
+                                       : std::vector<SubDeviceId>{};
+        auto sem = std::make_shared<const GlobalSemaphore>(
+            global_semaphore::create_global_semaphore(d, core_grid, 0, BufferType::L1, worker_subdevice_id));
+        semaphores.push_back(sem);
+    }
+
+    auto first_addr = semaphores.front()->address();
+    bool all_same = std::all_of(
+        semaphores.begin(), semaphores.end(), [first_addr](const auto& sem) { return sem->address() == first_addr; });
+
+    if (!all_same) {
+        DeviceAddr lowest_addr = semaphores.front()->address();
+        for (auto i = 1; i < semaphores.size(); i++) {
+            if (semaphores[i]->address() < lowest_addr) {
+                lowest_addr = semaphores[i]->address();
+            }
+        };
+        for (auto i = 0; i < semaphores.size(); i++) {
+            size_t attempts = 1000;
+            size_t attempt = 0;
+            std::vector<std::shared_ptr<const GlobalSemaphore>> garbage;
+            while (semaphores[i]->address() != lowest_addr) {
+                auto worker_subdevice_id = worker_subdevice_id_opt.has_value()
+                                               ? std::vector<SubDeviceId>{worker_subdevice_id_opt.value()}
+                                               : std::vector<SubDeviceId>{};
+                auto sem = std::make_shared<const GlobalSemaphore>(
+                    global_semaphore::create_global_semaphore(devices[i], worker_cores, 0, BufferType::L1, worker_subdevice_id));
+                if (sem->address() == lowest_addr) {
+                    semaphores[i] = sem;
+                } else {
+                    garbage.push_back(std::move(sem));
+                    attempt++;
+                }
+
+                if (attempt > attempts) {
+                    TT_THROW("Failed to create global semaphores with the same address");
+                }
+            }
+        }
+    }
+    return semaphores;
+}
+
+}  // namespace ttnn::ccl::worker_detail
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/common/ccl_global_semaphore_creation.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/common/ccl_global_semaphore_creation.hpp
new file mode 100644
index 00000000000..c4eaa425c59
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/common/ccl_global_semaphore_creation.hpp
@@ -0,0 +1,19 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "tt_metal/host_api.hpp"
+#include "tt_metal/impl/buffers/global_semaphore.hpp"
+
+#include <memory>
+#include <vector>
+
+namespace ttnn::ccl::worker_detail {
+
+std::vector<std::shared_ptr<const GlobalSemaphore>> create_global_semaphores(
+    const std::vector<Device*>& devices,
+    const CoreRangeSet& worker_cores,
+    std::optional<SubDeviceId> worker_subdevice_id_opt = std::nullopt);
+}
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp
index 8092a445b0c..d9aabed3f30 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp
@@ -7,6 +7,8 @@
 #include "tt_metal/host_api.hpp"
 #include "ttnn/cpp/ttnn/global_semaphore.hpp"
 
+#include "ttnn/cpp/ttnn/operations/experimental/ccl/common/ccl_global_semaphore_creation.hpp"
+
 #include <algorithm>
 #include <optional>
 #include <vector>
@@ -220,62 +222,6 @@ ttnn::operations::binary::BinaryOpType convert_reduce_type_to_eltwise_type(
 }  // namespace CMAKE_UNIQUE_NAMESPACE
 }  // namespace
 
-std::vector<std::shared_ptr<const GlobalSemaphore>> create_global_semaphores(
-    const std::vector<Device*>& devices, std::optional<SubDeviceId> worker_subdevice_id_opt = 
std::nullopt) {
-    std::vector<std::shared_ptr<const GlobalSemaphore>> semaphores;
-    auto worker_cores = CoreRangeSet(CoreRange(CoreCoord(0, 0), CoreCoord(6, 6)));
-    for (Device* d : devices) {
-        CoreCoord grid_size = devices[0]->compute_with_storage_grid_size();
-        auto core_grid = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1});
-        auto worker_subdevice_id = worker_subdevice_id_opt.has_value()
-                                       ? std::vector<SubDeviceId>{worker_subdevice_id_opt.value()}
-                                       : std::vector<SubDeviceId>{};
-        // TODO: Remove shared_ptr
-        auto sem = std::make_shared<const GlobalSemaphore>(
-            global_semaphore::create_global_semaphore(d, core_grid, 0, BufferType::L1, worker_subdevice_id));
-        semaphores.push_back(sem);
-    }
-
-    auto first_addr = semaphores.front()->address();
-    bool all_same = std::all_of(
-        semaphores.begin(), semaphores.end(), [first_addr](const auto& sem) { return sem->address() == first_addr; });
-
-    if (!all_same) {
-        DeviceAddr highest_addr = semaphores.front()->address();
-        for (auto i = 1; i < semaphores.size(); i++) {
-            if (semaphores[i]->address() > highest_addr) {
-                highest_addr = semaphores[i]->address();
-            }
-        };
-        for (auto i = 0; i < semaphores.size(); i++) {
-            size_t attempts = 1000;
-            size_t attempt = 0;
-            std::vector<std::shared_ptr<const GlobalSemaphore>> garbage;
-            CoreCoord grid_size = devices[0]->compute_with_storage_grid_size();
-            auto core_grid = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1});
-            while (semaphores[i]->address() != highest_addr) {
-                auto worker_subdevice_id = worker_subdevice_id_opt.has_value()
-                                               ? std::vector<SubDeviceId>{worker_subdevice_id_opt.value()}
-                                               : std::vector<SubDeviceId>{};
-                // TODO: Remove shared_ptr
-                auto sem = std::make_shared<const GlobalSemaphore>(global_semaphore::create_global_semaphore(
-                    devices[i], core_grid, 0, BufferType::L1, worker_subdevice_id));
-                if (sem->address() == highest_addr) {
-                    semaphores[i] = sem;
-                } else {
-                    garbage.push_back(std::move(sem));
-                    attempt++;
-                }
-
-                if (attempt > attempts) {
-                    TT_THROW("Failed to create global semaphores with the same address");
-                }
-            }
-        }
-    }
-    return semaphores;
-}
-
 namespace operations {
 namespace experimental {
 namespace ccl {
@@ -313,9 +259,13 @@ Tensor reduce_scatter(
 
     std::optional<std::vector<std::shared_ptr<const GlobalSemaphore>>> from_remote_inputs_semaphores_opt;
     std::optional<std::vector<std::shared_ptr<const GlobalSemaphore>>> to_remote_inputs_semaphores_opt;
+    CoreCoord grid_size = devices[0]->compute_with_storage_grid_size();
+    auto worker_cores = CoreRangeSet({CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1})});
     if (create_semaphore_handles) {
-        from_remote_inputs_semaphores_opt = create_global_semaphores(devices, worker_subdevice_id_opt);
-        to_remote_inputs_semaphores_opt = create_global_semaphores(devices, worker_subdevice_id_opt);
+        from_remote_inputs_semaphores_opt =
+            ttnn::ccl::worker_detail::create_global_semaphores(devices, worker_cores, worker_subdevice_id_opt);
+        to_remote_inputs_semaphores_opt =
+            ttnn::ccl::worker_detail::create_global_semaphores(devices, worker_cores, worker_subdevice_id_opt);
     } else {
         from_remote_inputs_semaphores_opt = std::nullopt;
         to_remote_inputs_semaphores_opt = std::nullopt;
@@ -390,9 +340,12 @@ Tensor reduce_scatter(
     auto devices = input_tensor.get_workers();
     std::optional<std::vector<std::shared_ptr<const GlobalSemaphore>>> from_remote_inputs_semaphores_opt;
    std::optional<std::vector<std::shared_ptr<const GlobalSemaphore>>> to_remote_inputs_semaphores_opt;
+    auto worker_cores = CoreRangeSet(CoreRange(CoreCoord(0, 0), CoreCoord(6, 6)));
     if (create_semaphore_handles) {
-        from_remote_inputs_semaphores_opt = create_global_semaphores(devices, worker_subdevice_id_opt);
-        to_remote_inputs_semaphores_opt = create_global_semaphores(devices, worker_subdevice_id_opt);
+        from_remote_inputs_semaphores_opt =
+            ttnn::ccl::worker_detail::create_global_semaphores(devices, worker_cores, 
worker_subdevice_id_opt); + to_remote_inputs_semaphores_opt = + ttnn::ccl::worker_detail::create_global_semaphores(devices, worker_cores, worker_subdevice_id_opt); } else { from_remote_inputs_semaphores_opt = std::nullopt; to_remote_inputs_semaphores_opt = std::nullopt; From 016437266da9bb375a7c4d2c98af34e95ef49e81 Mon Sep 17 00:00:00 2001 From: Sean Nijjar Date: Thu, 2 Jan 2025 23:16:50 +0000 Subject: [PATCH 3/7] update missed methods --- ttnn/ttnn/distributed/distributed.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py index 6b8ab1fd000..d9c35b2ea3b 100644 --- a/ttnn/ttnn/distributed/distributed.py +++ b/ttnn/ttnn/distributed/distributed.py @@ -452,11 +452,12 @@ def __init__(self, mesh_device: MeshDevice, dim: int): self.concat_dim = dim self.mesh_device = mesh_device - def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": + def compose(self, tensor: ttnn.Tensor, sub_device_ids: List[ttnn.SubDeviceId] = []) -> "torch.Tensor": import torch device_shards_converted_to_torch = [ - ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor) + ttnn.to_torch(tt_input_tensor, mesh_composer=None, sub_device_ids=sub_device_ids) + for tt_input_tensor in ttnn.get_device_tensors(tensor) ] return torch.cat(device_shards_converted_to_torch, dim=self.concat_dim) @@ -467,9 +468,10 @@ class ListMeshToTensor(MeshToTensor): def __init__(self, mesh_device: MeshDevice): self.mesh_device = mesh_device - def compose(self, tensor: ttnn.Tensor) -> List["torch.Tensor"]: + def compose(self, tensor: ttnn.Tensor, sub_device_ids: List[ttnn.SubDeviceId] = []) -> List["torch.Tensor"]: return [ - ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor) + ttnn.to_torch(tt_input_tensor, mesh_composer=None, sub_device_ids=sub_device_ids) + for tt_input_tensor in ttnn.get_device_tensors(tensor) ] From ac19676240277bce71f0070176fc89c4807d21bf Mon Sep 17 00:00:00 2001 From: Sean Nijjar Date: Fri, 3 Jan 2025 03:46:50 +0000 Subject: [PATCH 4/7] one more missed method --- models/demos/t3000/llama2_70b/tt/llama_common.py | 7 +++++-- ttnn/ttnn/distributed/distributed.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/models/demos/t3000/llama2_70b/tt/llama_common.py b/models/demos/t3000/llama2_70b/tt/llama_common.py index 63c8aad8233..c3b922173c0 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_common.py +++ b/models/demos/t3000/llama2_70b/tt/llama_common.py @@ -70,8 +70,11 @@ def __init__(self, mesh_device, dims, cluster_shape): self.cluster_shape = cluster_shape self.mesh_device = mesh_device - def compose(self, tensor: ttnn.Tensor) -> torch.Tensor: - tt_shards = [ttnn.to_torch(tt_input_tensor) for tt_input_tensor in ttnn.get_device_tensors(tensor)] + def compose(self, tensor: ttnn.Tensor, sub_device_ids: List[ttnn.SubDeviceId] = []) -> torch.Tensor: + tt_shards = [ + ttnn.to_torch(tt_input_tensor, sub_device_ids=sub_device_ids) + for tt_input_tensor in ttnn.get_device_tensors(tensor) + ] row_concat = [] for cluster_row in range(self.cluster_shape[1]): diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py index d9c35b2ea3b..97b7fad4663 100644 --- a/ttnn/ttnn/distributed/distributed.py +++ b/ttnn/ttnn/distributed/distributed.py @@ -253,7 +253,7 @@ class MeshToTensor: You can also "Bring your own MeshToTensor" based on your custom mapping. 
""" - def compose(self, tensor: ttnn.Tensor): + def compose(self, tensor: ttnn.Tensor, sub_device_ids: List[ttnn.SubDeviceId] = []): raise NotImplementedError("Subclasses must implement this method") From f1b53f0a6fe31e796f613f8631979a857caca018 Mon Sep 17 00:00:00 2001 From: Sean Nijjar Date: Fri, 3 Jan 2025 04:56:27 +0000 Subject: [PATCH 5/7] fix typing --- models/demos/t3000/llama2_70b/tt/llama_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/demos/t3000/llama2_70b/tt/llama_common.py b/models/demos/t3000/llama2_70b/tt/llama_common.py index c3b922173c0..7e70bd64041 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_common.py +++ b/models/demos/t3000/llama2_70b/tt/llama_common.py @@ -6,7 +6,7 @@ import math from loguru import logger import re -from typing import Tuple +from typing import Tuple, List import numpy as np import torch import ttnn From 3ff6e59d5150622a3bbf6b5d499f703c95a479e4 Mon Sep 17 00:00:00 2001 From: Sean Nijjar Date: Fri, 3 Jan 2025 15:55:12 +0000 Subject: [PATCH 6/7] fixes after rebase conflicts --- .../ccl/all_gather_async/device/all_gather_async_op.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp index e96f04b6697..8e4f104644c 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp @@ -258,9 +258,9 @@ Tensor all_gather_async( std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; CoreCoord grid_size = devices[0]->compute_with_storage_grid_size(); auto core_grid = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1}); - std::optional>> semaphore_handles_opt; + std::optional>> semaphores_opt; if (create_semaphore_handles) { - semaphore_handles_opt = ttnn::ccl::worker_detail::create_global_semaphores(devices, core_grid, subdevice_id); + semaphores_opt = ttnn::ccl::worker_detail::create_global_semaphores(devices, core_grid, subdevice_id); } operation::launch_op( From 0044f0446f6fc9e66d0ae16835726b6eee8e35b4 Mon Sep 17 00:00:00 2001 From: Sean Nijjar Date: Fri, 3 Jan 2025 16:04:55 +0000 Subject: [PATCH 7/7] trim trailing white space --- .../reduce_scatter_async/device/reduce_scatter_async_op.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp index d9aabed3f30..7b1dbe02d49 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp @@ -342,9 +342,9 @@ Tensor reduce_scatter( std::optional>> to_remote_inputs_semaphores_opt; auto worker_cores = CoreRangeSet(CoreRange(CoreCoord(0, 0), CoreCoord(6, 6))); if (create_semaphore_handles) { - from_remote_inputs_semaphores_opt = + from_remote_inputs_semaphores_opt = ttnn::ccl::worker_detail::create_global_semaphores(devices, worker_cores, worker_subdevice_id_opt); - to_remote_inputs_semaphores_opt = + to_remote_inputs_semaphores_opt = ttnn::ccl::worker_detail::create_global_semaphores(devices, worker_cores, worker_subdevice_id_opt); } else { 
from_remote_inputs_semaphores_opt = std::nullopt;
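
Usage sketch (illustrative, not part of the applied diffs): the thread running through these seven patches is that host-device transfers and device synchronization now take an explicit sub_device_ids list, so a CCL op can be launched and waited on via the worker sub-device while the persistent fabric sub-device — whose workers never terminate — is left out of the wait set. A minimal sketch of the resulting call pattern, assuming an already-open mesh_device fixture as in the tests above; the tensor shape and the ReplicateTensorToMesh mapper are assumptions for illustration, not values taken from the patches:

    import torch
    import ttnn

    # Worker sub-device spanning the full compute grid, mirroring the test setup above.
    grid = mesh_device.compute_with_storage_grid_size()
    worker_sub_device = ttnn.SubDevice(
        [ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(grid.x - 1, grid.y - 1))})]
    )
    worker_subdevice_ids = [ttnn.SubDeviceId(0)]

    # Writes, reads, and syncs are all scoped to the worker sub-device, never the fabric.
    tt_tensor = ttnn.from_torch(
        torch.rand(1, 1, 32, 256),  # illustrative shape
        device=mesh_device,
        layout=ttnn.TILE_LAYOUT,
        mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),  # assumed mapper choice
        sub_device_ids=worker_subdevice_ids,
    )
    out = ttnn.to_torch(
        tt_tensor,
        mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0),
        sub_device_ids=worker_subdevice_ids,  # forwarded to compose() per patches 1 and 3
    )
    for d in mesh_device.get_devices():
        ttnn.synchronize_device(d, sub_device_ids=worker_subdevice_ids)

Blocking on only the worker sub-device is what keeps these calls from hanging: a plain synchronize would also wait on the fabric sub-device, whose persistent workers run until teardown_fabric_interface() is called.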