OpenXLA-specific changes.
- Disable short pointer.
- Fixup expensiveLoadOrStore().
chsigg authored and karupayun committed Jul 14, 2023
1 parent 336d0e9 commit 87956e1
Showing 6 changed files with 9 additions and 17 deletions.
1 change: 1 addition & 0 deletions BUILD
@@ -350,6 +350,7 @@ cc_library(
     name = "TritonTransforms",
     srcs = glob(["lib/Dialect/Triton/Transforms/*.cpp"]),
     hdrs = glob(["include/triton/Dialect/Triton/Transforms/*.h"]),
+    copts = ["-Wno-parentheses"],
     includes = ["include"],
     deps = [
         ":TritonDialects",
5 changes: 5 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -93,6 +93,11 @@ bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
   // same
   if (isSingleValue(op->getOperand(0)))
     return false;
+  // TODO(manany): Investigate with OpenAI why the change in
+  // https://github.com/openai/triton/commit/640f3c392184cd14291c1bca6a4795eb0f32a61a
+  // (which introduces Case 2) breaks this test:
+  // //third_party/py/jax_triton/tests:pallas_test_sm80 --test_filter=test_fused_attention_bwd
+  return true;
   // Case 2: Tensor of pointers has more threads than elements
   // we can presume a high hit-rate that makes it cheap to load
   auto ptrType = op->getOperand(0).getType().cast<RankedTensorType>();
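The net effect of the early return added above: only loads and stores of a single (splat) pointer value are still considered cheap, and every other load or store is now reported as expensive, bypassing the Case 2 heuristic entirely. A minimal sketch of the function's effective logic after this patch, distilled from the context lines above rather than taken verbatim from the source:

    bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
      // Case 1: a splat / single pointer value is always cheap to rematerialize.
      if (isSingleValue(op->getOperand(0)))
        return false;
      // Case 2 ("tensor of pointers has more threads than elements") is bypassed:
      // every other load or store is treated as expensive.
      return true;
    }

This is presumably also why the remat_fast_load test in test/TritonGPU/combine.mlir is deleted further down: with the heuristic bypassed, that 16-element load no longer counts as cheap to rematerialize.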
5 changes: 2 additions & 3 deletions lib/Target/LLVMIR/LLVMIRTranslation.cpp
@@ -27,6 +27,7 @@
 #include "llvm/IRReader/IRReader.h"
 #include "llvm/Linker/Linker.h"
 #include "llvm/Support/SourceMgr.h"
+#include "third_party/py/triton/google/find_cuda.h"
 #ifdef _WIN32
 #define WIN32_LEAN_AND_MEAN
 #include <windows.h>
@@ -187,10 +188,8 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
   // Search for libdevice relative to its library path if used from Python
   // Then native code is in `triton/_C/libtriton.so` and libdevice in
   // `triton/third_party/cuda/lib/libdevice.10.bc`
-  static const auto this_library_path = getThisLibraryPath();
   static const auto runtime_path =
-      this_library_path.parent_path().parent_path() / "third_party" / "cuda" /
-      "lib" / "libdevice.10.bc";
+      fs::path(PathToLibdevice()) / "libdevice.10.bc";
   if (fs::exists(runtime_path)) {
     externLibs.try_emplace(libdevice, runtime_path.string());
   } else {
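For clarity, here is a minimal sketch of what the old and new libdevice lookups compute. It assumes getThisLibraryPath() returns the filesystem path of the loaded libtriton.so (as the comment above describes) and that PathToLibdevice(), declared in the newly included find_cuda.h, returns the directory containing libdevice; the exact signatures of both functions are assumptions, not taken from this diff:

    namespace fs = std::filesystem;

    // Old: walk up from .../triton/_C/libtriton.so to the package root,
    // then down into third_party/cuda/lib.
    fs::path old_path = fs::path(getThisLibraryPath()).parent_path().parent_path() /
                        "third_party" / "cuda" / "lib" / "libdevice.10.bc";

    // New: ask find_cuda.h where libdevice lives and append the file name.
    fs::path new_path = fs::path(PathToLibdevice()) / "libdevice.10.bc";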
2 changes: 1 addition & 1 deletion lib/Target/PTX/PTXTranslation.cpp
@@ -49,7 +49,7 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
   auto *shortPtr =
       static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
   assert(shortPtr);
-  shortPtr->setValue(true);
+  shortPtr->setValue(false);
   std::string sm = cc == 90 ? "sm_90a" : "sm_" + std::to_string(cc);
   // max PTX version
   int ptxMajor = maxPTX / 10;
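The value being flipped here is LLVM's nvptx-short-ptr option, which controls whether the NVPTX backend uses 32-bit ("short") pointers for the const, local, and shared address spaces; this commit turns it off so full 64-bit pointers are used. A minimal sketch of the lookup shown in the context lines, assuming the standard llvm::cl registered-options API and that the NVPTX target has been initialized so the option is registered:

    #include "llvm/Support/CommandLine.h"

    static void disableNvptxShortPtr() {
      auto &options = llvm::cl::getRegisteredOptions();
      auto *shortPtr =
          static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
      if (shortPtr)                 // registered by the NVPTX backend
        shortPtr->setValue(false);  // matches the new value set in this commit
    }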
12 changes: 0 additions & 12 deletions test/TritonGPU/combine.mlir
@@ -69,18 +69,6 @@ tt.func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
   tt.return
 }
 
-tt.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
-  %0 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<16x!tt.ptr<i32>, #layout1>
-  %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #layout1>
-  %2 = tt.addptr %0, %1 : tensor<16x!tt.ptr<i32>, #layout1>, tensor<16xi32, #layout1>
-  %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16xi32, #layout1>
-  // CHECK-NOT: triton_gpu.convert_layout
-  %4 = triton_gpu.convert_layout %3 : (tensor<16xi32, #layout1>) -> tensor<16xi32, #layout0>
-  %5 = triton_gpu.convert_layout %2 : (tensor<16x!tt.ptr<i32>, #layout1>) -> tensor<16x!tt.ptr<i32>, #layout0>
-  tt.store %5, %4 : tensor<16xi32, #layout0>
-  tt.return
-}
-
 // CHECK-LABEL: if
 tt.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
 // CHECK-NOT: triton_gpu.convert_layout
1 change: 0 additions & 1 deletion third_party/intel_xpu_backend
Submodule intel_xpu_backend deleted from 0bcc48
