Lower transpose to materialize & insert op
PiperOrigin-RevId: 678591749
vwbaker authored and Google-ML-Automation committed Sep 25, 2024
1 parent 46c95ae · commit b2ebc5a
Showing 7 changed files with 88 additions and 56 deletions.
7 changes: 3 additions & 4 deletions xla/service/gpu/fusions/tests/transpose/epilogue.hlo
@@ -11,10 +11,9 @@ fusion {
// CHECK-SAME: }, %[[OUT:.*]]: tensor<20x170x160xf32>

// CHECK: %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<1x32x33xf32>
// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.loop
// CHECK-SAME: iter_args(%[[SHMEM_:.*]] = %[[SHMEM]])
// CHECK: %[[EXP:.*]] = xla_gpu.pure_call @fusion_p0
// CHECK: tensor.insert %[[EXP]] into %[[SHMEM_]]
// CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize @fusion_p0
// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.insert %[[MATERIALIZED]]
// CHECK-SAME: into %[[SHMEM]] at #indexing_map

// CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]]

@@ -9,13 +9,14 @@ fusion {
ROOT %abs = f32[20,170,160] abs(%transpose)
}
// CHECK-LABEL: func.func @main(
// CHECK-SAME: %[[INPUT:.*]]: tensor<20x160x170xf32> {
// CHECK-SAME: }, %[[OUT:.*]]: tensor<20x170x160xf32>

// CHECK: %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<1x32x33xf32>
// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.loop
// CHECK-SAME: iter_args(%[[SHMEM_:.*]] = %[[SHMEM]])
// CHECK: %[[EXP:.*]] = xla_gpu.pure_call @fusion_exp
// CHECK: tensor.insert %[[EXP]] into %[[SHMEM_]]
// CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize
// CHECK-SAME: @fusion_exp(%[[INPUT]]) at #indexing_map
// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.insert %[[MATERIALIZED]]
// CHECK-SAME: into %[[SHMEM]] at #indexing_map

// CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]]

@@ -10,10 +10,9 @@ fusion {
// CHECK-SAME: }, %[[OUT:.*]]: tensor<170x160x3xi8>

// CHECK: %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<32x33x3xi8>
// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.loop
// CHECK-SAME: iter_args(%[[SHMEM_:.*]] = %[[SHMEM]])
// CHECK: %[[P0:.*]] = xla_gpu.pure_call @fusion_p0
// CHECK: tensor.insert %[[P0]] into %[[SHMEM_]]
// CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize @fusion_p0
// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.insert %[[MATERIALIZED]]
// CHECK-SAME: into %[[SHMEM]] at #indexing_map

// CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]]

@@ -12,10 +12,9 @@ fusion {
// CHECK-SAME: }, %[[OUT:.*]]: tensor<170x160x20xf32>
//
// CHECK: %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<32x1x33xf32>
// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.loop
// CHECK-SAME: iter_args(%[[SHMEM_:.*]] = %[[SHMEM]])
// CHECK: %[[EXP:.*]] = xla_gpu.pure_call @fusion_exp
// CHECK: tensor.insert %[[EXP]] into %[[SHMEM_]]
// CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize @fusion_exp
// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.insert %[[MATERIALIZED]]
// CHECK-SAME: into %[[SHMEM]] at #indexing_map

// CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]]

97 changes: 63 additions & 34 deletions xla/service/gpu/fusions/transpose_mlir.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/MLIRContext.h"
@@ -211,6 +212,8 @@ IndexingMap MlirTransposeFusion::GetSharedMemoryIndexing(
std::vector<int64_t> dim_var_sizes(6, 1);
dim_var_sizes[KernelFusionInterface::kIndexingMapThreadIdxDims[0]] =
kNumThreadsPerBlock;
dim_var_sizes[KernelFusionInterface::kIndexingMapBlockIdxDims[0]] =
Product(block_counts_);
return {mlir::AffineMap::get(6, 2, thread_offsets, ctx),
DimVarsFromTensorSizes(dim_var_sizes),
RangeVarsFromTensorSizes({block_size_ / kNumRows, vector_size_}),
@@ -233,44 +236,72 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir(
} else {
++shmem_tensor_size.back();
}
int num_inputs = fusion.fused_instructions_computation()->num_parameters();
SmallVector<Value> callee_operands(
entry_function.getArguments().take_front(num_inputs));
auto tids_and_bids = EmitThreadAndBlockIds(builder);
auto identity_map =
IndexingMapAttr::get(ctx, CreateIdentityMap(shmem_tensor_size, ctx));

// We can assume that all transpose operands have the same shape.
Shape operand_shape = shmem_transposes_.front()->operand(0)->shape();

// Allocate shared memory.
SmallVector<Value> inits;
// Indexing for MaterializeOp to read from input.
auto indexing = GetIndexing(/*input=*/true, operand_shape, ctx);

// Indexing for InsertOp to write into shared memory.
IndexingMap write_indexing = GetSharedMemoryIndexing(/*read=*/false, ctx);
// As we are writing the same elements that we are reading, any read
// constraints can also be constraints for the write.
for (auto constraint : indexing.GetConstraints()) {
write_indexing.AddConstraint(constraint.first, constraint.second);
}
for (auto [index, bound] : llvm::enumerate(indexing.GetSymbolBounds())) {
write_indexing.GetMutableSymbolBound(index) = bound;
}
write_indexing.Simplify();
auto dimensions = SmallVector<int64_t>(operand_shape.dimensions().begin(),
operand_shape.dimensions().end());
SmallVector<Value> shmem_tensors;
for (auto* transpose : shmem_transposes_) {
auto elem_type = mlir_converter::PrimitiveTypeToMlirType(
transpose->shape().element_type(), builder);
inits.push_back(builder.create<AllocateSharedOp>(
RankedTensorType::get(shmem_tensor_size, elem_type)));
auto shmem = builder.create<AllocateSharedOp>(
RankedTensorType::get(shmem_tensor_size, elem_type));
auto indexed_vector =
IndexedVectorType::get(ctx, shmem_tensor_size, elem_type,
IndexingMapAttr::get(ctx, write_indexing));
auto callee =
mlir::SymbolRefAttr::get(call_target_provider(transpose->operand(0)));

auto materialized = builder.create<MaterializeOp>(
/* result_type=*/indexed_vector,
/*input=*/callee_operands,
/*indices(dimensions)=*/tids_and_bids,
/*callee=*/callee,
/*map=*/IndexingMapAttr::get(ctx, indexing));

auto insert = builder.create<InsertOp>(
/*result_type=*/shmem.getType(),
/*source=*/materialized.getResult(),
/*indices(dimensions)=*/tids_and_bids,
/*dest=*/shmem,
/*map=*/identity_map);
shmem_tensors.push_back(insert.getResult());
}

// Add output arguments for side outputs.
int num_inputs = fusion.fused_instructions_computation()->num_parameters();
// Produce all side outputs and then write them.
SmallVector<Value> side_output_inits;
for (int index : side_output_root_indices_) {
inits.push_back(entry_function.getArgument(num_inputs + index));
side_output_inits.push_back(entry_function.getArgument(num_inputs + index));
}

IndexingMap write_indexing = GetSharedMemoryIndexing(/*read=*/false, ctx);
auto body_builder = [&](ValueRange symbol_values, ValueRange map_results,
ValueRange output_tensors) -> SmallVector<Value> {
auto input_indices = [&](const HloInstruction* instr) {
return ApplyIndexing(GetIndexing(/*input=*/true, instr->shape(), ctx),
thread_and_block_ids, symbol_values, builder);
};
SmallVector<Value> result_tensors;
auto shmem_indices = ApplyIndexing(write_indexing, thread_and_block_ids,
symbol_values, builder);
for (auto [transpose, output] :
llvm::zip(shmem_transposes_, output_tensors)) {
// Emit loop that writes subgraphs of transpose operands to shmem.
auto result_scalar = mlir_converter::ProvideParameter(
root_computation, transpose,
/*operand_index=*/0, input_indices(transpose->operand(0)),
call_target_provider, entry_function, builder)[0];
result_tensors.push_back(builder.create<mlir::tensor::InsertOp>(
result_scalar, output, shmem_indices));
}

// Produce all side outputs and then write them.
SmallVector<Value> side_outputs;
SmallVector<SmallVector<Value>> side_output_indices;
auto* root_tuple = fusion.fused_expression_root();
@@ ... @@
side_outputs.append(param_values.begin(), param_values.end());
}

SmallVector<Value> result_tensors;
for (const auto& [value, indices, output] :
llvm::zip(side_outputs, side_output_indices,
output_tensors.take_back(side_output_roots_.size()))) {
llvm::zip(side_outputs, side_output_indices, output_tensors)) {
result_tensors.push_back(
builder.create<mlir::tensor::InsertOp>(value, output, indices));
}

return result_tensors;
};

auto indexing = GetIndexing(
/*input=*/true, shmem_transposes_.front()->operand(0)->shape(), ctx);
auto written_vector = mlir_converter::EmitXlaLoopOp(
builder, thread_and_block_ids, inits, indexing, body_builder);
ValueRange written = written_vector;
auto shmem_tensors = written.take_front(shmem_transposes_.size());
mlir::ValueRange side_output_vector;
if (!side_output_inits.empty()) {
side_output_vector = mlir_converter::EmitXlaLoopOp(
builder, thread_and_block_ids, side_output_inits, indexing,
body_builder);
}

WriteResult result;
result.shmem_tensors =
@@ ... @@
.getResults();
result.updated_outputs = output_args;
for (auto [index, side_output_result] :
llvm::zip(side_output_root_indices_,
written.take_back(side_output_roots_.size()))) {
llvm::zip(side_output_root_indices_, side_output_vector)) {
result.updated_outputs[index] = side_output_result;
}
return result;
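For readers skimming the diff, here is a condensed C++ sketch of the new per-transpose emission in EmitWriteToShMemMlir, assembled from the hunk above. It is not a standalone, compilable unit: the setup of indexing, write_indexing, identity_map, tids_and_bids, callee_operands and shmem_tensor_size is shown in the diff and omitted here.

// Condensed from the hunk above; assumes the surrounding setup
// (indexing maps, thread/block ids, callee operands) shown in the diff.
SmallVector<Value> shmem_tensors;
for (auto* transpose : shmem_transposes_) {
  auto elem_type = mlir_converter::PrimitiveTypeToMlirType(
      transpose->shape().element_type(), builder);
  // Shared-memory tile for this transpose.
  auto shmem = builder.create<AllocateSharedOp>(
      RankedTensorType::get(shmem_tensor_size, elem_type));
  // Read the operand through the fused computation as an indexed vector.
  auto indexed_vector =
      IndexedVectorType::get(ctx, shmem_tensor_size, elem_type,
                             IndexingMapAttr::get(ctx, write_indexing));
  auto callee =
      mlir::SymbolRefAttr::get(call_target_provider(transpose->operand(0)));
  auto materialized = builder.create<MaterializeOp>(
      /*result_type=*/indexed_vector, /*input=*/callee_operands,
      /*indices=*/tids_and_bids, /*callee=*/callee,
      /*map=*/IndexingMapAttr::get(ctx, indexing));
  // Write the materialized values into the shared-memory tile.
  auto insert = builder.create<InsertOp>(
      /*result_type=*/shmem.getType(), /*source=*/materialized.getResult(),
      /*indices=*/tids_and_bids, /*dest=*/shmem, /*map=*/identity_map);
  shmem_tensors.push_back(insert.getResult());
}

This replaces the previous xla_gpu.loop plus tensor.insert emission and matches the updated CHECK patterns in the transpose tests above.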
15 changes: 9 additions & 6 deletions xla/service/gpu/model/indexing_analysis.cc
@@ -1151,18 +1151,21 @@ std::vector<int64_t> ToTransposeDimensions(const Layout& l) {

} // namespace

IndexingMap CreateIdentityMap(absl::Span<const int64_t> dimensions,
mlir::MLIRContext* mlir_context) {
return IndexingMap::FromTensorSizes(
AffineMap::getMultiDimIdentityMap(dimensions.size(), mlir_context),
/*dim_upper_bounds=*/dimensions, /*symbol_upper_bounds=*/{},
/*is_simplified=*/dimensions.empty());
}

IndexingMap CreateIdentityMap(const Shape& shape, MLIRContext* mlir_context) {
if (shape.IsTuple()) {
// Should happen only for variadic reduce. In that case all tuple shapes are
// equal.
return CreateIdentityMap(shape.tuple_shapes(0), mlir_context);
}

auto dimensions = shape.dimensions();
IndexingMap identity_map = IndexingMap::FromTensorSizes(
AffineMap::getMultiDimIdentityMap(dimensions.size(), mlir_context),
dimensions, {}, /*is_simplified=*/dimensions.empty());
return identity_map;
return CreateIdentityMap(shape.dimensions(), mlir_context);
}

llvm::SmallVector<AffineExpr, 4> DelinearizeInBoundsIndex(
2 changes: 2 additions & 0 deletions xla/service/gpu/model/indexing_analysis.h
@@ -163,6 +163,8 @@ std::vector<mlir::AffineExpr> DelinearizeIndex(absl::Span<const int64_t> dims,
// Creates an identity indexing map corresponding to the parameter shape.
IndexingMap CreateIdentityMap(const Shape& shape,
mlir::MLIRContext* mlir_context);
IndexingMap CreateIdentityMap(absl::Span<const int64_t> dimensions,
mlir::MLIRContext* mlir_context);

llvm::SmallVector<mlir::AffineExpr, 4> DelinearizeInBoundsIndex(
mlir::AffineExpr linear, absl::Span<const int64_t> sizes);
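A minimal usage sketch of the new CreateIdentityMap overload declared above. The dimension sizes and the xla::gpu namespace qualification are illustrative assumptions, not part of the diff; the Shape-based overload now simply forwards shape.dimensions() to this span-based version.

#include "mlir/IR/MLIRContext.h"
#include "xla/service/gpu/model/indexing_analysis.h"

// Hypothetical caller; the tensor sizes are made up for illustration.
xla::gpu::IndexingMap MakeIdentityIndexing(mlir::MLIRContext* ctx) {
  // Identity map over a rank-3 tensor: (d0, d1, d2) -> (d0, d1, d2),
  // with dimension upper bounds taken from the given sizes.
  return xla::gpu::CreateIdentityMap({20, 160, 170}, ctx);
}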
