Skip to content

Commit

Permalink
Fix compilatio issues related to aa08925
Browse files Browse the repository at this point in the history
(Triton in XLA for ROCm - ir_emitter_triton related changes)
  • Loading branch information
zoranjovanovic-ns committed Apr 8, 2024
1 parent ba124f3 commit 82818f9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
7 changes: 4 additions & 3 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -588,14 +588,15 @@ cc_library(
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:tensor_float_32_utils",
] + if_cuda_is_configured([
] + if_gpu_is_configured([
"@triton//:TritonNvidiaGPUTransforms",
"@triton//:TritonGPUToLLVM",
]) + if_cuda_is_configured([
"@triton//third_party/nvidia:NVGPUToLLVM",
"@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM",
"@triton//:TritonGPUTransforms",
"@triton//:TritonNvidiaGPUTransforms",
"@triton//:TritonLLVMIR",
"@triton//:TritonToTritonGPU",
"@triton//:TritonGPUToLLVM",
]),
)

Expand Down
9 changes: 6 additions & 3 deletions xla/service/gpu/ir_emitter_triton_rocm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h"
// TODO(ROCm): Enable and include ROCm Triton passes when ROCm Triton is
// included in build.
//#include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" // from @llvm-project
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" // from @llvm-project
Expand All @@ -24,6 +26,7 @@ limitations under the License.
#include "xla/service/hlo_module_config.h"
#include "tsl/platform/rocm_rocdl_path.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
Expand Down Expand Up @@ -88,11 +91,11 @@ absl::Status CreateTritonPipeline(

// Based on make_llir() in
// @triton//:third_party/nvidia/backend/compiler.py
pm.addPass(mt::gpu::createDecomposeUnsupportedConversionsPass());
//pm.addPass(mt::gpu::createDecomposeUnsupportedConversionsPass());
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(mlir::createConvertIndexToLLVMPass());
pm.addPass(mt::gpu::createAllocateSharedMemoryPass());
pm.addPass(mt::createConvertTritonAMDGPUToLLVMPass());
//pm.addPass(mt::createConvertTritonAMDGPUToLLVMPass());
pm.addPass(mlir::createArithToLLVMConversionPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
Expand Down

0 comments on commit 82818f9

Please sign in to comment.