#sdy fix sharding rule of @PartialReduce/ApproxTopK custom call.
PiperOrigin-RevId: 653581265
tomnatan30 authored and copybara-github committed Jul 19, 2024
1 parent 446f1b6 commit 189519a
Showing 4 changed files with 107 additions and 73 deletions.
@@ -184,6 +184,22 @@ OpShardingRuleBuilder& OpShardingRuleBuilder::addPointwiseIf(
return *this;
}

OpShardingRuleBuilder& OpShardingRuleBuilder::addPointwiseIfDimSizesMatch(
ArrayRef<int64_t> inShape, ArrayRef<int64_t> outShape, bool alwaysAddFactor,
std::function<void(int64_t dim, OpShardingRuleBuilder& builder)>
onMismatchFn) {
for (auto [dim, dimSizes] :
llvm::enumerate(llvm::zip_equal(inShape, outShape))) {
auto [inDimSize, outDimSize] = dimSizes;
if (alwaysAddFactor || inDimSize == outDimSize) {
addFactor(dim, inDimSize);
} else {
onMismatchFn(dim, *this);
}
}
return *this;
}

OpShardingRuleAttr createIdentityShardingRule(RankedTensorType type,
size_t numOperands,
size_t numResults) {
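A minimal usage sketch of the new helper with its default arguments (not part of this commit; the op, shapes, include path, and function name are assumed). With the default no-op `onMismatchFn`, a mismatched dimension is simply left without a factor, and the builder later fills it with a blocking size-1 factor:

```cpp
// Sketch only: illustrates addPointwiseIfDimSizesMatch with its defaults.
#include "shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h"  // assumed path

using namespace mlir;
using namespace mlir::sdy;

OpShardingRuleAttr buildExampleRule(Operation* op) {
  // Hypothetical shapes: one operand of shape [16, 4], one result of [16, 2].
  SmallVector<int64_t> inShape = {16, 4};
  SmallVector<int64_t> outShape = {16, 2};
  return OpShardingRuleBuilder(op)
      // Dim 0: sizes match (16 == 16) -> one factor of size 16, shared by
      // the operand and the result.
      // Dim 1: sizes differ (4 != 2) -> the default no-op onMismatchFn runs,
      // leaving the dimension to receive a blocking size-1 factor.
      .addPointwiseIfDimSizesMatch(inShape, outShape)
      .build();
}
```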
@@ -88,12 +88,20 @@ class OpShardingRuleBuilder {

// Adds a pointwise factor for every dimension that satisfies `pred`, across
// all operands/results that have rank at least 1.
//
// Adds a factor of size 1 to all other dimensions, which blocks any
// propagation along those dimensions.
OpShardingRuleBuilder& addPointwiseIf(ArrayRef<int64_t> shape,
std::function<bool(int64_t)> pred);

// Adds a pointwise factor for every dimension whose input and output sizes
// match, across all operands/results that have rank at least 1.
//
// If `alwaysAddFactor` is true, we add a factor for every dimension, with
// the corresponding size in `inShape`; otherwise we only add factors for
// dimensions whose sizes match, and invoke `onMismatchFn` on each mismatched
// dimension.
OpShardingRuleBuilder& addPointwiseIfDimSizesMatch(
ArrayRef<int64_t> inShape, ArrayRef<int64_t> outShape,
bool alwaysAddFactor = false,
std::function<void(int64_t dim, OpShardingRuleBuilder& builder)>
onMismatchFn = [](int64_t dim, OpShardingRuleBuilder& builder) {});

private:
MLIRContext* context;
SmallVector<int64_t> factorSizes;
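And a sketch of the `alwaysAddFactor = true` mode (hypothetical op and shapes), which the pad/reduce-window/select-and-scatter/slice call sites further down select via `/*alwaysAddFactor=*/!conservativePropagation`: mismatched dimensions still get a single factor shared by operands and results, sized by the input dimension per `addFactor(dim, inDimSize)`:

```cpp
// Sketch only: a pad-like op whose operand [6] grows to a result [8].
OpShardingRuleAttr buildPermissiveRule(Operation* op) {
  SmallVector<int64_t> inShape = {6};
  SmallVector<int64_t> outShape = {8};
  return OpShardingRuleBuilder(op)
      // With alwaysAddFactor == true, dim 0 gets a single shared factor of
      // size 6 (the input dim size) even though 6 != 8, so shardings keep
      // propagating through the mismatched dimension, at the cost of
      // communication.
      .addPointwiseIfDimSizesMatch(inShape, outShape,
                                   /*alwaysAddFactor=*/true)
      .build();
}
```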
120 changes: 65 additions & 55 deletions shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
@@ -55,23 +55,6 @@ bool isTranspose(stablehlo::Transpose transpose) {
llvm_unreachable("unknown stablehlo::Transpose");
}

// If `addFactorForMismatchedSize` is true, we add a factor for all dimensions
// with the corresponding size in `inType`, otherwise we only add a factor for
// dimensions with the same input and output size, letting the builder add size
// 1 factors for other dimensions.
OpShardingRuleAttr createMismatchedDimSizeShardingRule(
Operation* op, RankedTensorType inType, RankedTensorType outType,
bool addFactorForMismatchedSize) {
return OpShardingRuleBuilder(op)
.addPointwiseIf(inType.getShape(),
[&](int64_t dim) {
return addFactorForMismatchedSize ||
inType.getDimSize(dim) ==
outType.getDimSize(dim);
})
.build();
}

// Returns a vector with `numInputs` copies of `inputDim`, followed by a single
// `indicesDim`, then `numInputs` copies of `updateDim`, which matches the order
// and quantity of scatter operands.
@@ -450,11 +433,30 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
//
// Operands: [operand, iota, init_val (scalar), init_arg (scalar)]
// Results: [values, indices]
return createMismatchedDimSizeShardingRule(
customCall,
cast<RankedTensorType>(customCall.getOperand(0).getType()),
cast<RankedTensorType>(customCall.getResult(0).getType()),
/*addFactorForMismatchedSize=*/false);
ArrayRef<int64_t> inputShape =
getTensorShape(customCall.getOperand(0));
ArrayRef<int64_t> resultShape =
getTensorShape(customCall.getResult(0));
int64_t numInputs = 2, numResults = 2;
SmallVector<int64_t> operandDims(customCall->getNumOperands(),
kNullDim);
SmallVector<int64_t> resultDims(customCall->getNumResults(),
kNullDim);
return OpShardingRuleBuilder(customCall)
.addPointwiseIfDimSizesMatch(
inputShape, resultShape,
/*alwaysAddFactor=*/false,
/*onMismatchFn=*/
[&](int64_t dim, OpShardingRuleBuilder& builder) {
std::fill_n(operandDims.begin(), numInputs, dim);
resultDims.assign(numResults, kNullDim);
builder.addFactor(operandDims, resultDims, inputShape[dim]);
resultDims.assign(numResults, dim);
std::fill_n(operandDims.begin(), numInputs, kNullDim);
builder.addFactor(operandDims, resultDims,
resultShape[dim]);
})
.build();
}
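Concretely, for the 16x4 inputs and 16x2 results exercised in the updated tests below: dim 0 has matching sizes and becomes the shared factor i=16, while dim 1 mismatches (4 vs 2), so the callback first adds an operands-only factor j=4 and then a results-only factor k=2. Both inputs share j and both results share k, giving the rule the tests check: `([i, j], [i, j], [], [])->([i, k], [i, k]) {i=16, j=4, k=2}`.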
// TODO(b/327191011): output unregistered op stats instead.
unreachableFormatv(
@@ -540,30 +542,30 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
})
.Case<stablehlo::DynamicSliceOp>(
[](stablehlo::DynamicSliceOp dynamicSlice) {
return createMismatchedDimSizeShardingRule(
dynamicSlice, dynamicSlice.getOperand().getType(),
dynamicSlice.getType(),
/*addFactorForMismatchedSize=*/false);
return OpShardingRuleBuilder(dynamicSlice)
.addPointwiseIfDimSizesMatch(
getTensorShape(dynamicSlice.getOperand()),
getTensorShape(dynamicSlice.getResult()))
.build();
})
.Case<stablehlo::DynamicUpdateSliceOp>(
[](stablehlo::DynamicUpdateSliceOp dynamicUpdateSlice) {
OpShardingRuleBuilder builder(dynamicUpdateSlice);
ArrayRef<int64_t> operandShape =
dynamicUpdateSlice.getOperand().getType().getShape();
getTensorShape(dynamicUpdateSlice.getOperand());
ArrayRef<int64_t> updateShape =
dynamicUpdateSlice.getUpdate().getType().getShape();

getTensorShape(dynamicUpdateSlice.getUpdate());
SmallVector<int64_t> operandDims(
dynamicUpdateSlice->getNumOperands(), kNullDim);
for (auto [dim, dimSizes] :
llvm::enumerate(llvm::zip_equal(operandShape, updateShape))) {
auto [operandDimSize, updateDimSize] = dimSizes;
operandDims[0] = dim;
operandDims[1] = operandDimSize == updateDimSize ? dim : kNullDim;
builder.addFactor(operandDims, dim, operandDimSize);
}

return builder.build();
return OpShardingRuleBuilder(dynamicUpdateSlice)
.addPointwiseIfDimSizesMatch(
operandShape, updateShape,
/*alwaysAddFactor=*/false,
/*onMismatchFn=*/
[&](int64_t dim, OpShardingRuleBuilder& builder) {
operandDims[0] = dim;
builder.addFactor(operandDims, dim, operandShape[dim]);
})
.build();
})
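As a worked example of the mismatch callback above (the shapes and the scalar-index handling are assumptions, not taken from the commit), the equivalent hand-written calls would be roughly:

```cpp
// Sketch only: a hypothetical dynamic_update_slice with operand [8, 4] and
// update [8, 2]. Operand order: [operand, update, start_index0,
// start_index1]; the scalar start indices are assumed to map to kNullDim.
builder.addFactor(/*operandDims=*/{0, 0, kNullDim, kNullDim},
                  /*resultDim=*/0, /*factorSize=*/8);  // dim 0: 8 == 8
builder.addFactor(/*operandDims=*/{1, kNullDim, kNullDim, kNullDim},
                  /*resultDim=*/1, /*factorSize=*/4);  // dim 1: 4 != 2
// The update's dim 1 never receives a factor, so it ends up with a blocking
// size-1 factor, while the operand and result stay linked with a factor of
// size 4.
```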
.Case<stablehlo::FftOp>([](stablehlo::FftOp fft) {
ArrayRef<int64_t> inShape = getTensorShape(fft.getOperand());
@@ -602,9 +604,12 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
.Case<stablehlo::PadOp>([conservativePropagation](stablehlo::PadOp pad) {
// If `conservativePropagation` is false, we propagate through padded
// dimensions, even though that would require communication.
return createMismatchedDimSizeShardingRule(
pad, pad.getOperand().getType(), pad.getType(),
/*addFactorForMismatchedSize=*/!conservativePropagation);
return OpShardingRuleBuilder(pad)
.addPointwiseIfDimSizesMatch(
getTensorShape(pad.getOperand()),
getTensorShape(pad.getResult()),
/*alwaysAddFactor=*/!conservativePropagation)
.build();
})
.Case<stablehlo::ReduceOp>([](stablehlo::ReduceOp reduce) {
OpShardingRuleBuilder builder(reduce);
@@ -657,12 +662,12 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
// In conservative mode, we only add a factor if the input and
// output dimension sizes are equal.
// TODO(tomnatan): should the reduced factor be compound?
return createMismatchedDimSizeShardingRule(
reduceWindow,
cast<RankedTensorType>(reduceWindow.getResult(0).getType()),
cast<RankedTensorType>(
reduceWindow.getInputs().front().getType()),
/*addFactorForMismatchedSize=*/!conservativePropagation);
return OpShardingRuleBuilder(reduceWindow)
.addPointwiseIfDimSizesMatch(
getTensorShape(reduceWindow.getInputs().front()),
getTensorShape(reduceWindow.getResult(0)),
/*alwaysAddFactor=*/!conservativePropagation)
.build();
})
.Case<stablehlo::ReshapeOp>([](stablehlo::ReshapeOp reshape) {
RankedTensorType inType = reshape.getOperand().getType();
@@ -790,10 +795,12 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
// In conservative mode, we only add a factor if the input and
// source dimension sizes are equal.
// TODO(tomnatan): should the reduced factor be compound?
return createMismatchedDimSizeShardingRule(
selectAndScatter, selectAndScatter.getSource().getType(),
selectAndScatter.getOperand().getType(),
/*addFactorForMismatchedSize=*/!conservativePropagation);
return OpShardingRuleBuilder(selectAndScatter)
.addPointwiseIfDimSizesMatch(
getTensorShape(selectAndScatter.getOperand()),
getTensorShape(selectAndScatter.getSource()),
/*alwaysAddFactor=*/!conservativePropagation)
.build();
})
.Case<stablehlo::SelectOp>([](stablehlo::SelectOp select) {
// Case 1: `pred` is a scalar in which case it is broadcasted and must
@@ -815,9 +822,12 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
// `conservativePropagation`, and the reason is that for `SliceOp`
// the start indices are static, so we know how to shift the data
// to keep the sliced dimension sharded.
return createMismatchedDimSizeShardingRule(
slice, slice.getOperand().getType(), slice.getType(),
/*addFactorForMismatchedSize=*/!conservativePropagation);
return OpShardingRuleBuilder(slice)
.addPointwiseIfDimSizesMatch(
getTensorShape(slice.getOperand()),
getTensorShape(slice.getResult()),
/*alwaysAddFactor=*/!conservativePropagation)
.build();
})
.Case<stablehlo::TransposeOp>([](stablehlo::TransposeOp transpose) {
OpShardingRuleBuilder builder(transpose);
@@ -194,43 +194,43 @@ func.func @custom_call_householder_product(%arg0: tensor<8x12x16xf32>, %arg1: te
}

// CHECK-LABEL: func @custom_call_approx_topk
func.func @custom_call_approx_topk(%arg0: tensor<16x4xf32>, %arg1: tensor<16x4xf32>, %arg2: tensor<f32>, %arg3: tensor<i32>) -> (tensor<16x1xf32>, tensor<16x1xf32>) {
// CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, k], [], [])->([i, l], [i, m]) {i=16, j=1, k=1, l=1, m=1}>
func.func @custom_call_approx_topk(%arg0: tensor<16x4xf32>, %arg1: tensor<16x4xf32>, %arg2: tensor<f32>, %arg3: tensor<i32>) -> (tensor<16x2xf32>, tensor<16x2xf32>) {
// CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j], [], [])->([i, k], [i, k]) {i=16, j=4, k=2}>
%0:2 = stablehlo.custom_call @ApproxTopK(%arg0, %arg1, %arg2, %arg3) {
mhlo.backend_config = {
aggregate_to_topk = true,
recall_target = 0.9 : f32,
reduction_dim = 1 : i64,
reduction_input_size_override = -1 : i64,
top_k = 1 : i64},
top_k = 2 : i64},
called_computations = [@top_k_gt_f32_comparator]} :
(tensor<16x4xf32>, tensor<16x4xf32>, tensor<f32>, tensor<i32>) -> (tensor<16x1xf32>, tensor<16x1xf32>)
return %0#0, %0#1 : tensor<16x1xf32>, tensor<16x1xf32>
(tensor<16x4xf32>, tensor<16x4xf32>, tensor<f32>, tensor<i32>) -> (tensor<16x2xf32>, tensor<16x2xf32>)
return %0#0, %0#1 : tensor<16x2xf32>, tensor<16x2xf32>
}
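// Note what changed relative to the old CHECK line: the previous rule gave
// each operand and result its own size-1 factor on the reduction dimension
// (j=1, k=1, l=1, m=1), so the two inputs were not even linked to each other
// there. The fixed rule ties both inputs together with j=4 and both results
// with k=2, matching the actual reduction from 4 elements to top_k=2.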

// CHECK-LABEL: func @custom_call_partial_reduce
func.func @custom_call_partial_reduce(%arg0: tensor<16x4xf32>, %arg1: tensor<16x4xf32>, %arg2: tensor<f32>, %arg3: tensor<i32>) -> (tensor<16x1xf32>, tensor<16x1xf32>) {
// CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, k], [], [])->([i, l], [i, m]) {i=16, j=1, k=1, l=1, m=1}>
func.func @custom_call_partial_reduce(%arg0: tensor<16x4xf32>, %arg1: tensor<16x4xf32>, %arg2: tensor<f32>, %arg3: tensor<i32>) -> (tensor<16x2xf32>, tensor<16x2xf32>) {
// CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j], [], [])->([i, k], [i, k]) {i=16, j=4, k=2}>
%0:2 = stablehlo.custom_call @PartialReduce(%arg0, %arg1, %arg2, %arg3) {
mhlo.backend_config = {
aggregate_to_topk = true,
recall_target = 0.9 : f32,
reduction_dim = 1 : i64,
reduction_input_size_override = -1 : i64,
top_k = 1 : i64},
top_k = 2 : i64},
called_computations = [@top_k_gt_f32_comparator]} :
(tensor<16x4xf32>, tensor<16x4xf32>, tensor<f32>, tensor<i32>) -> (tensor<16x1xf32>, tensor<16x1xf32>)
return %0#0, %0#1 : tensor<16x1xf32>, tensor<16x1xf32>
(tensor<16x4xf32>, tensor<16x4xf32>, tensor<f32>, tensor<i32>) -> (tensor<16x2xf32>, tensor<16x2xf32>)
return %0#0, %0#1 : tensor<16x2xf32>, tensor<16x2xf32>
}

// CHECK-LABEL: func @custom_call_partial_reduce_string_backend_config
func.func @custom_call_partial_reduce_string_backend_config(%arg0: tensor<16x4xf32>, %arg1: tensor<16x4xf32>, %arg2: tensor<f32>, %arg3: tensor<i32>) -> (tensor<16x1xf32>, tensor<16x1xf32>) {
// CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, k], [], [])->([i, l], [i, m]) {i=16, j=1, k=1, l=1, m=1}>
func.func @custom_call_partial_reduce_string_backend_config(%arg0: tensor<16x4xf32>, %arg1: tensor<16x4xf32>, %arg2: tensor<f32>, %arg3: tensor<i32>) -> (tensor<16x2xf32>, tensor<16x2xf32>) {
// CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j], [], [])->([i, k], [i, k]) {i=16, j=4, k=2}>
%0:2 = stablehlo.custom_call @PartialReduce(%arg0, %arg1, %arg2, %arg3) {
backend_config = "{\22log2_reduction\22: 5, \22reduction_dim\22: 1, \22to_apply_type\22: \22comparator\22, \22top_k\22: 64, \22recall_target\22: 0.950000}",
backend_config = "{\22log2_reduction\22: 5, \22reduction_dim\22: 1, \22to_apply_type\22: \22comparator\22, \22top_k\22: 2, \22recall_target\22: 0.950000}",
called_computations = [@top_k_gt_f32_comparator]} :
(tensor<16x4xf32>, tensor<16x4xf32>, tensor<f32>, tensor<i32>) -> (tensor<16x1xf32>, tensor<16x1xf32>)
return %0#0, %0#1 : tensor<16x1xf32>, tensor<16x1xf32>
(tensor<16x4xf32>, tensor<16x4xf32>, tensor<f32>, tensor<i32>) -> (tensor<16x2xf32>, tensor<16x2xf32>)
return %0#0, %0#1 : tensor<16x2xf32>, tensor<16x2xf32>
}

// CHECK-LABEL: func @unregisterd_custom_call_with_existing_rule
