From ef534f0216d9881da2d5d4919651e1b1f013747b Mon Sep 17 00:00:00 2001
From: Jack Cai
Date: Mon, 13 Jan 2025 21:22:51 +0000
Subject: [PATCH 1/3] #0: added ccl async tests to TG nightly CI and fixed all gather core assignment + teardown

---
 .../tg/ccl/test_all_gather_async_nightly.py   |   1 +
 .../ccl/test_reduce_scatter_async_nightly.py  |   1 +
 tests/scripts/tg/run_tg_nightly_tests.sh      |   3 +-
 .../ccl/test_all_gather_TG_post_commit.py     |   8 +-
 .../ccl/test_all_gather_async_TG_nightly.py   | 324 ++++++++++++++++++
 .../ccl/test_reduce_scatter_TG_nightly.py     |   2 +
 .../test_reduce_scatter_async_TG_nightly.py   | 154 +++++++++
 .../device/all_gather_async_program.cpp       |  25 +-
 8 files changed, 498 insertions(+), 20 deletions(-)
 create mode 120000 tests/nightly/tg/ccl/test_all_gather_async_nightly.py
 create mode 120000 tests/nightly/tg/ccl/test_reduce_scatter_async_nightly.py
 create mode 100644 tests/ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py
 create mode 100644 tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async_TG_nightly.py

diff --git a/tests/nightly/tg/ccl/test_all_gather_async_nightly.py b/tests/nightly/tg/ccl/test_all_gather_async_nightly.py
new file mode 120000
index 00000000000..24cc55db361
--- /dev/null
+++ b/tests/nightly/tg/ccl/test_all_gather_async_nightly.py
@@ -0,0 +1 @@
+/proj_sw/user_dev/xuncai/tt-metal/tests/ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py
\ No newline at end of file
diff --git a/tests/nightly/tg/ccl/test_reduce_scatter_async_nightly.py b/tests/nightly/tg/ccl/test_reduce_scatter_async_nightly.py
new file mode 120000
index 00000000000..31f5aed40d5
--- /dev/null
+++ b/tests/nightly/tg/ccl/test_reduce_scatter_async_nightly.py
@@ -0,0 +1 @@
+/proj_sw/user_dev/xuncai/tt-metal/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async_TG_nightly.py
\ No newline at end of file
diff --git a/tests/scripts/tg/run_tg_nightly_tests.sh b/tests/scripts/tg/run_tg_nightly_tests.sh
index 89e5c253c7c..d3f23a6a50c 100755
--- a/tests/scripts/tg/run_tg_nightly_tests.sh
+++ b/tests/scripts/tg/run_tg_nightly_tests.sh
@@ -7,8 +7,7 @@ run_tg_llama3_70b_tests() {
 
   echo "LOG_METAL: Running run_tg_llama3_70b_tests"
 
-  pytest tests/nightly/tg/ccl/test_all_gather_nightly.py ; fail+=$?
-  pytest tests/nightly/tg/ccl/test_reduce_scatter_nightly.py ; fail+=$?
+  pytest -n auto tests/nightly/tg/ccl --timeout=180 ; fail+=$?
 
   # Falcon40B prefill 60 layer end to end with 10 loops; we need 8x8 grid size
   pytest tests/nightly/tg/models/demos/tg/llama3_70b ; fail+=$?
diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py
index 6bd93f99eb5..8162b81bf2e 100644
--- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py
+++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py
@@ -254,10 +254,10 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
             topology=ttnn.Topology.Linear,
         )
 
-        if enable_persistent_fabric:
-            logger.info(f"Waiting for op")
-            ttnn.synchronize_devices(mesh_device, sub_device_ids=sub_device_stall_group)
-            logger.info(f"Done iteration")
+    if enable_persistent_fabric:
+        logger.info(f"Waiting for op")
+        ttnn.synchronize_devices(mesh_device, sub_device_ids=sub_device_stall_group)
+        logger.info(f"Done iteration")
 
     if enable_persistent_fabric and teardown_persistent_fabric:
         logger.info("Tearing down persistent fabric interface")
diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py
new file mode 100644
index 00000000000..431e0151048
--- /dev/null
+++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py
@@ -0,0 +1,324 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import pytest
+from loguru import logger
+import ttnn
+from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
+from models.utility_functions import skip_for_grayskull
+from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import (
+    create_and_load_sub_device_manager_with_fabric_interface,
+    teardown_fabric_interface,
+    create_global_semaphore_with_same_address,
+)
+
+from tests.ttnn.unit_tests.operations.ccl.test_all_gather_TG_post_commit import (
+    run_line_all_gather_on_TG_with_mesh_tensor_along_rows,
+)
+
+from tests.ttnn.unit_tests.operations.ccl.test_new_all_gather import (
+    run_all_gather_impl,
+)
+
+
+# Enumerate the post-commit cases explicitly
+@skip_for_grayskull("Requires eth connected devices to run")
+@pytest.mark.parametrize(
+    "num_devices, num_links",
+    [(4, 1)],
+    # [(4, 3)], Multi-links fails
+)
+@pytest.mark.parametrize(
+    "input_dtype",
+    [
+        ttnn.bfloat16,
+        ttnn.bfloat8_b,
+    ],
+)
+@pytest.mark.parametrize("shard_grid_orientation", [ttnn.ShardOrientation.ROW_MAJOR])
+@pytest.mark.parametrize(
+    "tensor_mem_layout,per_chip_output_shape, dim, input_shard_shape,shard_grid,layout",
+    (
+        # LLama
+        (
+            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
+            (1, 1, 32, 1024 * 4),
+            3,
+            (32, 32),
+            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))}),
+            ttnn.TILE_LAYOUT,
+        ),
+        (
+            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
+            (4, 1, 32, 1280),
+            0,
+            (32, 32),
+            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 4))}),
+            ttnn.TILE_LAYOUT,
+        ),
+    ),
+)
+@pytest.mark.parametrize("replication_factor", [8])
+@pytest.mark.parametrize("enable_async", [True])
+@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
+def test_line_all_gather_sharded_on_TG_rows_post_commit(
+    mesh_device,
+    num_devices,
+    per_chip_output_shape,
+    input_shard_shape,
+    shard_grid,
+    shard_grid_orientation,
+    tensor_mem_layout,
+    dim,
+    num_links,
+    input_dtype,
+    layout,
+    use_program_cache,
+    function_level_defaults,
+    enable_async,
+    replication_factor,
+    num_iters=1,
+):
+    if len(mesh_device.get_devices()) != 32:
+        pytest.skip("Not TG!")
+    if input_dtype == ttnn.bfloat16 and per_chip_output_shape == (1, 1, 32, 1024 * 4):
+        pytest.skip("Skiped due to hang")
+    input_shard_spec = ttnn.ShardSpec(
+        shard_grid,
+        input_shard_shape,
+        shard_grid_orientation,
+    )
+    run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
+        mesh_device,
+        num_devices,
+        per_chip_output_shape,
+        tensor_mem_layout,
+        dim,
+        num_links,
+        input_dtype,
+        layout,
+        ttnn.BufferType.L1,
+        use_program_cache,
+        function_level_defaults,
+        enable_async=enable_async,
+        input_shard_spec=input_shard_spec,
+        num_iters=num_iters,
+        num_all_gather_instances=replication_factor,
+        cluster_axis=1,
+        use_all_gather_async=True,
+        enable_persistent_fabric=True,
+        create_persistent_fabric=True,
+        teardown_persistent_fabric=True,
+    )
+
+
+# Enumerate the post-commit cases explicitly
+@skip_for_grayskull("Requires eth connected devices to run")
+@pytest.mark.parametrize(
+    "num_devices, num_links",
+    [
+        (8, 1),
+    ],
+    # [(8, 4), (8, 3), (8, 2)], Multi-links fails
+)
+@pytest.mark.parametrize(
+    "input_dtype",
+    [
+        ttnn.bfloat16,
+        ttnn.bfloat8_b,
+    ],
+)
+@pytest.mark.parametrize("shard_grid_orientation", [ttnn.ShardOrientation.ROW_MAJOR])
+@pytest.mark.parametrize(
+    "tensor_mem_layout, input_shape, dim, input_shard_shape,shard_grid,layout",
+    (
+        (
+            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
+            (8, 1, 32, 2048),
+            0,
+            (32, 64),
+            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))}),
+            ttnn.TILE_LAYOUT,
+        ),
+        (
+            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
+            (1, 8, 32, 2048),
+            1,
+            (32, 64),
+            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))}),
+            ttnn.TILE_LAYOUT,
+        ),
+        (
+            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
+            (1, 1, 256, 2048),
+            2,
+            (32, 64),
+            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))}),
+            ttnn.TILE_LAYOUT,
+        ),
+        (
+            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
+            (1, 1, 32, 16384),
+            3,
+            (32, 64),
+            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))}),
+            ttnn.TILE_LAYOUT,
+        ),
+        (
+            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
+            (8, 1, 2048, 32),
+            0,
+            (64, 32),
+            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))}),
+            ttnn.TILE_LAYOUT,
+        ),
+        (
+            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
+            (1, 8, 2048, 32),
+            1,
+            (64, 32),
+            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))}),
+            ttnn.TILE_LAYOUT,
+        ),
+        (
+            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
+            (1, 1, 16384, 32),
+            2,
+            (64, 32),
+            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))}),
+            ttnn.TILE_LAYOUT,
+        ),
+        (
+            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
+            (1, 1, 2048, 256),
+            3,
+            (64, 32),
+            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))}),
+            ttnn.TILE_LAYOUT,
+        ),
+    ),
+)
+@pytest.mark.parametrize("replication_factor", [4])
+@pytest.mark.parametrize("enable_async", [True])
+@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
+def test_line_all_gather_sharded_on_TG_cols_post_commit(
+    mesh_device,
+    num_devices,
+    input_shape,
+    input_shard_shape,
+    shard_grid,
+    shard_grid_orientation,
+    tensor_mem_layout,
+    dim,
+    num_links,
+    input_dtype,
+    layout,
+    use_program_cache,
+    function_level_defaults,
+    enable_async,
+    replication_factor,
+    num_iters=1,
+):
+    if len(mesh_device.get_devices()) != 32:
+        pytest.skip("Not TG!")
+    if input_dtype == ttnn.bfloat16 and input_shape == (1, 1, 256, 2048):
+        pytest.skip("Skiped due to hang")
+    input_shard_spec = ttnn.ShardSpec(
+        shard_grid,
+        input_shard_shape,
+        shard_grid_orientation,
+    )
+
+    run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
+        mesh_device,
+        num_devices,
+        input_shape,
+        tensor_mem_layout,
+        dim,
+        num_links,
+        input_dtype,
+        layout,
+        ttnn.BufferType.L1,
+        use_program_cache,
+        function_level_defaults,
+        enable_async=enable_async,
+        num_iters=num_iters,
+        input_shard_spec=input_shard_spec,
+        num_all_gather_instances=replication_factor,
+        cluster_axis=0,
+        use_all_gather_async=True,
+        enable_persistent_fabric=True,
+        create_persistent_fabric=True,
+        teardown_persistent_fabric=True,
+    )
+
+
+# Enumerate the post-commit cases explicitly
+@skip_for_grayskull("Requires eth connected devices to run")
+@pytest.mark.parametrize(
+    "num_devices, num_links, per_chip_output_shape, dim, layout",
+    [
+        (8, 1, [1, 8, 32, 1280], 1, ttnn.TILE_LAYOUT),
+        (8, 1, [8, 1, 32, 1280], 0, ttnn.TILE_LAYOUT),
+        (8, 1, [1, 8, 32, 2048], 1, ttnn.TILE_LAYOUT),
+        (8, 1, [1, 8, 32, 2304], 1, ttnn.TILE_LAYOUT),
+        (8, 1, [1, 8, 32, 4096], 1, ttnn.TILE_LAYOUT),
+    ],
+)
+@pytest.mark.parametrize(
+    "input_dtype",
+    [
+        ttnn.bfloat16,
+        ttnn.bfloat8_b,
+    ],
+)
+@pytest.mark.parametrize(
+    "buffer_type",
+    [
+        ttnn.BufferType.DRAM,
+        ttnn.BufferType.L1,
+    ],
+)
+@pytest.mark.parametrize("replication_factor", [4])
+@pytest.mark.parametrize("enable_async", [True])
+@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
+def test_line_all_gather_on_TG_cols_nightly(
+    mesh_device,
+    num_devices,
+    per_chip_output_shape,
+    dim,
+    num_links,
+    input_dtype,
+    layout,
+    buffer_type,
+    use_program_cache,
+    function_level_defaults,
+    enable_async,
+    replication_factor,
+    num_iters=1,
+):
+    if len(mesh_device.get_devices()) != 32:
+        pytest.skip("Not TG!")
+    run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
+        mesh_device,
+        num_devices,
+        per_chip_output_shape,
+        ttnn.TensorMemoryLayout.INTERLEAVED,
+        dim,
+        num_links,
+        input_dtype,
+        layout,
+        buffer_type,
+        use_program_cache,
+        function_level_defaults,
+        enable_async=enable_async,
+        num_iters=num_iters,
+        num_all_gather_instances=replication_factor,
+        cluster_axis=0,
+        use_all_gather_async=True,
+        enable_persistent_fabric=True,
+        create_persistent_fabric=True,
+        teardown_persistent_fabric=True,
+    )
diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py
index 63c9fd1bd4e..ef0c2193e28 100644
--- a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py
+++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py
@@ -279,6 +279,8 @@ def run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
             memory_config=output_mem_config,
             topology=ttnn.Topology.Linear,
         )
+        if enable_persistent_fabric:
+            ttnn.synchronize_devices(mesh_device, sub_device_ids=sub_device_stall_group)
     ttnn.synchronize_devices(mesh_device, sub_device_ids=sub_device_stall_group)
 
     if enable_persistent_fabric and teardown_persistent_fabric:
diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async_TG_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async_TG_nightly.py
new file mode 100644
index 00000000000..5f3695e62ed
--- /dev/null
+++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async_TG_nightly.py
@@ -0,0 +1,154 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import pytest
+from loguru import logger
+import ttnn
+from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc
+from models.utility_functions import skip_for_grayskull
+from tests.ttnn.unit_tests.operations.ccl.test_reduce_scatter_TG_nightly import (
+    run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows,
+)
+from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import (
+    create_and_load_sub_device_manager_with_fabric_interface,
+    teardown_fabric_interface,
+    create_global_semaphore_with_same_address,
+)
+
+
+# Enumerate the post-commit cases explicitly
+@skip_for_grayskull("Requires eth connected devices to run")
+@pytest.mark.parametrize(
+    "num_devices, num_links, per_chip_output_shape, dim, layout",
+    [
+        (4, 1, [1, 4, 32, 2304], 1, ttnn.TILE_LAYOUT),
+    ],
+)
+@pytest.mark.parametrize(
+    "input_dtype",
+    [
+        ttnn.bfloat16,
+        # ttnn.bfloat8_b,
+    ],
+)
+@pytest.mark.parametrize(
+    "buffer_type",
+    [
+        ttnn.BufferType.DRAM,
+        ttnn.BufferType.L1,
+    ],
+)
+@pytest.mark.parametrize("replication_factor", [8])  # 1, 8])
+@pytest.mark.parametrize("enable_async", [True])
+@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
+@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
+@pytest.mark.parametrize("device_params", [{"trace_region_size": 10281600}], indirect=True)
+def test_line_reduce_scatter_on_TG_rows_post_commit(
+    mesh_device,
+    num_devices,
+    per_chip_output_shape,
+    dim,
+    num_links,
+    math_op,
+    input_dtype,
+    layout,
+    buffer_type,
+    use_program_cache,
+    function_level_defaults,
+    enable_async,
+    replication_factor,
+    num_iters=16,
+):
+    if len(mesh_device.get_devices()) != 32:
+        pytest.skip("Not TG!")
+    run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
+        mesh_device,
+        num_devices,
+        per_chip_output_shape,
+        ttnn.TensorMemoryLayout.INTERLEAVED,
+        dim,
+        num_links,
+        math_op,
+        input_dtype,
+        layout,
+        buffer_type,
+        use_program_cache,
+        function_level_defaults,
+        enable_async=enable_async,
+        num_iters=num_iters,
+        num_reduce_scatter_instances=replication_factor,
+        cluster_axis=1,
+        use_reduce_scatter_async=True,
+        enable_persistent_fabric=True,
+        create_persistent_fabric=True,
+        teardown_persistent_fabric=True,
+    )
+
+
+@skip_for_grayskull("Requires eth connected devices to run")
+@pytest.mark.parametrize(
+    "num_devices, num_links, per_chip_output_shape, dim, layout",
+    [
+        (8, 1, [1, 8, 32, 1280], 1, ttnn.TILE_LAYOUT),
+        (8, 1, [8, 1, 32, 1280], 0, ttnn.TILE_LAYOUT),
+    ],
+)
+@pytest.mark.parametrize(
+    "input_dtype",
+    [
+        ttnn.bfloat16,
+    ],
+)
+@pytest.mark.parametrize(
+    "buffer_type",
+    [
+        ttnn.BufferType.DRAM,
+    ],
+)
+@pytest.mark.parametrize("enable_async", [True])
+@pytest.mark.parametrize("replication_factor", [4])
+@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
+@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
+def test_line_reduce_scatter_on_TG_cols_post_commit(
+    mesh_device,
+    num_devices,
+    per_chip_output_shape,
+    dim,
+    num_links,
+    math_op,
+    input_dtype,
+    layout,
+    buffer_type,
+    use_program_cache,
+    function_level_defaults,
+    enable_async,
+    replication_factor,
+    num_iters=16,
+):
+    if len(mesh_device.get_devices()) != 32:
+        pytest.skip("Not TG!")
+
+    run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
+        mesh_device,
+        num_devices,
+        per_chip_output_shape,
+        ttnn.TensorMemoryLayout.INTERLEAVED,
+        dim,
+        num_links,
+        math_op,
+        input_dtype,
+        layout,
+        buffer_type,
+        use_program_cache,
+        function_level_defaults,
+        enable_async=enable_async,
+        num_iters=num_iters,
+        num_reduce_scatter_instances=replication_factor,
+        cluster_axis=0,
+        use_reduce_scatter_async=True,
+        enable_persistent_fabric=True,
+        create_persistent_fabric=True,
+        teardown_persistent_fabric=True,
+    )
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp
index 1dd6e8293c9..05d3eb0a662 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp
@@ -208,12 +208,6 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers(
     );
 
     // KERNEL CREATION
-    const auto& worker_defines = op_config.emit_worker_defines();
-    static const std::string& sender_kernel_reader_path =
-        "ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader.cpp";
-    static const std::string& sender_kernel_writer_path =
-        "ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_writer.cpp";
-
     KernelHandle worker_sender_reader_kernel_id =
         ttnn::ccl::worker_detail::generate_multi_command_stream_kernel_ct_args(
             program,
@@ -261,7 +255,8 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers(
     std::unordered_map writer_rt_args_overrider_map;
 
     for (std::size_t link = 0; link < num_links; link++) {
-        CoreCoord core = {num_workers_per_link - 1, link};
+        // CoreCoord core = {num_workers_per_link - 1, link};
+        CoreCoord core = sender_worker_cores[link];
         if (link == 0) {
             // drain sync core is the first worker core
             drain_sync_core = device->worker_core_from_logical_core(core);
@@ -334,16 +329,12 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers(
             // 2, mcast the semaphore to all dest for teardown
             writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::fabric_multicast_semaphore_inc(
                 &semaphore, ttnn::ccl::cmd::CclCommandAtomicInc{1}, drain_sync_core.x, drain_sync_core.y, mcast_dest_args));
-            if (!enable_async_output_tensor) {
+            bool wait_for_semaphore = !enable_async_output_tensor && link == 0;
+            if (wait_for_semaphore) {
                 // 3, wait for n_chip*num_links number of semaphore at teardown semaphore address for first chip, and
                 // n_chip*num_links+1 for other chips
                 writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::local_semaphore_wait(
-                    &semaphore,
-                    is_first_chip ? ring_size * num_links : ring_size * num_links + !enable_persistent_fabric_mode));
-            }
-
-            bool generate_teardown_commands = !enable_persistent_fabric_mode && link == 0;
-            if (generate_teardown_commands) {
+                    &semaphore, is_first_chip ? ring_size * num_links : ring_size * num_links + 1));
                 // 4, send semaphore unicast to forward device except for the last chip
                 if (!is_last_chip) {
                     writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::fabric_unicast_semaphore_inc(
@@ -353,6 +344,9 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers(
                         drain_sync_core.y,
                         ttnn::ccl::cmd::UnicastCommandDestArgs{1, true}));
                 }
+            }
+            bool generate_teardown_commands = !enable_persistent_fabric_mode && link == 0;
+            if (generate_teardown_commands) {
                 // 5, increment the termination semaphore for local device for local teardown only for the drain sync core
                 auto termination_infos = local_fabric_handle->generate_local_chip_fabric_termination_infos(device);
                 for (auto& info : termination_infos) {
@@ -362,6 +356,9 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers(
                     writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::local_chip_noc_absolute_address_semaphore_inc(
                         info.edm_noc_x, info.edm_noc_y, info.termination_addr, 1));
                 }
+            }
+            bool reset_semaphore = generate_teardown_commands || (!enable_async_output_tensor && link == 0);
+            if (reset_semaphore) {
                 // 6. (drain sync core) reset semaphore to 0
                 writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::local_core_semaphore_set(&semaphore, 0));
             }

From c2552f3aee52fcb6660125188249f404117a46c6 Mon Sep 17 00:00:00 2001
From: Jack Cai
Date: Mon, 13 Jan 2025 21:56:55 +0000
Subject: [PATCH 2/3] enabled multi-link reduce scatter (passing) and added link to all gather failure issue

---
 .../ccl/test_all_gather_async_TG_nightly.py     | 14 ++++++++++----
 .../ccl/test_reduce_scatter_async_TG_nightly.py |  6 +++---
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py
index 431e0151048..74c1f113240 100644
--- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py
+++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py
@@ -28,7 +28,7 @@
 @pytest.mark.parametrize(
     "num_devices, num_links",
     [(4, 1)],
-    # [(4, 3)], Multi-links fails
+    # [(4, 3)], Multi-links fails https://github.com/tenstorrent/tt-metal/issues/16699
 )
 @pytest.mark.parametrize(
     "input_dtype",
@@ -84,7 +84,7 @@ def test_line_all_gather_sharded_on_TG_rows_post_commit(
     if len(mesh_device.get_devices()) != 32:
         pytest.skip("Not TG!")
     if input_dtype == ttnn.bfloat16 and per_chip_output_shape == (1, 1, 32, 1024 * 4):
-        pytest.skip("Skiped due to hang")
+        pytest.skip("Skiped due to hang Issue #16699")
     input_shard_spec = ttnn.ShardSpec(
         shard_grid,
         input_shard_shape,
@@ -121,7 +121,7 @@ def test_line_all_gather_sharded_on_TG_rows_post_commit(
     [
         (8, 1),
     ],
-    # [(8, 4), (8, 3), (8, 2)], Multi-links fails
+    # [(8, 4), (8, 3), (8, 2)], Multi-links fails https://github.com/tenstorrent/tt-metal/issues/16699
 )
 @pytest.mark.parametrize(
     "input_dtype",
@@ -224,7 +224,7 @@ def test_line_all_gather_sharded_on_TG_cols_post_commit(
     if len(mesh_device.get_devices()) != 32:
         pytest.skip("Not TG!")
     if input_dtype == ttnn.bfloat16 and input_shape == (1, 1, 256, 2048):
-        pytest.skip("Skiped due to hang")
+        pytest.skip("Skiped due to hang Issue #16699")
     input_shard_spec = ttnn.ShardSpec(
         shard_grid,
         input_shard_shape,
@@ -265,6 +265,12 @@ def test_line_all_gather_sharded_on_TG_cols_post_commit(
         (8, 1, [1, 8, 32, 2048], 1, ttnn.TILE_LAYOUT),
         (8, 1, [1, 8, 32, 2304], 1, ttnn.TILE_LAYOUT),
         (8, 1, [1, 8, 32, 4096], 1, ttnn.TILE_LAYOUT),
+        # multi-links fails: https://github.com/tenstorrent/tt-metal/issues/16699
+        # (8, 4, [1, 8, 32, 1280], 1, ttnn.TILE_LAYOUT),
+        # (8, 4, [8, 1, 32, 1280], 0, ttnn.TILE_LAYOUT),
+        # (8, 4, [1, 8, 32, 2048], 1, ttnn.TILE_LAYOUT),
+        # (8, 4, [1, 8, 32, 2304], 1, ttnn.TILE_LAYOUT),
+        # (8, 4, [1, 8, 32, 4096], 1, ttnn.TILE_LAYOUT),
     ],
 )
 @pytest.mark.parametrize(
diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async_TG_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async_TG_nightly.py
index 5f3695e62ed..5025ffcfbe3 100644
--- a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async_TG_nightly.py
+++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async_TG_nightly.py
@@ -23,7 +23,7 @@
 @pytest.mark.parametrize(
     "num_devices, num_links, per_chip_output_shape, dim, layout",
     [
-        (4, 1, [1, 4, 32, 2304], 1, ttnn.TILE_LAYOUT),
+        (4, 2, [1, 4, 32, 2304], 1, ttnn.TILE_LAYOUT),
     ],
 )
 @pytest.mark.parametrize(
@@ -91,8 +91,8 @@ def test_line_reduce_scatter_on_TG_rows_post_commit(
 @pytest.mark.parametrize(
     "num_devices, num_links, per_chip_output_shape, dim, layout",
     [
-        (8, 1, [1, 8, 32, 1280], 1, ttnn.TILE_LAYOUT),
-        (8, 1, [8, 1, 32, 1280], 0, ttnn.TILE_LAYOUT),
+        (8, 2, [1, 8, 32, 1280], 1, ttnn.TILE_LAYOUT),
+        (8, 2, [8, 1, 32, 1280], 0, ttnn.TILE_LAYOUT),
     ],
 )
 @pytest.mark.parametrize(

From 1fcb0191c687e097f69d24fdc9f38d8f3c0314cf Mon Sep 17 00:00:00 2001
From: Jack Cai
Date: Tue, 14 Jan 2025 15:43:15 +0000
Subject: [PATCH 3/3] addressed pr comments

---
 tests/nightly/tg/ccl/test_all_gather_async_nightly.py      | 2 +-
 tests/nightly/tg/ccl/test_reduce_scatter_async_nightly.py  | 2 +-
 .../operations/ccl/test_all_gather_TG_post_commit.py       | 3 +--
 .../operations/ccl/test_all_gather_async_TG_nightly.py     | 6 +++---
 .../operations/ccl/test_reduce_scatter_async_TG_nightly.py | 2 +-
 .../all_gather_async/device/all_gather_async_program.cpp   | 1 -
 6 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/tests/nightly/tg/ccl/test_all_gather_async_nightly.py b/tests/nightly/tg/ccl/test_all_gather_async_nightly.py
index 24cc55db361..f342d96f5be 120000
--- a/tests/nightly/tg/ccl/test_all_gather_async_nightly.py
+++ b/tests/nightly/tg/ccl/test_all_gather_async_nightly.py
@@ -1 +1 @@
-/proj_sw/user_dev/xuncai/tt-metal/tests/ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py
\ No newline at end of file
+../../../ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py
\ No newline at end of file
diff --git a/tests/nightly/tg/ccl/test_reduce_scatter_async_nightly.py b/tests/nightly/tg/ccl/test_reduce_scatter_async_nightly.py
index 31f5aed40d5..2187a4cc4fb 120000
--- a/tests/nightly/tg/ccl/test_reduce_scatter_async_nightly.py
+++ b/tests/nightly/tg/ccl/test_reduce_scatter_async_nightly.py
@@ -1 +1 @@
-/proj_sw/user_dev/xuncai/tt-metal/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async_TG_nightly.py
\ No newline at end of file
+../../../ttnn/unit_tests/operations/ccl/test_reduce_scatter_async_TG_nightly.py
\ No newline at end of file
diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py
index 8162b81bf2e..7534038d205 100644
--- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py
+++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py
@@ -255,9 +255,8 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
         )
 
     if enable_persistent_fabric:
-        logger.info(f"Waiting for op")
         ttnn.synchronize_devices(mesh_device, sub_device_ids=sub_device_stall_group)
-        logger.info(f"Done iteration")
+        ttnn.synchronize_devices(mesh_device, sub_device_ids=sub_device_stall_group)
 
     if enable_persistent_fabric and teardown_persistent_fabric:
         logger.info("Tearing down persistent fabric interface")
diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py
index 74c1f113240..b572de93aab 100644
--- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py
+++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
 
 # SPDX-License-Identifier: Apache-2.0
 
@@ -84,7 +84,7 @@ def test_line_all_gather_sharded_on_TG_rows_post_commit(
     if len(mesh_device.get_devices()) != 32:
         pytest.skip("Not TG!")
     if input_dtype == ttnn.bfloat16 and per_chip_output_shape == (1, 1, 32, 1024 * 4):
-        pytest.skip("Skiped due to hang Issue #16699")
+        pytest.skip("Skipped due to hang Issue #16699")
     input_shard_spec = ttnn.ShardSpec(
         shard_grid,
         input_shard_shape,
@@ -224,7 +224,7 @@ def test_line_all_gather_sharded_on_TG_cols_post_commit(
     if len(mesh_device.get_devices()) != 32:
         pytest.skip("Not TG!")
     if input_dtype == ttnn.bfloat16 and input_shape == (1, 1, 256, 2048):
-        pytest.skip("Skiped due to hang Issue #16699")
+        pytest.skip("Skipped due to hang Issue #16699")
     input_shard_spec = ttnn.ShardSpec(
         shard_grid,
         input_shard_shape,
diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async_TG_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async_TG_nightly.py
index 5025ffcfbe3..d7ff05200d0 100644
--- a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async_TG_nightly.py
+++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async_TG_nightly.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
 
 # SPDX-License-Identifier: Apache-2.0
 
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp
index 05d3eb0a662..d8e695a2657 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp
@@ -255,7 +255,6 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers(
     std::unordered_map writer_rt_args_overrider_map;
 
     for (std::size_t link = 0; link < num_links; link++) {
-        // CoreCoord core = {num_workers_per_link - 1, link};
         CoreCoord core = sender_worker_cores[link];
         if (link == 0) {
             // drain sync core is the first worker core