#2105: Expanded test coverage for working multichip configurations and ccl operations
tapspatel committed Feb 6, 2025
1 parent 3718fb4 commit 435efeb
Showing 8 changed files with 240 additions and 10 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/build-and-test.yml
@@ -171,8 +171,11 @@ jobs:
{runs-on: n150, enable_perf: OFF, enable_emitc: ON, enable_async: OFF, name: "run", build_name: "emitc", ttrt_flags: "--emitc", container-options: "--device /dev/tenstorrent/0"},
{runs-on: n150, enable_perf: OFF, enable_emitc: OFF, enable_async: ON, name: "run", build_name: "async", ttrt_flags: "--non-zero --enable-async-ttnn", container-options: "--device /dev/tenstorrent/0"},
{runs-on: n300, enable_perf: OFF, enable_emitc: OFF, enable_async: OFF, name: "run", build_name: "run", ttrt_flags: "--non-zero", container-options: "--device /dev/tenstorrent/0"},
{runs-on: n300, enable_perf: ON, enable_emitc: OFF, enable_async: OFF, name: "perf", build_name: "perf", container-options: "--device /dev/tenstorrent/0"},
{runs-on: llmbox, enable_perf: OFF, enable_emitc: OFF, enable_async: OFF, name: "run", build_name: "run", ttrt_flags: "--non-zero", container-options: "--device /dev/tenstorrent/0 --device /dev/tenstorrent/1 --device /dev/tenstorrent/2 --device /dev/tenstorrent/3"},
{runs-on: llmbox, enable_perf: ON, enable_emitc: OFF, enable_async: OFF, name: "perf", build_name: "perf", container-options: "--device /dev/tenstorrent/0 --device /dev/tenstorrent/1 --device /dev/tenstorrent/2 --device /dev/tenstorrent/3"},
{runs-on: tg, enable_perf: OFF, enable_emitc: OFF, enable_async: OFF, name: "run", build_name: "run", ttrt_flags: "--non-zero --disable-eth-dispatch", container-options: "--device /dev/tenstorrent/0 --device /dev/tenstorrent/1 --device /dev/tenstorrent/2 --device /dev/tenstorrent/3"},
{runs-on: tg, enable_perf: ON, enable_emitc: OFF, enable_async: OFF, name: "perf", build_name: "perf", ttrt_flags: "--disable-eth-dispatch", container-options: "--device /dev/tenstorrent/0 --device /dev/tenstorrent/1 --device /dev/tenstorrent/2 --device /dev/tenstorrent/3"},
]
name: "run-tests (${{ matrix.build.runs-on }}, ${{ matrix.build.enable_perf }}, ${{ matrix.build.enable_emitc }}, ${{ matrix.build.enable_async }}, ${{ matrix.build.build_name }})"

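The llmbox and tg entries in this matrix target multi-device runners, so their container-options must expose every Tenstorrent device node to the container. As a minimal illustration, the helper below (not part of the workflow, purely a sketch) shows how the repeated --device flags are formed:

# Illustrative only: build the repeated --device flags used in container-options
# for a runner that exposes several Tenstorrent devices.
def device_flags(num_devices: int) -> str:
    return " ".join(f"--device /dev/tenstorrent/{i}" for i in range(num_devices))

print(device_flags(4))
# --device /dev/tenstorrent/0 --device /dev/tenstorrent/1 --device /dev/tenstorrent/2 --device /dev/tenstorrent/3
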
10 changes: 10 additions & 0 deletions runtime/tools/python/ttrt/common/perf.py
@@ -98,6 +98,13 @@ def initialize_api():
choices=[True, False],
help="dump memory reports after every op execution",
)
Perf.register_arg(
name="--disable-eth-dispatch",
type=bool,
default=False,
choices=[True, False],
help="disable putting dispatch on ethernet cores - place it on worker cores instead",
)
Perf.register_arg(
name="binary",
type=str,
@@ -388,6 +395,9 @@ def get_available_port():
if self["--disable-golden"]:
command_options += " --disable-golden "

if self["--disable-eth-dispatch"]:
command_options += " --disable-eth-dispatch "

ttrt_executable_path = shutil.which("ttrt")
test_command = (
f"{ttrt_executable_path} run {bin.file_path} {command_options}"
155 changes: 145 additions & 10 deletions test/ttmlir/Conversion/StableHLOToTTIR/ccl_ops.mlir
@@ -196,8 +196,8 @@ module @all_reduce_1x32 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_rep
}
}

// jax/pjrt sharding target 1x2 for n300 all_gather
module @all_gather_1x2 attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} {
// jax/pjrt sharding target 1x2 for n300 all_gather rank=2
module @all_gather_1x2_rank_2 attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<16384x800xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[2,1]<=[2]}"} : (tensor<8192x800xf32>) -> tensor<8192x800xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x800xf32>) -> tensor<4096x800xf32>
@@ -223,8 +223,35 @@ module @all_gather_1x2 attributes {mhlo.num_partitions = 2 : i32, mhlo.num_repli
}
}

// jax/pjrt sharding target 1x8 for t3k all_gather
module @all_gather_1x8 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
// jax/pjrt sharding target 1x2 for n300 all_gather rank=4
module @all_gather_1x2_rank_4 attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<1x1x8192x784xf32>) -> (tensor<1x1x16384x784xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,2,1]<=[2]}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x8192x784xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x4096x784xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: -1, 2>
// CHECK-SAME: shard_direction = #tt.shard_direction<full_to_shard>
// CHECK-SAME: shard_shape = array<i64: 1, 1, 2, 1>
// CHECK-SAME: shard_type = #tt.shard_type<devices>
%2 = call @shmap_body(%1) : (tensor<1x1x4096x784xf32>) -> tensor<1x1x8192x784xf32>
%3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x8192x784xf32>
%4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,2,1]<=[2]}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x16384x784xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: -1, 2>
// CHECK-SAME: shard_direction = #tt.shard_direction<shard_to_full>
// CHECK-SAME: shard_shape = array<i64: 1, 1, 2, 1>
// CHECK-SAME: shard_type = #tt.shard_type<devices>
return %4 : tensor<1x1x16384x784xf32>
}
func.func private @shmap_body(%arg0: tensor<1x1x4096x784xf32>) -> (tensor<1x1x8192x784xf32> {jax.result_info = "[None, None, ('model',), None]"}) {
%0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 2 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, use_global_device_ids}> : (tensor<1x1x4096x784xf32>) -> tensor<1x1x8192x784xf32>
// CHECK: %[[C:.*]] = "ttir.all_gather"[[C:.*]]
return %0 : tensor<1x1x8192x784xf32>
}
}

// jax/pjrt sharding target 1x8 for t3k all_gather rank=2
module @all_gather_1x8_rank_2 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<65536x800xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<8192x800xf32>) -> tensor<8192x800xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x800xf32>) -> tensor<1024x800xf32>
@@ -250,8 +277,35 @@ module @all_gather_1x8 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_repli
}
}

// jax/pjrt sharding target 2x4 for t3k all_gather
module @all_gather_2x4 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
// jax/pjrt sharding target 1x8 for t3k all_gather rank=4
module @all_gather_1x8_rank4 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<1x1x8192x784xf32>) -> (tensor<1x1x65536x784xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,8,1]<=[8]}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x8192x784xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x1024x784xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: -1, 2>
// CHECK-SAME: shard_direction = #tt.shard_direction<full_to_shard>
// CHECK-SAME: shard_shape = array<i64: 1, 1, 8, 1>
// CHECK-SAME: shard_type = #tt.shard_type<devices>
%2 = call @shmap_body(%1) : (tensor<1x1x1024x784xf32>) -> tensor<1x1x8192x784xf32>
%3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x8192x784xf32>
%4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,8,1]<=[8]}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x65536x784xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: -1, 2>
// CHECK-SAME: shard_direction = #tt.shard_direction<shard_to_full>
// CHECK-SAME: shard_shape = array<i64: 1, 1, 8, 1>
// CHECK-SAME: shard_type = #tt.shard_type<devices>
return %4 : tensor<1x1x65536x784xf32>
}
func.func private @shmap_body(%arg0: tensor<1x1x1024x784xf32>) -> (tensor<1x1x8192x784xf32> {jax.result_info = "[None, None, ('model',), None]"}) {
%0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 2 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<1x1x1024x784xf32>) -> tensor<1x1x8192x784xf32>
// CHECK: %[[C:.*]] = "ttir.all_gather"[[C:.*]]
return %0 : tensor<1x1x8192x784xf32>
}
}

// jax/pjrt sharding target 2x4 for t3k all_gather rank=2
module @all_gather_2x4_rank_2 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<32768x800xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[4,2]<=[2,4]T(1,0)}"} : (tensor<8192x800xf32>) -> tensor<8192x800xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x800xf32>) -> tensor<2048x400xf32>
@@ -277,8 +331,35 @@ module @all_gather_2x4 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_repli
}
}

// jax/pjrt sharding target 1x32 for tg all_gather
module @all_gather_1x32 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} {
// jax/pjrt sharding target 2x4 for t3k all_gather rank=4
module @all_gather_2x4_rank4 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<1x1x8192x784xf32>) -> (tensor<1x1x32768x784xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,4,2]<=[2,4]T(1,0)}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x8192x784xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x2048x392xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: 3, 2>
// CHECK-SAME: shard_direction = #tt.shard_direction<full_to_shard>
// CHECK-SAME: shard_shape = array<i64: 1, 1, 4, 2>
// CHECK-SAME: shard_type = #tt.shard_type<devices>
%2 = call @shmap_body(%1) : (tensor<1x1x2048x392xf32>) -> tensor<1x1x8192x392xf32>
%3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x392xf32>) -> tensor<1x1x8192x392xf32>
%4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,4,2]<=[2,4]T(1,0)}"} : (tensor<1x1x8192x392xf32>) -> tensor<1x1x32768x784xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: 3, 2>
// CHECK-SAME: shard_direction = #tt.shard_direction<shard_to_full>
// CHECK-SAME: shard_shape = array<i64: 1, 1, 4, 2>
// CHECK-SAME: shard_type = #tt.shard_type<devices>
return %4 : tensor<1x1x32768x784xf32>
}
func.func private @shmap_body(%arg0: tensor<1x1x2048x392xf32>) -> (tensor<1x1x8192x392xf32> {jax.result_info = "[None, None, ('model',), ('batch',)]"}) {
%0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 2 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<1x1x2048x392xf32>) -> tensor<1x1x8192x392xf32>
// CHECK: %[[C:.*]] = "ttir.all_gather"[[C:.*]]
return %0 : tensor<1x1x8192x392xf32>
}
}

// jax/pjrt sharding target 1x32 for tg all_gather rank=2
module @all_gather_1x32_rank_2 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<262144x800xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[32,1]<=[32]}"} : (tensor<8192x800xf32>) -> tensor<8192x800xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x800xf32>) -> tensor<256x800xf32>
@@ -304,8 +385,35 @@ module @all_gather_1x32 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_rep
}
}

// jax/pjrt sharding target 8x4 for tg all_gather
module @all_gather_4x8 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} {
// jax/pjrt sharding target 1x32 for tg all_gather rank4
module @all_gather_1x32_rank4 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<1x1x8192x784xf32>) -> (tensor<1x1x262144x784xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,32,1]<=[32]}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x8192x784xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x256x784xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: -1, 2>
// CHECK-SAME: shard_direction = #tt.shard_direction<full_to_shard>
// CHECK-SAME: shard_shape = array<i64: 1, 1, 32, 1>
// CHECK-SAME: shard_type = #tt.shard_type<devices>
%2 = call @shmap_body(%1) : (tensor<1x1x256x784xf32>) -> tensor<1x1x8192x784xf32>
%3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x8192x784xf32>
%4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,32,1]<=[32]}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x262144x784xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: -1, 2>
// CHECK-SAME: shard_direction = #tt.shard_direction<shard_to_full>
// CHECK-SAME: shard_shape = array<i64: 1, 1, 32, 1>
// CHECK-SAME: shard_type = #tt.shard_type<devices>
return %4 : tensor<1x1x262144x784xf32>
}
func.func private @shmap_body(%arg0: tensor<1x1x256x784xf32>) -> (tensor<1x1x8192x784xf32> {jax.result_info = "[None, None, ('model',), None]"}) {
%0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 2 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]]> : tensor<1x32xi64>, use_global_device_ids}> : (tensor<1x1x256x784xf32>) -> tensor<1x1x8192x784xf32>
// CHECK: %[[C:.*]] = "ttir.all_gather"[[C:.*]]
return %0 : tensor<1x1x8192x784xf32>
}
}

// jax/pjrt sharding target 8x4 for tg all_gather rank=2
module @all_gather_4x8_rank_2 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<65536x800xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[8,4]<=[4,8]T(1,0)}"} : (tensor<8192x800xf32>) -> tensor<8192x800xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x800xf32>) -> tensor<1024x200xf32>
@@ -331,6 +439,33 @@ module @all_gather_4x8 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_repl
}
}

// jax/pjrt sharding target 8x4 for tg all_gather rank4
module @all_gather_8x4_rank4 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<1x1x8192x784xf32>) -> (tensor<1x1x32768x784xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,4,8]<=[8,4]T(1,0)}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x8192x784xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x2048x98xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: 3, 2>
// CHECK-SAME: shard_direction = #tt.shard_direction<full_to_shard>
// CHECK-SAME: shard_shape = array<i64: 1, 1, 4, 8>
// CHECK-SAME: shard_type = #tt.shard_type<devices>
%2 = call @shmap_body(%1) : (tensor<1x1x2048x98xf32>) -> tensor<1x1x8192x98xf32>
%3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x98xf32>) -> tensor<1x1x8192x98xf32>
%4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,4,8]<=[8,4]T(1,0)}"} : (tensor<1x1x8192x98xf32>) -> tensor<1x1x32768x784xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: 3, 2>
// CHECK-SAME: shard_direction = #tt.shard_direction<shard_to_full>
// CHECK-SAME: shard_shape = array<i64: 1, 1, 4, 8>
// CHECK-SAME: shard_type = #tt.shard_type<devices>
return %4 : tensor<1x1x32768x784xf32>
}
func.func private @shmap_body(%arg0: tensor<1x1x2048x98xf32>) -> (tensor<1x1x8192x98xf32> {jax.result_info = "[None, None, ('model',), ('batch',)]"}) {
%0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 2 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27], [28, 29, 30, 31]]> : tensor<8x4xi64>, use_global_device_ids}> : (tensor<1x1x2048x98xf32>) -> tensor<1x1x8192x98xf32>
// CHECK: %[[C:.*]] = "ttir.all_gather"[[C:.*]]
return %0 : tensor<1x1x8192x98xf32>
}
}

// jax/pjrt sharding target 2x4 for t3k - GSPMD negative, sharding [None, "x", None, "y"]
module @jit_neg_basic0 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<1x1024x128x1024xf32>) -> (tensor<1x1024x128x1024xf32> {jax.result_info = ""}) {
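The rank-4 modules added above all follow the same shape arithmetic: full-to-shard divides each sharded dimension by its shard factor, the all_gather multiplies the gather dimension back by the replica-group size, and shard-to-full multiplies by the shard shape again. A small Python sketch of that arithmetic (illustrative, not part of the test suite):

# Illustrative shape trace for the rank-4 all_gather tests above.
def trace_shapes(full, shard_shape, gather_dim, group_size):
    shard = [d // s for d, s in zip(full, shard_shape)]     # full_to_shard
    gathered = list(shard)
    gathered[gather_dim] *= group_size                      # all_gather
    output = [d * s for d, s in zip(gathered, shard_shape)] # shard_to_full
    return shard, gathered, output

# 1x2 mesh, sharded on dim 2: (1,1,4096,784) -> (1,1,8192,784) -> (1,1,16384,784)
print(trace_shapes([1, 1, 8192, 784], [1, 1, 2, 1], gather_dim=2, group_size=2))
# 2x4 mesh, sharded on dims 2 and 3: (1,1,2048,392) -> (1,1,8192,392) -> (1,1,32768,784)
print(trace_shapes([1, 1, 8192, 784], [1, 1, 4, 2], gather_dim=2, group_size=4))
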
16 changes: 16 additions & 0 deletions test/ttmlir/Silicon/TTNN/llmbox/perf/all_gather.mlir
@@ -0,0 +1,16 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% mesh-shape=2,4" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn

func.func @forward(%arg0: tensor<1x1x256x512xf32>) -> tensor<1x1x256x512xf32> {
%0 = tensor.empty() : tensor<1x1x128x128xf32>
%1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array<i64: 2, 3>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 1, 1, 2, 4>, shard_type = #tt.shard_type<devices>}> : (tensor<1x1x256x512xf32>, tensor<1x1x128x128xf32>) -> tensor<1x1x128x128xf32>
// CHECK: %[[C:.*]] = "ttnn.mesh_shard"[[C:.*]]
%2 = tensor.empty() : tensor<1x1x128x512xf32>
%3 = "ttir.all_gather"(%1, %2) <{dim = 3 : si32}> : (tensor<1x1x128x128xf32>, tensor<1x1x128x512xf32>) -> tensor<1x1x128x512xf32>
// CHECK: %[[C:.*]] = "ttnn.all_gather"[[C:.*]]
%4 = tensor.empty() : tensor<1x1x256x512xf32>
%5 = "ttir.mesh_shard"(%3, %4) <{shard_dims = array<i64: 2, -1>, shard_direction = #tt.shard_direction<shard_to_full>, shard_shape = array<i64: 1, 1, 2, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<1x1x128x512xf32>, tensor<1x1x256x512xf32>) -> tensor<1x1x256x512xf32>
// CHECK: %[[C:.*]] = "ttnn.mesh_shard"[[C:.*]]
return %5 : tensor<1x1x256x512xf32>
}