fixed rebase errors and updated llama tests
caixunshiren committed Jan 16, 2025
1 parent 6279a64 commit ff143d9
Showing 4 changed files with 16 additions and 186 deletions.
190 changes: 10 additions & 180 deletions tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py
@@ -30,9 +30,9 @@
 @pytest.mark.parametrize(
     "num_devices, num_links",
     [
-        (8, 1),
+        (4, 1),
     ],
-    # [(8, 4), (8, 3), (8, 2)], Multi-links fails https://github.com/tenstorrent/tt-metal/issues/16699
+    # [(4, 3), (4, 2)], Multi-links fails https://github.com/tenstorrent/tt-metal/issues/16699
 )
 @pytest.mark.parametrize(
     "input_dtype",
@@ -45,97 +45,20 @@
 @pytest.mark.parametrize(
     "tensor_mem_layout, output_shape, dim, input_shard_shape,shard_grid,layout",
     (
-        (
-            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
-            (1, 1, 32, 8192),
-            3,
+        ( # AllGather after SDPA (~160 us)
+            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
+            (1, 32, 32, 128),
+            1,
             (32, 128),
             ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 0))}),
             ttnn.TILE_LAYOUT,
         ),
-    ),
-)
-@pytest.mark.parametrize("replication_factor", [4])
-@pytest.mark.parametrize("enable_async", [True])
-@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
-def test_line_all_gather_sharded_on_TG_cols_llama(
-    mesh_device,
-    num_devices,
-    output_shape,
-    input_shard_shape,
-    shard_grid,
-    shard_grid_orientation,
-    tensor_mem_layout,
-    dim,
-    num_links,
-    input_dtype,
-    layout,
-    use_program_cache,
-    function_level_defaults,
-    enable_async,
-    replication_factor,
-    num_iters=10,
-):
-    if len(mesh_device.get_devices()) != 32:
-        pytest.skip("Not TG!")
-    input_shard_spec = ttnn.ShardSpec(
-        shard_grid,
-        input_shard_shape,
-        shard_grid_orientation,
-    )
-
-    logger.warning("sharding not used due to issue #16699")
-
-    run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
-        mesh_device,
-        num_devices,
-        output_shape,
-        ttnn.TensorMemoryLayout.INTERLEAVED, # tensor_mem_layout,
-        dim,
-        num_links,
-        input_dtype,
-        layout,
-        ttnn.BufferType.L1,
-        use_program_cache,
-        function_level_defaults,
-        enable_async=enable_async,
-        num_iters=num_iters,
-        # input_shard_spec=input_shard_spec,
-        num_all_gather_instances=replication_factor,
-        cluster_axis=0,
-        use_all_gather_async=True,
-        enable_persistent_fabric=True,
-        create_persistent_fabric=True,
-        teardown_persistent_fabric=True,
-    )
-
-
-# Enumerate the post-commit cases explicitly
-@skip_for_grayskull("Requires eth connected devices to run")
-@pytest.mark.parametrize(
-    "num_devices, num_links",
-    [
-        (4, 1),
-    ],
-    # [(4, 3), (4, 2)], Multi-links fails https://github.com/tenstorrent/tt-metal/issues/16699
-)
-@pytest.mark.parametrize(
-    "input_dtype",
-    [
-        ttnn.bfloat16, # hang??
-        # ttnn.bfloat8_b,
-    ],
-)
-@pytest.mark.parametrize("shard_grid_orientation", [ttnn.ShardOrientation.ROW_MAJOR])
-@pytest.mark.parametrize(
-    "tensor_mem_layout, output_shape, dim, input_shard_shape,shard_grid,layout",
-    (
-        ( # AllGather after Binary Mult+Silu
+        ( # AllGather after Binary Mult+Silu (~160 us)
             ttnn.TensorMemoryLayout.WIDTH_SHARDED,
-            (1, 1, 32, 30720),
+            (1, 1, 32, 3840),
             3,
-            (32, 96),
-            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 4))}),
+            (32, 32),
+            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(5, 4))}),
             ttnn.TILE_LAYOUT,
         ),
     ),
@@ -195,99 +118,6 @@ def test_line_all_gather_sharded_on_TG_rows_llama(
     )


-@skip_for_grayskull("Requires eth connected devices to run")
-@pytest.mark.parametrize(
-    "num_devices, num_links",
-    [
-        (8, 2),
-    ],
-)
-@pytest.mark.parametrize(
-    "tensor_mem_layout, per_chip_input_shape, dim, input_shard_shape,shard_grid,layout",
-    (
-        (
-            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
-            (1, 1, 32, 3584),
-            3,
-            (32, 160),
-            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 2))}),
-            ttnn.TILE_LAYOUT,
-        ),
-    ),
-)
-@pytest.mark.parametrize("shard_grid_orientation", [ttnn.ShardOrientation.ROW_MAJOR])
-@pytest.mark.parametrize(
-    "input_dtype",
-    [
-        ttnn.bfloat16,
-        # ttnn.bfloat8_b,
-    ],
-)
-@pytest.mark.parametrize(
-    "buffer_type",
-    [
-        ttnn.BufferType.L1,
-    ],
-)
-@pytest.mark.parametrize("enable_async", [True])
-@pytest.mark.parametrize("replication_factor", [4])
-@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
-@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
-def test_line_reduce_scatter_sharded_on_TG_cols_llama(
-    mesh_device,
-    num_devices,
-    per_chip_input_shape,
-    tensor_mem_layout,
-    input_shard_shape,
-    shard_grid,
-    shard_grid_orientation,
-    dim,
-    num_links,
-    math_op,
-    input_dtype,
-    layout,
-    buffer_type,
-    use_program_cache,
-    function_level_defaults,
-    enable_async,
-    replication_factor,
-    num_iters=10,
-):
-    if len(mesh_device.get_devices()) != 32:
-        pytest.skip("Not TG!")
-    input_shard_spec = ttnn.ShardSpec(
-        shard_grid,
-        input_shard_shape,
-        shard_grid_orientation,
-    )
-
-    logger.warning("sharding not used due to issue #16699")
-
-    run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
-        mesh_device,
-        num_devices,
-        per_chip_input_shape,
-        ttnn.TensorMemoryLayout.INTERLEAVED, # tensor_mem_layout,
-        dim,
-        num_links,
-        math_op,
-        input_dtype,
-        layout,
-        buffer_type,
-        use_program_cache,
-        function_level_defaults,
-        enable_async=enable_async,
-        # input_shard_spec=input_shard_spec,
-        num_iters=num_iters,
-        num_reduce_scatter_instances=replication_factor,
-        cluster_axis=0,
-        use_reduce_scatter_async=True,
-        enable_persistent_fabric=True,
-        create_persistent_fabric=True,
-        teardown_persistent_fabric=True,
-    )
-
-
 @skip_for_grayskull("Requires eth connected devices to run")
 @pytest.mark.parametrize(
     "num_devices, num_links",
@@ -4,9 +4,9 @@

 #include "all_reduce_async.hpp"

-#include "ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.hpp"
+#include "cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.hpp"
 #include "ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp"
-#include "ttnn/cpp/ttnn/global_semaphore.hpp"
+#include "cpp/ttnn/global_semaphore.hpp"

 namespace ttnn::operations::experimental::ccl {

@@ -9,8 +9,8 @@

 #include "ttnn/operations/reduction/generic/generic_reductions.hpp"

-#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_types.hpp"
-#include "ttnn/cpp/ttnn/global_semaphore.hpp"
+#include "cpp/ttnn/operations/ccl/ccl_host_types.hpp"
+#include "cpp/ttnn/global_semaphore.hpp"

 namespace ttnn {
 namespace operations {
@@ -7,10 +7,10 @@
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>

-#include "ttnn/cpp/pybind11/decorators.hpp"
+#include "cpp/pybind11/decorators.hpp"
 #include "ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.hpp"
 #include "ttnn/types.hpp"
-#include "ttnn/cpp/ttnn/global_semaphore.hpp"
+#include "cpp/ttnn/global_semaphore.hpp"

 #include "ttnn/operations/reduction/generic/generic_reductions.hpp"
