diff --git a/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl b/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl index 339733755d6f1..2d10066010bf7 100644 --- a/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl +++ b/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl @@ -62,7 +62,7 @@ def rocm_hipblaslt(): return %{rocm_is_configured} and %{rocm_hipblaslt} def if_rocm_hipblaslt(x): - if %{rocm_is_configured} and (%{rocm_hipblaslt} == "True"): + if %{rocm_is_configured} and (%{rocm_hipblaslt} == True): return select({"//conditions:default": x}) return select({"//conditions:default": []}) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index e1edc1110968c..fc106d2b4328d 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -588,14 +588,15 @@ cc_library( "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:tensor_float_32_utils", - ] + if_cuda_is_configured([ + ] + if_gpu_is_configured([ + "@triton//:TritonNvidiaGPUTransforms", + "@triton//:TritonGPUToLLVM", + "@triton//:TritonToTritonGPU", + ]) + if_cuda_is_configured([ "@triton//third_party/nvidia:NVGPUToLLVM", "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", "@triton//:TritonGPUTransforms", - "@triton//:TritonNvidiaGPUTransforms", "@triton//:TritonLLVMIR", - "@triton//:TritonToTritonGPU", - "@triton//:TritonGPUToLLVM", ]), ) diff --git a/xla/service/gpu/ir_emitter_triton_rocm.cc b/xla/service/gpu/ir_emitter_triton_rocm.cc index c8147aa6c0bfd..7bd0ca8df7d9f 100644 --- a/xla/service/gpu/ir_emitter_triton_rocm.cc +++ b/xla/service/gpu/ir_emitter_triton_rocm.cc @@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h" +// TODO(ROCm): Enable and include ROCm Triton passes when ROCm Triton is +// included in build. +//#include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" // from @llvm-project #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" // from @llvm-project @@ -24,6 +26,7 @@ limitations under the License. #include "xla/service/hlo_module_config.h" #include "tsl/platform/rocm_rocdl_path.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" #include "triton/Dialect/Triton/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" @@ -88,11 +91,11 @@ absl::Status CreateTritonPipeline( // Based on make_llir() in // @triton//:third_party/nvidia/backend/compiler.py - pm.addPass(mt::gpu::createDecomposeUnsupportedConversionsPass()); + //pm.addPass(mt::gpu::createDecomposeUnsupportedConversionsPass()); pm.addPass(mlir::createConvertSCFToCFPass()); pm.addPass(mlir::createConvertIndexToLLVMPass()); pm.addPass(mt::gpu::createAllocateSharedMemoryPass()); - pm.addPass(mt::createConvertTritonAMDGPUToLLVMPass()); + //pm.addPass(mt::createConvertTritonAMDGPUToLLVMPass()); pm.addPass(mlir::createArithToLLVMConversionPass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); diff --git a/xla/service/gpu/llvm_gpu_backend/BUILD b/xla/service/gpu/llvm_gpu_backend/BUILD index 72ea7fdc71d29..0bb8b72f2d950 100644 --- a/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/xla/service/gpu/llvm_gpu_backend/BUILD @@ -79,6 +79,7 @@ cc_library( ] + if_rocm_is_configured([ "@local_config_rocm//rocm:rocm_headers", "@llvm-project//llvm:AMDGPUCodeGen", + "@llvm-project//llvm:AMDGPUAsmParser", ]), )