[Flow] Avoid fusion of dequantization-like ops with producers (#16610)
Dequantization-like operations (e.g. arith.extf/extui/extsi) are
generally best fused aggressively with their consumers, because this
minimizes the amount of memory crossing dispatch boundaries.
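
For illustration, a minimal dequantization-like op of the kind this heuristic targets (a hand-written sketch adapted from the tests added below; the shape and value names are hypothetical):

// An element-wise, all-parallel linalg.generic whose only body op widens
// the element type (f16 -> f32 via arith.extf): dequantization-like.
%dequant = linalg.generic {
    indexing_maps = [
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>
    ], iterator_types = ["parallel", "parallel"]}
    ins(%quantized : tensor<128x128xf16>) outs(%init : tensor<128x128xf32>) {
  ^bb0(%in: f16, %out: f32):
    %ext = arith.extf %in : f16 to f32
    linalg.yield %ext : f32
  } -> tensor<128x128xf32>

Fusing this op into its consumers means the 16-bit tensor, not the 32-bit one, crosses the dispatch boundary.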

This patch changes the heuristics in FusionOfTensorOps to avoid fusing
a dequant-like consumer with its producer. It also disables multi-use
fusion for multi-use dequant-like producers for the same reason: we
would rather fuse by cloning the dequant-like op into each consumer
than by typical multi-use fusion.
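
Schematically (a hand-drawn sketch, not actual compiler output), cloning a multi-use dequant-like producer keeps the narrow type on every dispatch boundary:

// Before: one extf (f16 -> f32) with two consumers; multi-use fusion
// would put the f32 result on a dispatch boundary.
//   %dq = arith.extf %src : f16 to f32
//   dispatch A: reads %dq, computes math.exp
//   dispatch B: reads %dq, computes a reduction
//
// After cloning: each dispatch re-extends the f16 input itself, so only
// 16-bit data crosses the boundary.
//   dispatch A: reads %src (f16) { arith.extf; math.exp }
//   dispatch B: reads %src (f16) { arith.extf; reduction }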
qedawkins authored Feb 29, 2024
1 parent 24bf0ac commit 8237d9a
Showing 2 changed files with 128 additions and 1 deletion.
@@ -105,6 +105,14 @@ static bool areFusableOps(MLIRContext *context, OpOperand *fusedOperand,
if (!producerOp->hasOneUse())
return false;

// Do not fuse dequantization-like operations with producers as we want to keep
// the smallest bitwidths at the dispatch boundaries, unless the consumer
// dequantization op only has one use, in which case elementwise op fusion
// is fine.
if (isDequantizationLikeOp(consumerOp) && !consumerOp->hasOneUse()) {
return false;
}

// If the producer has a single use (this op), only fuse if
// - 1) The consumer op is all parallel loops. The parallelism of the consumer
// can be used as a way to amortize cost of redundant computation
@@ -113,6 +121,7 @@ static bool areFusableOps(MLIRContext *context, OpOperand *fusedOperand,
// broadcast this ends up redundantly computing operations without more
// parallelism.
if (auto linalgConsumerOp = dyn_cast<linalg::LinalgOp>(consumerOp)) {

if (linalgConsumerOp.getNumParallelLoops() ==
linalgConsumerOp.getNumLoops()) {
return true;
@@ -225,6 +234,12 @@ static FailureOr<unsigned> fuseMultiUseProducers(Operation *funcOp,
return;
}

// Dequantization-like operations should be fused with consumers to keep
// the smaller bit width on the dispatch boundary.
if (isDequantizationLikeOp(genericOp)) {
return;
}

Operation *fusableProducer = nullptr;
for (OpOperand &operand : genericOp->getOpOperands()) {
// 2. Only fuse with `linalg.generic` producers that aren't
@@ -259,7 +274,13 @@ static FailureOr<unsigned> fuseMultiUseProducers(Operation *funcOp,
continue;
}

// 7. All uses from `producer` -> `consumer` need to be fusable.
// 7. Skip dequantization-like `producer` ops as we would rather fuse
// by cloning the producer instead of multi-use fusion.
if (isDequantizationLikeOp(producer)) {
return;
}

// 8. All uses from `producer` -> `consumer` need to be fusable.
// Without this the `producer` is still live, and there is no
// advantage to do the fusion.
if (llvm::any_of(getAllUsesInConsumer(producer, genericOp),
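
Both new checks pivot on isDequantizationLikeOp, a helper defined elsewhere in IREE. As a rough sketch of the kind of check such a predicate performs (a simplified approximation for illustration only, not IREE's actual implementation, which also inspects the ops in the generic's body):

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/TypeUtilities.h"

using namespace mlir;

// Sketch: treat an op as "dequantization-like" if it is an element-wise
// linalg.generic producing a single result whose element type is strictly
// wider than the element type of every input (e.g. a body built from
// arith.extf/extui/extsi).
static bool isDequantizationLikeOpSketch(Operation *op) {
  auto genericOp = dyn_cast<linalg::GenericOp>(op);
  if (!genericOp || genericOp->getNumResults() != 1)
    return false;
  // Element-wise: every loop must be parallel.
  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
    return false;
  Type outElemTy = getElementTypeOrSelf(genericOp->getResult(0).getType());
  if (!outElemTy.isIntOrFloat())
    return false;
  unsigned outWidth = outElemTy.getIntOrFloatBitWidth();
  // Every input element type must be strictly narrower than the output's.
  for (Value input : genericOp.getDpsInputs()) {
    Type inElemTy = getElementTypeOrSelf(input.getType());
    if (!inElemTy.isIntOrFloat() ||
        inElemTy.getIntOrFloatBitWidth() >= outWidth)
      return false;
  }
  return true;
}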
@@ -553,3 +553,109 @@ util.func public @fix_issue_14953(%arg0: tensor<11008x32x1xf16>, %arg1: tensor<1
// CHECK: %[[GENERIC1:.+]] = linalg.generic
// CHECK-SAME: ins(%{{.+}}, %[[GENERIC0]] :
// CHECK: flow.return %[[GENERIC1]]

// -----

util.func public @fuse_single_use_dequant_with_producer(%arg0: tensor<12x128x128xf16>, %arg1: tensor<12x128x128xf16>) -> tensor<12x128x128xf32> {
%4 = tensor.empty() : tensor<12x128x128xf16>
%5 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>
], iterator_types = ["parallel", "parallel", "parallel"]}
ins(%arg0, %arg1 : tensor<12x128x128xf16>, tensor<12x128x128xf16>) outs(%4 : tensor<12x128x128xf16>) {
^bb0(%b0: f16, %b1: f16, %arg2: f16):
%9 = arith.subf %b0, %b1 : f16
linalg.yield %9 : f16
} -> tensor<12x128x128xf16>
%6 = tensor.empty() : tensor<12x128x128xf32>
%7 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>
], iterator_types = ["parallel", "parallel", "parallel"]}
ins(%5 : tensor<12x128x128xf16>) outs(%6 : tensor<12x128x128xf32>) {
^bb0(%b0: f16, %b1: f32):
%10 = arith.extf %b0 : f16 to f32
linalg.yield %10 : f32
} -> tensor<12x128x128xf32>
%8 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>
], iterator_types = ["parallel", "parallel", "parallel"]}
ins(%7 : tensor<12x128x128xf32>) outs(%6 : tensor<12x128x128xf32>) {
^bb0(%b0: f32, %b1: f32):
%11 = math.exp %b0 : f32
linalg.yield %11 : f32
} -> tensor<12x128x128xf32>
util.return %8 : tensor<12x128x128xf32>
}
// CHECK-LABEL: util.func public @fuse_single_use_dequant_with_producer
// CHECK: %[[GENERIC0:.+]] = linalg.generic
// CHECK: arith.subf
// CHECK-NEXT: arith.extf
// CHECK-NEXT: math.exp

// -----

util.func public @no_fuse_multi_use_dequant_with_producer(%arg0: tensor<12x128x128xf16>,
%arg1: tensor<12x128x128xf16>) -> (tensor<12x128x128xf32>, tensor<12x128xf32>) {
%4 = tensor.empty() : tensor<12x128x128xf16>
%5 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>
], iterator_types = ["parallel", "parallel", "parallel"]}
ins(%arg0, %arg1 : tensor<12x128x128xf16>, tensor<12x128x128xf16>) outs(%4 : tensor<12x128x128xf16>) {
^bb0(%b0: f16, %b1: f16, %arg2: f16):
%10 = arith.subf %b0, %b1 : f16
linalg.yield %10 : f16
} -> tensor<12x128x128xf16>
%6 = tensor.empty() : tensor<12x128x128xf32>
%7 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>
], iterator_types = ["parallel", "parallel", "parallel"]}
ins(%5 : tensor<12x128x128xf16>) outs(%6 : tensor<12x128x128xf32>) {
^bb0(%b0: f16, %b1: f32):
%11 = arith.extf %b0 : f16 to f32
linalg.yield %11 : f32
} -> tensor<12x128x128xf32>
%8 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>
], iterator_types = ["parallel", "parallel", "parallel"]}
ins(%7 : tensor<12x128x128xf32>) outs(%6 : tensor<12x128x128xf32>) {
^bb0(%b0: f32, %b1: f32):
%12 = math.exp %b0 : f32
linalg.yield %12 : f32
} -> tensor<12x128x128xf32>
%empty = tensor.empty() : tensor<12x128xf32>
%9 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>
], iterator_types = ["parallel", "parallel", "reduction"]}
ins(%7 : tensor<12x128x128xf32>) outs(%empty : tensor<12x128xf32>) {
^bb0(%b0: f32, %b1: f32):
%13 = math.rsqrt %b0 : f32
%14 = arith.addf %13, %b1 : f32
linalg.yield %14 : f32
} -> tensor<12x128xf32>
util.return %8, %9 : tensor<12x128x128xf32>, tensor<12x128xf32>
}
// CHECK-LABEL: util.func public @no_fuse_multi_use_dequant_with_producer
// CHECK: %[[GENERIC0:.+]] = linalg.generic
// CHECK: arith.subf
// CHECK: %[[GENERIC1:.+]] = linalg.generic
// CHECK: arith.extf
// CHECK: %[[GENERIC2:.+]] = linalg.generic
// CHECK: math.exp
// CHECK: %[[GENERIC3:.+]] = linalg.generic
// CHECK: math.rsqrt
// CHECK-NEXT: arith.addf
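
These two tests pin down both sides of the new heuristic (a summary, not part of the test file): in @fuse_single_use_dequant_with_producer the arith.extf result has a single use, so the exception in areFusableOps lets the subf producer, the extf, and the exp consumer collapse into one linalg.generic; in @no_fuse_multi_use_dequant_with_producer the extf result feeds both the exp and the reduction, so it stays its own generic rather than fusing with its producer or going through multi-use fusion, leaving it to be cloned into each consumer dispatch later.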
