#2105: Expanded test coverage for working multichip configurations and ccl operations #2106

Merged 1 commit on Feb 7, 2025
5 changes: 4 additions & 1 deletion .github/workflows/build-and-test.yml
@@ -173,8 +173,11 @@ jobs:
{runs-on: n150, enable_perf: OFF, enable_emitc: ON, enable_async: OFF, name: "run", build_name: "emitc", ttrt_flags: "--emitc", container-options: "--device /dev/tenstorrent/0"},
{runs-on: n150, enable_perf: OFF, enable_emitc: OFF, enable_async: ON, name: "run", build_name: "async", ttrt_flags: "--non-zero --enable-async-ttnn", container-options: "--device /dev/tenstorrent/0"},
{runs-on: n300, enable_perf: OFF, enable_emitc: OFF, enable_async: OFF, name: "run", build_name: "run", ttrt_flags: "--non-zero", container-options: "--device /dev/tenstorrent/0"},
{runs-on: n300, enable_perf: ON, enable_emitc: OFF, enable_async: OFF, name: "perf", build_name: "perf", container-options: "--device /dev/tenstorrent/0"},
{runs-on: llmbox, enable_perf: OFF, enable_emitc: OFF, enable_async: OFF, name: "run", build_name: "run", ttrt_flags: "--non-zero", container-options: "--device /dev/tenstorrent/0 --device /dev/tenstorrent/1 --device /dev/tenstorrent/2 --device /dev/tenstorrent/3"},
{runs-on: tg, enable_perf: OFF, enable_emitc: OFF, enable_async: OFF, name: "run", build_name: "run", ttrt_flags: "--non-zero --disable-eth-dispatch", container-options: "--device /dev/tenstorrent/0 --device /dev/tenstorrent/1 --device /dev/tenstorrent/2 --device /dev/tenstorrent/3"},
{runs-on: llmbox, enable_perf: ON, enable_emitc: OFF, enable_async: OFF, name: "perf", build_name: "perf", container-options: "--device /dev/tenstorrent/0 --device /dev/tenstorrent/1 --device /dev/tenstorrent/2 --device /dev/tenstorrent/3"},
#{runs-on: tg, enable_perf: OFF, enable_emitc: OFF, enable_async: OFF, name: "run", build_name: "run", ttrt_flags: "--non-zero --disable-eth-dispatch", container-options: "--device /dev/tenstorrent/0 --device /dev/tenstorrent/1 --device /dev/tenstorrent/2 --device /dev/tenstorrent/3"},
#{runs-on: tg, enable_perf: ON, enable_emitc: OFF, enable_async: OFF, name: "perf", build_name: "perf", ttrt_flags: "--disable-eth-dispatch", container-options: "--device /dev/tenstorrent/0 --device /dev/tenstorrent/1 --device /dev/tenstorrent/2 --device /dev/tenstorrent/3"},
]
name: "run-tests (${{ matrix.build.runs-on }}, ${{ matrix.build.enable_perf }}, ${{ matrix.build.enable_emitc }}, ${{ matrix.build.enable_async }}, ${{ matrix.build.build_name }})"

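The new multichip entries above expose four Tenstorrent devices to the test container through container-options and pass extra flags through ttrt_flags; the tg entries additionally carry --disable-eth-dispatch. As a rough Python illustration (not part of the workflow, and assuming ttrt is invoked as "ttrt <name> <flatbuffer> <flags>"), such an entry boils down to a command like this:

import shlex

# Hypothetical reconstruction of the command one matrix entry produces once the
# workflow substitutes ttrt_flags; "out.ttnn" is a placeholder flatbuffer name.
entry = {"name": "run", "ttrt_flags": "--non-zero --disable-eth-dispatch"}
cmd = ["ttrt", entry["name"], "out.ttnn", *shlex.split(entry["ttrt_flags"])]
print(shlex.join(cmd))  # ttrt run out.ttnn --non-zero --disable-eth-dispatch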
16 changes: 15 additions & 1 deletion runtime/tools/python/ttrt/common/perf.py
@@ -98,6 +98,13 @@ def initialize_api():
choices=[True, False],
help="dump memory reports after every op execution",
)
Perf.register_arg(
name="--disable-eth-dispatch",
type=bool,
default=False,
choices=[True, False],
help="disable putting dispatch on ethernet cores - place it on worker cores instead",
)
Perf.register_arg(
name="binary",
type=str,
@@ -134,7 +141,11 @@ def __init__(self, args={}, logger=None, artifacts=None):
artifacts_folder_path=self["--artifact-dir"],
)
)
self.query = Query({"--quiet": True}, self.logger, self.artifacts)
self.query = Query(
{"--quiet": True, "--disable-eth-dispatch": self["--disable-eth-dispatch"]},
self.logger,
self.artifacts,
)
self.ttnn_binaries = []
self.ttmetal_binaries = []
self.tracy_capture_tool_path = (
Expand Down Expand Up @@ -385,6 +396,9 @@ def get_available_port():
if self["--memory"]:
command_options += " --memory "

if self["--disable-eth-dispatch"]:
command_options += " --disable-eth-dispatch "

if self["--disable-golden"]:
command_options += " --disable-golden "

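For context, the change above follows a simple propagation pattern: the new boolean flag is registered once, forwarded to the device Query, and re-emitted on the child command line. A minimal sketch of that pattern, using plain argparse instead of ttrt's own register_arg machinery (the parser, dictionary, and example binary name are illustrative stand-ins, not part of the diff):

import argparse

# Simplified stand-in for Perf.register_arg("--disable-eth-dispatch", ...).
parser = argparse.ArgumentParser(prog="perf-sketch")
parser.add_argument(
    "--disable-eth-dispatch",
    action="store_true",
    help="disable putting dispatch on ethernet cores - place it on worker cores instead",
)
parser.add_argument("binary")
args = parser.parse_args(["out.ttnn", "--disable-eth-dispatch"])

# Forwarded to the device query (this dict stands in for ttrt's Query arguments) ...
query_args = {"--quiet": True, "--disable-eth-dispatch": args.disable_eth_dispatch}

# ... and re-emitted on the child command line, mirroring command_options above.
command_options = ""
if args.disable_eth_dispatch:
    command_options += " --disable-eth-dispatch "

print(query_args)
print(command_options)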
155 changes: 145 additions & 10 deletions test/ttmlir/Conversion/StableHLOToTTIR/ccl_ops.mlir
@@ -196,8 +196,8 @@ module @all_reduce_1x32 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_rep
}
}

// jax/pjrt sharding target 1x2 for n300 all_gather
module @all_gather_1x2 attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} {
// jax/pjrt sharding target 1x2 for n300 all_gather rank=2
module @all_gather_1x2_rank_2 attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<16384x800xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[2,1]<=[2]}"} : (tensor<8192x800xf32>) -> tensor<8192x800xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x800xf32>) -> tensor<4096x800xf32>
@@ -223,8 +223,35 @@ module @all_gather_1x2 attributes {mhlo.num_partitions = 2 : i32, mhlo.num_repli
}
}

// jax/pjrt sharding target 1x8 for t3k all_gather
module @all_gather_1x8 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
// jax/pjrt sharding target 1x2 for n300 all_gather rank=4
module @all_gather_1x2_rank_4 attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<1x1x8192x784xf32>) -> (tensor<1x1x16384x784xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,2,1]<=[2]}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x8192x784xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x4096x784xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: -1, 2>
// CHECK-SAME: shard_direction = #tt.shard_direction<full_to_shard>
// CHECK-SAME: shard_shape = array<i64: 1, 1, 2, 1>
// CHECK-SAME: shard_type = #tt.shard_type<devices>
%2 = call @shmap_body(%1) : (tensor<1x1x4096x784xf32>) -> tensor<1x1x8192x784xf32>
%3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x8192x784xf32>
%4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,2,1]<=[2]}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x16384x784xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: -1, 2>
// CHECK-SAME: shard_direction = #tt.shard_direction<shard_to_full>
// CHECK-SAME: shard_shape = array<i64: 1, 1, 2, 1>
// CHECK-SAME: shard_type = #tt.shard_type<devices>
return %4 : tensor<1x1x16384x784xf32>
}
func.func private @shmap_body(%arg0: tensor<1x1x4096x784xf32>) -> (tensor<1x1x8192x784xf32> {jax.result_info = "[None, None, ('model',), None]"}) {
%0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 2 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, use_global_device_ids}> : (tensor<1x1x4096x784xf32>) -> tensor<1x1x8192x784xf32>
// CHECK: %[[C:.*]] = "ttir.all_gather"[[C:.*]]
return %0 : tensor<1x1x8192x784xf32>
}
}

// jax/pjrt sharding target 1x8 for t3k all_gather rank=2
module @all_gather_1x8_rank_2 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<65536x800xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<8192x800xf32>) -> tensor<8192x800xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x800xf32>) -> tensor<1024x800xf32>
@@ -250,8 +277,35 @@ module @all_gather_1x8 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_repli
}
}

// jax/pjrt sharding target 2x4 for t3k all_gather
module @all_gather_2x4 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
// jax/pjrt sharding target 1x8 for t3k all_gather rank=4
module @all_gather_1x8_rank4 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<1x1x8192x784xf32>) -> (tensor<1x1x65536x784xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,8,1]<=[8]}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x8192x784xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x1024x784xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: -1, 2>
// CHECK-SAME: shard_direction = #tt.shard_direction<full_to_shard>
// CHECK-SAME: shard_shape = array<i64: 1, 1, 8, 1>
// CHECK-SAME: shard_type = #tt.shard_type<devices>
%2 = call @shmap_body(%1) : (tensor<1x1x1024x784xf32>) -> tensor<1x1x8192x784xf32>
%3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x8192x784xf32>
%4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,8,1]<=[8]}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x65536x784xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: -1, 2>
// CHECK-SAME: shard_direction = #tt.shard_direction<shard_to_full>
// CHECK-SAME: shard_shape = array<i64: 1, 1, 8, 1>
// CHECK-SAME: shard_type = #tt.shard_type<devices>
return %4 : tensor<1x1x65536x784xf32>
}
func.func private @shmap_body(%arg0: tensor<1x1x1024x784xf32>) -> (tensor<1x1x8192x784xf32> {jax.result_info = "[None, None, ('model',), None]"}) {
%0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 2 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<1x1x1024x784xf32>) -> tensor<1x1x8192x784xf32>
// CHECK: %[[C:.*]] = "ttir.all_gather"[[C:.*]]
return %0 : tensor<1x1x8192x784xf32>
}
}

// jax/pjrt sharding target 2x4 for t3k all_gather rank=2
module @all_gather_2x4_rank_2 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<32768x800xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[4,2]<=[2,4]T(1,0)}"} : (tensor<8192x800xf32>) -> tensor<8192x800xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x800xf32>) -> tensor<2048x400xf32>
@@ -277,8 +331,35 @@ module @all_gather_2x4 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_repli
}
}

// jax/pjrt sharding target 1x32 for tg all_gather
module @all_gather_1x32 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} {
// jax/pjrt sharding target 2x4 for t3k all_gather rank=4
module @all_gather_2x4_rank4 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<1x1x8192x784xf32>) -> (tensor<1x1x32768x784xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,4,2]<=[2,4]T(1,0)}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x8192x784xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x2048x392xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: 3, 2>
// CHECK-SAME: shard_direction = #tt.shard_direction<full_to_shard>
// CHECK-SAME: shard_shape = array<i64: 1, 1, 4, 2>
// CHECK-SAME: shard_type = #tt.shard_type<devices>
%2 = call @shmap_body(%1) : (tensor<1x1x2048x392xf32>) -> tensor<1x1x8192x392xf32>
%3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x392xf32>) -> tensor<1x1x8192x392xf32>
%4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,4,2]<=[2,4]T(1,0)}"} : (tensor<1x1x8192x392xf32>) -> tensor<1x1x32768x784xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: 3, 2>
// CHECK-SAME: shard_direction = #tt.shard_direction<shard_to_full>
// CHECK-SAME: shard_shape = array<i64: 1, 1, 4, 2>
// CHECK-SAME: shard_type = #tt.shard_type<devices>
return %4 : tensor<1x1x32768x784xf32>
}
func.func private @shmap_body(%arg0: tensor<1x1x2048x392xf32>) -> (tensor<1x1x8192x392xf32> {jax.result_info = "[None, None, ('model',), ('batch',)]"}) {
%0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 2 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<1x1x2048x392xf32>) -> tensor<1x1x8192x392xf32>
// CHECK: %[[C:.*]] = "ttir.all_gather"[[C:.*]]
return %0 : tensor<1x1x8192x392xf32>
}
}

// jax/pjrt sharding target 1x32 for tg all_gather rank=2
module @all_gather_1x32_rank_2 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<262144x800xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[32,1]<=[32]}"} : (tensor<8192x800xf32>) -> tensor<8192x800xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x800xf32>) -> tensor<256x800xf32>
@@ -304,8 +385,35 @@ module @all_gather_1x32 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_rep
}
}

// jax/pjrt sharding target 8x4 for tg all_gather
module @all_gather_4x8 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} {
// jax/pjrt sharding target 1x32 for tg all_gather rank4
module @all_gather_1x32_rank4 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<1x1x8192x784xf32>) -> (tensor<1x1x262144x784xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,32,1]<=[32]}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x8192x784xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x256x784xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: -1, 2>
// CHECK-SAME: shard_direction = #tt.shard_direction<full_to_shard>
// CHECK-SAME: shard_shape = array<i64: 1, 1, 32, 1>
// CHECK-SAME: shard_type = #tt.shard_type<devices>
%2 = call @shmap_body(%1) : (tensor<1x1x256x784xf32>) -> tensor<1x1x8192x784xf32>
%3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x8192x784xf32>
%4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,32,1]<=[32]}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x262144x784xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: -1, 2>
// CHECK-SAME: shard_direction = #tt.shard_direction<shard_to_full>
// CHECK-SAME: shard_shape = array<i64: 1, 1, 32, 1>
// CHECK-SAME: shard_type = #tt.shard_type<devices>
return %4 : tensor<1x1x262144x784xf32>
}
func.func private @shmap_body(%arg0: tensor<1x1x256x784xf32>) -> (tensor<1x1x8192x784xf32> {jax.result_info = "[None, None, ('model',), None]"}) {
%0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 2 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]]> : tensor<1x32xi64>, use_global_device_ids}> : (tensor<1x1x256x784xf32>) -> tensor<1x1x8192x784xf32>
// CHECK: %[[C:.*]] = "ttir.all_gather"[[C:.*]]
return %0 : tensor<1x1x8192x784xf32>
}
}

// jax/pjrt sharding target 8x4 for tg all_gather rank=2
module @all_gather_4x8_rank_2 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<65536x800xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[8,4]<=[4,8]T(1,0)}"} : (tensor<8192x800xf32>) -> tensor<8192x800xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x800xf32>) -> tensor<1024x200xf32>
@@ -331,6 +439,33 @@ module @all_gather_4x8 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_repl
}
}

// jax/pjrt sharding target 8x4 for tg all_gather rank4
module @all_gather_8x4_rank4 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<1x1x8192x784xf32>) -> (tensor<1x1x32768x784xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,1,4,8]<=[8,4]T(1,0)}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x8192x784xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x784xf32>) -> tensor<1x1x2048x98xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: 3, 2>
// CHECK-SAME: shard_direction = #tt.shard_direction<full_to_shard>
// CHECK-SAME: shard_shape = array<i64: 1, 1, 4, 8>
// CHECK-SAME: shard_type = #tt.shard_type<devices>
%2 = call @shmap_body(%1) : (tensor<1x1x2048x98xf32>) -> tensor<1x1x8192x98xf32>
%3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x1x8192x98xf32>) -> tensor<1x1x8192x98xf32>
%4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[1,1,4,8]<=[8,4]T(1,0)}"} : (tensor<1x1x8192x98xf32>) -> tensor<1x1x32768x784xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: 3, 2>
// CHECK-SAME: shard_direction = #tt.shard_direction<shard_to_full>
// CHECK-SAME: shard_shape = array<i64: 1, 1, 4, 8>
// CHECK-SAME: shard_type = #tt.shard_type<devices>
return %4 : tensor<1x1x32768x784xf32>
}
func.func private @shmap_body(%arg0: tensor<1x1x2048x98xf32>) -> (tensor<1x1x8192x98xf32> {jax.result_info = "[None, None, ('model',), ('batch',)]"}) {
%0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 2 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27], [28, 29, 30, 31]]> : tensor<8x4xi64>, use_global_device_ids}> : (tensor<1x1x2048x98xf32>) -> tensor<1x1x8192x98xf32>
// CHECK: %[[C:.*]] = "ttir.all_gather"[[C:.*]]
return %0 : tensor<1x1x8192x98xf32>
}
}

// jax/pjrt sharding target 2x4 for t3k - GSPMD negative, sharding [None, "x", None, "y"]
module @jit_neg_basic0 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<1x1024x128x1024xf32>) -> (tensor<1x1024x128x1024xf32> {jax.result_info = ""}) {
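The added all_gather modules are annotated as jax/pjrt sharding targets. As a hedged illustration (not taken from this PR), a shard_map program of roughly the following shape is the kind of source assumed to produce a rank-4 module such as @all_gather_1x2_rank_4; it needs two visible devices (for example an n300) to run, and the function and variable names are placeholders:

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# 1x2 mesh over the 'model' axis (assumes at least two visible devices).
mesh = Mesh(np.array(jax.devices()[:2]), axis_names=("model",))

def body(x):
    # Concatenate the per-device shards along dim 2; this lowers to
    # stablehlo.all_gather with all_gather_dim = 2 when traced under shard_map.
    return jax.lax.all_gather(x, "model", axis=2, tiled=True)

fn = shard_map(
    body,
    mesh=mesh,
    in_specs=P(None, None, "model", None),
    out_specs=P(None, None, "model", None),
)

x = jnp.zeros((1, 1, 8192, 784), jnp.float32)  # full (unsharded) input shape
y = jax.jit(fn)(x)
print(y.shape)  # (1, 1, 16384, 784), matching the expected output of the test module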