Skip to content

Commit

Permalink
enable async in tests, bump up iteration count
Browse files Browse the repository at this point in the history
  • Loading branch information
SeanNijjar committed Dec 31, 2024
1 parent 0105614 commit 5386bb5
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 10 deletions.
16 changes: 8 additions & 8 deletions tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,8 @@ def run_all_gather_impl(
ttnn.MemoryConfig(buffer_type=ttnn.BufferType.DRAM),
],
)
@pytest.mark.parametrize("num_iters", [1])
@pytest.mark.parametrize("enable_async", [False])
@pytest.mark.parametrize("num_iters", [8])
@pytest.mark.parametrize("enable_async", [True])
def test_all_gather(
t3k_mesh_device,
# pcie_mesh_device,
Expand Down Expand Up @@ -395,8 +395,8 @@ def test_all_gather(
ttnn.bfloat16,
],
)
@pytest.mark.parametrize("num_iters", [1])
@pytest.mark.parametrize("enable_async", [False])
@pytest.mark.parametrize("num_iters", [8])
@pytest.mark.parametrize("enable_async", [True])
def test_all_gather_sharded(
t3k_mesh_device,
# pcie_mesh_device,
Expand Down Expand Up @@ -525,7 +525,7 @@ def test_line_all_gather_async_on_T3K_cols_transient_fabric_post_commit(
ttnn.BufferType.DRAM,
],
)
@pytest.mark.parametrize("enable_async", [False])
@pytest.mark.parametrize("enable_async", [True])
@pytest.mark.parametrize("replication_factor", [4])
def test_line_all_gather_async_on_T3K_cols_persistent_fabric_post_commit(
t3k_mesh_device,
Expand Down Expand Up @@ -594,7 +594,7 @@ def test_line_all_gather_async_on_T3K_cols_persistent_fabric_post_commit(
],
)
@pytest.mark.parametrize("replication_factor", [2])
@pytest.mark.parametrize("enable_async", [False])
@pytest.mark.parametrize("enable_async", [True])
def test_line_all_gather_async_on_T3K_rows_transient_fabric_post_commit(
t3k_mesh_device,
num_devices,
Expand Down Expand Up @@ -665,7 +665,7 @@ def test_line_all_gather_async_on_T3K_rows_transient_fabric_post_commit(
],
)
@pytest.mark.parametrize("replication_factor", [2])
@pytest.mark.parametrize("enable_async", [False])
@pytest.mark.parametrize("enable_async", [True])
def test_line_all_gather_async_on_T3K_rows_persistent_fabric_post_commit(
t3k_mesh_device,
num_devices,
Expand Down Expand Up @@ -733,7 +733,7 @@ def test_line_all_gather_async_on_T3K_rows_persistent_fabric_post_commit(
],
)
@pytest.mark.parametrize("replication_factor1", [4])
@pytest.mark.parametrize("enable_async", [False])
@pytest.mark.parametrize("enable_async", [True])
@pytest.mark.parametrize(
"num_devices2, num_links2, per_chip_output_shape2, dim2, layout2",
[
Expand Down
73 changes: 71 additions & 2 deletions tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,10 +353,10 @@ def test_line_reduce_scatter_async_post_commit(
ttnn.BufferType.DRAM,
],
)
@pytest.mark.parametrize("enable_async", [False])
@pytest.mark.parametrize("enable_async", [True])
@pytest.mark.parametrize("replication_factor", [4])
@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
def test_line_reduce_scatter_async_on_TG_cols_post_commit(
def test_line_reduce_scatter_async_on_T3K_cols_post_commit(
t3k_mesh_device,
num_devices,
per_chip_output_shape,
Expand Down Expand Up @@ -397,3 +397,72 @@ def test_line_reduce_scatter_async_on_TG_cols_post_commit(
create_persistent_fabric=True,
teardown_persistent_fabric=True,
)


@pytest.mark.skip(
    reason="persistent fabric test with cluster-axis API and multiple concurrent reduce_scatter instances not enabled yet"
)
@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.parametrize(
    "num_devices, num_links, per_chip_output_shape, dim, layout",
    [
        (4, 1, [1, 4, 32, 1280], 1, ttnn.TILE_LAYOUT),
        (4, 1, [4, 1, 32, 1280], 0, ttnn.TILE_LAYOUT),
    ],
)
@pytest.mark.parametrize("input_dtype", [ttnn.bfloat16])
@pytest.mark.parametrize("buffer_type", [ttnn.BufferType.DRAM])
@pytest.mark.parametrize("enable_async", [True])
@pytest.mark.parametrize("replication_factor", [2])
@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
def test_line_reduce_scatter_async_on_T3K_rows_post_commit(
    t3k_mesh_device,
    num_devices,
    per_chip_output_shape,
    dim,
    num_links,
    math_op,
    input_dtype,
    layout,
    buffer_type,
    use_program_cache,
    function_level_defaults,
    enable_async,
    replication_factor,
    num_iters=1,
):
    """Drive the async line reduce-scatter runner along cluster axis 1 of the mesh.

    The test is currently disabled by the skip marker above. When enabled it
    also skips itself if the mesh exposes fewer than 8 devices. The runner is
    invoked with an interleaved tensor memory layout,
    ``replication_factor`` concurrent reduce-scatter instances, and a
    persistent fabric that is both created and torn down by this call.
    """
    mesh_devices = t3k_mesh_device.get_devices()
    if len(mesh_devices) < 8:
        pytest.skip("Not T3K!")

    # The fabric is created and torn down within this single invocation.
    persistent_fabric_options = {
        "enable_persistent_fabric": True,
        "create_persistent_fabric": True,
        "teardown_persistent_fabric": True,
    }

    # Positional arguments keep the exact order expected by the shared runner.
    run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
        t3k_mesh_device,
        num_devices,
        per_chip_output_shape,
        ttnn.TensorMemoryLayout.INTERLEAVED,
        dim,
        num_links,
        math_op,
        input_dtype,
        layout,
        buffer_type,
        use_program_cache,
        function_level_defaults,
        enable_async=enable_async,
        num_iters=num_iters,
        num_reduce_scatter_instances=replication_factor,
        cluster_axis=1,
        use_reduce_scatter_async=True,
        **persistent_fabric_options,
    )

0 comments on commit 5386bb5

Please sign in to comment.