Reland fix to multi-row reduction triggering.
Apparently there was no actual breakage, just a numerically
unstable model.

Reverts 87e62ee

PiperOrigin-RevId: 679548584
jreiffers authored and Google-ML-Automation committed Sep 27, 2024
1 parent 7ea3146 commit 9b1056c
Showing 3 changed files with 140 additions and 59 deletions.
160 changes: 102 additions & 58 deletions xla/service/gpu/fusions/reduction_mlir.cc
@@ -770,32 +770,11 @@ llvm::SmallVector<mlir::Value> MlirSmallColumnReductionFusion::EmitReduction(
shared_rows_ / 2);
}

std::unique_ptr<MlirReductionFusion> CreateMlirReductionFusion(
const HloFusionAnalysis& analysis) {
auto* hero_reduction = analysis.FindHeroReduction();
CHECK_NE(hero_reduction, nullptr);
ReductionDimensions reduction_dimensions =
GetReductionKindAndContiguousComponents(*hero_reduction);
if (reduction_dimensions.is_row_reduction) {
if (RowReductionGetRowsPerWarp(
reduction_dimensions.dimensions[kRowMinorReduced]) > 1) {
return std::make_unique<MlirMultiRowReductionFusion>(analysis);
}
return std::make_unique<MlirRowReductionFusion>(analysis);
}

if (WarpSize() % reduction_dimensions.dimensions[kColMinorKept] == 0) {
return std::make_unique<MlirSmallColumnReductionFusion>(analysis);
}
return std::make_unique<MlirColumnReductionFusion>(analysis);
}

MlirRowReductionFusion::MlirRowReductionFusion(
const HloFusionAnalysis& analysis)
: MlirReductionFusion(analysis) {
CHECK(reduction_dimensions_.is_row_reduction);
Vector3 shape = reduction_dimensions_.dimensions;
CHECK_EQ(RowReductionGetRowsPerWarp(shape[kRowMinorReduced]), 1);
constexpr int64_t kMinorReducedElementsPerThread = 16;

int64_t num_threads_kept = 1;
@@ -931,58 +910,104 @@ llvm::SmallVector<mlir::Value> MlirRowReductionFusion::EmitReduction(
}

MlirMultiRowReductionFusion::MlirMultiRowReductionFusion(
const HloFusionAnalysis& analysis)
const HloFusionAnalysis& analysis, int vector_size)
: MlirReductionFusion(analysis) {
CHECK(reduction_dimensions_.is_row_reduction);
Vector3 shape = reduction_dimensions_.dimensions;
int64_t rows_per_warp = RowReductionGetRowsPerWarp(shape[kRowMinorReduced]);
input_shape_ = {shape[0], shape[1], shape[2]};
CHECK_GT(rows_per_warp, 1);

auto compute_block_size = [&](int vector_size) {
int64_t num_threads_reduced = shape[kRowMinorReduced] / vector_size;

constexpr int64_t kThreadsPerBlockTarget = 256;
int64_t kept_size = reduction_dimensions_.dimensions[kRowKept];
int64_t num_threads_kept = 1;
if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) {
num_threads_kept = kept_size;
} else {
num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced;
}
num_threads_ = {num_threads_kept, num_threads_reduced};
tile_sizes_per_thread_ = {shape[0], vector_size};
num_blocks_ = {CeilOfRatio(input_shape_[kRowKept], num_threads_kept)};
};
num_threads_ = GetNumThreads(reduction_dimensions_, vector_size);
num_blocks_ = {GetNumBlocks(reduction_dimensions_, num_threads_)};
tile_sizes_per_thread_ = {shape[0], vector_size};
}

// Compute the launch grid without vectorization. We use the results to
// compute the vectorized launch grid.
compute_block_size(1);
std::unique_ptr<MlirReductionFusion> MlirMultiRowReductionFusion::TryCreate(
const HloFusionAnalysis& analysis) {
auto* hero_reduction = analysis.FindHeroReduction();
CHECK_NE(hero_reduction, nullptr);
auto reduction_dimensions =
GetReductionKindAndContiguousComponents(*hero_reduction);
auto shape = reduction_dimensions.dimensions;
// This emitter only supports reductions where the reduced dimension is a
// power of 2.
if (shape[kRowMinorReduced] & (shape[kRowMinorReduced] - 1)) {
return nullptr;
}

// Normally, we only consider input types for vectorization. However, in
// multi-row reductions, the input:output ratio is much higher, so we consider
// both inputs and outputs.
int smallest_input_or_output_bits =
std::min(analysis.input_output_info().smallest_input_dtype_bits,
analysis.input_output_info().smallest_output_dtype_bits);
int largest_input_or_output_bits =
std::max(analysis.input_output_info().smallest_input_dtype_bits,
analysis.input_output_info().smallest_output_dtype_bits);

// This vector size is always valid: we know that the reduced dimension is a
// power of 2, since otherwise RowReductionGetRowsPerWarp would have
// returned 1.
// Our codegen can't currently deal with vectorization across rows, so we
// limit the vector size to the size of the row. Note that this emitter
// essentially reverts to the loop emitter in this case, except for side
// outputs.
int vector_size = std::min(static_cast<int>(input_shape_[kRowMinorReduced]),
32 / smallest_input_or_output_bits);

// We target 8 warps per block, which means there could be up to 8 blocks per
// SM, but we have no good way of knowing. In practice, enabling vectorization
// for decently sized reductions at least does not hurt.
if (num_blocks_.front() > analysis.device_info().core_count() &&
vector_size > 1) {
compute_block_size(vector_size);
int vector_size = std::min(static_cast<int>(shape[kRowMinorReduced]),
64 / smallest_input_or_output_bits);

// Very large vector sizes for f32 can be detrimental, so we limit the vector
// size to 16 bytes if we have some >= 32 bit inputs or outputs. This is still
// a bit on the high side, but remember that we also have very small inputs
// or outputs.
if (largest_input_or_output_bits >= 32) {
vector_size = std::min(128 / largest_input_or_output_bits, vector_size);
}

// The reduced dimension must fit into a single warp.
if (shape[kRowMinorReduced] > WarpSize() * vector_size) {
return nullptr;
}

// At the very least, we want to have work for every SM.
// TODO(jreiffers): This limit is probably too low: if we have as many blocks
// as SMs, we'll only run about 8 warps per SM, so occupancy will be very low.
// Further measurements are needed to refine this heuristic.
int64_t min_desired_blocks = analysis.device_info().core_count();
while (vector_size > 1 &&
GetNumBlocks(reduction_dimensions,
GetNumThreads(reduction_dimensions, vector_size)) <
min_desired_blocks) {
vector_size /= 2;
}

// Check again that the reduced dimension fits after potentially reducing the
// vector size.
if (shape[kRowMinorReduced] > WarpSize() * vector_size) {
return nullptr;
}

return std::make_unique<MlirMultiRowReductionFusion>(analysis, vector_size);
}

absl::InlinedVector<int64_t, 4> MlirMultiRowReductionFusion::GetNumThreads(
const ReductionDimensions& reduction_dimensions, int vector_size) {
int64_t num_threads_reduced =
reduction_dimensions.dimensions[kRowMinorReduced] / vector_size;

constexpr int64_t kThreadsPerBlockTarget = 256;
int64_t kept_size = reduction_dimensions.dimensions[kRowKept];
int64_t num_threads_kept = 1;
if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) {
num_threads_kept = kept_size;
} else {
num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced;
}
return {num_threads_kept, num_threads_reduced};
}

int64_t MlirMultiRowReductionFusion::GetNumBlocks(
const ReductionDimensions& reduction_dimensions,
const absl::InlinedVector<int64_t, 4>& num_threads) {
CHECK_EQ(num_threads.size(), 2)
<< "Expected num_threads to contain the number of threads in the {kept, "
"reduced} dimensions.";
return CeilOfRatio(reduction_dimensions.dimensions[kRowKept],
num_threads.front());
}

IndexingMap MlirMultiRowReductionFusion::ComputeReductionInputIndexing(
@@ -1013,8 +1038,7 @@ IndexingMap MlirMultiRowReductionFusion::ComputeReductionOutputIndexing(
: mlir::getAffineDimExpr(3, ctx);
IndexingMap projected_index =
GetIndexingMap(block_id * num_threads_[0] + thread_id[0]);
projected_index.AddConstraint(thread_id[1] % (WarpSize() / GetRowsPerWarp()),
{0, 0});
projected_index.AddConstraint(thread_id[1] % num_threads_[1], {0, 0});
// We don't need a constraint on the loop dimensions, because they are removed
// by GetIndexingMap (since they don't show up in the output index
// computation).
@@ -1034,10 +1058,30 @@ llvm::SmallVector<mlir::Value> MlirMultiRowReductionFusion::EmitReduction(
auto per_thread =
state.EmitPerThreadElements(group_id, inits, state.FusionOutputs());
auto reduced = state.ShuffleReduce(reductions, per_thread.reduction_scalars,
WarpSize() / 2 / GetRowsPerWarp());
num_threads_[1] / 2);
return EvaluateEpilogue(reduced, std::move(per_thread.outputs), state,
group_id, /*symbol_values=*/{});
}

std::unique_ptr<MlirReductionFusion> CreateMlirReductionFusion(
const HloFusionAnalysis& analysis) {
auto* hero_reduction = analysis.FindHeroReduction();
CHECK_NE(hero_reduction, nullptr);
ReductionDimensions reduction_dimensions =
GetReductionKindAndContiguousComponents(*hero_reduction);
if (reduction_dimensions.is_row_reduction) {
auto multi_row_emitter = MlirMultiRowReductionFusion::TryCreate(analysis);
if (multi_row_emitter != nullptr) {
return multi_row_emitter;
}
return std::make_unique<MlirRowReductionFusion>(analysis);
}

if (WarpSize() % reduction_dimensions.dimensions[kColMinorKept] == 0) {
return std::make_unique<MlirSmallColumnReductionFusion>(analysis);
}
return std::make_unique<MlirColumnReductionFusion>(analysis);
}

} // namespace gpu
} // namespace xla
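
To make the new triggering heuristic concrete, the following standalone C++ sketch traces MlirMultiRowReductionFusion::TryCreate for the f16[2048,64] reduction added in the test below. It is an illustration only, not XLA code; the 108-SM core count is an assumed device (roughly an A100-class part), and the real decision depends on analysis.device_info().

// Standalone sketch of the vector-size heuristic in
// MlirMultiRowReductionFusion::TryCreate, traced on f16[2048,64].
// The core count is an assumption for illustration.
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t kWarpSize = 32;
  const int64_t kThreadsPerBlockTarget = 256;
  const int64_t kCoreCount = 108;  // assumed SM count

  const int64_t kept = 2048, reduced = 64;  // f16[2048,64], reduced over dim 1
  const int smallest_bits = 16, largest_bits = 16;  // f16 inputs and outputs

  // The emitter only supports power-of-two reduced dimensions.
  if (reduced & (reduced - 1)) return 1;

  // Start from 64 bits' worth of the smallest element type, capped by the row.
  int vector_size =
      std::min<int>(static_cast<int>(reduced), 64 / smallest_bits);  // 4
  // Cap at 128 bits if any input/output is >= 32 bits wide (not the case here).
  if (largest_bits >= 32) {
    vector_size = std::min(128 / largest_bits, vector_size);
  }

  auto num_blocks = [&](int vs) {
    int64_t threads_reduced = reduced / vs;
    int64_t threads_kept = kept * threads_reduced <= kThreadsPerBlockTarget
                               ? kept
                               : kThreadsPerBlockTarget / threads_reduced;
    return (kept + threads_kept - 1) / threads_kept;
  };

  // Halve the vector size until there is at least one block per SM.
  while (vector_size > 1 && num_blocks(vector_size) < kCoreCount) {
    vector_size /= 2;
  }

  // The reduced dimension must still fit into a single warp.
  if (reduced > kWarpSize * vector_size) return 1;

  std::cout << "vector_size=" << vector_size           // 4
            << ", blocks=" << num_blocks(vector_size)  // 128
            << "\n";
}

With 128 blocks of 16x16 threads against an assumed 108 SMs, the while loop never halves the vector size, so the emitter keeps the 4-wide f16 loads that the new f16_v4.hlo test checks for.
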
17 changes: 16 additions & 1 deletion xla/service/gpu/fusions/reduction_mlir.h
@@ -16,6 +16,7 @@ limitations under the License.
#define XLA_SERVICE_GPU_FUSIONS_REDUCTION_MLIR_H_

#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <vector>
@@ -168,9 +169,23 @@ class MlirRowReductionFusion : public MlirReductionFusion {

class MlirMultiRowReductionFusion : public MlirReductionFusion {
public:
explicit MlirMultiRowReductionFusion(const HloFusionAnalysis& analysis);
MlirMultiRowReductionFusion(const HloFusionAnalysis& analysis,
int vector_size);

// Attempts to create a multi-row reduction emitter for the given analysis.
// Returns nullptr if the fusion is not supported.
static std::unique_ptr<MlirReductionFusion> TryCreate(
const HloFusionAnalysis& analysis);

protected:
// Returns the number of {kept, reduced} threads for the given reduction and
// vector size.
static absl::InlinedVector<int64_t, 4> GetNumThreads(
const ReductionDimensions& reduction_dimensions, int vector_size);
static int64_t GetNumBlocks(
const ReductionDimensions& reduction_dimensions,
const absl::InlinedVector<int64_t, 4>& num_threads);

int GetRowsPerWarp() const;
llvm::SmallVector<mlir::Value> EmitReduction(
int group_id, EmitterState& state) const override;
22 changes: 22 additions & 0 deletions xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo
@@ -0,0 +1,22 @@
// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \
// RUN: -xla-gpu-test-transform-loops | FileCheck %s

// The reference implementation reduces in f64, so we need a larger tolerance.
// RUN: test_correctness %s --bijection_inputs=reduce:0 \
// RUN: --bijection_outputs=reduce --abs_error_bound=0.005 --rel_error_bound=0.005

add {
lhs = f16[] parameter(0)
rhs = f16[] parameter(1)
ROOT add = f16[] add(lhs, rhs)
}

fusion {
param_0 = f16[2048,64] parameter(0)
c = f16[] constant(0)
ROOT reduce = f16[2048] reduce(param_0, c), dimensions={1}, to_apply=add
}

// If unvectorized, this would be a regular row reduction. However, since we can
// vectorize to size four, we can emit this as a multi-row reduction.
// CHECK: vector.transfer_read {{.*}} vector<4xf16>
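
For the same shape, the per-warp layout of the multi-row emitter can be sketched as below; the 32-lane warp is an assumption, and the numbers simply restate what the generated kernel would do rather than code from this change.

// Sketch of the per-warp layout for f16[2048,64] with vector_size = 4.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t warp_size = 32;  // assumed WarpSize()
  const int64_t reduced = 64;    // minor reduced dimension
  const int vector_size = 4;     // picked by the heuristic sketched earlier

  int64_t threads_per_row = reduced / vector_size;      // 16 lanes cover a row
  int64_t rows_per_warp = warp_size / threads_per_row;  // 2 rows per warp
  // EmitReduction starts the warp shuffle at num_threads_[1] / 2 lanes.
  int64_t first_shuffle_distance = threads_per_row / 2;  // 8

  std::cout << threads_per_row << " threads/row, " << rows_per_warp
            << " rows/warp, first shuffle distance " << first_shuffle_distance
            << "\n";
}

Because only 16 of the 32 lanes belong to each row, every warp reduces two rows at once, which is what makes this case a multi-row reduction once the 4-wide vectorization is applied.
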
