diff --git a/BUILD b/BUILD index 65e12b002f49..38843a597039 100644 --- a/BUILD +++ b/BUILD @@ -53,9 +53,9 @@ _no_unused_variable = select({ "//conditions:default": ["-Wno-unused-variable"], }) -_no_unused_variable_no_parentheses = select({ +_no_parentheses = select({ ":compiler_is_msvc": [], - "//conditions:default": ["-Wno-unused-variable -Wno-parentheses"], + "//conditions:default": ["-Wno-parentheses"], }) td_library( @@ -69,6 +69,7 @@ td_library( "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", "@llvm-project//mlir:FunctionInterfacesTdFiles", "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LLVMOpsTdFiles", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:PassBaseTdFiles", "@llvm-project//mlir:SideEffectInterfacesTdFiles", @@ -76,6 +77,23 @@ td_library( ], ) +gentbl_cc_library( + name = "triton_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/Triton/IR/TritonAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/Triton/IR/TritonAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonAttrDefs.td", + deps = ["td_files"], +) + gentbl_cc_library( name = "triton_dialect_inc_gen", tbl_outs = [ @@ -93,6 +111,23 @@ gentbl_cc_library( deps = ["td_files"], ) +gentbl_cc_library( + name = "triton_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-attr-interface-decls"], + "include/triton/Dialect/Triton/IR/AttrInterfaces.h.inc", + ), + ( + ["--gen-attr-interface-defs"], + "include/triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonInterfaces.td", + deps = ["td_files"], +) + gentbl_cc_library( name = "triton_ops_inc_gen", tbl_outs = [ @@ -112,14 +147,6 @@ gentbl_cc_library( ["--gen-op-defs"], "include/triton/Dialect/Triton/IR/Ops.cpp.inc", ), - ( - ["--gen-typedef-decls"], - "include/triton/Dialect/Triton/IR/Types.h.inc", - ), - ( - ["--gen-typedef-defs"], - "include/triton/Dialect/Triton/IR/Types.cpp.inc", - ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/triton/Dialect/Triton/IR/TritonOps.td", @@ -127,19 +154,19 @@ gentbl_cc_library( ) gentbl_cc_library( - name = "triton_interfaces_inc_gen", + name = "triton_types_inc_gen", tbl_outs = [ ( - ["--gen-attr-interface-decls"], - "include/triton/Dialect/Triton/IR/AttrInterfaces.h.inc", + ["--gen-typedef-decls"], + "include/triton/Dialect/Triton/IR/Types.h.inc", ), ( - ["--gen-attr-interface-defs"], - "include/triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc", + ["--gen-typedef-defs"], + "include/triton/Dialect/Triton/IR/Types.cpp.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/triton/Dialect/Triton/IR/TritonInterfaces.td", + td_file = "include/triton/Dialect/Triton/IR/TritonTypes.td", deps = ["td_files"], ) @@ -174,6 +201,31 @@ gentbl_cc_library( deps = ["td_files"], ) +gentbl_cc_library( + name = "triton_gpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/TritonGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td", + deps = ["td_files"], +) + gentbl_cc_library( name = "triton_gpu_dialect_inc_gen", tbl_outs = [ @@ -209,19 +261,19 @@ gentbl_cc_library( ) gentbl_cc_library( - name = "triton_gpu_attr_inc_gen", + name = "triton_gpu_types_inc_gen", tbl_outs = [ ( - ["--gen-attrdef-decls"], - "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc", + ["--gen-typedef-decls"], + "include/triton/Dialect/TritonGPU/IR/Types.h.inc", ), ( - ["--gen-attrdef-defs"], - "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc", + ["--gen-typedef-defs"], + "include/triton/Dialect/TritonGPU/IR/Types.cpp.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td", deps = ["td_files"], ) @@ -241,6 +293,177 @@ gentbl_cc_library( deps = ["td_files"], ) +gentbl_cc_library( + name = "triton_nvgpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvgpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/NVGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/NVGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/NVGPU/IR/NVGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvgpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-llvmir-conversions"], + "include/triton/Dialect/NVGPU/IR/OpsConversions.inc", + ), + ( + ["--gen-op-decls"], + "include/triton/Dialect/NVGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/NVGPU/IR/Ops.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/NVGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/NVGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/NVGPU/IR/NVGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-op-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonNvidiaGPU", + ], + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_nvgpu_to_llvm_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=NVGPUToLLVM", + ], + "include/triton/Conversion/NVGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Conversion/NVGPUToLLVM/Passes.td", + deps = ["td_files"], +) + gentbl_cc_library( name = "triton_conversion_triton_gpu_to_llvm_passes_inc_gen", tbl_outs = [ @@ -289,38 +512,85 @@ gentbl_cc_library( deps = ["td_files"], ) +cc_library( + name = "NVGPUToLLVM", + srcs = glob([ + "lib/Conversion/NVGPUToLLVM/*.cpp", + ]), + hdrs = glob([ + "include/triton/Conversion/NVGPUToLLVM/*.h", + ]), + copts = _no_unused_variable, + includes = ["include"], + deps = [ + ":TritonAnalysis", + ":TritonDialects", + ":TritonGPUToLLVM", + ":triton_conversion_nvgpu_to_llvm_passes_inc_gen", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) + cc_library( name = "TritonAnalysis", - srcs = glob(["lib/Analysis/*.cpp"]), - hdrs = glob(["include/triton/Analysis/*.h"]), + srcs = [ + "lib/Analysis/Alias.cpp", + "lib/Analysis/Allocation.cpp", + "lib/Analysis/AxisInfo.cpp", + "lib/Analysis/Membar.cpp", + "lib/Analysis/Utility.cpp", + ], + hdrs = [ + "include/triton/Analysis/Alias.h", + "include/triton/Analysis/Allocation.h", + "include/triton/Analysis/AxisInfo.h", + "include/triton/Analysis/Membar.h", + "include/triton/Analysis/Utility.h", + "include/triton/Conversion/MLIRTypes.h", + "include/triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h", + "include/triton/Dialect/TritonGPU/Transforms/Utility.h", + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h", + "lib/Conversion/TritonGPUToLLVM/Utility.h", + ], copts = _no_unused_variable, includes = ["include"], deps = [ ":TritonDialects", ":TritonTools", - ":triton_gpu_attr_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", ], ) cc_library( name = "TritonDialects", srcs = glob([ + "lib/Dialect/NVGPU/IR/*.cpp", "lib/Dialect/Triton/IR/*.cpp", "lib/Dialect/TritonGPU/IR/*.cpp", + "lib/Dialect/TritonNvidiaGPU/IR/*.cpp", ]) + [ "include/triton/Analysis/Utility.h", # Avoid circular dependency. ], hdrs = glob([ + "include/triton/Dialect/NVGPU/IR/*.h", "include/triton/Dialect/Triton/IR/*.h", "include/triton/Dialect/TritonGPU/IR/*.h", + "include/triton/Dialect/TritonNvidiaGPU/IR/*.h", ]), copts = _no_unused_variable, includes = ["include"], @@ -329,25 +599,31 @@ cc_library( ":triton_gpu_attr_inc_gen", ":triton_gpu_dialect_inc_gen", ":triton_gpu_ops_inc_gen", - ":triton_gpu_transforms_inc_gen", + ":triton_gpu_types_inc_gen", ":triton_interfaces_inc_gen", + ":triton_nvgpu_attr_inc_gen", + ":triton_nvgpu_dialect_inc_gen", + ":triton_nvgpu_ops_inc_gen", + ":triton_nvidia_gpu_attr_inc_gen", + ":triton_nvidia_gpu_dialect_inc_gen", + ":triton_nvidia_gpu_ops_inc_gen", + ":triton_nvidia_gpu_types_inc_gen", ":triton_ops_inc_gen", + ":triton_types_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:ControlFlowInterfaces", - "@llvm-project//mlir:DestinationStyleOpInterface", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:Transforms", ], ) @@ -355,23 +631,16 @@ cc_library( name = "TritonTransforms", srcs = glob(["lib/Dialect/Triton/Transforms/*.cpp"]), hdrs = glob(["include/triton/Dialect/Triton/Transforms/*.h"]), - copts = ["-Wno-parentheses"], + copts = _no_parentheses, includes = ["include"], deps = [ ":TritonDialects", ":triton_combine_inc_gen", ":triton_transforms_inc_gen", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ControlFlowDialect", - "@llvm-project//mlir:ControlFlowInterfaces", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", - "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", ], alwayslink = True, # TritonDialect uses getCanonicalizationPatterns(). @@ -383,8 +652,19 @@ cc_library( "lib/Dialect/TritonGPU/Transforms/*.cpp", "lib/Dialect/TritonGPU/Transforms/*.h", ]), - hdrs = glob(["include/triton/Dialect/TritonGPU/Transforms/*.h"]), - copts = _no_unused_variable, + hdrs = glob([ + "include/triton/Dialect/TritonGPU/Transforms/*.h", + ]) + [ + "include/triton/Tools/Sys/GetEnv.hpp", + ], + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-reorder-ctor", + "-Wno-return-type", + "-Wno-unused-variable", + ], + }), includes = ["include"], deps = [ ":TritonAnalysis", @@ -392,13 +672,8 @@ cc_library( ":triton_gpu_transforms_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:ControlFlowDialect", - "@llvm-project//mlir:ControlFlowInterfaces", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", - "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Support", @@ -419,7 +694,14 @@ cc_library( "include/triton/Tools/Sys/*.hpp", "include/triton/Conversion/TritonGPUToLLVM/*.h", ]), - copts = _no_unused_variable_no_parentheses, + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-parentheses", + "-Wno-reorder-ctor", + "-Wno-unused-variable", + ], + }), includes = [ "include", "lib/Conversion/TritonGPUToLLVM", @@ -427,6 +709,9 @@ cc_library( deps = [ ":TritonAnalysis", ":TritonDialects", + ":TritonNvidiaGPUTransforms", + ":TritonTmaMetadata", + ":cuda_compat", ":triton_conversion_triton_gpu_to_llvm_passes_inc_gen", ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", "@llvm-project//llvm:Support", @@ -454,6 +739,40 @@ cc_library( ], ) +cc_library( + name = "TritonNvidiaGPUTransforms", + srcs = glob([ + "lib/Dialect/TritonNvidiaGPU/Transforms/*.cpp", + ]), + hdrs = glob([ + "include/triton/Dialect/TritonNvidiaGPU/Transforms/*.h", + ]), + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-ctad-maybe-unsupported", + "-Wno-logical-op-parentheses", + "-Wno-non-virtual-dtor", + "-Wno-return-type", + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonAnalysis", + ":TritonDialects", + ":TritonGPUTransforms", + ":triton_nvidia_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + ], +) + cc_library( name = "TritonToTritonGPU", srcs = glob([ @@ -466,18 +785,14 @@ cc_library( ":TritonAnalysis", ":TritonDialects", ":TritonGPUTransforms", - ":triton_conversion_triton_gpu_to_llvm_passes_inc_gen", + ":TritonTmaMetadata", ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:IR", "@llvm-project//mlir:IndexDialect", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", ], ) @@ -486,32 +801,39 @@ cc_library( name = "TritonLLVMIR", srcs = glob([ "lib/Target/LLVMIR/*.cpp", - ]) + [ - "include/triton/Tools/Sys/GetEnv.hpp", - ], + "lib/Target/LLVMIR/*.h", + ]), hdrs = glob(["include/triton/Target/LLVMIR/*.h"]), copts = _no_unused_variable, includes = ["include"], deps = [ + ":NVGPUToLLVM", ":TritonGPUToLLVM", + ":TritonTmaMetadata", ":TritonTransforms", ":triton_target_llvmir_passes_inc_gen", + "@llvm-project//llvm:Analysis", "@llvm-project//llvm:BinaryFormat", "@llvm-project//llvm:Core", "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:InstCombine", "@llvm-project//llvm:Linker", + "@llvm-project//llvm:Passes", "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//mlir:ArithToLLVM", "@llvm-project//mlir:BuiltinToLLVMIRTranslation", - "@llvm-project//mlir:ConversionPasses", "@llvm-project//mlir:ExecutionEngine", "@llvm-project//mlir:ExecutionEngineUtils", "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexToLLVM", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LLVMIRTransforms", "@llvm-project//mlir:LLVMToLLVMIRTranslation", "@llvm-project//mlir:NVVMToLLVMIRTranslation", "@llvm-project//mlir:Pass", "@llvm-project//mlir:ROCDLToLLVMIRTranslation", + "@llvm-project//mlir:SCFToControlFlow", "@llvm-project//mlir:ToLLVMIRTranslation", "@llvm-project//mlir:Transforms", # copybara:uncomment "//third_party/py/triton/google:find_cuda", @@ -528,6 +850,7 @@ cc_library( deps = [ ":TritonLLVMIR", "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", "@llvm-project//llvm:MC", "@llvm-project//llvm:Support", "@llvm-project//llvm:Target", @@ -562,29 +885,52 @@ cc_library( ], ) +cc_library( + name = "TritonTmaMetadata", + hdrs = ["include/triton/Target/PTX/TmaMetadata.h"], + includes = ["include"], + deps = [ + "@llvm-project//llvm:Support", + ], +) + cc_library( name = "TritonTools", hdrs = ["include/triton/Tools/Sys/GetEnv.hpp"], includes = ["include"], ) +cc_library( + name = "cuda_compat", + hdrs = ["include/triton/Tools/cuda_compat.h"], + includes = ["include"], + deps = ["//third_party/gpus/cuda:cuda_headers"], +) + cc_binary( name = "triton-opt", srcs = [ "bin/RegisterTritonDialects.h", "bin/triton-opt.cpp", + "include/triton/Conversion/NVGPUToLLVM/Passes.h", "include/triton/Conversion/TritonGPUToLLVM/Passes.h", "include/triton/Conversion/TritonToTritonGPU/Passes.h", + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h", ], includes = ["include"], deps = [ + ":NVGPUToLLVM", ":TritonDialects", ":TritonGPUToLLVM", ":TritonGPUTransforms", + ":TritonNvidiaGPUTransforms", + ":TritonTmaMetadata", ":TritonToTritonGPU", ":TritonTransforms", + ":triton_conversion_nvgpu_to_llvm_passes_inc_gen", ":triton_conversion_triton_gpu_to_llvm_passes_inc_gen", ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", + ":triton_nvidia_gpu_transforms_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:AllPassesAndDialects", diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index dd763d3454b4..18d947fc5007 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -9,7 +9,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" -#include "mlir/IR/FunctionInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "triton/Dialect/Triton/IR/Dialect.h.inc" diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index f3de5a21e330..dae00fd984ad 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -647,7 +647,7 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI operand_range getArgOperands() { return {arg_operand_begin(), arg_operand_end()}; } - MutableOperandRange getArgOperandsMutable() { + mlir::MutableOperandRange getArgOperandsMutable() { return getOperandsMutable(); } diff --git a/include/triton/Target/PTX/TmaMetadata.h b/include/triton/Target/PTX/TmaMetadata.h index f183f4e5b0fb..1ae5f4e8a6e2 100644 --- a/include/triton/Target/PTX/TmaMetadata.h +++ b/include/triton/Target/PTX/TmaMetadata.h @@ -24,7 +24,6 @@ #ifndef TRITON_TARGET_PTX_TMAMETADATA_H #define TRITON_TARGET_PTX_TMAMETADATA_H -#include "python/triton/third_party/cuda/include/cuda.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Format.h" diff --git a/include/triton/Tools/cuda_compat.h b/include/triton/Tools/cuda_compat.h new file mode 100644 index 000000000000..443af54b9cd8 --- /dev/null +++ b/include/triton/Tools/cuda_compat.h @@ -0,0 +1,51 @@ +#include "cuda.h" + +// Compatibility with CUDA 11 +#ifndef CU_TENSOR_MAP_NUM_QWORDS +#define CU_TENSOR_MAP_NUM_QWORDS 16 + +typedef struct CUtensorMap_st { + cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS]; +} CUtensorMap; + +typedef enum CUtensorMapDataType_enum { + CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0, + CU_TENSOR_MAP_DATA_TYPE_UINT16, + CU_TENSOR_MAP_DATA_TYPE_UINT32, + CU_TENSOR_MAP_DATA_TYPE_INT32, + CU_TENSOR_MAP_DATA_TYPE_UINT64, + CU_TENSOR_MAP_DATA_TYPE_INT64, + CU_TENSOR_MAP_DATA_TYPE_FLOAT16, + CU_TENSOR_MAP_DATA_TYPE_FLOAT32, + CU_TENSOR_MAP_DATA_TYPE_FLOAT64, + CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ, + CU_TENSOR_MAP_DATA_TYPE_TFLOAT32, + CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ +} CUtensorMapDataType; + +typedef enum CUtensorMapInterleave_enum { + CU_TENSOR_MAP_INTERLEAVE_NONE = 0, + CU_TENSOR_MAP_INTERLEAVE_16B, + CU_TENSOR_MAP_INTERLEAVE_32B +} CUtensorMapInterleave; + +typedef enum CUtensorMapSwizzle_enum { + CU_TENSOR_MAP_SWIZZLE_NONE = 0, + CU_TENSOR_MAP_SWIZZLE_32B, + CU_TENSOR_MAP_SWIZZLE_64B, + CU_TENSOR_MAP_SWIZZLE_128B +} CUtensorMapSwizzle; + +typedef enum CUtensorMapL2promotion_enum { + CU_TENSOR_MAP_L2_PROMOTION_NONE = 0, + CU_TENSOR_MAP_L2_PROMOTION_L2_64B, + CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + CU_TENSOR_MAP_L2_PROMOTION_L2_256B +} CUtensorMapL2promotion; + +typedef enum CUtensorMapFloatOOBfill_enum { + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE = 0, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA +} CUtensorMapFloatOOBfill; +#endif diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 70e675c7bc5f..5ce2d2952b13 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -178,9 +178,11 @@ struct ConvertLayoutOpConversion Value _16 = i32_val(16); if (mmaLayout.isAmpere() || mmaLayout.isHopper()) { multiDimWarpId[0] = - urem(multiDimWarpId[0], i32_val(shapePerCTA[0] / instrShape[0])); + urem(multiDimWarpId[0], + i32_val(ceil(shapePerCTA[0], instrShape[0]))); multiDimWarpId[1] = - urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / instrShape[1])); + urem(multiDimWarpId[1], + i32_val(ceil(shapePerCTA[1], instrShape[1]))); Value mmaGrpId = udiv(laneId, _4); Value mmaGrpIdP8 = add(mmaGrpId, _8); diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index b1c73cd3230f..503ab4aa6b98 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -1230,8 +1230,8 @@ void populateElementwiseOpToLLVMPatterns( POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // << POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >> POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >> - POPULATE_BINARY_OP(arith::MinFOp, LLVM::MinNumOp) // fmin - POPULATE_BINARY_OP(arith::MaxFOp, LLVM::MaxNumOp) // fmax + POPULATE_BINARY_OP(arith::MinimumFOp, LLVM::MinNumOp) // fmin + POPULATE_BINARY_OP(arith::MaximumFOp, LLVM::MaxNumOp) // fmax POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp index c050c3ad3a38..5be924dd1ba7 100644 --- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -7,6 +7,7 @@ #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" +#include "triton/Tools/cuda_compat.h" #include diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 06d338685909..22f92244bc3b 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -240,7 +240,7 @@ Value linearize(ConversionPatternRewriter &rewriter, Location loc, Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, Value val, Value pred) { MLIRContext *ctx = rewriter.getContext(); - unsigned bits = val.getType().getIntOrFloatBitWidth(); + unsigned bits = std::max(8u, val.getType().getIntOrFloatBitWidth()); const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r"); PTXBuilder builder; @@ -257,7 +257,7 @@ Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, auto ptrTy = ptr.getType().cast(); assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for loadShared"); auto elemTy = ptrTy.getElementType(); - unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); + unsigned bitwidth = std::max(8u, elemTy.getIntOrFloatBitWidth()); const char *c = bitwidth == 64 ? "=l" : (bitwidth == 16 ? "=h" : "=r"); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h index 9dd072d0c942..f8e0d44ccb50 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.h +++ b/lib/Conversion/TritonGPUToLLVM/Utility.h @@ -89,7 +89,7 @@ ptxBuilder.launch(rewriter, op->getLoc(), voidTy); \ } while (0) #define undef(...) rewriter.create(loc, __VA_ARGS__) -#define null(...) rewriter.create(loc, __VA_ARGS__) +#define null(...) rewriter.create(loc, __VA_ARGS__) #define call(...) rewriter.create(loc, __VA_ARGS__) // Types diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index de5ad6947c53..564a44768ca4 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -129,8 +129,8 @@ void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter, // Floating point GenericOpPattern, GenericOpPattern, // MaxMin - GenericOpPattern, GenericOpPattern, - GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, // Floating point GenericOpPattern, GenericOpPattern, @@ -728,8 +728,8 @@ struct SCFForPattern : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto newOp = cast(rewriter.cloneWithoutRegions(*op.getOperation())); - rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(), - newOp.getLoopBody().end()); + rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); // Now, update all the types. @@ -738,7 +738,7 @@ struct SCFForPattern : public OpConversionPattern { // The entry block may have a special conversion if `entryConversion` is // provided. On success, the new entry block to the region is returned for // convenience. Otherwise, failure is returned. - if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(), + if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *getTypeConverter()))) { return rewriter.notifyMatchFailure(op, "could not convert body types"); } diff --git a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp index 2393599143be..503fa5ca5316 100644 --- a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp +++ b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp @@ -325,9 +325,9 @@ class RewriteTensorPointerPass Operation *rewriteForOp(OpBuilder &builder, scf::ForOp op, std::stack &eraser) { // Generate new iteration operands and set rewrited information - SmallVector oldIterOperands = op.getIterOperands(); - SmallVector newIterOperands = op.getIterOperands(); - for (unsigned i = 0, oldI = 0, size = op.getNumIterOperands(); i < size; + SmallVector oldIterOperands = llvm::to_vector(op.getInitArgs()); + SmallVector newIterOperands = llvm::to_vector(op.getInitArgs()); + for (unsigned i = 0, oldI = 0, size = op.getInitArgs().size(); i < size; ++i, ++oldI) { if (!triton::isTensorPointerType(newIterOperands[i].getType())) continue; @@ -350,7 +350,7 @@ class RewriteTensorPointerPass // mapping. It may refer to a value in the old loop, but we will rewrite it // later IRMapping mapping; - for (unsigned i = 0, oldI = 0; oldI < op.getNumIterOperands(); + for (unsigned i = 0, oldI = 0; oldI < op.getInitArgs().size(); ++i, ++oldI) { auto oldRegionIterArg = op.getRegionIterArg(oldI); if (triton::isTensorPointerType(oldRegionIterArg.getType())) { @@ -377,7 +377,7 @@ class RewriteTensorPointerPass } // Replace later usages - assert(op.getNumResults() == op.getNumIterOperands()); + assert(op.getNumResults() == op.getInitArgs().size()); for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) { auto oldResult = op.getResult(oldI); if (triton::isTensorPointerType(oldResult.getType())) { diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 169939eb37f2..adb32ffa2470 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -118,7 +118,7 @@ class BlockedToMMA : public mlir::RewritePattern { int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth(); int origBitWidth = finalBitWidth; SetVector slice; - mlir::getBackwardSlice(x, &slice, bwdFilter); + mlir::getBackwardSlice(x, &slice, {{bwdFilter}}); Operation *firstOp = slice.empty() ? nullptr : *slice.begin(); if (firstOp) if (Value arg = firstOp->getOperand(0)) @@ -298,8 +298,10 @@ class BlockedToMMA : public mlir::RewritePattern { } else { // convert operands - int minBitwidth = - std::min(computeOrigBitWidth(a), computeOrigBitWidth(b)); + // TODO(b/296812125): Fix minBitwidth issue upstream and uncomment. + // int minBitwidth = + // std::min(computeOrigBitWidth(a), computeOrigBitWidth(b)); + int minBitwidth = 0; Type minType = IntegerType::get(ctx, minBitwidth); // convert A operand auto newAEncoding = ttg::DotOperandEncodingAttr::get( diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 5a1d93c2569a..2475ec5a035f 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -252,8 +252,9 @@ class TritonGPUOptimizeDotOperandsPass mlir::RewritePatternSet patterns(context); patterns.add(context); - if (triton::gpu::TritonGPUDialect::getComputeCapability(m) >= 80) - patterns.add(context); + // TODO(b/283035396): Fix CUDA_ERROR_MISALIGNED_ADDRESS and uncomment. + // if (triton::gpu::TritonGPUDialect::getComputeCapability(m) >= 80) + // patterns.add(context); patterns.add(context); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) signalPassFailure(); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index 13de5d266cb6..cc71b208ddef 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -347,7 +347,7 @@ LogicalResult LoopPipeliner::collectOps(SetVector &ops) { void LoopPipeliner::collectValueDep(Value v, int stage, SetVector &deps) { // Loop-invariant value, skip - if (v.getParentRegion() != &forOp.getLoopBody()) + if (v.getParentRegion() != &forOp.getRegion()) return; // Since we only need to peel the loop numStages-1 times, don't worry @@ -671,7 +671,7 @@ void LoopPipeliner::createBufferTypes() { } void LoopPipeliner::createOrderedDeps() { - for (Operation &op : forOp.getLoopBody().front()) { + for (Operation &op : *forOp.getBody()) { if (depOps.contains(&op)) orderedDeps.push_back(&op); else if (op.getNumResults() > 0 && validLoads.contains(op.getResult(0))) @@ -1007,8 +1007,7 @@ SmallVector LoopPipeliner::collectNewLoopArgs() { // We need this to update operands for yield // original block arg => new arg's idx SmallVector newLoopArgs; - for (auto v : forOp.getIterOperands()) - newLoopArgs.push_back(v); + for (auto v : forOp.getInitArgs()) newLoopArgs.push_back(v); bufferIdx = newLoopArgs.size(); for (auto loadOp : validLoads) diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index 07e982dbf65d..24dfdb61dd70 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -269,8 +269,7 @@ scf::ForOp Prefetcher::createNewForOp() { OpBuilder builder(forOp); SmallVector loopArgs; - for (auto v : forOp.getIterOperands()) - loopArgs.push_back(v); + for (auto v : forOp.getInitArgs()) loopArgs.push_back(v); for (Value dot : dots) { loopArgs.push_back( operand2headPrefetch[dot.getDefiningOp().getA()]); diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index cbdb59c88d2b..3e33c9a2d873 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -760,16 +760,16 @@ static scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, // Create a new loop before the existing one, with the extra operands. rewriter.setInsertionPoint(loop); - auto operands = llvm::to_vector<4>(loop.getIterOperands()); + auto operands = llvm::to_vector<4>(loop.getInitArgs()); operands.append(newIterOperands.begin(), newIterOperands.end()); scf::ForOp newLoop = rewriter.create( loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), operands); newLoop.getBody()->erase(); - newLoop.getLoopBody().getBlocks().splice( - newLoop.getLoopBody().getBlocks().begin(), - loop.getLoopBody().getBlocks()); + newLoop.getRegion().getBlocks().splice( + newLoop.getRegion().getBlocks().begin(), + loop.getRegion().getBlocks()); for (Value operand : newIterOperands) newLoop.getBody()->addArgument(operand.getType(), operand.getLoc()); @@ -805,8 +805,8 @@ static void rewriteSlice(SetVector &slice, if (slice.count(arg)) { OpOperand &initVal = forOp.getOpOperandForRegionIterArg(arg); argMapping.push_back( - std::make_pair(*forOp.getIterArgNumberForOpOperand(initVal), - forOp.getNumIterOperands() + newOperands.size())); + std::make_pair(forOp.getResultForOpOperand(initVal).getResultNumber(), + forOp.getInitArgs().size() + newOperands.size())); newOperands.push_back(mapping.lookup(initVal.get())); } } diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 6e92fb2901ae..fdd96d73c7a6 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -597,7 +597,7 @@ struct ForOpDeadArgElimination : public OpRewritePattern { Value yieldOperand = forOwner.getBody()->getTerminator()->getOperand(iterIdx); markLive(yieldOperand); - markLive(forOwner.getIterOperands()[iterIdx]); + markLive(forOwner.getInitArgs()[iterIdx]); } } SmallVector deadArg; diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp index d79da1ee9961..17c896e6aeeb 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp @@ -628,7 +628,7 @@ bool CTAPlanner::isElementwiseOp(Operation *op) const { arith::CeilDivUIOp, arith::DivFOp, arith::DivSIOp, arith::DivUIOp, arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp, arith::FloorDivSIOp, arith::FPToSIOp, arith::FPToUIOp, - arith::MaxFOp, arith::MaxSIOp, arith::MaxUIOp, arith::MinFOp, + arith::MaximumFOp, arith::MaxSIOp, arith::MaxUIOp, arith::MinimumFOp, arith::MinSIOp, arith::MinUIOp, arith::MulFOp, arith::MulIOp, arith::NegFOp, arith::OrIOp, arith::RemFOp, arith::RemSIOp, arith::RemUIOp, arith::ShLIOp, arith::ShRSIOp, arith::ShRUIOp, diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp index e13cf8bd9179..5a2d3beaa4b4 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp @@ -523,9 +523,9 @@ class TritonGPURewriteTensorPointerPass std::stack &eraser, DenseSet &valueToRemove) { // Generate new iteration operands and set rewrited information - SmallVector oldIterOperands = op.getIterOperands(); - SmallVector newIterOperands = op.getIterOperands(); - for (unsigned i = 0, oldI = 0, size = op.getNumIterOperands(); i < size; + SmallVector oldIterOperands = llvm::to_vector(op.getInitArgs()); + SmallVector newIterOperands = llvm::to_vector(op.getInitArgs()); + for (unsigned i = 0, oldI = 0, size = op.getInitArgs().size(); i < size; ++i, ++oldI) { if (!tt::isTensorPointerType(newIterOperands[i].getType())) continue; @@ -550,7 +550,7 @@ class TritonGPURewriteTensorPointerPass // mapping. It may refer to a value in the old loop, but we will rewrite it // later IRMapping mapping; - for (unsigned i = 0, oldI = 0; oldI < op.getNumIterOperands(); + for (unsigned i = 0, oldI = 0; oldI < op.getInitArgs().size(); ++i, ++oldI) { auto oldRegionIterArg = op.getRegionIterArg(oldI); if (tt::isTensorPointerType(oldRegionIterArg.getType()) && @@ -586,7 +586,7 @@ class TritonGPURewriteTensorPointerPass valueToRemove.insert(v); // Replace later usages - assert(op.getNumResults() == op.getNumIterOperands()); + assert(op.getNumResults() == op.getInitArgs().size()); for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) { auto oldResult = op.getResult(oldI); if (tt::isTensorPointerType(oldResult.getType()) && @@ -787,8 +787,8 @@ class TritonGPURewriteTensorPointerPass } } if (auto forOp = dyn_cast(op)) { - SmallVector iterOperands = forOp.getIterOperands(); - for (unsigned i = 0, size = forOp.getNumIterOperands(); i < size; ++i) { + SmallVector iterOperands = llvm::to_vector(forOp.getInitArgs()); + for (unsigned i = 0, size = forOp.getInitArgs().size(); i < size; ++i) { if (tt::isTensorPointerType(iterOperands[i].getType())) { auto makeTensorPtrOp = getMakeTensorPtrOp(iterOperands[i]); if (shouldRemove(makeTensorPtrOp, computeCapability)) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSMaterialization.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSMaterialization.cpp index 9ebc78497cbb..95a7fed0e572 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/WSMaterialization.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSMaterialization.cpp @@ -79,6 +79,7 @@ void materializeGetAgentIdOp(Operation *parentOp) { builder.setInsertionPoint(agentIdOp); Value globalRoleId = builder.create(loc, 0, 32); int globalNumWarps = 0; + SmallVector deprecatedOps; for (auto cmpOp : agentIdOp->getUsers()) { assert(isa(cmpOp)); for (auto u : cmpOp->getUsers()) { @@ -111,11 +112,14 @@ void materializeGetAgentIdOp(Operation *parentOp) { Value cond = builder.create(loc, lowerBound, upperBound); cmpOp->getResult(0).replaceAllUsesWith(cond); - cmpOp->erase(); + deprecatedOps.push_back(cmpOp); break; } } } + for (Operation* cmpOp : deprecatedOps) { + cmpOp->erase(); + } }); } @@ -530,6 +534,7 @@ void mutexSyncPingPang(Operation *parentOp, int numAgents, int &nameBarrierId, builder.create(loc, nameBarrierId - 1, 32); // Process mutex users int numUsers = 0; + SmallVector deprecatedOps; for (Operation *user : createMutexOp.getResult().getUsers()) { numUsers++; assert(numUsers <= 2); @@ -543,12 +548,19 @@ void mutexSyncPingPang(Operation *parentOp, int numAgents, int &nameBarrierId, Value barLeave = builder.create( loc, isRole0, namedBarrierId1, namedBarrierId0); builder.create(loc, barLeave, numThreads); - } else + } else { llvm_unreachable("Unexpected user of mutex"); + } + deprecatedOps.push_back(user); + } + for (Operation *user : deprecatedOps) { user->erase(); } nameBarrierId -= 2; nameBarrierIdEnd -= 2; + }); + + parentOp->walk([](ttng::CreateMutexOp createMutexOp) { createMutexOp.erase(); }); } @@ -587,6 +599,7 @@ void materializeMutexOperationsOthers(ModuleOp parentOp) { OpBuilder builder(createMutexOp); // Process mutex users + SmallVector deprecatedOps; for (Operation *user : createMutexOp.getResult().getUsers()) { auto loc = user->getLoc(); builder.setInsertionPoint(user); @@ -596,6 +609,10 @@ void materializeMutexOperationsOthers(ModuleOp parentOp) { processUnlockOp(builder, op); else llvm_unreachable("Unexpected user of mutex"); + deprecatedOps.push_back(user); + } + + for (Operation *user : deprecatedOps) { user->erase(); } diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp index 4ed9f0c64996..8c1e654fece6 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp @@ -153,7 +153,7 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp, Value newIdx = builder.createWithAgentIds(loc, pipelineIdx, curRoleId); - persistentForOp.setIterArg(persistentForOp.getNumIterOperands() - 1, newIdx); + persistentForOp.getInitArgsMutable().slice(persistentForOp.getInitArgs().size()-1, 1).assign(newIdx); auto yield = llvm::cast(persistentForOp.getBody()->getTerminator()); auto idxPlusOneOp = diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp index 373eac0e548b..6fab67f555af 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp @@ -162,7 +162,7 @@ scf::ForOp appendPipelineIdxToLoopArgs(scf::ForOp forOp, int numStages, // Copy iter operands of forOp SmallVector newLoopArgs; - for (auto operand : forOp.getIterOperands()) + for (auto operand : llvm::to_vector(forOp.getInitArgs())) newLoopArgs.push_back(operand); // Append initial value of pipelineIdx to newLoopArgs @@ -302,7 +302,7 @@ DenseMap createForOpsForEachAgentId(scf::ForOp forOp) { // Prepare newLoopArgs SmallVector newLoopArgs; for (unsigned argNumber : usedArgs) - newLoopArgs.push_back(forOp.getIterOperands()[argNumber]); + newLoopArgs.push_back(forOp.getInitArgs()[argNumber]); // Create newForOp builder.setAgentIdsFromArray({agentId}); diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index 3acc6a92e09c..d2a3f7c74f10 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -2,7 +2,6 @@ #include "LLVMPasses.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" -#include "mlir/Conversion/Passes.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" @@ -256,19 +255,6 @@ static std::map getExternLibs(mlir::ModuleOp module) { funcs.push_back(func); }); - for (LLVM::LLVMFuncOp func : funcs) { - if (auto libnameAttr = func->getDiscardableAttr("libname")) { - auto name = libnameAttr.dyn_cast(); - auto path = func.getOperation() - ->getDiscardableAttr("libpath") - .dyn_cast(); - if (name) { - std::string libName = name.str(); - externLibs[libName] = path.str(); - } - } - } - if (auto externsAttr = module->getDiscardableAttr("triton_gpu.externs")) { for (auto &attr : externsAttr.cast()) { externLibs[attr.getName().strref().trim().str()] = @@ -287,10 +273,8 @@ static std::map getExternLibs(mlir::ModuleOp module) { // Search for libdevice relative to its library path if used from Python // Then native code is in `triton/_C/libtriton.so` and libdevice in // `triton/third_party/cuda/lib/libdevice.10.bc` - static const auto this_library_path = getThisLibraryPath(); static const auto runtime_path = - this_library_path.parent_path().parent_path() / "third_party" / "cuda" / - "lib" / "libdevice.10.bc"; + fs::path(PathToLibdevice()) / "libdevice.10.bc"; if (fs::exists(runtime_path)) { externLibs.try_emplace(libdevice, runtime_path.string()); } else { diff --git a/lib/Target/PTX/PTXTranslation.cpp b/lib/Target/PTX/PTXTranslation.cpp index fe8841997c35..fe2cd3690436 100644 --- a/lib/Target/PTX/PTXTranslation.cpp +++ b/lib/Target/PTX/PTXTranslation.cpp @@ -54,7 +54,7 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) { auto *shortPtr = static_cast *>(options["nvptx-short-ptr"]); assert(shortPtr); - shortPtr->setValue(true); + shortPtr->setValue(false); std::string sm = cc == 90 ? "sm_90a" : "sm_" + std::to_string(cc); // max PTX version int ptxMajor = maxPTX / 10; @@ -89,7 +89,7 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) { opt.NoNaNsFPMath = true; llvm::TargetMachine *machine = target->createTargetMachine( module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, - std::nullopt, llvm::CodeGenOpt::Aggressive); + std::nullopt, llvm::CodeGenOptLevel::Aggressive); // set data layout if (layout.empty()) module.setDataLayout(machine->createDataLayout()); @@ -105,7 +105,7 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) { llvm::legacy::PassManager pass; // emit machine->addPassesToEmitFile(pass, pstream, nullptr, - llvm::CodeGenFileType::CGFT_AssemblyFile); + llvm::CodeGenFileType::AssemblyFile); pass.run(module); } // post-process diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index 7ae22a56492c..58b294881ab4 100644 --- a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -257,9 +257,9 @@ tt.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr, % %5 = arith.cmpi slt, %1, %arg1 : index cf.cond_br %5, ^bb2, ^bb3 ^bb2: // pred: ^bb1 - %6 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #blocked> + %6 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL> gpu.barrier - %7 = tt.cat %2, %3 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #blocked> + %7 = tt.cat %2, %3 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL> %8 = arith.addi %1, %arg2 : index cf.br ^bb1(%8, %4, %2, %3 : index, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) ^bb3: // pred: ^bb1 diff --git a/test/BUILD b/test/BUILD index 7bade11ef83d..3e8cf6d90308 100644 --- a/test/BUILD +++ b/test/BUILD @@ -9,6 +9,7 @@ # ) # # glob_lit_tests( +# name = "all_tests", # data = [ # "@llvm-project//llvm:FileCheck", # "//:triton-opt", @@ -19,8 +20,10 @@ # "Target/tritongpu_to_llvmir_noinline.mlir", # "Target/tritongpu_to_llvmir.mlir", # "Target/tritongpu_to_ptx.mlir", -# # TODO(b/283035396): broken because pattern is disabled by cl532546169.patch. +# # TODO(b/283035396): broken by cl536931041.patch # "TritonGPU/dot-operands.mlir", +# # TODO(b/303352886): hopper integration issues +# "NVGPU/test_wgmma.mlir", # ], # test_file_exts = ["mlir"], # ) @@ -48,7 +51,6 @@ cc_library( name = "TritonTestAnalysis", srcs = glob(["lib/Analysis/*.cpp"]), deps = [ - "@llvm-project//mlir:Analysis", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 053330d47a49..7a8322d3292e 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -12,7 +12,7 @@ module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 : %dst = triton_gpu.alloc_tensor : tensor<1x64x64xf16, #shared> %c0 = arith.constant 0 : i32 %src = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, %coord1] {order = array} : !tt.ptr, 1> - // CHECK: nvgpu.tma_load_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {operand_segment_sizes = array} : !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32 + // CHECK: nvgpu.tma_load_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {operandSegmentSizes = array} : !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32 %res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr -> tensor<1x64x64xf16, #shared> tt.return } diff --git a/test/NVGPU/test_cga.mlir b/test/NVGPU/test_cga.mlir index 8b9705db54f2..8b5a333c7370 100644 --- a/test/NVGPU/test_cga.mlir +++ b/test/NVGPU/test_cga.mlir @@ -14,7 +14,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : nvgpu.cga_barrier_arrive nvgpu.cga_barrier_wait - %ptr = llvm.mlir.null : !llvm.ptr + %ptr = llvm.mlir.zero : !llvm.ptr // CHECK: llvm.inline_asm %v = nvgpu.cluster_id diff --git a/test/NVGPU/test_mbarrier.mlir b/test/NVGPU/test_mbarrier.mlir index b12ea58647c7..95b608810378 100644 --- a/test/NVGPU/test_mbarrier.mlir +++ b/test/NVGPU/test_mbarrier.mlir @@ -2,7 +2,7 @@ #SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { tt.func @test_mbarrier() { - %mbarrier = llvm.mlir.null : !llvm.ptr + %mbarrier = llvm.mlir.zero : !llvm.ptr %pred = arith.constant 1 : i1 // CHECK: llvm.inline_asm nvgpu.mbarrier_init %mbarrier, %pred { count = 32 : i32 } : !llvm.ptr diff --git a/test/NVGPU/test_tma.mlir b/test/NVGPU/test_tma.mlir index 4cf7f9b5e838..791cab166228 100644 --- a/test/NVGPU/test_tma.mlir +++ b/test/NVGPU/test_tma.mlir @@ -2,9 +2,9 @@ #SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { tt.func @test_tma(%im2colOffsets0 : !llvm.struct<(i16, i16)>, %im2colOffsets1 : !llvm.struct<(i16, i16, i16)>) { - %mbarrier = llvm.mlir.null : !llvm.ptr - %tmaDesc = llvm.mlir.null : !llvm.ptr - %dst = llvm.mlir.null : !llvm.ptr + %mbarrier = llvm.mlir.zero : !llvm.ptr + %tmaDesc = llvm.mlir.zero : !llvm.ptr + %dst = llvm.mlir.zero : !llvm.ptr %l2desc = arith.constant 0 : i64 %c0 = arith.constant 0 : i32 %c1 = arith.constant 1 : i32 diff --git a/test/NVGPU/test_wgmma.mlir b/test/NVGPU/test_wgmma.mlir index f4ae65ad04cf..0098b359a37a 100644 --- a/test/NVGPU/test_wgmma.mlir +++ b/test/NVGPU/test_wgmma.mlir @@ -2,7 +2,7 @@ #SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { tt.func @test_tma(%opC : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) { - %buffer = llvm.mlir.null : !llvm.ptr + %buffer = llvm.mlir.zero : !llvm.ptr %height = arith.constant 16 : i32 // CHECK: llvm.ptrtoint // CHECK: llvm.inline_asm diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index c6be17f8eb04..9d74f97d9e4a 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -69,18 +69,6 @@ tt.func @remat_single_value(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { tt.return } -tt.func @remat_fast_load(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { - %0 = tt.splat %arg : (!tt.ptr) -> tensor<16x!tt.ptr, #layout1> - %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #layout1> - %2 = tt.addptr %0, %1 : tensor<16x!tt.ptr, #layout1>, tensor<16xi32, #layout1> - %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16xi32, #layout1> - // CHECK-NOT: triton_gpu.convert_layout - %4 = triton_gpu.convert_layout %3 : (tensor<16xi32, #layout1>) -> tensor<16xi32, #layout0> - %5 = triton_gpu.convert_layout %2 : (tensor<16x!tt.ptr, #layout1>) -> tensor<16x!tt.ptr, #layout0> - tt.store %5, %4 : tensor<16xi32, #layout0> - tt.return -} - // Hoist the convert on top of ext to make it cheaper. // CHECK-LABEL: hoist_above_ext tt.func @hoist_above_ext(%arg0: tensor<1024xf16, #layout0>, %arg1: f32) -> tensor<1024xf32, #layout1> { @@ -920,7 +908,7 @@ tt.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: ! %29 = "triton_gpu.select"(%28, %26, %cst_2) : (tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked2> %30 = "tt.reduce" (%29) ({ ^bb0(%arg4: f32, %arg5: f32): - %max = arith.maxf %arg4, %arg5 : f32 + %max = arith.maximumf %arg4, %arg5 : f32 tt.reduce.return %max : f32 }) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %31 = triton_gpu.convert_layout %30 : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16xf32, #blocked0> @@ -1695,11 +1683,11 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %122 = arith.extf %121 : tensor<128x64xf16, #blocked2> to tensor<128x64xf32, #blocked2> %123 = "tt.reduce"(%122) <{axis = 1 : i32}> ({ ^bb0(%arg28: f32, %arg29: f32): - %153 = arith.maxf %arg28, %arg29 : f32 + %153 = arith.maximumf %arg28, %arg29 : f32 tt.reduce.return %153 : f32 }) : (tensor<128x64xf32, #blocked2>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %124 = triton_gpu.convert_layout %123 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xf32, #blocked1> - %125 = arith.maxf %arg25, %124 : tensor<128xf32, #blocked1> + %125 = arith.maximumf %arg25, %124 : tensor<128xf32, #blocked1> %126 = arith.subf %arg25, %125 : tensor<128xf32, #blocked1> %127 = tt.extern_elementwise %126 {pure = true, libname = "libdevice", libpath = "/root/.pyenv/versions/3.9.9/lib/python3.9/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #blocked1> %128 = triton_gpu.convert_layout %125 : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> diff --git a/test/TritonGPU/materialize-load-store.mlir b/test/TritonGPU/materialize-load-store.mlir index 65ca0e6c65a7..518694774c5e 100644 --- a/test/TritonGPU/materialize-load-store.mlir +++ b/test/TritonGPU/materialize-load-store.mlir @@ -12,7 +12,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %a_tileptr_init = tt.make_tensor_ptr %A, [%c64, %c16], [%c16, %c1], [%c0, %c0] { order = array } : !tt.ptr, 1> // CHECK: %[[BUFFER:.*]] = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared> // CHECK: %[[MBAR:.*]] = triton_nvidia_gpu.alloc_mbarrier {count = 1 : i32} : !tt.ptr - // CHECK: triton_nvidia_gpu.mbarrier_arrive %[[MBAR]], %{{.*}} {operand_segment_sizes = array, trackAsyncOp = false, txCount = 2048 : i32} : !tt.ptr, i1 + // CHECK: triton_nvidia_gpu.mbarrier_arrive %[[MBAR]], %{{.*}} {operandSegmentSizes = array, trackAsyncOp = false, txCount = 2048 : i32} : !tt.ptr, i1 // CHECK: %[[INSERT:.*]] = triton_nvidia_gpu.insert_slice_async_v2 %[[TENSOR_PTR]], %[[BUFFER]], %{{.*}}, %[[MBAR]] // CHECK: %[[EXT:.*]] = triton_gpu.extract_slice %[[INSERT]][0, 0, 0] [1, 64, 16] [1, 1, 1] : tensor<1x64x16xf16, #shared> to tensor<64x16xf16, #shared> // CHECK: triton_nvidia_gpu.mbarrier_wait %[[MBAR]], %false :