diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index e1edc1110968c2..011d4ed28bd92b 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -588,14 +588,15 @@ cc_library( "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:tensor_float_32_utils", - ] + if_cuda_is_configured([ + ] + if_gpu_is_configured([ + "@triton//:TritonNvidiaGPUTransforms", + "@triton//:TritonGPUToLLVM", + ]) + if_cuda_is_configured([ "@triton//third_party/nvidia:NVGPUToLLVM", "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", "@triton//:TritonGPUTransforms", - "@triton//:TritonNvidiaGPUTransforms", "@triton//:TritonLLVMIR", "@triton//:TritonToTritonGPU", - "@triton//:TritonGPUToLLVM", ]), ) diff --git a/xla/service/gpu/ir_emitter_triton_rocm.cc b/xla/service/gpu/ir_emitter_triton_rocm.cc index c8147aa6c0bfda..7bd0ca8df7d9f9 100644 --- a/xla/service/gpu/ir_emitter_triton_rocm.cc +++ b/xla/service/gpu/ir_emitter_triton_rocm.cc @@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h" +// TODO(ROCm): Enable and include ROCm Triton passes when ROCm Triton is +// included in build. +//#include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" // from @llvm-project #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" // from @llvm-project @@ -24,6 +26,7 @@ limitations under the License. #include "xla/service/hlo_module_config.h" #include "tsl/platform/rocm_rocdl_path.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" #include "triton/Dialect/Triton/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" @@ -88,11 +91,11 @@ absl::Status CreateTritonPipeline( // Based on make_llir() in // @triton//:third_party/nvidia/backend/compiler.py - pm.addPass(mt::gpu::createDecomposeUnsupportedConversionsPass()); + //pm.addPass(mt::gpu::createDecomposeUnsupportedConversionsPass()); pm.addPass(mlir::createConvertSCFToCFPass()); pm.addPass(mlir::createConvertIndexToLLVMPass()); pm.addPass(mt::gpu::createAllocateSharedMemoryPass()); - pm.addPass(mt::createConvertTritonAMDGPUToLLVMPass()); + //pm.addPass(mt::createConvertTritonAMDGPUToLLVMPass()); pm.addPass(mlir::createArithToLLVMConversionPass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass());