Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:GPU] Compute a graph of TiledHloInstruction for given tile parameters. #11266

Merged
merged 1 commit into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,7 @@ cc_library(
"//xla/service/gpu/model:indexing_map",
"//xla/service/gpu/model:symbolic_tile_analysis",
"//xla/service/gpu/model:symbolic_tiled_hlo_instruction",
"//xla/service/gpu/model:tiled_hlo_instruction",
"//xla/service/llvm_ir:llvm_util",
"//xla/stream_executor:device_description",
"//xla/stream_executor:launch_dim",
Expand Down
80 changes: 36 additions & 44 deletions xla/service/gpu/ir_emitter_triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ limitations under the License.
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/gpu/model/indexing_map.h"
#include "xla/service/gpu/model/symbolic_tile_analysis.h"
#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h"
#include "xla/service/gpu/model/tiled_hlo_instruction.h"
#include "xla/service/gpu/target_util.h"
#include "xla/service/gpu/triton_fusion_analysis.h"
#include "xla/service/gpu/triton_tiling_propagation.h"
Expand Down Expand Up @@ -738,11 +738,10 @@ absl::StatusOr<Value> EmitNestedFusion(

// TODO(b/331332678): Add unit tests to target this function specifically.
Value EmitTiledBroadcast(
ImplicitLocOpBuilder& b, const SymbolicTileAnalysis& analysis,
const SymbolicTiledHloInstruction& tiled_broadcast,
absl::flat_hash_map<const SymbolicTiledHloInstruction*, Value>& values) {
auto input_tile_shape = analysis.TileSizes(*tiled_broadcast.operand(0));
auto output_tile_shape = analysis.TileSizes(tiled_broadcast);
ImplicitLocOpBuilder& b, const TiledHloInstruction& tiled_broadcast,
absl::flat_hash_map<const TiledHloInstruction*, Value>& values) {
auto input_tile_shape = tiled_broadcast.operand(0)->tile_sizes();
auto output_tile_shape = tiled_broadcast.tile_sizes();

Value expanded_input = values[tiled_broadcast.operand(0)];

Expand Down Expand Up @@ -799,11 +798,10 @@ Value EmitTiledBroadcast(
absl::StatusOr<Value> EmitTiledHloInstruction(
ImplicitLocOpBuilder& b, absl::string_view libdevice_path,
const se::DeviceDescription& device_info,
const SymbolicTileAnalysis& analysis,
const SymbolicTiledHloInstruction& tiled_hlo,
std::function<absl::StatusOr<Value>(const SymbolicTiledHloInstruction&)>
const TiledHloInstruction& tiled_hlo,
std::function<absl::StatusOr<Value>(const TiledHloInstruction&)>
emit_param_load_fn,
absl::flat_hash_map<const SymbolicTiledHloInstruction*, Value>& values) {
absl::flat_hash_map<const TiledHloInstruction*, Value>& values) {
const HloInstruction* hlo = tiled_hlo.hlo();

if (hlo->opcode() == HloOpcode::kParameter) {
Expand All @@ -817,7 +815,7 @@ absl::StatusOr<Value> EmitTiledHloInstruction(
}

if (hlo->opcode() == HloOpcode::kBroadcast) {
return EmitTiledBroadcast(b, analysis, tiled_hlo, values);
return EmitTiledBroadcast(b, tiled_hlo, values);
}

if (hlo->opcode() == HloOpcode::kReduce) {
Expand All @@ -829,7 +827,7 @@ absl::StatusOr<Value> EmitTiledHloInstruction(
std::vector<Value> operands;
operands.reserve(hlo->operands().size());

for (const SymbolicTiledHloInstruction* operand : tiled_hlo.operands()) {
for (const TiledHloInstruction* operand : tiled_hlo.operands()) {
operands.push_back(values[operand]);
}
return EmitElementwise(b, libdevice_path, device_info, *hlo, operands);
Expand All @@ -852,21 +850,22 @@ absl::StatusOr<Value> EmitTiledHloInstruction(
absl::StatusOr<Value> EmitTiledScope(
ImplicitLocOpBuilder& b, absl::string_view libdevice_path,
const se::DeviceDescription& device_info,
const SymbolicTileAnalysis& analysis,
std::function<absl::StatusOr<Value>(const SymbolicTiledHloInstruction&)>
const std::vector<std::unique_ptr<TiledHloInstruction>>&
tiled_hlo_instructions,
std::function<absl::StatusOr<Value>(const TiledHloInstruction&)>
emit_param_load_fn,
absl::flat_hash_map<const SymbolicTiledHloInstruction*, Value>& values) {
for (const auto& tiled_hlo : analysis.GetTiledHloInstructions()) {
absl::flat_hash_map<const TiledHloInstruction*, Value>& values) {
for (const auto& tiled_hlo : tiled_hlo_instructions) {
TF_ASSIGN_OR_RETURN(
Value result,
EmitTiledHloInstruction(b, libdevice_path, device_info, analysis,
*tiled_hlo, emit_param_load_fn, values));
EmitTiledHloInstruction(b, libdevice_path, device_info, *tiled_hlo,
emit_param_load_fn, values));
TF_RET_CHECK(values.insert({tiled_hlo.get(), result}).second)
<< tiled_hlo->hlo()->ToString();
VLOG(8) << "Emitted "
<< tiled_hlo->hlo()->ToString(HloPrintOptions::ShortParsable());
}
return values[analysis.GetRoot()];
return values[tiled_hlo_instructions.back().get()];
}

// Emit sequence of instructions using compatible tiling ordered producers
Expand Down Expand Up @@ -2219,10 +2218,10 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
// `tile_offset_indexing` is a mapping from
// (program_id) -> [tile_offset0, ..., tile_offsetN]
Value ComputeBasePtrOffset(ImplicitLocOpBuilder b, Value pid,
const Shape& shape,
const IndexingMap& tile_offset_indexing) {
const TiledHloInstruction& tiled_hlo) {
const Shape& shape = tiled_hlo.hlo()->shape();
ArrayRef<mlir::AffineExpr> dimension_exprs =
tile_offset_indexing.GetAffineMap().getResults();
tiled_hlo.block_id_to_tile_offsets_indexing().GetAffineMap().getResults();

mlir::AffineExpr linear_index =
mlir::getAffineConstantExpr(0, b.getContext());
Expand Down Expand Up @@ -2292,7 +2291,9 @@ absl::Status EmitTiledSoftMax(mlir::OpBuilder builder,
computation->root_instruction()->shape().rank(), 1);
output_tile_sizes.back() = row_len;

analysis->SetTileSizes(output_tile_sizes);
TF_ASSIGN_OR_RETURN(
std::vector<std::unique_ptr<TiledHloInstruction>> tiled_hlo_instructions,
analysis->ComputeTiledHloInstructions(output_tile_sizes));

// block_size must be a power of two.
int result_block_size = llvm::PowerOf2Ceil(row_len);
Expand All @@ -2303,27 +2304,21 @@ absl::Status EmitTiledSoftMax(mlir::OpBuilder builder,
}

// Emits load instructions
auto emit_param_load = [&](const SymbolicTiledHloInstruction& tiled_hlo)
-> absl::StatusOr<Value> {
auto emit_param_load =
[&](const TiledHloInstruction& tiled_hlo) -> absl::StatusOr<Value> {
std::vector<Value> tile_sizes, tile_strides, tile_offsets;
for (auto [size, stride, offset] : llvm::zip(
analysis->TileSizes(tiled_hlo), analysis->TileStrides(tiled_hlo),
analysis->TileOffsets(tiled_hlo))) {
for (auto [size, stride] :
llvm::zip(tiled_hlo.tile_sizes(), tiled_hlo.tile_strides())) {
if (size == 1) continue;

tile_sizes.push_back(CreateConst(b, b.getI64Type(), size));
tile_strides.push_back(CreateConst(b, b.getI64Type(), stride));
tile_offsets.push_back(CreateConst(b, b.getI32Type(), offset));
tile_offsets.push_back(CreateConst(b, b.getI32Type(), 0));
}

TF_ASSIGN_OR_RETURN(
IndexingMap program_id_to_input_tile_indexing,
analysis->ComputeBlockIdToTileOffsetIndexing(tiled_hlo));

    // Manually compute the pointer offset to avoid materializing fully
    // parallel dimensions in the tile. Current codegen tries to avoid
    // size-1 dims.
Value ptr_offset = ComputeBasePtrOffset(b, pid, tiled_hlo.hlo()->shape(),
program_id_to_input_tile_indexing);
Value ptr_offset = ComputeBasePtrOffset(b, pid, tiled_hlo);

auto fn_arg = fn.getArgument(tiled_hlo.hlo()->parameter_number());
auto tile_ptr = AddPtr(b, fn_arg, ptr_offset);
Expand All @@ -2343,17 +2338,14 @@ absl::Status EmitTiledSoftMax(mlir::OpBuilder builder,
return EmitParameterLoad(b, emitted_tensor, boundary_checks);
};

absl::flat_hash_map<const SymbolicTiledHloInstruction*, Value> values_out;
TF_ASSIGN_OR_RETURN(Value result,
EmitTiledScope(b, libdevice_path, device_info, *analysis,
emit_param_load, values_out));

absl::flat_hash_map<const TiledHloInstruction*, Value> values_out;
TF_ASSIGN_OR_RETURN(
IndexingMap program_id_to_output_tile_indexing,
analysis->ComputeBlockIdToTileOffsetIndexing(*analysis->GetRoot()));
Value result,
EmitTiledScope(b, libdevice_path, device_info, tiled_hlo_instructions,
emit_param_load, values_out));

Value ptr_offset = ComputeBasePtrOffset(b, pid, root_shape,
program_id_to_output_tile_indexing);
Value ptr_offset =
ComputeBasePtrOffset(b, pid, *tiled_hlo_instructions.back());

Value store_tensor = b.create<mt::MakeTensorPtrOp>(
/*base=*/AddPtr(b, fn.getArgument(computation->num_parameters()),
Expand Down
6 changes: 4 additions & 2 deletions xla/service/gpu/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -632,8 +632,8 @@ cc_library(
":indexing_map",
":symbolic_tile",
":symbolic_tiled_hlo_instruction",
":tiled_hlo_instruction",
"//xla:status",
"//xla:status_macros",
"//xla/hlo/ir:hlo",
"//xla/service:instruction_fusion",
"@com_google_absl//absl/algorithm:container",
Expand All @@ -646,6 +646,8 @@ cc_library(
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:statusor",
],
)

Expand All @@ -655,7 +657,7 @@ xla_cc_test(
deps = [
":indexing_test_utils",
":symbolic_tile_analysis",
":symbolic_tiled_hlo_instruction",
":tiled_hlo_instruction",
"//xla/hlo/ir:hlo",
"//xla/tests:hlo_test_base",
"//xla/tests:verified_hlo_module",
Expand Down
Loading
Loading