Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][mesh] Fix wrong argument passed to targetShardingInUnsplitLastAxis #95059

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
builder.setInsertionPointAfterValue(sourceShard);

MeshShardingAttr targetSharding =
targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitMeshAxis);
targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
Value allGatherResult = builder.create<AllGatherOp>(
Expand Down
35 changes: 35 additions & 0 deletions mlir/test/Dialect/Linalg/mesh-spmdization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,38 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partia
// CHECK: return %[[SHARDED_MATMUL]] : tensor<4x8xi8>
return %res_shared2 : tensor<4x8xi8>
}

// -----

// One-dimensional mesh of shape 4 referenced by the test case below.
mesh.mesh @mesh_1d(shape = 4)

// Matmul test on @mesh_1d where the second operand, the DPS output and the
// result are annotated with split axes [[], [0]], and the result is then
// requested by its users with empty split axes [[], []].
// Expected output: both the second operand and the DPS output are sliced along
// tensor axis 1 over mesh axis 0, the matmul runs on the 6x2 / 4x2 slices, and
// the 4x2 result is all-gathered along axis 1 back to the full 4x8 tensor
// before being returned.
// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis
func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis(
// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x6xi8>,
%in1: tensor<4x6xi8>,
// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<6x8xi8>,
%in2: tensor<6x8xi8>,
// CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8>
%dps_out: tensor<4x8xi8>
// CHECK-SAME: -> tensor<4x8xi8> {
) -> tensor<4x8xi8> {
// First operand: both annotations carry empty split axes, and the expected
// matmul consumes the function argument directly with its full 4x6 type.
%in1_replicated1 = mesh.shard %in1 to <@mesh_1d, [[], []]> : tensor<4x6xi8>
%in1_replicated2 = mesh.shard %in1_replicated1 to <@mesh_1d, [[], []]> annotate_for_users : tensor<4x6xi8>
// Second operand: goes from empty split axes to [[], [0]] for its users, which
// is expected to lower to a slice along tensor axis 1.
// CHECK: %[[ALL_SLICE1:.*]] = mesh.all_slice %[[IN2]] on @mesh_1d mesh_axes = [0] slice_axis = 1
%in2_replicated = mesh.shard %in2 to <@mesh_1d, [[], []]> : tensor<6x8xi8>
%in2_sharded = mesh.shard %in2_replicated to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<6x8xi8>
// DPS output: same transition as the second operand, so it is sliced the same
// way.
// CHECK: %[[ALL_SLICE2:.*]] = mesh.all_slice %[[DPS_OUT]] on @mesh_1d mesh_axes = [0] slice_axis = 1
%dps_out_replicated = mesh.shard %dps_out to <@mesh_1d, [[], []]> : tensor<4x8xi8>
%dps_out_sharded = mesh.shard %dps_out_replicated to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x8xi8>
// The expected per-shard matmul operates on the sliced shapes (6x2 and 4x2).
// CHECK: %[[MATMUL_RES:.*]] = linalg.matmul
// CHECK-SAME: ins(%[[IN1]], %[[ALL_SLICE1]] : tensor<4x6xi8>, tensor<6x2xi8>)
// CHECK-SAME: outs(%[[ALL_SLICE2]] : tensor<4x2xi8>)
// CHECK-SAME: -> tensor<4x2xi8>
%res = linalg.matmul ins(%in1_replicated2, %in2_sharded : tensor<4x6xi8>, tensor<6x8xi8>)
outs(%dps_out_sharded : tensor<4x8xi8>) -> tensor<4x8xi8>
// Result: annotated [[], [0]] and then requested with empty split axes, so the
// last tensor axis has to be gathered back to the full 4x8 shape.
// CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[MATMUL_RES]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<4x2xi8> -> tensor<4x8xi8>
%res_sharded = mesh.shard %res to <@mesh_1d, [[], [0]]> : tensor<4x8xi8>
%res_replicated = mesh.shard %res_sharded to <@mesh_1d, [[], []]> annotate_for_users: tensor<4x8xi8>
// CHECK: return %[[ALL_GATHER]] : tensor<4x8xi8>
return %res_replicated : tensor<4x8xi8>
}
13 changes: 13 additions & 0 deletions mlir/test/Dialect/Mesh/resharding-spmdization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,19 @@ func.func @unshard_static_axis(
return %1 : tensor<10x14xf32>
}

// Resharding test: a statically shaped 10x14 tensor annotated with split axes
// [[], [0]] on @mesh_1d is requested by its users with empty split axes
// [[], []].
// Expected output: the argument is cast to its 10x7 shard type and then
// all-gathered along tensor axis 1 over mesh axis 0 to rebuild the full 10x14
// tensor.
// NOTE(review): @mesh_1d is declared outside this excerpt; the 14 -> 7 shard
// size suggests a mesh of shape 2 -- confirm against the declaration.
// CHECK-LABEL: func @unshard_static_last_axis
func.func @unshard_static_last_axis(
// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
%arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
// CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<10x7xf32>
// CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32>
// Source annotation: split axes [[], [0]].
%0 = mesh.shard %arg0 to <@mesh_1d, [[], [0]]> : tensor<10x14xf32>
// Target annotation for users: empty split axes, which triggers the gather
// expected above.
%1 = mesh.shard %0 to <@mesh_1d, [[], []]> annotate_for_users : tensor<10x14xf32>
// CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
return %1 : tensor<10x14xf32>
}

// CHECK-LABEL: func @unshard_dynamic_axis
func.func @unshard_dynamic_axis(
// CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
Expand Down
Loading