Skip to content

Commit

Permalink
[mlir][mesh] Fix wrong argument passed to targetShardingInUnsplitLast…
Browse files Browse the repository at this point in the history
…Axis (#95059)

In unsplitLastAxisInResharding, wrong argument was passed when calling
targetShardingInUnsplitLastAxis.There weren't any tests to uncover this.
I added one in mesh-spmdization.mlir for Linalg and one in
resharding-spmdization.mlir for Mesh dialects.
  • Loading branch information
Arda Unal authored Jun 13, 2024
1 parent 1ebda11 commit 01a429c
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 1 deletion.
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
builder.setInsertionPointAfterValue(sourceShard);

MeshShardingAttr targetSharding =
targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitMeshAxis);
targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
Value allGatherResult = builder.create<AllGatherOp>(
Expand Down
35 changes: 35 additions & 0 deletions mlir/test/Dialect/Linalg/mesh-spmdization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,38 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partia
// CHECK: return %[[SHARDED_MATMUL]] : tensor<4x8xi8>
return %res_shared2 : tensor<4x8xi8>
}

// -----

mesh.mesh @mesh_1d(shape = 4)

// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis
func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis(
// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x6xi8>,
%in1: tensor<4x6xi8>,
// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<6x8xi8>,
%in2: tensor<6x8xi8>,
// CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8>
%dps_out: tensor<4x8xi8>
// CHECK-SAME: -> tensor<4x8xi8> {
) -> tensor<4x8xi8> {
%in1_replicated1 = mesh.shard %in1 to <@mesh_1d, [[], []]> : tensor<4x6xi8>
%in1_replicated2 = mesh.shard %in1_replicated1 to <@mesh_1d, [[], []]> annotate_for_users : tensor<4x6xi8>
// CHECK: %[[ALL_SLICE1:.*]] = mesh.all_slice %[[IN2]] on @mesh_1d mesh_axes = [0] slice_axis = 1
%in2_replicated = mesh.shard %in2 to <@mesh_1d, [[], []]> : tensor<6x8xi8>
%in2_sharded = mesh.shard %in2_replicated to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<6x8xi8>
// CHECK: %[[ALL_SLICE2:.*]] = mesh.all_slice %[[DPS_OUT]] on @mesh_1d mesh_axes = [0] slice_axis = 1
%dps_out_replicated = mesh.shard %dps_out to <@mesh_1d, [[], []]> : tensor<4x8xi8>
%dps_out_sharded = mesh.shard %dps_out_replicated to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x8xi8>
// CHECK: %[[MATMUL_RES:.*]] = linalg.matmul
// CHECK-SAME: ins(%[[IN1]], %[[ALL_SLICE1]] : tensor<4x6xi8>, tensor<6x2xi8>)
// CHECK-SAME: outs(%[[ALL_SLICE2]] : tensor<4x2xi8>)
// CHECK-SAME: -> tensor<4x2xi8>
%res = linalg.matmul ins(%in1_replicated2, %in2_sharded : tensor<4x6xi8>, tensor<6x8xi8>)
outs(%dps_out_sharded : tensor<4x8xi8>) -> tensor<4x8xi8>
// CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[MATMUL_RES]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<4x2xi8> -> tensor<4x8xi8>
%res_sharded = mesh.shard %res to <@mesh_1d, [[], [0]]> : tensor<4x8xi8>
%res_replicated = mesh.shard %res_sharded to <@mesh_1d, [[], []]> annotate_for_users: tensor<4x8xi8>
// CHECK: return %[[ALL_GATHER]] : tensor<4x8xi8>
return %res_replicated : tensor<4x8xi8>
}
13 changes: 13 additions & 0 deletions mlir/test/Dialect/Mesh/resharding-spmdization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,19 @@ func.func @unshard_static_axis(
return %1 : tensor<10x14xf32>
}

// CHECK-LABEL: func @unshard_static_last_axis
func.func @unshard_static_last_axis(
// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
%arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
// CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<10x7xf32>
// CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32>
%0 = mesh.shard %arg0 to <@mesh_1d, [[], [0]]> : tensor<10x14xf32>
%1 = mesh.shard %0 to <@mesh_1d, [[], []]> annotate_for_users : tensor<10x14xf32>
// CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
return %1 : tensor<10x14xf32>
}

// CHECK-LABEL: func @unshard_dynamic_axis
func.func @unshard_dynamic_axis(
// CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
Expand Down

0 comments on commit 01a429c

Please sign in to comment.