Update OpenXLA's Triton dependency to include the AMD backend.
chsigg authored and tensorflower-gardener committed May 30, 2024
1 parent 4dbe674 commit b4451fc
Showing 10 changed files with 187 additions and 36 deletions.
4 changes: 2 additions & 2 deletions third_party/triton/workspace.bzl
@@ -8,8 +8,8 @@ load("//third_party/triton/xla_extensions:series.bzl", "extensions_files_patch_l
def repo():
"""Imports Triton."""

TRITON_COMMIT = "cl637553582"
TRITON_SHA256 = "400077180416fc59486b698a6523013ee11589c6269e1aeb992292ca12cc1e58"
TRITON_COMMIT = "cl638583630"
TRITON_SHA256 = "769385a2295fa7256a04fcdc886054fb0853a25ee1c35dcdc0aabf755508f9fc"
tf_http_archive(
name = "triton",
sha256 = TRITON_SHA256,
4 changes: 2 additions & 2 deletions third_party/xla/third_party/triton/workspace.bzl
@@ -8,8 +8,8 @@ load("//third_party/triton/xla_extensions:series.bzl", "extensions_files_patch_l
def repo():
"""Imports Triton."""

TRITON_COMMIT = "cl637553582"
TRITON_SHA256 = "400077180416fc59486b698a6523013ee11589c6269e1aeb992292ca12cc1e58"
TRITON_COMMIT = "cl638583630"
TRITON_SHA256 = "769385a2295fa7256a04fcdc886054fb0853a25ee1c35dcdc0aabf755508f9fc"
tf_http_archive(
name = "triton",
sha256 = TRITON_SHA256,
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/BUILD
@@ -1206,6 +1206,7 @@ cc_library(
hdrs = ["triton_support.h"],
deps = [
":variant_visitor",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/service:instruction_fusion",
4 changes: 2 additions & 2 deletions third_party/xla/xla/service/gpu/gpu_fusible.cc
@@ -574,9 +574,9 @@ static int64_t SharedMemoryUsageNoCache(const HloInstruction& instr) {
// __shared__[32] is used for row reduction.
return 32 * primitive_size * num_variadic;
} else {
// __shared__[2][32][33] cache is used for column reduction ("2" comes
// __shared__[4][32][33] cache is used for column reduction ("4" comes
// from potential x-tiling).
return 2 * 32 * 33 * primitive_size * num_variadic;
return 4 * 32 * 33 * primitive_size * num_variadic;
}
} else if (GetDescriptionForTiledTransposeEmitter(instr, instr).has_value()) {
// Tile size for transposition.
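For context, a minimal standalone sketch of the updated column-reduction estimate (only the 4 * 32 * 33 factor comes from the diff; the function name and the bank-conflict rationale are illustrative assumptions):

```cpp
#include <cstdint>

// Shared-memory estimate for a column reduction: a [4][32][33]-element cache
// per output value ("4" from potential x-tiling; the 33rd column is assumed
// to be padding to avoid bank conflicts -- not stated in the diff).
int64_t ColumnReductionSharedMemoryBytes(int64_t primitive_size,
                                         int64_t num_variadic) {
  return 4 * 32 * 33 * primitive_size * num_variadic;
}

// For a single f32 output (primitive_size = 4, num_variadic = 1) this
// estimate is 4 * 32 * 33 * 4 = 16896 bytes of shared memory.
```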
23 changes: 13 additions & 10 deletions third_party/xla/xla/service/gpu/ir_emitter_triton.cc
@@ -1556,21 +1556,24 @@ class MatMulEmitterHelper {
majormost_dim_start_index_ptr_val, mt::CacheModifier::NONE,
mt::EvictionPolicy::NORMAL,
/*isVolatile=*/false);
Value majormost_dim_start_index_lower_limit_val =
CreateConst(b_, majormost_dim_start_index_val.getType(), 0);
int64_t majormost_dim_start_index_upper_limit =
hlo->operand(0)->shape().dimensions(majormost_dim) -
hlo->dynamic_slice_sizes().at(majormost_dim);
Value majormost_dim_start_index_upper_limit_val =
CreateConst(b_, majormost_dim_start_index_val.getType(),
majormost_dim_start_index_upper_limit);
// Our Triton codegen only supports signed integers so far.
// We don't want to cast S64 indices to S32, because that could result
// in an incorrect value.
if (majormost_dim_start_index_val.getType().isInteger() &&
majormost_dim_start_index_val.getType().getIntOrFloatBitWidth() ==
64) {
return UncompilableMatmul(
"64 bit dynamic-slice indices are not supported yet.");
}
majormost_dim_start_index_val =
b_.create<ma::MaxSIOp>(majormost_dim_start_index_val,
majormost_dim_start_index_lower_limit_val);
Cast(b_, majormost_dim_start_index_val, b_.getI32Type());
majormost_dim_start_index_val =
b_.create<ma::MinSIOp>(majormost_dim_start_index_val,
majormost_dim_start_index_upper_limit_val);
b_.create<ma::MaxSIOp>(majormost_dim_start_index_val, Cst32(0));
majormost_dim_start_index_val = b_.create<ma::MinSIOp>(
majormost_dim_start_index_val,
Cst32(majormost_dim_start_index_upper_limit));

// How many "rows" (non-contracting dim values) are there in a slice of
// size 1?
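A minimal standalone sketch of the index handling this hunk introduces, assuming a signed start index of at most 32 bits (the function and parameter names are illustrative, not the emitter's API):

```cpp
#include <algorithm>
#include <cstdint>
#include <optional>

// Returns std::nullopt when the fusion must be rejected, mirroring the
// UncompilableMatmul path for 64-bit indices in the diff above.
std::optional<int32_t> ClampDynamicSliceStartIndex(int64_t start_index,
                                                   int64_t index_bit_width,
                                                   int64_t dim_size,
                                                   int64_t slice_size) {
  if (index_bit_width == 64) {
    // Casting S64 down to S32 could change the value, so bail out instead.
    return std::nullopt;
  }
  // Clamp to [0, dim_size - slice_size], the last start index for which the
  // slice stays in bounds.
  const int64_t upper_limit = dim_size - slice_size;
  return static_cast<int32_t>(
      std::clamp<int64_t>(start_index, 0, upper_limit));
}
```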
@@ -1200,7 +1200,7 @@ TEST_F(MultiOutputFusionTest, SharedMemoryBudget) {
.value();
ASSERT_TRUE(mof_.Run(module.get()).value());

EXPECT_EQ(2, CountMultiOutputFusions(module.get()));
EXPECT_EQ(5, CountMultiOutputFusions(module.get()));
}

TEST_F(MultiOutputFusionTest, DoNotGroupTooManyReductions) {
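A hedged reading of this expectation change (an inference, not stated in the commit): the per-column-reduction shared-memory estimate above doubled, so a fixed per-fusion budget admits fewer reductions and the pass emits more, smaller multi-output fusions. A rough illustration under an assumed budget:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  const int64_t budget = 48 * 1024;            // assumed per-block budget
  const int64_t f32 = 4;
  const int64_t old_cost = 2 * 32 * 33 * f32;  // previous estimate: 8448 B
  const int64_t new_cost = 4 * 32 * 33 * f32;  // updated estimate: 16896 B
  // Roughly half as many reductions now fit in one fusion, which is
  // consistent with the test expecting more multi-output fusions (5 vs. 2).
  std::cout << "reductions per fusion, old: " << budget / old_cost
            << ", new: " << budget / new_cost << "\n";
}
```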
43 changes: 43 additions & 0 deletions third_party/xla/xla/service/gpu/triton_support.cc
@@ -15,6 +15,7 @@ limitations under the License.

#include "xla/service/gpu/triton_support.h"

#include <cstdint>
#include <iterator>
#include <variant>
#include <vector>
@@ -25,6 +26,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/layout.h"
#include "xla/service/gpu/variant_visitor.h"
#include "xla/stream_executor/device_description.h"
#include "xla/xla_data.pb.h"
@@ -314,6 +316,43 @@ CodegenDecision CanTritonHandleReduce(
return "Reduction is not a row-reduction of a single operand.";
}

CodegenDecision IsTritonSupportedDynamicSlice(
const HloDynamicSliceInstruction& instr) {
for (const HloInstruction* index_operand : instr.index_operands()) {
switch (index_operand->shape().element_type()) {
case S8:
case S16:
case S32:
break; // supported
default:
return CodegenDecision(
"Dynamic slice is only supported with S8, S16, or S32 indices.");
}
}

// Similar to normal slice, we cannot slice a non-major-most dimension as
// that would introduce non-contiguous strides under tiling. The existing
// check against this in GetRequirementsIfSupportedOrder is not suitable for
// dynamic slices, so we instead check for this here.
const HloInstruction* input = instr.operand(0);
Layout in_layout = input->shape().layout();
int64_t majormost_dim_id =
in_layout.minor_to_major(in_layout.minor_to_major_size() - 1);

for (int i = 0; i < input->shape().dimensions_size(); ++i) {
if (i == majormost_dim_id) {
continue;
} else if (input->shape().dimensions(i) != instr.slice_sizes(i)) {
return CodegenDecision(
"Unsupported dynamic slice on non-major-most dimension.");
}
}

// TODO(b/343143854): Check the subtleties of which dynamic slices are
// supported, for example that a fragmented dimension cannot be sliced.
return CodegenDecision{};
}

CodegenDecision IsTritonSupportedInstruction(
const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) {
if (instr.IsElementwise()) {
@@ -334,6 +373,10 @@ CodegenDecision IsTritonSupportedInstruction(
}
return "Only supports root tuples.";
}
case HloOpcode::kDynamicSlice: {
return IsTritonSupportedDynamicSlice(
*Cast<HloDynamicSliceInstruction>(&instr));
}
case HloOpcode::kBitcast:
case HloOpcode::kTranspose:
case HloOpcode::kSlice:
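A minimal standalone sketch of the layout rule the new check enforces, using the shapes from the test added below (the helper is illustrative, not XLA's API; with the default {1,0} layout, dimension 0 is major-most):

```cpp
#include <cstdint>
#include <vector>

// Returns true if every dimension other than the major-most one keeps its
// full extent, i.e. the dynamic slice only shrinks the major-most dimension.
bool OnlyMajorMostDimIsSliced(const std::vector<int64_t>& dims,
                              const std::vector<int64_t>& slice_sizes,
                              int64_t majormost_dim) {
  for (int64_t i = 0; i < static_cast<int64_t>(dims.size()); ++i) {
    if (i != majormost_dim && dims[i] != slice_sizes[i]) {
      return false;  // would introduce non-contiguous strides under tiling
    }
  }
  return true;
}

// OnlyMajorMostDimIsSliced({7, 2}, {5, 2}, /*majormost_dim=*/0) -> true
// OnlyMajorMostDimIsSliced({5, 4}, {5, 2}, /*majormost_dim=*/0) -> false
```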
9 changes: 9 additions & 0 deletions third_party/xla/xla/service/gpu/triton_support.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/instruction_fusion.h"
#include "xla/stream_executor/device_description.h"
@@ -52,6 +53,14 @@ bool IsTritonSupportedElementwise(HloOpcode, PrimitiveType);
CodegenDecision IsTritonSupportedInstruction(
const HloInstruction& instr, const se::GpuComputeCapability& gpu_version);

// Checks dynamic slice against requirements of triton emitter.
//
// This is exposed separately from IsTritonSupportedInstruction because we can
// use it in the dimension order propagation without adding a dependency on the
// GPU version.
CodegenDecision IsTritonSupportedDynamicSlice(
const HloDynamicSliceInstruction& instr);

} // namespace gpu
} // namespace xla

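A hedged usage sketch of the new entry point, e.g. from code that has an HLO instruction but no GPU compute capability at hand (the wrapper below is illustrative, not part of this change):

```cpp
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/gpu/triton_support.h"

namespace xla::gpu {

// Checks a dynamic slice without needing a se::GpuComputeCapability, which is
// what makes IsTritonSupportedDynamicSlice usable from dimension-order
// propagation.
bool CanTileDynamicSlice(const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kDynamicSlice) return false;
  return IsTritonSupportedDynamicSlice(
             *Cast<HloDynamicSliceInstruction>(&instr))
      .CanFuse();
}

}  // namespace xla::gpu
```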
109 changes: 109 additions & 0 deletions third_party/xla/xla/service/gpu/triton_support_test.cc
@@ -34,8 +34,10 @@ limitations under the License.
#include "llvm/IR/Module.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "xla/error_spec.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/primitive_util.h"
@@ -474,6 +476,113 @@ INSTANTIATE_TEST_SUITE_P(DotTestTestSuite, DotTest,
::testing::Values(HloOpcode::kDot)),
TestParamsToString);

struct DynamicSliceTestParam {
PrimitiveType data_type;
PrimitiveType index_type;
bool is_the_majormost_dim_being_sliced;

using TupleType = std::tuple<PrimitiveType, PrimitiveType, bool>;

explicit DynamicSliceTestParam(const TupleType& tuple)
: data_type(std::get<0>(tuple)),
index_type(std::get<1>(tuple)),
is_the_majormost_dim_being_sliced(std::get<2>(tuple)) {}
};

std::string DynamicSliceTestParamToString(
const ::testing::TestParamInfo<DynamicSliceTestParam>& info) {
return absl::StrCat(
primitive_util::LowercasePrimitiveTypeName(info.param.data_type), "_",
primitive_util::LowercasePrimitiveTypeName(info.param.index_type), "_",
info.param.is_the_majormost_dim_being_sliced ? "majormost"
: "not_majormost");
}

class DynamicSliceTest
: public TritonSupportTest,
public ::testing::WithParamInterface<DynamicSliceTestParam> {};

TEST_P(DynamicSliceTest, IsTritonSupportedExecutesCorrectlyForDynamicSlice) {
if (!GetCudaComputeCapability().IsAtLeast(
se::CudaComputeCapability::AMPERE) &&
GetParam().data_type == BF16) {
GTEST_SKIP() << "No BF16 before Ampere.";
}

constexpr absl::string_view kHloTestTemplate =
R"(
HloModule m
triton_gemm {
dynamic_slice_input = $0[$2,$3] parameter(0)
dot_rhs = f32[2,4] parameter(1)
start_index0 = $1[] parameter(2)
start_index1 = $1[] parameter(3)
dynamic_slice = $0[5,2] dynamic-slice(dynamic_slice_input, start_index0, start_index1),
dynamic_slice_sizes={5,2}
convert = f32[5,2] convert(dynamic_slice)
ROOT dot = f32[5, 4] dot(convert, dot_rhs),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
ENTRY e {
dynamic_slice_input = $0[$2,$3] parameter(0)
dot_rhs = f32[2,4] parameter(1)
start_index0 = $1[] constant($4)
start_index1 = $1[] constant($5)
ROOT fusion = f32[5,4] fusion(dynamic_slice_input, dot_rhs, start_index0, start_index1),
kind=kCustom, calls=triton_gemm,
backend_config={
"fusion_backend_config":{
"kind":"__triton_gemm","triton_gemm_config":{
"block_m":"32","block_n":"32","block_k":"32","split_k":"1",
"num_stages":"1","num_warps":"4","num_ctas":"1"}}}
})";

const std::string hlo_test = absl::Substitute(
kHloTestTemplate,
primitive_util::LowercasePrimitiveTypeName(GetParam().data_type),
primitive_util::LowercasePrimitiveTypeName(GetParam().index_type),
GetParam().is_the_majormost_dim_being_sliced ? 7 : 5, // input dim0
GetParam().is_the_majormost_dim_being_sliced ? 2 : 4, // input dim1
GetParam().is_the_majormost_dim_being_sliced ? 1 : 0, // start_index0
GetParam().is_the_majormost_dim_being_sliced ? 0 : 1 // start_index1
);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_test));
const HloComputation* computation =
module->GetComputationWithName("triton_gemm");
ASSERT_NE(computation, nullptr);
const HloInstruction* instr = hlo_query::GetFirstInstructionWithOpcode(
*computation, HloOpcode::kDynamicSlice);

const bool is_supported_instruction =
IsTritonSupportedInstruction(*instr, GetCudaComputeCapability())
.CanFuse();
const bool is_supported_dynamic_slice =
IsTritonSupportedDynamicSlice(*Cast<HloDynamicSliceInstruction>(instr))
.CanFuse();
EXPECT_EQ(is_supported_instruction, is_supported_dynamic_slice);

if (is_supported_instruction) {
TF_EXPECT_OK(ApplyFloatNormalization(module.get()));
EXPECT_TRUE(RunAndCompareNoHloPasses(
std::move(module), ErrorSpec{/*aabs=*/2e-4, /*arel=*/2e-4}));
} else {
EXPECT_THAT(TritonFusionAnalysis::Execute(*computation),
tsl::testing::StatusIs(absl::StatusCode::kFailedPrecondition));
}
}

INSTANTIATE_TEST_SUITE_P(
All, DynamicSliceTest,
::testing::ConvertGenerator<DynamicSliceTestParam::TupleType>(
::testing::Combine(::testing::Values(F16, BF16, F32),
::testing::Values(S8, S16, S32, S64, U8, U16, U32,
U64),
::testing::Bool())),
DynamicSliceTestParamToString);

TEST_F(TritonSupportTest, UnsupportedDotOutputTypeFailsGracefullyWithTriton) {
const std::string kHloTest = R"(
triton_gemm___computation {
24 changes: 5 additions & 19 deletions third_party/xla/xla/service/gpu/triton_tiling_propagation.cc
@@ -977,25 +977,11 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo,
return "Dynamic slices for now are only supported in GEMM fusions.";
}

// Similar to normal slice, we cannot slice a non-major-most dimension as
// that would introduce non-contiguous strides under tiling. The existing
// check against this in GetRequirementsIfSupportedOrder is not suitable for
// dynamic slices, so we instead check for this here.
const HloInstruction* input = hlo.operand(0);
Layout in_layout = input->shape().layout();
int64_t majormost =
in_layout.minor_to_major(in_layout.minor_to_major_size() - 1);
const HloDynamicSliceInstruction* dynamic_slice =
Cast<HloDynamicSliceInstruction>(&hlo);

for (int i = 0; i < input->shape().dimensions_size(); ++i) {
if (i == majormost) {
continue;
} else if (input->shape().dimensions(i) !=
dynamic_slice->slice_sizes(i)) {
return FusionDecision(
"Unsupported dynamic slice on non-major-most dimension.");
}
if (CodegenDecision decision = IsTritonSupportedDynamicSlice(
*Cast<HloDynamicSliceInstruction>(&hlo));
!decision.CanFuse()) {
// CodegenDecision is actually the same type as FusionDecision.
return decision;
}

return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order,
