OpenXLA-specific changes.
- Disable short pointer.
- Fixup expensiveLoadOrStore().
chsigg authored and karupayun committed Aug 18, 2023
1 parent 8a313ed commit dcc85b4
Showing 7 changed files with 15 additions and 17 deletions.
8 changes: 7 additions & 1 deletion BUILD
@@ -53,6 +53,11 @@ _no_unused_variable = select({
"//conditions:default": ["-Wno-unused-variable"],
})

_no_unused_variable_no_parentheses = select({
":compiler_is_msvc": [],
"//conditions:default": ["-Wno-unused-variable -Wno-parentheses"],
})

td_library(
name = "td_files",
srcs = glob(["include/triton/**/*.td"]),
@@ -350,6 +355,7 @@ cc_library(
name = "TritonTransforms",
srcs = glob(["lib/Dialect/Triton/Transforms/*.cpp"]),
hdrs = glob(["include/triton/Dialect/Triton/Transforms/*.h"]),
copts = ["-Wno-parentheses"],
includes = ["include"],
deps = [
":TritonDialects",
@@ -413,7 +419,7 @@ cc_library(
"include/triton/Tools/Sys/*.hpp",
"include/triton/Conversion/TritonGPUToLLVM/*.h",
]),
copts = _no_unused_variable,
copts = _no_unused_variable_no_parentheses,
includes = [
"include",
"lib/Conversion/TritonGPUToLLVM",
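For context, -Wparentheses is the GCC/Clang warning family that the new -Wno-parentheses copts above silence, typically emitted for operator-precedence ambiguities. A minimal standalone C++ illustration of the kind of expression it flags (not code from this repository):

#include <cstdio>

// Mixing && and || without explicit grouping triggers
// "suggest parentheses around '&&' within '||'" under -Wparentheses;
// -Wno-parentheses, as added to the Triton targets above, silences it.
bool inRange(int x, int lo, int hi, bool strict) {
  return strict && x > lo && x < hi || !strict && x >= lo && x <= hi;
}

int main() { std::printf("%d\n", inRange(3, 1, 5, true)); }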
2 changes: 1 addition & 1 deletion lib/Analysis/Utility.cpp
@@ -125,7 +125,7 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() {

unsigned bytesPerElem = 0;
for (const auto &ty : srcElementTypes) {
bytesPerElem += ty.getIntOrFloatBitWidth() / 8;
bytesPerElem += (ty.getIntOrFloatBitWidth() + 7) / 8;
}
return bytesPerElem * elems;
}
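The lib/Analysis/Utility.cpp change switches the per-element byte count from truncating division to a round-up, so sub-byte element types (e.g. i1) still reserve at least one byte of scratch space. A small standalone sketch of the difference, with hypothetical helper names:

#include <cstdio>

unsigned bytesTruncated(unsigned bitWidth) { return bitWidth / 8; }        // old: i1 -> 0
unsigned bytesRoundedUp(unsigned bitWidth) { return (bitWidth + 7) / 8; }  // new: i1 -> 1

int main() {
  for (unsigned bits : {1u, 8u, 16u, 33u})
    std::printf("bits=%u old=%u new=%u\n", bits,
                bytesTruncated(bits), bytesRoundedUp(bits));
}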
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -91,7 +91,7 @@ class BlockedToMMA : public mlir::RewritePattern {
int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
int origBitWidth = finalBitWidth;
SetVector<Operation *> slice;
mlir::getBackwardSlice(x, &slice, bwdFilter);
mlir::getBackwardSlice(x, &slice, {{bwdFilter}});
Operation *firstOp = slice.empty() ? nullptr : *slice.begin();
if (firstOp)
if (Value arg = firstOp->getOperand(0))
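The extra braces around bwdFilter most likely adapt this call to an MLIR revision where getBackwardSlice takes an options struct instead of a bare filter, with the filter stored in a base struct. A hedged C++17 sketch of that aggregate-initialization pattern, using hypothetical stand-in types rather than MLIR's actual headers:

#include <cstdio>
#include <functional>

struct SliceOptions {
  std::function<bool(int)> filter;  // stand-in for MLIR's operation filter
};
struct BackwardSliceOptions : SliceOptions {
  bool omitBlockArguments = false;
};

void getBackwardSliceLike(int op, const BackwardSliceOptions &options) {
  if (!options.filter || options.filter(op))
    std::printf("op %d accepted by filter\n", op);
}

int main() {
  auto bwdFilter = [](int op) { return op % 2 == 0; };
  // Outer braces build the options struct; inner braces fill its base,
  // i.e. the filter -- mirroring the {{bwdFilter}} spelling in the diff.
  getBackwardSliceLike(4, {{bwdFilter}});
}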
5 changes: 5 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -93,6 +93,11 @@ bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
// same
if (isSingleValue(op->getOperand(0)))
return false;
// TODO(manany): Investigate with Openai why the change here
// https://github.com/openai/triton/commit/640f3c392184cd14291c1bca6a4795eb0f32a61a
// which introduces Case 2 causes breakage to this test
// //third_party/py/jax_triton/tests:pallas_test_sm80 --test_filter=test_fused_attention_bwd
return true;
// Case 2: Tensor of pointers has more threads than elements
// we can presume a high hit-rate that makes it cheap to load
auto ptrType = op->getOperand(0).getType().cast<RankedTensorType>();
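The inserted return true short-circuits the function, so the Case 2 heuristic below it is never evaluated and every load or store that is not a single splatted value is treated as expensive again. A minimal illustration of the resulting control flow (hypothetical, not the real Triton code):

#include <cstdio>

bool isExpensiveLoadOrStoreLike(bool isSingleValue) {
  // Case 1: a single splatted value is never expensive to rematerialize.
  if (isSingleValue)
    return false;
  return true;  // added early return: the heuristic below is now dead code
  // Case 2 heuristic would go here (unreachable after the change).
}

int main() { std::printf("%d\n", isExpensiveLoadOrStoreLike(false)); }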
2 changes: 1 addition & 1 deletion lib/Target/PTX/PTXTranslation.cpp
@@ -49,7 +49,7 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
auto *shortPtr =
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
assert(shortPtr);
shortPtr->setValue(true);
shortPtr->setValue(false);
std::string sm = cc == 90 ? "sm_90a" : "sm_" + std::to_string(cc);
// max PTX version
int ptxMajor = maxPTX / 10;
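This hunk is the "Disable short pointer" item from the commit message: in upstream LLVM, nvptx-short-ptr is described as using 32-bit pointers for the const/local/shared address spaces, so setting it to false keeps those pointers at the default 64-bit width. A self-contained version of the lookup shown above, assuming LLVM's Support headers are available and the NVPTX backend has already registered the option:

#include "llvm/Support/CommandLine.h"
#include <cassert>

void disableNvptxShortPtr() {
  // cl::getRegisteredOptions() exposes every registered command-line option.
  auto &options = llvm::cl::getRegisteredOptions();
  auto *shortPtr =
      static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
  assert(shortPtr && "NVPTX target not initialized; option not registered");
  shortPtr->setValue(false);  // prefer 64-bit pointers, as in the diff above
}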
12 changes: 0 additions & 12 deletions test/TritonGPU/combine.mlir
@@ -69,18 +69,6 @@ tt.func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
tt.return
}

tt.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%0 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<16x!tt.ptr<i32>, #layout1>
%1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #layout1>
%2 = tt.addptr %0, %1 : tensor<16x!tt.ptr<i32>, #layout1>, tensor<16xi32, #layout1>
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16xi32, #layout1>
// CHECK-NOT: triton_gpu.convert_layout
%4 = triton_gpu.convert_layout %3 : (tensor<16xi32, #layout1>) -> tensor<16xi32, #layout0>
%5 = triton_gpu.convert_layout %2 : (tensor<16x!tt.ptr<i32>, #layout1>) -> tensor<16x!tt.ptr<i32>, #layout0>
tt.store %5, %4 : tensor<16xi32, #layout0>
tt.return
}

// CHECK-LABEL: if
tt.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
1 change: 0 additions & 1 deletion third_party/intel_xpu_backend
Submodule intel_xpu_backend deleted from 0bcc48
