From 436cd54e5d5b6e20af828d81711701d708cb028a Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Tue, 9 Apr 2024 02:46:45 -0700 Subject: [PATCH] [XLA:GPU] Compute a graph of TiledHloInstruction for given tile parameters. Use TiledHloInstructions in the Triton emitter. PiperOrigin-RevId: 623103735 --- xla/service/gpu/BUILD | 1 + xla/service/gpu/ir_emitter_triton.cc | 80 ++++---- xla/service/gpu/model/BUILD | 6 +- .../gpu/model/symbolic_tile_analysis.cc | 175 +++++++++++------- .../gpu/model/symbolic_tile_analysis.h | 58 ++---- .../gpu/model/symbolic_tile_analysis_test.cc | 111 ++++++----- 6 files changed, 231 insertions(+), 200 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index e1edc1110968c..6b4e152a6b95b 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -538,6 +538,7 @@ cc_library( "//xla/service/gpu/model:indexing_map", "//xla/service/gpu/model:symbolic_tile_analysis", "//xla/service/gpu/model:symbolic_tiled_hlo_instruction", + "//xla/service/gpu/model:tiled_hlo_instruction", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor:device_description", "//xla/stream_executor:launch_dim", diff --git a/xla/service/gpu/ir_emitter_triton.cc b/xla/service/gpu/ir_emitter_triton.cc index f1bbceb65c676..235ebe695a352 100644 --- a/xla/service/gpu/ir_emitter_triton.cc +++ b/xla/service/gpu/ir_emitter_triton.cc @@ -115,7 +115,7 @@ limitations under the License. #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/model/symbolic_tile_analysis.h" -#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" +#include "xla/service/gpu/model/tiled_hlo_instruction.h" #include "xla/service/gpu/target_util.h" #include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/service/gpu/triton_tiling_propagation.h" @@ -738,11 +738,10 @@ absl::StatusOr EmitNestedFusion( // TODO(b/331332678): Add unit tests to target this function specifically. Value EmitTiledBroadcast( - ImplicitLocOpBuilder& b, const SymbolicTileAnalysis& analysis, - const SymbolicTiledHloInstruction& tiled_broadcast, - absl::flat_hash_map& values) { - auto input_tile_shape = analysis.TileSizes(*tiled_broadcast.operand(0)); - auto output_tile_shape = analysis.TileSizes(tiled_broadcast); + ImplicitLocOpBuilder& b, const TiledHloInstruction& tiled_broadcast, + absl::flat_hash_map& values) { + auto input_tile_shape = tiled_broadcast.operand(0)->tile_sizes(); + auto output_tile_shape = tiled_broadcast.tile_sizes(); Value expanded_input = values[tiled_broadcast.operand(0)]; @@ -799,11 +798,10 @@ Value EmitTiledBroadcast( absl::StatusOr EmitTiledHloInstruction( ImplicitLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, - const SymbolicTileAnalysis& analysis, - const SymbolicTiledHloInstruction& tiled_hlo, - std::function(const SymbolicTiledHloInstruction&)> + const TiledHloInstruction& tiled_hlo, + std::function(const TiledHloInstruction&)> emit_param_load_fn, - absl::flat_hash_map& values) { + absl::flat_hash_map& values) { const HloInstruction* hlo = tiled_hlo.hlo(); if (hlo->opcode() == HloOpcode::kParameter) { @@ -817,7 +815,7 @@ absl::StatusOr EmitTiledHloInstruction( } if (hlo->opcode() == HloOpcode::kBroadcast) { - return EmitTiledBroadcast(b, analysis, tiled_hlo, values); + return EmitTiledBroadcast(b, tiled_hlo, values); } if (hlo->opcode() == HloOpcode::kReduce) { @@ -829,7 +827,7 @@ absl::StatusOr EmitTiledHloInstruction( std::vector operands; operands.reserve(hlo->operands().size()); - for (const SymbolicTiledHloInstruction* operand : tiled_hlo.operands()) { + for (const TiledHloInstruction* operand : tiled_hlo.operands()) { operands.push_back(values[operand]); } return EmitElementwise(b, libdevice_path, device_info, *hlo, operands); @@ -852,21 +850,22 @@ absl::StatusOr EmitTiledHloInstruction( absl::StatusOr EmitTiledScope( ImplicitLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, - const SymbolicTileAnalysis& analysis, - std::function(const SymbolicTiledHloInstruction&)> + const std::vector>& + tiled_hlo_instructions, + std::function(const TiledHloInstruction&)> emit_param_load_fn, - absl::flat_hash_map& values) { - for (const auto& tiled_hlo : analysis.GetTiledHloInstructions()) { + absl::flat_hash_map& values) { + for (const auto& tiled_hlo : tiled_hlo_instructions) { TF_ASSIGN_OR_RETURN( Value result, - EmitTiledHloInstruction(b, libdevice_path, device_info, analysis, - *tiled_hlo, emit_param_load_fn, values)); + EmitTiledHloInstruction(b, libdevice_path, device_info, *tiled_hlo, + emit_param_load_fn, values)); TF_RET_CHECK(values.insert({tiled_hlo.get(), result}).second) << tiled_hlo->hlo()->ToString(); VLOG(8) << "Emitted " << tiled_hlo->hlo()->ToString(HloPrintOptions::ShortParsable()); } - return values[analysis.GetRoot()]; + return values[tiled_hlo_instructions.back().get()]; } // Emit sequence of instructions using compatible tiling ordered producers @@ -2219,10 +2218,10 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, // `tile_offset_indexing` is a mapping from // (program_id) -> [tile_offset0, ..., tile_offsetN] Value ComputeBasePtrOffset(ImplicitLocOpBuilder b, Value pid, - const Shape& shape, - const IndexingMap& tile_offset_indexing) { + const TiledHloInstruction& tiled_hlo) { + const Shape& shape = tiled_hlo.hlo()->shape(); ArrayRef dimension_exprs = - tile_offset_indexing.GetAffineMap().getResults(); + tiled_hlo.block_id_to_tile_offsets_indexing().GetAffineMap().getResults(); mlir::AffineExpr linear_index = mlir::getAffineConstantExpr(0, b.getContext()); @@ -2292,7 +2291,9 @@ absl::Status EmitTiledSoftMax(mlir::OpBuilder builder, computation->root_instruction()->shape().rank(), 1); output_tile_sizes.back() = row_len; - analysis->SetTileSizes(output_tile_sizes); + TF_ASSIGN_OR_RETURN( + std::vector> tiled_hlo_instructions, + analysis->ComputeTiledHloInstructions(output_tile_sizes)); // block_size must be a power of two. int result_block_size = llvm::PowerOf2Ceil(row_len); @@ -2303,27 +2304,21 @@ absl::Status EmitTiledSoftMax(mlir::OpBuilder builder, } // Emits load instructions - auto emit_param_load = [&](const SymbolicTiledHloInstruction& tiled_hlo) - -> absl::StatusOr { + auto emit_param_load = + [&](const TiledHloInstruction& tiled_hlo) -> absl::StatusOr { std::vector tile_sizes, tile_strides, tile_offsets; - for (auto [size, stride, offset] : llvm::zip( - analysis->TileSizes(tiled_hlo), analysis->TileStrides(tiled_hlo), - analysis->TileOffsets(tiled_hlo))) { + for (auto [size, stride] : + llvm::zip(tiled_hlo.tile_sizes(), tiled_hlo.tile_strides())) { if (size == 1) continue; tile_sizes.push_back(CreateConst(b, b.getI64Type(), size)); tile_strides.push_back(CreateConst(b, b.getI64Type(), stride)); - tile_offsets.push_back(CreateConst(b, b.getI32Type(), offset)); + tile_offsets.push_back(CreateConst(b, b.getI32Type(), 0)); } - TF_ASSIGN_OR_RETURN( - IndexingMap program_id_to_input_tile_indexing, - analysis->ComputeBlockIdToTileOffsetIndexing(tiled_hlo)); - // Manually compute pointer offset to avoid materialized fully parallel // dimensions in the tile. Current codegen tried to avoid size-1 dims. - Value ptr_offset = ComputeBasePtrOffset(b, pid, tiled_hlo.hlo()->shape(), - program_id_to_input_tile_indexing); + Value ptr_offset = ComputeBasePtrOffset(b, pid, tiled_hlo); auto fn_arg = fn.getArgument(tiled_hlo.hlo()->parameter_number()); auto tile_ptr = AddPtr(b, fn_arg, ptr_offset); @@ -2343,17 +2338,14 @@ absl::Status EmitTiledSoftMax(mlir::OpBuilder builder, return EmitParameterLoad(b, emitted_tensor, boundary_checks); }; - absl::flat_hash_map values_out; - TF_ASSIGN_OR_RETURN(Value result, - EmitTiledScope(b, libdevice_path, device_info, *analysis, - emit_param_load, values_out)); - + absl::flat_hash_map values_out; TF_ASSIGN_OR_RETURN( - IndexingMap program_id_to_output_tile_indexing, - analysis->ComputeBlockIdToTileOffsetIndexing(*analysis->GetRoot())); + Value result, + EmitTiledScope(b, libdevice_path, device_info, tiled_hlo_instructions, + emit_param_load, values_out)); - Value ptr_offset = ComputeBasePtrOffset(b, pid, root_shape, - program_id_to_output_tile_indexing); + Value ptr_offset = + ComputeBasePtrOffset(b, pid, *tiled_hlo_instructions.back()); Value store_tensor = b.create( /*base=*/AddPtr(b, fn.getArgument(computation->num_parameters()), diff --git a/xla/service/gpu/model/BUILD b/xla/service/gpu/model/BUILD index 9d40d7a510427..c07a386f84fda 100644 --- a/xla/service/gpu/model/BUILD +++ b/xla/service/gpu/model/BUILD @@ -632,8 +632,8 @@ cc_library( ":indexing_map", ":symbolic_tile", ":symbolic_tiled_hlo_instruction", + ":tiled_hlo_instruction", "//xla:status", - "//xla:status_macros", "//xla/hlo/ir:hlo", "//xla/service:instruction_fusion", "@com_google_absl//absl/algorithm:container", @@ -646,6 +646,8 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", ], ) @@ -655,7 +657,7 @@ xla_cc_test( deps = [ ":indexing_test_utils", ":symbolic_tile_analysis", - ":symbolic_tiled_hlo_instruction", + ":tiled_hlo_instruction", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", diff --git a/xla/service/gpu/model/symbolic_tile_analysis.cc b/xla/service/gpu/model/symbolic_tile_analysis.cc index 424c11266e751..3b9e376c8b32e 100644 --- a/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -41,9 +41,11 @@ limitations under the License. #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/model/symbolic_tile.h" #include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" +#include "xla/service/gpu/model/tiled_hlo_instruction.h" #include "xla/service/instruction_fusion.h" #include "xla/status.h" -#include "xla/status_macros.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -55,7 +57,7 @@ using ::mlir::MLIRContext; // Computes indexing map from program id into the tile offset for the given // shape and tile sizes. -IndexingMap ComputeProgramIdToOutputTileIndexing( +IndexingMap ComputeBlockIdToOutputTileIndexing( absl::Span dimensions, absl::Span tile_sizes, mlir::MLIRContext* mlir_context) { CHECK_EQ(dimensions.size(), tile_sizes.size()); // Crash OK @@ -87,6 +89,50 @@ IndexingMap ComputeProgramIdToOutputTileIndexing( /*dim_upper_bounds=*/{num_tiles}, /*symbol_upper_bounds=*/{}); } +absl::StatusOr ComputeBlockIdToTileOffsetIndexing( + const SymbolicTiledHloInstruction& tiled_hlo, + const IndexingMap& block_id_to_root_tile_offset, + mlir::MLIRContext* mlir_context) { + IndexingMap block_id_to_tile_offset_indexing = ComposeIndexingMaps( + block_id_to_root_tile_offset, tiled_hlo.indexing_map()); + + // A symbol in an indexing map means that to produce on element of output, we + // need to read all elements of input in the symbol range. Since this function + // computes start of the tile, we need to substitute each symbol with its + // lower bound value. We assume here the iteration order is normalized. + // TODO(b/330906085): Support cases when tile offsets are not 0. + if (absl::c_any_of(block_id_to_tile_offset_indexing.GetSymbolBounds(), + [](const Interval& symbol_bound) { + return symbol_bound.lower != 0; + })) { + return absl::FailedPreconditionError( + absl::StrCat("Symbol lower bound is not zero. ", + block_id_to_tile_offset_indexing.ToString())); + } + + std::vector symbol_lower_bounds( + block_id_to_tile_offset_indexing.GetSymbolCount(), + mlir::getAffineConstantExpr(0, mlir_context)); + + mlir::AffineMap simplified_affine_map = + block_id_to_tile_offset_indexing.GetAffineMap().replaceDimsAndSymbols( + /*dimReplacements=*/{}, symbol_lower_bounds, + block_id_to_tile_offset_indexing.GetDimVarsCount(), + /*numResultSyms=*/ + block_id_to_tile_offset_indexing.GetRangeVarsCount()); + + IndexingMap simplified_indexing_map = IndexingMap{ + simplified_affine_map, block_id_to_tile_offset_indexing.GetDimVars(), + block_id_to_tile_offset_indexing.GetRangeVars(), + block_id_to_tile_offset_indexing.GetRTVars()}; + + simplified_indexing_map.Simplify(GetIndexingMapForInstruction); + simplified_indexing_map.RescaleSymbols(); + simplified_indexing_map.RemoveUnusedSymbols(); + + return simplified_indexing_map; +} + } // namespace /*static*/ SymbolicTileAnalysisOrError SymbolicTileAnalysis::AnalyzeComputation( @@ -194,81 +240,78 @@ IndexingMap ComputeProgramIdToOutputTileIndexing( return SymbolicTileAnalysis(std::move(tiled_hlo_instructions), ctx); } -std::vector SymbolicTileAnalysis::TileOffsets( - const SymbolicTiledHloInstruction& tiled_hlo) const { - CHECK(tile_parameters_.has_value()) // Crash OK - << "SetTileSizes() must be called before TileOffsets()"; - return tiled_hlo.TileOffsets(*tile_parameters_); -} +absl::StatusOr>> +SymbolicTileAnalysis::ComputeTiledHloInstructions( + const std::vector& tile_parameters) const { + IndexingMap block_id_to_root_tile_offset = ComputeBlockIdToOutputTileIndexing( + GetRoot()->hlo()->shape().dimensions(), tile_parameters, context_); -// TODO(bchetioui): remove dependency on stride and offset parameters. -std::vector SymbolicTileAnalysis::TileSizes( - const SymbolicTiledHloInstruction& tiled_hlo) const { - CHECK(tile_parameters_.has_value()) // Crash OK - << "SetTileSizes() must be called before TileSizes()"; - return tiled_hlo.TileSizes(*tile_parameters_); -} + std::vector> tiled_hlo_instructions; + absl::flat_hash_map + symbolic_to_tiled_hlo_map; + absl::flat_hash_set + tiled_hlo_instructions_set; -std::vector SymbolicTileAnalysis::TileStrides( - const SymbolicTiledHloInstruction& tiled_hlo) const { - CHECK(tile_parameters_.has_value()) // Crash OK - << "SetTileSizes() must be called before TileStrides()"; - return tiled_hlo.TileStrides(*tile_parameters_); -} + absl::flat_hash_map topological_order; -absl::StatusOr -SymbolicTileAnalysis::ComputeBlockIdToTileOffsetIndexing( - const SymbolicTiledHloInstruction& tiled_hlo) const { - TF_RET_CHECK(block_id_to_root_tile_offset_.has_value()) - << "SetTileSizes() must be called before " - "ComputeBlockIdToTileOffsetIndexing()"; + std::function( + const SymbolicTiledHloInstruction*)> + get_tiled_hlo_instruction; - IndexingMap block_id_to_tile_offset_indexing = ComposeIndexingMaps( - *block_id_to_root_tile_offset_, tiled_hlo.indexing_map()); + get_tiled_hlo_instruction = + [&](const SymbolicTiledHloInstruction* symbolic_tiled_hlo) + -> absl::StatusOr { + auto it1 = symbolic_to_tiled_hlo_map.find(symbolic_tiled_hlo); + if (it1 != symbolic_to_tiled_hlo_map.end()) { + return it1->second; + } - // A symbol in an indexing map means that to produce on element of output, we - // need to read all elements of input in the symbol range. Since this function - // computes start of the tile, we need to substitute each symbol with its - // lower bound value. We assume here the iteration order is normalized. - // TODO(b/330906085): Support cases when tile offsets are not 0. - if (absl::c_any_of(block_id_to_tile_offset_indexing.GetSymbolBounds(), - [](const Interval& symbol_bound) { - return symbol_bound.lower != 0; - })) { - return absl::FailedPreconditionError( - absl::StrCat("Symbol lower bound is not zero. ", - block_id_to_tile_offset_indexing.ToString())); - } + std::vector tile_sizes = + symbolic_tiled_hlo->TileSizes(tile_parameters); + std::vector tile_strides = + symbolic_tiled_hlo->TileStrides(tile_parameters); + + TF_ASSIGN_OR_RETURN( + IndexingMap block_id_to_block_offset_indexing, + ComputeBlockIdToTileOffsetIndexing( + *symbolic_tiled_hlo, block_id_to_root_tile_offset, context_)); + + TF_ASSIGN_OR_RETURN(std::unique_ptr tiled_hlo_holder, + TiledHloInstruction::Create( + symbolic_tiled_hlo->hlo(), std::move(tile_sizes), + std::move(tile_strides), + std::move(block_id_to_block_offset_indexing))); + + auto it2 = tiled_hlo_instructions_set.find(tiled_hlo_holder.get()); + if (it2 != tiled_hlo_instructions_set.end()) { + return *it2; + } - std::vector symbol_lower_bounds( - block_id_to_tile_offset_indexing.GetSymbolCount(), - mlir::getAffineConstantExpr(0, context_)); + tiled_hlo_instructions.push_back(std::move(tiled_hlo_holder)); + TiledHloInstruction* tiled_hlo = tiled_hlo_instructions.back().get(); + tiled_hlo_instructions_set.insert(tiled_hlo); + symbolic_to_tiled_hlo_map[symbolic_tiled_hlo] = tiled_hlo; - mlir::AffineMap simplified_affine_map = - block_id_to_tile_offset_indexing.GetAffineMap().replaceDimsAndSymbols( - /*dimReplacements=*/{}, symbol_lower_bounds, - block_id_to_tile_offset_indexing.GetDimVarsCount(), - /*numResultSyms=*/ - block_id_to_tile_offset_indexing.GetRangeVarsCount()); - - IndexingMap simplified_indexing_map = IndexingMap{ - simplified_affine_map, block_id_to_tile_offset_indexing.GetDimVars(), - block_id_to_tile_offset_indexing.GetRangeVars(), - block_id_to_tile_offset_indexing.GetRTVars()}; + for (SymbolicTiledHloInstruction* operand : + symbolic_tiled_hlo->operands()) { + TF_ASSIGN_OR_RETURN(TiledHloInstruction * tiled_operand, + get_tiled_hlo_instruction(operand)); + tiled_hlo->AppendOperand(tiled_operand); + } - simplified_indexing_map.Simplify(GetIndexingMapForInstruction); - simplified_indexing_map.RescaleSymbols(); - simplified_indexing_map.RemoveUnusedSymbols(); + topological_order[tiled_hlo] = topological_order.size(); + return tiled_hlo; + }; - return simplified_indexing_map; -} + TF_CHECK_OK(get_tiled_hlo_instruction(GetRoot()).status()); -void SymbolicTileAnalysis::SetTileSizes(std::vector sizes) { - block_id_to_root_tile_offset_ = ComputeProgramIdToOutputTileIndexing( - GetRoot()->hlo()->shape().dimensions(), sizes, context_); + // Order instructions in def-before-use order. + absl::c_sort(tiled_hlo_instructions, [&](const auto& i1, const auto& i2) { + return topological_order.at(i1.get()) < topological_order.at(i2.get()); + }); - // TODO(bchetioui): CHECK num parameters somehow? - tile_parameters_ = std::vector(std::move(sizes)); + return tiled_hlo_instructions; } } // namespace gpu diff --git a/xla/service/gpu/model/symbolic_tile_analysis.h b/xla/service/gpu/model/symbolic_tile_analysis.h index db643f924fe8a..40b184d795f38 100644 --- a/xla/service/gpu/model/symbolic_tile_analysis.h +++ b/xla/service/gpu/model/symbolic_tile_analysis.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -26,8 +25,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" +#include "xla/service/gpu/model/tiled_hlo_instruction.h" #include "xla/service/instruction_fusion.h" namespace xla { @@ -50,42 +49,21 @@ class SymbolicTileAnalysis { static SymbolicTileAnalysisOrError AnalyzeComputation( const HloComputation& computation, mlir::MLIRContext* ctx); - // Evaluates the tile offsets of an instruction from the analyzed computation - // following the provided path from the root. Tile parameters must have been - // set before calling this method. - std::vector TileOffsets( - const SymbolicTiledHloInstruction& tiled_hlo) const; - // Evaluates the tile sizes of an instruction from the analyzed computation - // following the provided path from the root. Tile parameters must have been - // set before calling this method. - std::vector TileSizes( - const SymbolicTiledHloInstruction& tiled_hlo) const; - // Evaluates the tile strides of an instruction from the analyzed computation - // following the provided path from the root. Tile parameters must have been - // set before calling this method. - std::vector TileStrides( - const SymbolicTiledHloInstruction& tiled_hlo) const; - - // Computes the indexing map from block id to tile offset of the tiled HLO - // instruction. The indexing map has the following form: - // - // (block_id) -> (tile_offset0, tile_offset1, ...) - absl::StatusOr ComputeBlockIdToTileOffsetIndexing( - const SymbolicTiledHloInstruction& tiled_hlo) const; - - // Populates input tile sizes. This is a prerequisite in order to extract - // concrete values using `TileOffsets`, `TileSizes`, and `TileStrides`. - void SetTileSizes(std::vector sizes); + // Returns a graph of HLO instructions tiled with the given tile parameters. + // Result vector has instructions in def-before-use order. + absl::StatusOr>> + ComputeTiledHloInstructions( + const std::vector& tile_parameters) const; // Returns the tiled root instruction. const SymbolicTiledHloInstruction* GetRoot() const { - return tiled_hlo_instructions_.back().get(); + return symbolic_tiled_hlo_instructions_.back().get(); } - // Returns the tiled HLO instructions in def-before-use order. + // Returns the symbolic tiled HLO instructions in def-before-use order. const std::vector>& - GetTiledHloInstructions() const { - return tiled_hlo_instructions_; + GetSymbolicTiledHloComputation() const { + return symbolic_tiled_hlo_instructions_; } // Return the underlying MLIRContext. @@ -93,25 +71,17 @@ class SymbolicTileAnalysis { private: SymbolicTileAnalysis(std::vector> - tiled_hlo_instructions, + symbolic_tiled_hlo_instructions, mlir::MLIRContext* context) - : tiled_hlo_instructions_(std::move(tiled_hlo_instructions)), + : symbolic_tiled_hlo_instructions_( + std::move(symbolic_tiled_hlo_instructions)), context_(context) {} // The tiled HLO instructions in def-before-use order. std::vector> - tiled_hlo_instructions_; + symbolic_tiled_hlo_instructions_; mlir::MLIRContext* context_; - // Optionally set tile parameters. These parameters can be set by calling - // `SetTileParameters`, and correspond to the output tile for the analyzed - // computation. The order and type of parameters are as explained in the - // documentation of `SymbolicTile`. - std::optional> tile_parameters_; - - // Indexing map from block id to root tile offset. Computed from the tile - // parameters. - std::optional block_id_to_root_tile_offset_; }; } // namespace gpu diff --git a/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/xla/service/gpu/model/symbolic_tile_analysis_test.cc index 8925c46f91721..b6340a83636e3 100644 --- a/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -26,7 +27,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_module.h" #include "xla/service/gpu/model/indexing_test_utils.h" -#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" +#include "xla/service/gpu/model/tiled_hlo_instruction.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" #include "tsl/platform/statusor.h" @@ -74,12 +75,13 @@ ENTRY main { EXPECT_TRUE(SetAnalysis(module.get())); - const SymbolicTiledHloInstruction* root = analysis_->GetRoot(); + TF_ASSERT_OK_AND_ASSIGN( + std::vector> tiled_hlo_instructions, + analysis_->ComputeTiledHloInstructions(/*tile_parameters=*/{1, 10})); - analysis_->SetTileSizes(/*sizes=*/{1, 10}); + TiledHloInstruction* root = tiled_hlo_instructions.back().get(); - EXPECT_THAT(*analysis_->ComputeBlockIdToTileOffsetIndexing(*root), - MatchIndexingMap(R"( + EXPECT_THAT(root->block_id_to_tile_offsets_indexing(), MatchIndexingMap(R"( (d0) -> (d0 floordiv 10, (d0 mod 10) * 10) domain: d0 in [0, 19] @@ -88,25 +90,21 @@ ENTRY main { auto p0_from_subtract0 = root->operand(0); auto p0_from_subtract1 = root->operand(1)->operand(0)->operand(0); - EXPECT_THAT(analysis_->TileOffsets(*p0_from_subtract0), ElementsAre(0, 0)); - EXPECT_THAT(analysis_->TileSizes(*p0_from_subtract0), ElementsAre(1, 10)); - EXPECT_THAT(analysis_->TileStrides(*p0_from_subtract0), ElementsAre(1, 1)); + EXPECT_THAT(p0_from_subtract0->tile_sizes(), ElementsAre(1, 10)); + EXPECT_THAT(p0_from_subtract0->tile_strides(), ElementsAre(1, 1)); - EXPECT_THAT( - *analysis_->ComputeBlockIdToTileOffsetIndexing(*p0_from_subtract0), - MatchIndexingMap(R"( + EXPECT_THAT(p0_from_subtract0->block_id_to_tile_offsets_indexing(), + MatchIndexingMap(R"( (d0) -> (d0 floordiv 10, (d0 mod 10) * 10) domain: d0 in [0, 19] )")); - EXPECT_THAT(analysis_->TileOffsets(*p0_from_subtract1), ElementsAre(0, 0)); - EXPECT_THAT(analysis_->TileSizes(*p0_from_subtract1), ElementsAre(1, 97)); - EXPECT_THAT(analysis_->TileStrides(*p0_from_subtract1), ElementsAre(1, 1)); + EXPECT_THAT(p0_from_subtract1->tile_sizes(), ElementsAre(1, 97)); + EXPECT_THAT(p0_from_subtract1->tile_strides(), ElementsAre(1, 1)); - EXPECT_THAT( - *analysis_->ComputeBlockIdToTileOffsetIndexing(*p0_from_subtract1), - MatchIndexingMap(R"( + EXPECT_THAT(p0_from_subtract1->block_id_to_tile_offsets_indexing(), + MatchIndexingMap(R"( (d0) -> (d0 floordiv 10, 0) domain: d0 in [0, 19] @@ -125,7 +123,11 @@ ENTRY main { EXPECT_TRUE(SetAnalysis(module.get())); - const SymbolicTiledHloInstruction* root = analysis_->GetRoot(); + TF_ASSERT_OK_AND_ASSIGN( + std::vector> tiled_hlo_instructions, + analysis_->ComputeTiledHloInstructions(/*tile_parameters=*/{1, 10})); + + TiledHloInstruction* root = tiled_hlo_instructions.back().get(); auto p0_from_subtract0 = root->operand(0)->operand(0); auto p0_from_subtract1 = root->operand(1)->operand(0); @@ -143,18 +145,19 @@ ENTRY main { EXPECT_TRUE(SetAnalysis(module.get())); - analysis_->SetTileSizes(/*sizes=*/{2, 4, 2}); + TF_ASSERT_OK_AND_ASSIGN( + std::vector> tiled_hlo_instructions, + analysis_->ComputeTiledHloInstructions(/*tile_parameters=*/{2, 4, 2})); - const SymbolicTiledHloInstruction* root = analysis_->GetRoot(); + TiledHloInstruction* root = tiled_hlo_instructions.back().get(); - EXPECT_THAT(*analysis_->ComputeBlockIdToTileOffsetIndexing(*root), - MatchIndexingMap(R"( + EXPECT_THAT(root->block_id_to_tile_offsets_indexing(), MatchIndexingMap(R"( (d0) -> ((d0 floordiv 16) * 2, ((d0 floordiv 8) mod 2) * 4, (d0 mod 8) * 2) domain: d0 in [0, 31] )")); - EXPECT_THAT(*analysis_->ComputeBlockIdToTileOffsetIndexing(*root->operand(0)), + EXPECT_THAT(root->operand(0)->block_id_to_tile_offsets_indexing(), MatchIndexingMap(R"( (d0) -> (((d0 floordiv 8) mod 2) * 4, (d0 mod 8) * 2, (d0 floordiv 16) * 2) domain: @@ -162,6 +165,47 @@ ENTRY main { )")); } +TEST_F(SymbolicTileAnalysisTest, SliceOffsetIndexingIsCorrect) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY main { + p0 = f32[8,16] parameter(0) + slice.0 = f32[4,8] slice(p0), slice={[0:4], [2:10]} + slice.1 = f32[4,8] slice(p0), slice={[3:7], [4:12]} + ROOT add = f32[4,8] add(slice.0, slice.1) +})")); + + EXPECT_TRUE(SetAnalysis(module.get())); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector> tiled_hlo_instructions, + analysis_->ComputeTiledHloInstructions(/*tile_parameters=*/{2, 2})); + + TiledHloInstruction* root = tiled_hlo_instructions.back().get(); + const TiledHloInstruction* p0_from_slice0 = root->operand(0)->operand(0); + const TiledHloInstruction* p0_from_slice1 = root->operand(1)->operand(0); + + EXPECT_THAT(root->block_id_to_tile_offsets_indexing(), MatchIndexingMap(R"( + (d0) -> ((d0 floordiv 4) * 2, (d0 mod 4) * 2) + domain: + d0 in [0, 7] + )")); + + EXPECT_THAT(p0_from_slice0->block_id_to_tile_offsets_indexing(), + MatchIndexingMap(R"( + (d0) -> ((d0 floordiv 4) * 2, (d0 mod 4) * 2 + 2) + domain: + d0 in [0, 7] + )")); + + EXPECT_THAT(p0_from_slice1->block_id_to_tile_offsets_indexing(), + MatchIndexingMap(R"( + (d0) -> ((d0 floordiv 4) * 2 + 3, (d0 mod 4) * 2 + 4) + domain: + d0 in [0, 7] + )")); +} + TEST_F(SymbolicTileAnalysisTest, BailOutOnUnsupportedDot) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( @@ -195,10 +239,6 @@ ENTRY main { ROOT bitcast = f32[2] bitcast(p0) })")); - mlir::MLIRContext mlir_ctx; - SymbolicTileAnalysisOrError analysis_or_error = - SymbolicTileAnalysis::AnalyzeComputation(*module->entry_computation(), - &mlir_ctx); EXPECT_FALSE(SetAnalysis(module.get())); } @@ -214,23 +254,6 @@ ENTRY main { EXPECT_FALSE(SetAnalysis(module.get())); } -TEST_F(SymbolicTileAnalysisTest, ComputingIndexingMapFailsWithoutTileSizes) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY main { - p0 = f32[4,8]{1,0} parameter(0) - ROOT exponential = f32[4,8]{1,0} exponential(p0) -})")); - - EXPECT_TRUE(SetAnalysis(module.get())); - - const SymbolicTiledHloInstruction* root = analysis_->GetRoot(); - - EXPECT_THAT( - analysis_->ComputeBlockIdToTileOffsetIndexing(*root).status().message(), - ::testing::HasSubstr("SetTileSizes() must be called before")); -} - } // namespace } // namespace gpu } // namespace xla