From b2ebc5a2a593659d94cc07b2421b0e0bf7c4cbc3 Mon Sep 17 00:00:00 2001
From: Tori Baker
Date: Wed, 25 Sep 2024 01:40:32 -0700
Subject: [PATCH] Lower transpose to materialize & insert op

PiperOrigin-RevId: 678591749
---
 .../gpu/fusions/tests/transpose/epilogue.hlo  |  7 +-
 .../tests/transpose/fused_transpose_021.hlo   |  9 +-
 .../tests/transpose/fused_transpose_102.hlo   |  7 +-
 .../tests/transpose/fused_transpose_210.hlo   |  7 +-
 xla/service/gpu/fusions/transpose_mlir.cc     | 97 ++++++++++++-------
 xla/service/gpu/model/indexing_analysis.cc    | 15 +--
 xla/service/gpu/model/indexing_analysis.h     |  2 +
 7 files changed, 88 insertions(+), 56 deletions(-)

diff --git a/xla/service/gpu/fusions/tests/transpose/epilogue.hlo b/xla/service/gpu/fusions/tests/transpose/epilogue.hlo
index 1ca7036259669..25695c8212f7d 100644
--- a/xla/service/gpu/fusions/tests/transpose/epilogue.hlo
+++ b/xla/service/gpu/fusions/tests/transpose/epilogue.hlo
@@ -11,10 +11,9 @@ fusion {
 // CHECK-SAME:   }, %[[OUT:.*]]: tensor<20x170x160xf32>
 
 // CHECK:     %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<1x32x33xf32>
-// CHECK:     %[[SHMEM_WITH_VALS:.*]] = xla_gpu.loop
-// CHECK-SAME:   iter_args(%[[SHMEM_:.*]] = %[[SHMEM]])
-// CHECK:     %[[EXP:.*]] = xla_gpu.pure_call @fusion_p0
-// CHECK:     tensor.insert %[[EXP]] into %[[SHMEM_]]
+// CHECK:     %[[MATERIALIZED:.*]] = xla_gpu.materialize @fusion_p0
+// CHECK:     %[[SHMEM_WITH_VALS:.*]] = xla_gpu.insert %[[MATERIALIZED]]
+// CHECK-SAME:   into %[[SHMEM]] at #indexing_map
 
 // CHECK:     %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]]
 
diff --git a/xla/service/gpu/fusions/tests/transpose/fused_transpose_021.hlo b/xla/service/gpu/fusions/tests/transpose/fused_transpose_021.hlo
index 60e5cd404e150..7c2d63c78a47e 100644
--- a/xla/service/gpu/fusions/tests/transpose/fused_transpose_021.hlo
+++ b/xla/service/gpu/fusions/tests/transpose/fused_transpose_021.hlo
@@ -9,13 +9,14 @@ fusion {
   ROOT %abs = f32[20,170,160] abs(%transpose)
 }
 // CHECK-LABEL: func.func @main(
+// CHECK-SAME:   %[[INPUT:.*]]: tensor<20x160x170xf32> {
 // CHECK-SAME:   }, %[[OUT:.*]]: tensor<20x170x160xf32>
 
 // CHECK:     %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<1x32x33xf32>
-// CHECK:     %[[SHMEM_WITH_VALS:.*]] = xla_gpu.loop
-// CHECK-SAME:   iter_args(%[[SHMEM_:.*]] = %[[SHMEM]])
-// CHECK:     %[[EXP:.*]] = xla_gpu.pure_call @fusion_exp
-// CHECK:     tensor.insert %[[EXP]] into %[[SHMEM_]]
+// CHECK:     %[[MATERIALIZED:.*]] = xla_gpu.materialize
+// CHECK-SAME:   @fusion_exp(%[[INPUT]]) at #indexing_map
+// CHECK:     %[[SHMEM_WITH_VALS:.*]] = xla_gpu.insert %[[MATERIALIZED]]
+// CHECK-SAME:   into %[[SHMEM]] at #indexing_map
 
 // CHECK:     %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]]
 
diff --git a/xla/service/gpu/fusions/tests/transpose/fused_transpose_102.hlo b/xla/service/gpu/fusions/tests/transpose/fused_transpose_102.hlo
index 55c2976d32b34..2fc3855efe4c1 100644
--- a/xla/service/gpu/fusions/tests/transpose/fused_transpose_102.hlo
+++ b/xla/service/gpu/fusions/tests/transpose/fused_transpose_102.hlo
@@ -10,10 +10,9 @@ fusion {
 // CHECK-SAME:   }, %[[OUT:.*]]: tensor<170x160x3xi8>
 
 // CHECK:     %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<32x33x3xi8>
-// CHECK:     %[[SHMEM_WITH_VALS:.*]] = xla_gpu.loop
-// CHECK-SAME:   iter_args(%[[SHMEM_:.*]] = %[[SHMEM]])
-// CHECK:     %[[P0:.*]] = xla_gpu.pure_call @fusion_p0
-// CHECK:     tensor.insert %[[P0]] into %[[SHMEM_]]
+// CHECK:     %[[MATERIALIZED:.*]] = xla_gpu.materialize @fusion_p0
+// CHECK:     %[[SHMEM_WITH_VALS:.*]] = xla_gpu.insert %[[MATERIALIZED]]
+// CHECK-SAME:   into %[[SHMEM]] at #indexing_map
 
 // CHECK:     %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]]
 
diff --git a/xla/service/gpu/fusions/tests/transpose/fused_transpose_210.hlo b/xla/service/gpu/fusions/tests/transpose/fused_transpose_210.hlo
index 0dd4a27547514..97e23f171b713 100644
--- a/xla/service/gpu/fusions/tests/transpose/fused_transpose_210.hlo
+++ b/xla/service/gpu/fusions/tests/transpose/fused_transpose_210.hlo
@@ -12,10 +12,9 @@ fusion {
 // CHECK-SAME:   }, %[[OUT:.*]]: tensor<170x160x20xf32>
 //
 // CHECK:     %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<32x1x33xf32>
-// CHECK:     %[[SHMEM_WITH_VALS:.*]] = xla_gpu.loop
-// CHECK-SAME:   iter_args(%[[SHMEM_:.*]] = %[[SHMEM]])
-// CHECK:     %[[EXP:.*]] = xla_gpu.pure_call @fusion_exp
-// CHECK:     tensor.insert %[[EXP]] into %[[SHMEM_]]
+// CHECK:     %[[MATERIALIZED:.*]] = xla_gpu.materialize @fusion_exp
+// CHECK:     %[[SHMEM_WITH_VALS:.*]] = xla_gpu.insert %[[MATERIALIZED]]
+// CHECK-SAME:   into %[[SHMEM]] at #indexing_map
 
 // CHECK:     %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]]
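The CHECK changes above capture the new lowering contract: instead of an xla_gpu.loop that threads the shared-memory tensor through iter_args and performs a per-element tensor.insert, the emitter now issues a single xla_gpu.materialize that evaluates the fused operand at the read indexing map, followed by an xla_gpu.insert that commits the result to shared memory at the write indexing map. As a condensed C++ sketch of the per-transpose emission (all names come from the transpose_mlir.cc hunk below; type setup, constraint propagation, and side outputs are omitted, so this is not a drop-in replacement):

    // Condensed from EmitWriteToShMemMlir below.
    for (auto* transpose : shmem_transposes_) {
      // Shared-memory destination tile for this transpose.
      auto shmem = builder.create<AllocateSharedOp>(
          RankedTensorType::get(shmem_tensor_size, elem_type));
      // Evaluate the fused operand subgraph at the read indexing map.
      auto materialized = builder.create<MaterializeOp>(
          indexed_vector, callee_operands, tids_and_bids, callee,
          IndexingMapAttr::get(ctx, indexing));
      // Commit the materialized values into shared memory.
      auto insert = builder.create<InsertOp>(
          shmem.getType(), materialized.getResult(), tids_and_bids, shmem,
          identity_map);
      shmem_tensors.push_back(insert.getResult());
    }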
diff --git a/xla/service/gpu/fusions/transpose_mlir.cc b/xla/service/gpu/fusions/transpose_mlir.cc
index fd18cef310a8f..57e7d4fa104f5 100644
--- a/xla/service/gpu/fusions/transpose_mlir.cc
+++ b/xla/service/gpu/fusions/transpose_mlir.cc
@@ -32,6 +32,7 @@ limitations under the License.
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/MLIRContext.h"
@@ -211,6 +212,8 @@ IndexingMap MlirTransposeFusion::GetSharedMemoryIndexing(
   std::vector<int64_t> dim_var_sizes(6, 1);
   dim_var_sizes[KernelFusionInterface::kIndexingMapThreadIdxDims[0]] =
       kNumThreadsPerBlock;
+  dim_var_sizes[KernelFusionInterface::kIndexingMapBlockIdxDims[0]] =
+      Product(block_counts_);
   return {mlir::AffineMap::get(6, 2, thread_offsets, ctx),
           DimVarsFromTensorSizes(dim_var_sizes),
           RangeVarsFromTensorSizes({block_size_ / kNumRows, vector_size_}),
@@ -233,44 +236,72 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir(
   } else {
     ++shmem_tensor_size.back();
   }
+  int num_inputs = fusion.fused_instructions_computation()->num_parameters();
+  SmallVector<Value> callee_operands(
+      entry_function.getArguments().take_front(num_inputs));
+  auto tids_and_bids = EmitThreadAndBlockIds(builder);
+  auto identity_map =
+      IndexingMapAttr::get(ctx, CreateIdentityMap(shmem_tensor_size, ctx));
+
+  // We can assume that all transpose operands have the same shape.
+  Shape operand_shape = shmem_transposes_.front()->operand(0)->shape();
 
-  // Allocate shared memory.
-  SmallVector<Value> inits;
+  // Indexing for MaterializeOp to read from input.
+  auto indexing = GetIndexing(/*input=*/true, operand_shape, ctx);
+
+  // Indexing for InsertOp to write into shared memory.
+  IndexingMap write_indexing = GetSharedMemoryIndexing(/*read=*/false, ctx);
+  // As we are writing the same elements that we are reading, any read
+  // constraints can also be constraints for the write.
+  for (auto constraint : indexing.GetConstraints()) {
+    write_indexing.AddConstraint(constraint.first, constraint.second);
+  }
+  for (auto [index, bound] : llvm::enumerate(indexing.GetSymbolBounds())) {
+    write_indexing.GetMutableSymbolBound(index) = bound;
+  }
+  write_indexing.Simplify();
+  auto dimensions = SmallVector<int64_t>(operand_shape.dimensions().begin(),
+                                         operand_shape.dimensions().end());
+  SmallVector<Value> shmem_tensors;
   for (auto* transpose : shmem_transposes_) {
     auto elem_type = mlir_converter::PrimitiveTypeToMlirType(
         transpose->shape().element_type(), builder);
-    inits.push_back(builder.create<AllocateSharedOp>(
-        RankedTensorType::get(shmem_tensor_size, elem_type)));
+    auto shmem = builder.create<AllocateSharedOp>(
+        RankedTensorType::get(shmem_tensor_size, elem_type));
+    auto indexed_vector =
+        IndexedVectorType::get(ctx, shmem_tensor_size, elem_type,
+                               IndexingMapAttr::get(ctx, write_indexing));
+    auto callee =
+        mlir::SymbolRefAttr::get(call_target_provider(transpose->operand(0)));
+
+    auto materialized = builder.create<MaterializeOp>(
+        /*result_type=*/indexed_vector,
+        /*input=*/callee_operands,
+        /*indices(dimensions)=*/tids_and_bids,
+        /*callee=*/callee,
+        /*map=*/IndexingMapAttr::get(ctx, indexing));
+
+    auto insert = builder.create<InsertOp>(
+        /*result_type=*/shmem.getType(),
+        /*source=*/materialized.getResult(),
+        /*indices(dimensions)=*/tids_and_bids,
+        /*dest=*/shmem,
+        /*map=*/identity_map);
+    shmem_tensors.push_back(insert.getResult());
   }
 
-  // Add output arguments for side outputs.
-  int num_inputs = fusion.fused_instructions_computation()->num_parameters();
+  // Produce all side outputs and then write them.
+  SmallVector<Value> side_output_inits;
   for (int index : side_output_root_indices_) {
-    inits.push_back(entry_function.getArgument(num_inputs + index));
+    side_output_inits.push_back(entry_function.getArgument(num_inputs + index));
   }
-
-  IndexingMap write_indexing = GetSharedMemoryIndexing(/*read=*/false, ctx);
   auto body_builder = [&](ValueRange symbol_values, ValueRange map_results,
                           ValueRange output_tensors) -> SmallVector<Value> {
     auto input_indices = [&](const HloInstruction* instr) {
       return ApplyIndexing(GetIndexing(/*input=*/true, instr->shape(), ctx),
                            thread_and_block_ids, symbol_values, builder);
     };
-    SmallVector<Value> result_tensors;
-    auto shmem_indices = ApplyIndexing(write_indexing, thread_and_block_ids,
-                                       symbol_values, builder);
-    for (auto [transpose, output] :
-         llvm::zip(shmem_transposes_, output_tensors)) {
-      // Emit loop that writes subgraphs of transpose operands to shmem.
-      auto result_scalar = mlir_converter::ProvideParameter(
-          root_computation, transpose,
-          /*operand_index=*/0, input_indices(transpose->operand(0)),
-          call_target_provider, entry_function, builder)[0];
-      result_tensors.push_back(builder.create<tensor::InsertOp>(
-          result_scalar, output, shmem_indices));
-    }
-    // Produce all side outputs and then write them.
     SmallVector<Value> side_outputs;
     SmallVector<SmallVector<Value>> side_output_indices;
     auto* root_tuple = fusion.fused_expression_root();
@@ -283,22 +314,21 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir(
       side_outputs.append(param_values.begin(), param_values.end());
     }
 
+    SmallVector<Value> result_tensors;
     for (const auto& [value, indices, output] :
-         llvm::zip(side_outputs, side_output_indices,
-                   output_tensors.take_back(side_output_roots_.size()))) {
+         llvm::zip(side_outputs, side_output_indices, output_tensors)) {
       result_tensors.push_back(
          builder.create<tensor::InsertOp>(value, output, indices));
     }
     return result_tensors;
   };
-
-  auto indexing = GetIndexing(
-      /*input=*/true, shmem_transposes_.front()->operand(0)->shape(), ctx);
-  auto written_vector = mlir_converter::EmitXlaLoopOp(
-      builder, thread_and_block_ids, inits, indexing, body_builder);
-  ValueRange written = written_vector;
-  auto shmem_tensors = written.take_front(shmem_transposes_.size());
+  mlir::ValueRange side_output_vector;
+  if (!side_output_inits.empty()) {
+    side_output_vector = mlir_converter::EmitXlaLoopOp(
+        builder, thread_and_block_ids, side_output_inits, indexing,
+        body_builder);
+  }
 
   WriteResult result;
   result.shmem_tensors =
@@ -307,8 +337,7 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir(
       builder
          .create<SyncThreadsOp>(mlir::TypeRange(shmem_tensors), shmem_tensors)
          .getResults();
   result.updated_outputs = output_args;
   for (auto [index, side_output_result] :
-       llvm::zip(side_output_root_indices_,
-                 written.take_back(side_output_roots_.size()))) {
+       llvm::zip(side_output_root_indices_, side_output_vector)) {
     result.updated_outputs[index] = side_output_result;
   }
   return result;
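A note on the indexing bookkeeping above: because the materialize reads and the shared-memory insert write exactly the same elements, every constraint and symbol bound of the read map is transplanted onto the write map before that map is baked into the IndexedVectorType. A minimal standalone sketch of that step, assuming only the IndexingMap API already used in this patch (GetConstraints, AddConstraint, GetSymbolBounds, GetMutableSymbolBound, Simplify); the helper name PropagateReadConstraints is hypothetical and not part of the patch:

    #include "llvm/ADT/STLExtras.h"  // for llvm::enumerate

    // Hypothetical helper: copy read-side constraints and symbol bounds
    // onto a write-side indexing map, then simplify the result.
    IndexingMap PropagateReadConstraints(const IndexingMap& read,
                                         IndexingMap write) {
      for (const auto& [expr, interval] : read.GetConstraints()) {
        write.AddConstraint(expr, interval);
      }
      for (auto [index, bound] : llvm::enumerate(read.GetSymbolBounds())) {
        write.GetMutableSymbolBound(index) = bound;
      }
      write.Simplify();
      return write;
    }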
diff --git a/xla/service/gpu/model/indexing_analysis.cc b/xla/service/gpu/model/indexing_analysis.cc
index 18ed0d526862a..ad9215c921b09 100644
--- a/xla/service/gpu/model/indexing_analysis.cc
+++ b/xla/service/gpu/model/indexing_analysis.cc
@@ -1151,18 +1151,21 @@ std::vector<int64_t> ToTransposeDimensions(const Layout& l) {
 
 }  // namespace
 
+IndexingMap CreateIdentityMap(absl::Span<const int64_t> dimensions,
+                              mlir::MLIRContext* mlir_context) {
+  return IndexingMap::FromTensorSizes(
+      AffineMap::getMultiDimIdentityMap(dimensions.size(), mlir_context),
+      /*dim_upper_bounds=*/dimensions, /*symbol_upper_bounds=*/{},
+      /*is_simplified=*/dimensions.empty());
+}
+
 IndexingMap CreateIdentityMap(const Shape& shape, MLIRContext* mlir_context) {
   if (shape.IsTuple()) {
     // Should happen only for variadic reduce. In that case all tuple shapes
     // are equal.
     return CreateIdentityMap(shape.tuple_shapes(0), mlir_context);
   }
-
-  auto dimensions = shape.dimensions();
-  IndexingMap identity_map = IndexingMap::FromTensorSizes(
-      AffineMap::getMultiDimIdentityMap(dimensions.size(), mlir_context),
-      dimensions, {}, /*is_simplified=*/dimensions.empty());
-  return identity_map;
+  return CreateIdentityMap(shape.dimensions(), mlir_context);
 }
 
 llvm::SmallVector<AffineExpr, 4> DelinearizeInBoundsIndex(
diff --git a/xla/service/gpu/model/indexing_analysis.h b/xla/service/gpu/model/indexing_analysis.h
index 965b060da30be..d4c170aace206 100644
--- a/xla/service/gpu/model/indexing_analysis.h
+++ b/xla/service/gpu/model/indexing_analysis.h
@@ -163,6 +163,8 @@ std::vector<int64_t> DelinearizeIndex(absl::Span<const int64_t> dims,
 // Creates an identity indexing map corresponding to the parameter shape.
 IndexingMap CreateIdentityMap(const Shape& shape,
                               mlir::MLIRContext* mlir_context);
+IndexingMap CreateIdentityMap(absl::Span<const int64_t> dimensions,
+                              mlir::MLIRContext* mlir_context);
 
 llvm::SmallVector<mlir::AffineExpr, 4> DelinearizeInBoundsIndex(
     mlir::AffineExpr linear, absl::Span<const int64_t> sizes);
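For reference, the new CreateIdentityMap overload lets callers build an identity map directly from a dimension list instead of a Shape, which is what the transpose emitter uses for the shared-memory tile. A usage sketch; the tile sizes are illustrative only:

    #include "mlir/IR/MLIRContext.h"
    #include "xla/service/gpu/model/indexing_analysis.h"

    // Illustrative only: identity map for a 1x32x33 shared-memory tile,
    // like the one used as the InsertOp destination above.
    mlir::MLIRContext ctx;
    xla::gpu::IndexingMap id_map =
        xla::gpu::CreateIdentityMap({1, 32, 33}, &ctx);
    // id_map: (d0, d1, d2) -> (d0, d1, d2),
    // with d0 in [0, 0], d1 in [0, 31], d2 in [0, 32].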