Skip to content

Commit

Permalink
OpenXLA-specific changes.
Browse files Browse the repository at this point in the history
- Disable short pointer.
- Fixup expensiveLoadOrStore().
  • Loading branch information
chsigg committed Jul 11, 2023
1 parent c12fd7f commit ee1163a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
5 changes: 5 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
// same
if (isSingleValue(op->getOperand(0)))
return false;
// TODO(manany): Investigate with Openai why the change here
// https://github.com/openai/triton/commit/640f3c392184cd14291c1bca6a4795eb0f32a61a
// which introduces Case 2 causes breakage to this test
// //third_party/py/jax_triton/tests:pallas_test_sm80 --test_filter=test_fused_attention_bwd
return true;
// Case 2: Tensor of pointers has more threads than elements
// we can presume a high hit-rate that makes it cheap to load
auto ptrType = op->getOperand(0).getType().cast<RankedTensorType>();
Expand Down
2 changes: 1 addition & 1 deletion lib/Target/PTX/PTXTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
auto *shortPtr =
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
assert(shortPtr);
shortPtr->setValue(true);
shortPtr->setValue(false);
std::string sm = cc == 90 ? "sm_90a" : "sm_" + std::to_string(cc);
// max PTX version
int ptxMajor = maxPTX / 10;
Expand Down

0 comments on commit ee1163a

Please sign in to comment.