diff --git a/BUILD b/BUILD new file mode 100644 index 000000000000..9f410b75ccc3 --- /dev/null +++ b/BUILD @@ -0,0 +1,904 @@ +# This package imports OpenAI's Triton (https://github.com/openai/triton). +# +# There are two versions of Triton in google3 at the moment. The older version +# can be found at //third_party/py/triton. This is the MLIR-based version close +# to head. We expect to transition users to this version in the following +# weeks. +# +# There is no SLA associated with this package and it may get broken by LLVM +# imports at any time. + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +# copybara:uncomment load("//tools/build_defs/license:license.bzl", "license") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = [":license"], + # default_compatible_with = ["//buildenv/target:gce"], + # default_visibility = [ + # # Add your project here if you need to depend on Triton's C++ sources. + # # Add a point of contact we can reach out to when needed in the comment. + # # + # # If you need to use the Python fronted, add your project to + # # google3/third_party/py/triton/BUILD instead. + # # + # # By adding your project here, you agree to the Triton SLA: go/triton-google3-sla + # "//third_party/py/jax:__subpackages__", # cjfj@ + # "//third_party/tensorflow/compiler/xla:__subpackages__", # bchetioui@ + # "//platforms/xla/experimental/gpu:__subpackages__", # csigg@ + # # Triton-internal visibility + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end + # TODO(csigg): fix and remove + features = [ + "-parse_headers", + "-use_header_modules", + ], +) + +# copybara:uncomment_begin +# license(name = "license") +# +# licenses(["notice"]) +# +# exports_files(["LICENSE"]) +# copybara:uncomment_end + +config_setting( + name = "compiler_is_msvc", + flag_values = { + # copybara:comment_begin + "@bazel_tools" + + # copybara:comment_end + "//tools/cpp:compiler": "msvc-cl", + }, +) + +# TODO(csigg): fix, enable error upstream, remove. +_no_unused_variable = select({ + ":compiler_is_msvc": [], + "//conditions:default": ["-Wno-unused-variable"], +}) + +td_library( + name = "td_files", + srcs = glob(["include/triton/**/*.td"]), + includes = ["include"], + deps = [ + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:CastInterfacesTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LLVMOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@llvm-project//mlir:ViewLikeInterfaceTdFiles", + ], +) + +gentbl_cc_library( + name = "triton_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/Triton/IR/TritonAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/Triton/IR/TritonAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/Triton/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/Triton/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-attr-interface-decls"], + "include/triton/Dialect/Triton/IR/AttrInterfaces.h.inc", + ), + ( + ["--gen-attr-interface-defs"], + "include/triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonInterfaces.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-enum-decls"], + "include/triton/Dialect/Triton/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/Triton/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-op-decls"], + "include/triton/Dialect/Triton/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/Triton/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/Triton/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/Triton/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=Triton", + ], + "include/triton/Dialect/Triton/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_combine_inc_gen", + # The generated file is #included without relative path. + strip_include_prefix = "lib/Dialect/Triton/Transforms", + tbl_outs = [ + ( + ["--gen-rewriters"], + "lib/Dialect/Triton/Transforms/TritonCombine.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "lib/Dialect/Triton/Transforms/Combine.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/TritonGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-attr-interface-decls"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.h.inc", + ), + ( + ["--gen-attr-interface-defs"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/TritonGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/TritonGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-op-decls"], + "include/triton/Dialect/TritonGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/TritonGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/TritonGPU/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/TritonGPU/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonGPU", + ], + "include/triton/Dialect/TritonGPU/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-op-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonNvidiaGPU", + ], + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_to_triton_gpu_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonToTritonGPU", + ], + "include/triton/Conversion/TritonToTritonGPU/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Conversion/TritonToTritonGPU/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_target_llvmir_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonLLVMIR", + ], + "include/triton/Target/LLVMIR/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Target/LLVMIR/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_gpu_to_llvm_pass_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonGPUToLLVM", + ], + "include/triton/Conversion/TritonGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Conversion/TritonGPUToLLVM/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_type_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-type-interface-decls"], + "include/triton/Dialect/Triton/IR/TritonTypeInterfaces.h.inc", + ), + ( + ["--gen-type-interface-defs"], + "include/triton/Dialect/Triton/IR/TritonTypeInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td", + deps = ["td_files"], +) + +cc_library( + name = "TritonAnalysis", + srcs = [ + "lib/Analysis/Alias.cpp", + "lib/Analysis/Allocation.cpp", + "lib/Analysis/Membar.cpp", + # Part of TritonDialects compilation unit to avoid circular dependencies. + # "lib/Analysis/Utility.cpp", + # "lib/Analysis/AxisInfo.cpp", + ], + hdrs = [ + "include/triton/Analysis/Alias.h", + "include/triton/Analysis/Allocation.h", + "include/triton/Analysis/Membar.h", + # Part of TritonDialects compilation unit to avoid circular dependencies. + # "include/triton/Analysis/AxisInfo.h", + # "include/triton/Analysis/Utility.h", + "include/triton/Conversion/MLIRTypes.h", + "include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h", + "include/triton/Conversion/TritonGPUToLLVM/Utility.h", + "include/triton/Dialect/TritonGPU/Transforms/Utility.h", + ], + copts = _no_unused_variable, + deps = [ + ":TritonDialects", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonDialects", + srcs = glob([ + "lib/Dialect/Triton/IR/*.cpp", + "lib/Dialect/TritonGPU/IR/*.cpp", + "lib/Dialect/TritonNvidiaGPU/IR/*.cpp", + "lib/Tools/*.cpp", + ]) + [ + "include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h", # Avoid circular dependency. + "lib/Analysis/AxisInfo.cpp", # Avoid circular dependency. + "lib/Analysis/Utility.cpp", # Avoid circular dependency. + "lib/Dialect/TritonGPU/Transforms/Utility.cpp", # Avoid circular dependency. + ], + hdrs = glob([ + "include/triton/Dialect/Triton/IR/*.h", + "include/triton/Dialect/TritonGPU/IR/*.h", + "include/triton/Dialect/TritonNvidiaGPU/IR/*.h", + "include/triton/Tools/*.h", + ]) + [ + "include/triton/Analysis/AxisInfo.h", # Avoid circular dependency. + "include/triton/Analysis/Utility.h", # Avoid circular dependency. + "include/triton/Dialect/TritonGPU/Transforms/Utility.h", # Avoid circular dependency. + ], + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + "-Wno-logical-op-parentheses", + ], + }), + includes = ["include"], + deps = [ + ":triton_dialect_inc_gen", + ":triton_gpu_attr_inc_gen", + ":triton_gpu_dialect_inc_gen", + ":triton_gpu_ops_inc_gen", + ":triton_gpu_types_inc_gen", + ":triton_interfaces_inc_gen", + ":triton_nvidia_gpu_attr_inc_gen", + ":triton_nvidia_gpu_dialect_inc_gen", + ":triton_nvidia_gpu_ops_inc_gen", + ":triton_nvidia_gpu_types_inc_gen", + ":triton_ops_inc_gen", + ":triton_types_inc_gen", + ":triton_type_interfaces_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@triton//third_party/nvidia:NVGPUDialect", + # The following is added to make Utility compile + ":TritonTools", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@triton//third_party/f2reduce", + ], +) + +cc_library( + name = "TritonTransforms", + srcs = glob(["lib/Dialect/Triton/Transforms/*.cpp"]), + hdrs = glob(["include/triton/Dialect/Triton/Transforms/*.h"]), + copts = _no_unused_variable, + deps = [ + ":TritonDialects", + ":triton_combine_inc_gen", + ":triton_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFUtils", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], + alwayslink = True, # TritonDialect uses getCanonicalizationPatterns(). +) + +cc_library( + name = "TritonGPUTransforms", + srcs = glob( + [ + "lib/Dialect/TritonGPU/Transforms/*.cpp", + "lib/Dialect/TritonGPU/Transforms/*.h", + "lib/Dialect/TritonGPU/Transforms/Pipeliner/*.cpp", + "lib/Dialect/TritonGPU/Transforms/Pipeliner/*.h", + ], + exclude = ["lib/Dialect/TritonGPU/Transforms/Utility.cpp"], + ), + hdrs = glob( + [ + "include/triton/Dialect/TritonGPU/Transforms/*.h", + ], + exclude = ["include/triton/Dialect/TritonGPU/Transforms/Utility.h"], + ) + [ + "include/triton/Tools/Sys/GetEnv.hpp", + ], + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-reorder-ctor", + "-Wno-return-type", + "-Wno-unused-variable", + ], + }), + deps = [ + ":TritonAnalysis", + ":TritonDialects", + ":TritonGPUToLLVM", + ":triton_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:SCFUtils", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonGPUToLLVM", + srcs = glob([ + "lib/Conversion/TritonGPUToLLVM/*.h", + "lib/Conversion/TritonGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/triton/Tools/Sys/*.hpp", + "include/triton/Conversion/TritonGPUToLLVM/*.h", + ]), + copts = select({ + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonAnalysis", + ":TritonDialects", + ":triton_conversion_triton_gpu_to_llvm_pass_inc_gen", + ":triton_gpu_attr_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonNvidiaGPUTransforms", + srcs = glob([ + "lib/Dialect/TritonNvidiaGPU/Transforms/*.cpp", + ]), + hdrs = glob([ + "include/triton/Dialect/TritonNvidiaGPU/Transforms/*.h", + ]), + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-ctad-maybe-unsupported", + "-Wno-logical-op-parentheses", + "-Wno-non-virtual-dtor", + "-Wno-return-type", + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonDialects", + ":TritonGPUTransforms", + ":triton_nvidia_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonToTritonGPU", + srcs = glob([ + "lib/Conversion/TritonToTritonGPU/*.h", + "lib/Conversion/TritonToTritonGPU/*.cpp", + ]), + hdrs = glob(["include/triton/Conversion/TritonToTritonGPU/*.h"]), + includes = ["include"], + deps = [ + ":TritonDialects", + ":TritonGPUTransforms", + ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonLLVMIR", + srcs = glob([ + "lib/Target/LLVMIR/*.cpp", + "lib/Target/LLVMIR/*.h", + ]), + hdrs = glob(["include/triton/Target/LLVMIR/*.h"]), + copts = _no_unused_variable, + deps = [ + ":TritonTransforms", + ":triton_target_llvmir_passes_inc_gen", + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:BinaryFormat", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:InstCombine", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexToLLVM", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMIRTransforms", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLToLLVMIRTranslation", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", + # copybara:uncomment "//third_party/py/triton/google:find_cuda", + ], +) + +cc_library( + name = "TritonPTX", + srcs = glob([ + "lib/Target/PTX/*.cpp", + ]), + hdrs = glob(["include/triton/Target/PTX/*.h"]), + deps = ["@llvm-project//llvm:Support"], +) + +cc_library( + name = "TritonHSACO", + srcs = glob([ + "lib/Target/HSACO/*.cpp", + ]), + hdrs = glob(["include/triton/Target/HSACO/*.h"]), + deps = [ + ":TritonLLVMIR", + ":TritonTools", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:ExecutionEngine", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Scalar", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//llvm:TransformUtils", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + ], +) + +cc_library( + name = "TritonTools", + hdrs = ["include/triton/Tools/Sys/GetEnv.hpp"], +) + +cc_library( + name = "AllPassesAndDialects", + srcs = [ + "include/triton/Conversion/TritonToTritonGPU/Passes.h", + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h", + ], + hdrs = ["bin/RegisterTritonDialects.h"], + includes = ["."], # because it includes third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h + deps = [ + ":TritonDialects", + ":TritonGPUToLLVM", + ":TritonGPUTransforms", + ":TritonLLVMIR", + ":TritonNvidiaGPUTransforms", + ":TritonToTritonGPU", + ":TritonTransforms", + ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", + ":triton_nvidia_gpu_transforms_inc_gen", + "@llvm-project//mlir:AllPassesAndDialects", + "@triton//test:TritonTestAnalysis", + "@triton//third_party/amd:TritonAMDGPUToLLVM", + "@triton//third_party/amd:TritonAMDGPUTransforms", + "@triton//third_party/nvidia:NVGPUDialect", + "@triton//third_party/nvidia:NVGPUToLLVM", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) + +cc_binary( + name = "triton-opt", + srcs = [ + "bin/triton-opt.cpp", + ], + deps = [ + ":AllPassesAndDialects", + "@llvm-project//mlir:MlirOptLib", + ], +) + +cc_binary( + name = "triton-llvm-opt", + srcs = [ + "bin/triton-llvm-opt.cpp", + "lib/Target/LLVMIR/LLVMPasses.h", + ], + deps = [ + ":TritonLLVMIR", + "@llvm-project//llvm:CodeGen", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:Option", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + ], +) + +# See go/triton-debug for usage. +cc_binary( + name = "triton-reduce", + srcs = ["bin/triton-reduce.cpp"], + deps = [ + ":AllPassesAndDialects", + "@llvm-project//mlir:MlirReduceLib", + ], +) + +cc_binary( + name = "triton-tensor-layout", + srcs = ["bin/triton-tensor-layout.cpp"], + deps = [ + ":AllPassesAndDialects", + ":TritonDialects", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + ], +) + +filegroup( + name = "metadata-file", + srcs = ["METADATA"], +) diff --git a/include/triton/Analysis/AxisInfo.h b/include/triton/Analysis/AxisInfo.h index 22a7ed5549e9..aad4503b4840 100644 --- a/include/triton/Analysis/AxisInfo.h +++ b/include/triton/Analysis/AxisInfo.h @@ -180,8 +180,8 @@ class ModuleAxisInfoAnalysis : public CallGraph { for (auto funcOp : llvm::reverse(sortedFuncs)) { initialize(funcOp); funcOp.walk([&](CallOpInterface callOp) { - auto callee = - dyn_cast(callOp.resolveCallable(&symbolTable)); + auto callee = dyn_cast( + callOp.resolveCallableInTable(&symbolTable)); update(callOp, callee); }); } diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index e0b22b2c795b..9e3eff155b40 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -316,7 +316,7 @@ template class CallGraph { moduleOp.walk([&](Operation *op) { auto caller = op->getParentOfType(); if (auto callOp = dyn_cast(op)) { - auto *callee = callOp.resolveCallable(&symbolTable); + auto *callee = callOp.resolveCallableInTable(&symbolTable); auto funcOp = dyn_cast_or_null(callee); if (funcOp) { graph[caller].emplace_back( diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index ac7bf96f7fdd..243b934367ad 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -91,6 +91,10 @@ def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [DeclareOpInterfaceMethods, diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 7d4088b16f2b..679dcc88d788 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -488,6 +488,11 @@ bool supportMMA(triton::DotOp op, int version) { if (triton::tools::getBoolEnv("DISABLE_MMA_V3")) return false; auto retType = op.getType(); + RankedTensorType typeA = op.getA().getType(); + int k = typeA.getShape().back(); + // If k size is smaller than the native mma size, we cannot use MMA. + if (k < 256 / aElemTy.getIntOrFloatBitWidth()) + return false; auto retShapePerCTA = getShapePerCTA(retType); auto rank = retShapePerCTA.size(); auto mod = op->getParentOfType(); diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 0287207be51a..2d21a1bfeca2 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -40,7 +40,8 @@ SmallVector reorderValues(const SmallVector &values, Type inType, auto ouEltTy = ouTensorTy.getElementType(); if (inBitWidth == ouBitWidth) return values; - if (inBitWidth == 16 && ouBitWidth == 32) { + if ((inBitWidth == 16 && ouBitWidth == 32) || + (inBitWidth == 32 && ouBitWidth == 16)) { SmallVector ret; for (unsigned i = 0; i < values.size(); i += 8) { ret.push_back(values[i]); @@ -610,10 +611,9 @@ struct IndexCastOpLowering if (targetBits == sourceBits) return {operands[0][0]}; if (targetBits < sourceBits) - return {rewriter.replaceOpWithNewOp(op, elemTy, - operands[0][0])}; - return { - rewriter.replaceOpWithNewOp(op, elemTy, operands[0][0])}; + return { + rewriter.create(op.getLoc(), elemTy, operands[0][0])}; + return {rewriter.create(op.getLoc(), elemTy, operands[0][0])}; } }; diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index a6c8e4df369b..0eb0d50acd95 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -87,8 +87,9 @@ struct ArithConstantSplatOpConversion // LLVM IR. if (type::isFloat8(elemType)) elemType = rewriter.getIntegerType(8); - auto constOp = rewriter.create(loc, elemType, val); auto typeConverter = getTypeConverter(); + auto constOp = rewriter.create( + loc, typeConverter->convertType(elemType), val); auto llStruct = SplatOpConversion::convertSplatLikeOp( elemType, op.getType(), constOp, typeConverter, rewriter, loc); rewriter.replaceOp(op, llStruct); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index a454fef56674..c0ca4ebfcff1 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -2717,6 +2717,11 @@ struct CanonicalizeConvertFromAlloc auto convert = op.getSrc().getDefiningOp(); if (!convert) return failure(); + // LocalAllocOp lowering doesn't support going from DotOperandEncoding + // to SharedEncoding, so we want to keep this layout conversion. + if (mlir::isa( + convert.getSrc().getType().getEncoding())) + return failure(); rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), convert.getSrc()); return mlir::success(); diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index d9bbd51bd9a1..7776a93305ff 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -153,6 +153,21 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, auto newType = MemDescType::get(argType.getShape(), argType.getElementType(), newLayout, SharedMemorySpace); rewriter.setInsertionPointAfterValue(arg); + + // LocalAllocOp lowering doesn't support going from DotOperandEncoding + // to SharedEncoding. + if (auto dotOpEnc = mlir::dyn_cast( + argType.getEncoding())) { + // Create a layout conversion from DotOperandEncoding to BlockedEncoding + // then pass it to the LocalAllocOp. + auto newArgType = RankedTensorType::get( + argType.getShape(), argType.getElementType(), dotOpEnc.getParent()); + auto dotOperandToBlockedCvt = + rewriter.create(arg.getLoc(), newArgType, arg); + return rewriter.create(arg.getLoc(), newType, + dotOperandToBlockedCvt); + } + return rewriter.create(arg.getLoc(), newType, arg); } @@ -162,6 +177,15 @@ class BlockedToMMA : public mlir::OpRewritePattern { mutable llvm::DenseMap dotOpInstNs; static bool bwdFilter(Operation *op) { + // Dot operand layout assignment to Predicates are not currently supported + // during lowering from TritonGPU to LLVM in Triton for MMA cases. This + // condition limits visibility of the original bit-width so that predicate + // are not considered, hence, kwidth can never be = 32. + if (isa(op)) { + Type srcType = getElementTypeOrSelf(op->getOperand(0)); + if (srcType.isInteger(1)) + return false; + } return op->getNumOperands() == 1 && (isa(op) || isPureUnaryInlineAsm(op) || diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp index b5ed64cca3ac..dd9b4ad139f5 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp @@ -14,7 +14,12 @@ namespace gpu { namespace { bool dotSupportsAccInitFlag(Operation *op) { assert(op->hasTrait() && "Expected a dot-like operation"); - return isa(op); + if (auto wgDotOp = dyn_cast(op)) { + // Partial accumulation would require a select op to handle the + // initialization that would degrade the performance. + return !wgDotOp.needsPartialAccumulator(); + } + return false; } std::pair getAccumulatorUseAndDef(Operation *op) { diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 6d8279795209..e6e0ec8d7cef 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -111,7 +111,8 @@ class HoistLayoutConversion : public OpRewritePattern { PatternRewriter &rewriter) const override { // Only consider conversions to dot operand. auto cvtTy = cast(cvt.getType()); - if (!isa(cvtTy.getEncoding())) + auto dotOpEnc = dyn_cast(cvtTy.getEncoding()); + if (!dotOpEnc) return failure(); auto src = cvt.getSrc().getDefiningOp(); @@ -126,6 +127,12 @@ class HoistLayoutConversion : public OpRewritePattern { [](Type ty) { return isa(ty); })) return failure(); + // Quick handling to fix loading issues when computing the original + // bitwidth is unable to realize that there is a mixed-precision dot + // (hence kWidth = 1) but wants to hoist through the type conversion. + if (isa(src) && dotOpEnc.getKWidth() == 1) + return failure(); + // Only consider custom conversions or arith ops. // TODO(jlebar): Is this too restrictive? if (!isa(src) && !isPureUnaryInlineAsm(src) && @@ -138,6 +145,14 @@ class HoistLayoutConversion : public OpRewritePattern { if (isa(src)) return failure(); + // Don't hoist through u1 -> fp casts as they aren't supported in + // ElementwiseOpToLLVM::reorderValues(). + if (isa(src)) { + Type srcType = getElementTypeOrSelf(src->getOperand(0)); + if (srcType.isInteger(1)) + return failure(); + } + // Check that the conversion is transitively dependent on a load, and all // operations between the load and the conversion are layout preserving. // diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index 02994e1ac059..cd6fc806928d 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -140,8 +140,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, type.getMemorySpace()), v, offsetsVal); + // We need to assign kwidth to zero in the case where the parent layout is + // Blocked, otherwise the verifier emits a failure. The parent layout is + // Blocked only when Tensor Cores are disabled. + int kwidth = dyn_cast(dotEncoding) + ? 0 + : prefetchWidth / 8; auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get( - builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8); + builder.getContext(), opIdx, dotEncoding, kwidth); Value prefetchSlice = builder.create( v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc), newSmem); @@ -187,6 +193,15 @@ LogicalResult Prefetcher::initialize() { break; if (!op->getResult(0).hasOneUse()) break; + // Similar to issues faced in HoistLayoutConversion pattern in + // OptimizeDotOperands.cpp, we can't propagate through type casts from + // predicates as they aren't supported in Triton when encoded with dot_op + // layout. + if (isa(op)) { + Type srcType = getElementTypeOrSelf(op->getOperand(0)); + if (srcType.isInteger(1)) + break; + } rets.push_back(op->getOperand(0)); if (auto cvt = dyn_cast(op)) { foundConvertFromShared = true; diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index bd19acf04fdb..37c69eef8adb 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -70,6 +70,18 @@ void WarpGroupDotOp::getEffects( mlir::triton::gpu::SharedMemory::get()); } +bool WarpGroupDotOp::needsPartialAccumulator() { + const auto &a = getA(); + const auto &d = getD(); + auto aTensorTy = cast(a.getType()); + auto aElTy = cast(a.getType()).getElementType(); + bool isFP8 = aElTy.isFloat8E5M2() || aElTy.isFloat8E4M3FN() || + aElTy.isFloat8E5M2FNUZ() || aElTy.isFloat8E4M3FNUZ(); + bool accFP32 = cast(d.getType()).getElementType().isF32(); + uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc(); + return isFP8 && accFP32 && maxNumImpreciseAcc <= aTensorTy.getShape()[1]; +} + // -- WarpGroupDotWaitOp -- LogicalResult WarpGroupDotWaitOp::inferReturnTypes( ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, diff --git a/python/BUILD b/python/BUILD new file mode 100644 index 000000000000..334dd4aec41a --- /dev/null +++ b/python/BUILD @@ -0,0 +1,77 @@ +# NOTE: Do not depend on any targets from this directory, +# but use //third_party/py/triton instead. + +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//third_party/py/triton:__pkg__", + "@triton//python:__subpackages__", + ], +) + +cc_library( + name = "passes", + hdrs = ["src/passes.h"], + includes = ["src"], + visibility = ["@triton//third_party:__subpackages__"], +) + +pybind_extension( + name = "libtriton", + srcs = [ + "src/interpreter.cc", + "src/ir.cc", + "src/llvm.cc", + "src/main.cc", + "src/passes.cc", + ], + copts = ["-DTRITON_BACKENDS_TUPLE=(nvidia)"], + deps = [ + ":passes", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:InstCombine", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMIRTransforms", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonGPUTransforms", + "//:TritonHSACO", + "//:TritonLLVMIR", + "//:TritonNvidiaGPUTransforms", + "//:TritonPTX", + "//:TritonToTritonGPU", + "//:TritonTools", + "//:TritonTransforms", + "@triton//third_party/nvidia:triton_nvidia", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["triton/**/*.py"], + ), +) diff --git a/python/src/llvm.cc b/python/src/llvm.cc index 3b5ec1fada07..02317919107f 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -1,4 +1,4 @@ -#include "mlir/IR/BuiltinOps.h" // mlir::ModuleOp +#include "mlir/IR/BuiltinOps.h" // mlir::ModuleOp #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "triton/Tools/Sys/GetEnv.hpp" @@ -346,8 +346,8 @@ void init_triton_llvm(py::module &&m) { // and break the lowering of some target specific intrinsics. std::unique_ptr targetMachine = nullptr; if (!arch.empty() && pluginFile.empty()) - targetMachine = std::move( - createTargetMachine(mod, arch, enable_fp_fusion, features)); + targetMachine = + createTargetMachine(mod, arch, enable_fp_fusion, features); PassBuilder pb(/*targetMachine=*/targetMachine.get(), tuningOptions, std::nullopt, instrCbPtr); diff --git a/python/test/regression/BUILD b/python/test/regression/BUILD new file mode 100644 index 000000000000..a88f4eeae1f8 --- /dev/null +++ b/python/test/regression/BUILD @@ -0,0 +1,26 @@ +load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests") + +package( + default_applicable_licenses = ["//:license"], +) + +pytest_multi_tests( + name = "tests", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm80", + ], + tests = glob( + include = ["test_*.py"], + exclude = [ + "test_performance.py", #TODO(b/321005767): fix failing test + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/test/regression/conftest.py b/python/test/regression/conftest.py new file mode 100644 index 000000000000..7a02d322b49f --- /dev/null +++ b/python/test/regression/conftest.py @@ -0,0 +1,12 @@ +# content of conftest.py + +import pytest + + +def pytest_addoption(parser): + parser.addoption("--device", action="store", default='cuda') + + +@pytest.fixture +def device(request): + return request.config.getoption("--device") diff --git a/python/test/unit/BUILD b/python/test/unit/BUILD new file mode 100644 index 000000000000..396b6dcb2f4e --- /dev/null +++ b/python/test/unit/BUILD @@ -0,0 +1,181 @@ +load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests", "pytest_test") + +package( + default_applicable_licenses = ["//:license"], +) + +_requires_gpu_sm80 = [ + "config-cuda-only", + "requires-gpu-sm80", +] + +_requires_config_cuda = select( + {"@local_config_cuda//cuda:using_clang_allow_exec": []}, + no_match_error = "Requires --config=cuda", +) + +EXCLUDE_TESTS = [ + "language/test_reproducer.py", # this is not an actual test, but a tool for running reproducers + "language/test_subprocess.py", # TODO(b/320224484): fix failing test + "runtime/test_launch.py", # TODO(b/320226169): fix failing tests + "tools/test_aot.py", # TODO(b/320224484): fix failing test + "tools/test_disasm.py", # TODO(b/320224484): fix failing test + "hopper/test_persistent_warp_specialized_gemm.py", # TODO (b/342348738): fix failing test + "hopper/test_tma_descriptor.py", # TODO (b/358060133): fix failing test + "runtime/test_cublas.py", # TODO(b/346755023): fix failing test +] + +# Runs all python tests on H100 +pytest_multi_tests( + name = "hopper", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + "language/test_core.py", + ], + name_suffix = "_h100", + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm90", + ], + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["**/test_*.py"], + exclude = EXCLUDE_TESTS + [ + "language/test_core.py", + "language/test_pipeliner.py", # TODO(b/362458006): fix failing test + "hopper/test_experimental_tma.py", # TODO(b/362458006): fix failing test + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +# Shard test_core more, as it is otherwise very slow to run. +pytest_test( + name = "hopper/language/test_core_h100", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + ], + shard_count = 40, + tags = [ + "config-cuda-only", + "requires-gpu-sm90", + ], + target_compatible_with = _requires_config_cuda, + tests = ["language/test_core.py"], + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "language", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + "language/test_core.py", + ], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["language/**/test_*.py"], + exclude = EXCLUDE_TESTS + ["language/test_core.py"], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +# Shard test_core more, as it is otherwise very slow to run. +pytest_test( + name = "language/test_core", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + ], + shard_count = 40, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = ["language/test_core.py"], + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "instrumentation", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["instrumentation/**/test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "runtime", + srcs = ["conftest.py"], + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["runtime/**/test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "tools", + size = "large", + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["tools/**/test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "unit", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/test/unit/language/test_compile_errors.py b/python/test/unit/language/test_compile_errors.py index d0b98c9859d5..d784429b1764 100644 --- a/python/test/unit/language/test_compile_errors.py +++ b/python/test/unit/language/test_compile_errors.py @@ -329,7 +329,7 @@ def kernel(a=GLOBAL): pass # No error. - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={0: "i32"}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constants={})) def test_defaults_assign_no_err(): diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index af8f414ad31c..6c9b94e39af8 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2139,6 +2139,8 @@ def kernel(X, Z, BLOCK: tl.constexpr): reduce_bool = [(op, 'bool', shape, axis, False) for op in ['xor_sum'] for shape in reduce2d_shapes for axis in [0, 1]] +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] >= 9, + reason='Reduction test produces wrong results on H100, b/342347027') @pytest.mark.interpreter @pytest.mark.parametrize( "op, dtype_str, shape, axis, keep_dims", reduce_configs1 + reduce_configs2 + reduce_configs3 + invalid_config + @@ -3642,6 +3644,25 @@ def _kernel(out): kernel[(1, )](out) assert torch.all(out == out_ref) +@pytest.mark.interpreter +def test_dot_on_broadcast(device): + @triton.jit + def _kernel(a, b, out): + a_offsets = tl.arange(0, 64)[:, None] * 32 + tl.arange(0, 32)[None, :] + lhs = tl.load(a + a_offsets, mask=a_offsets < 32 * 64) + rhs = tl.load(b) + rhs_bc = tl.broadcast_to(rhs, [32, 32]) + c = tl.dot(lhs, rhs_bc) + out_ptr = out + tl.arange(0, 64)[:, None] * 32 + tl.arange(0, 32)[None, :] + tl.store(out_ptr, c) + + a = torch.ones((64, 32), dtype=getattr(torch, 'float32'), device=device) + b = torch.tensor([1.0], dtype=getattr(torch, 'float32'), device=device) + out_ref = torch.matmul(a, torch.broadcast_to(b, (32, 32))) + out = torch.zeros((64, 32), dtype=getattr(torch, 'float32'), device=device) + _kernel[(1, )](a, b, out, num_stages=1, num_warps=4) + assert torch.all(out == out_ref) + # --------------- # test arange diff --git a/python/test/unit/runtime/test_bindings.py b/python/test/unit/runtime/test_bindings.py index b3aebc9d8526..11ce4f8fc14b 100644 --- a/python/test/unit/runtime/test_bindings.py +++ b/python/test/unit/runtime/test_bindings.py @@ -61,10 +61,12 @@ def walk_fn(op): ] src = triton.compiler.compiler.ASTSource( fn=kernel, - signature={i: kernel._type_of(kernel._key_of(arg)) - for i, arg in enumerate(args) - if i not in kernel.constexprs}, - constants={i: arg + signature={ + kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg)) + for i, arg in enumerate(args) + if i not in kernel.constexprs + }, + constants={kernel.arg_names[i]: arg for i, arg in enumerate(args) if not isinstance(arg, torch.Tensor)}, attrs=kernel._get_config(*args, ), diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index 7240fb7bb562..c4afd1e0ed11 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -17,8 +17,8 @@ def kernel_sub(a, b, o, N: tl.constexpr): src = ASTSource( fn=kernel_sub, - constants={3: 32}, - signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, + constants={'N': 32}, + signature={'a': "*fp32", 'b': "*fp32", 'o': "*fp32"}, attrs=attrs, ) triton.compile(src=src, target=target) @@ -42,7 +42,7 @@ def kernel_dot(Z): z = tl.dot(z, z) tl.store(Z + offs, z) - src = ASTSource(fn=kernel_dot, signature={0: "*fp32"}, attrs=attrs, constants=dict()) + src = ASTSource(fn=kernel_dot, signature={'Z': "*fp32"}, attrs=attrs, constants={}) triton.compile(src=src, target=target) @@ -63,7 +63,7 @@ def empty_kernel(): import gc gc.collect() - src = ASTSource(fn=empty_kernel, signature={}, attrs=attrs, constants=dict()) + src = ASTSource(fn=empty_kernel, signature={}, attrs=attrs, constants={}) triton.compile(src=src, target=target) diff --git a/python/test/unit/test_perf_warning.py b/python/test/unit/test_perf_warning.py index 8b793dd36095..9f4b9a0830a0 100644 --- a/python/test/unit/test_perf_warning.py +++ b/python/test/unit/test_perf_warning.py @@ -37,8 +37,9 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, triton.compile( triton.compiler.ASTSource( fn=matmul_kernel, signature={ - 0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32', 5: 'i32', 6: 'i32', 7: 'i32', 8: 'i32', 9: - 'i32', 10: 'i32', 11: 'i32' + 'a_ptr': '*fp32', 'b_ptr': '*fp32', 'c_ptr': '*fp32', 'M': 'i32', 'N': 'i32', 'K': 'i32', 'stride_am': + 'i32', 'stride_ak': 'i32', 'stride_bk': 'i32', 'stride_bn': 'i32', 'stride_cm': 'i32', 'stride_cn': + 'i32' }, constants={})) captured = capfd.readouterr() @@ -75,8 +76,10 @@ def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr) XBLOCK = 1024 triton.compile( - triton.compiler.ASTSource(fn=ldst_vec, signature={0: '*i64', 1: '*i64', 2: '*fp16', 3: '*fp32', 4: '*fp16'}, - constants={"XBLOCK": XBLOCK}), options={"num_warps": 1}) + triton.compiler.ASTSource( + fn=ldst_vec, signature={ + 'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*fp16', 'in_ptr3': '*fp32', 'out_ptr0': '*fp16' + }, constants={"XBLOCK": XBLOCK}), options={"num_warps": 1}) _, err = capfd.readouterr() assert ("remark: Warning: vectorization fails" in err), "expect vectorization failure remark" diff --git a/python/triton/_C/include b/python/triton/_C/include index b85a409837d1..8a5dba6c4b56 120000 --- a/python/triton/_C/include +++ b/python/triton/_C/include @@ -1 +1 @@ -../../../include/ \ No newline at end of file +../../../include \ No newline at end of file diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py index 92ba144ba97b..f9bab523bf6c 100644 --- a/python/triton/backends/__init__.py +++ b/python/triton/backends/__init__.py @@ -46,5 +46,8 @@ def _discover_backends(): _find_concrete_subclasses(driver, DriverBase)) return backends - -backends = _discover_backends() +from triton.backends.nvidia.driver import CudaDriver +from triton.backends.nvidia.compiler import CUDABackend +backends = { + "nvidia": Backend(CUDABackend, CudaDriver) +} diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 3af4c788e4a1..19d09de85919 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -603,7 +603,7 @@ def visit_then_else_blocks(self, node, liveins, then_block, else_block): then_defs[name] = liveins[name] # variables that are both in then and else but not in liveins # TODO: could probably be cleaned up - for name in then_defs.keys() & else_defs.keys(): + for name in sorted(then_defs.keys() & else_defs.keys()): if name in names: continue then_ty = then_defs[name].type diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 40b21e1a1219..873a0844901a 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -96,8 +96,16 @@ def __init__(self, fn, signature, constants=None, attrs=None) -> None: self.attrs = attrs if isinstance(self.signature, str): self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} + else: + for k in self.signature.keys(): + if not isinstance(k, str): + raise TypeError("Signature keys must be string") if self.constants is None: - self.constants = dict() + self.constants = {} + else: + for k in self.constants.keys(): + if not isinstance(k, str): + raise TypeError("Constants keys must be string") if self.attrs is None: self.attrs = AttrsDescriptor() diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py index 872332b03d99..1e9697cc8226 100644 --- a/python/triton/tools/compile.py +++ b/python/triton/tools/compile.py @@ -92,11 +92,15 @@ def constexpr(s): hints = {i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s} hints = {k: v for k, v in hints.items() if v is not None} - constants = {i: constexpr(s) for i, s in enumerate(signature)} + constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)} constants = {k: v for k, v in constants.items() if v is not None} - signature = {i: s.split(":")[0] for i, s in enumerate(signature) if i not in constants} + signature = { + kernel.arg_names[i]: s.split(":")[0] + for i, s in enumerate(signature) + if kernel.arg_names[i] not in constants + } const_sig = 'x'.join([str(v) for v in constants.values()]) - doc_string = [f"{kernel.arg_names[i]}={constants[i]}" for i in constants.keys()] + doc_string = [f"{k}={v}" for k, v in constants.items()] doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"] # compile ast into cubin @@ -106,16 +110,23 @@ def constexpr(s): equal_to_1 = [i for i, h in hints.items() if h == 1] attrs = triton.compiler.AttrsDescriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1) for i in equal_to_1: - constants.update({i: 1}) + constants.update({kernel.arg_names[i]: 1}) src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs) opts = {"num_warps": args.num_warps, "num_stages": args.num_stages} ccinfo = triton.compile(src, options=opts) arg_names = [] arg_types = [] - for i in signature.keys(): - if i not in equal_to_1: - arg_names += [kernel.arg_names[i]] - arg_types += [signature[i]] + arg_names_not_1 = [] + arg_types_not_1 = [] + for i, arg_name in enumerate(kernel.arg_names): + if arg_name not in constants: + arg_names.append(arg_name) + arg_types.append(signature[arg_name]) + arg_names_not_1.append(arg_name) + arg_types_not_1.append(signature[arg_name]) + elif i in equal_to_1: + arg_names.append(arg_name) + arg_types.append(signature[arg_name]) # dump C stub code suffix = kernel_suffix(signature.values(), attrs) @@ -126,10 +137,10 @@ def constexpr(s): "triton_kernel_name": args.kernel_name, "bin_size": len(hex_), "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]), - "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]), - "full_signature": ", ".join([f"{ty_to_cpp(signature[i])} {kernel.arg_names[i]}" for i in signature.keys()]), - "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names]), - "num_args": len(arg_names), + "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names_not_1, arg_types_not_1)]), + "full_signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]), + "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1]), + "num_args": len(arg_names_not_1), "kernel_docstring": doc_string, "shared": ccinfo.metadata.shared, "num_warps": args.num_warps, diff --git a/test/BUILD b/test/BUILD new file mode 100644 index 000000000000..e94e1c4abdf0 --- /dev/null +++ b/test/BUILD @@ -0,0 +1,63 @@ +# copybara:uncomment_begin +# load("//third_party/llvm/build_defs:lit.bzl", "glob_lit_tests") +# load("//tools/build_defs/build_test:build_test.bzl", "build_test") +# +# package( +# default_applicable_licenses = ["//:license"], +# default_compatible_with = ["//buildenv/target:gce"], +# default_visibility = ["//:__subpackages__"], +# ) +# +# glob_lit_tests( +# name = "all_tests", +# data = [ +# "@llvm-project//llvm:FileCheck", +# "//:triton-llvm-opt", +# "//:triton-opt", +# "//:triton-tensor-layout", +# ], +# driver = "@llvm-project//mlir:run_lit.sh", +# exclude = [ +# "Conversion/amd/dedup-by-constancy.mlir", # AMD-specific, broken +# "TritonGPU/dot-operands.mlir", # TODO: b/283035396 - broken by cl536931041.patch +# "TritonGPU/optimize_epilogue.mlir", # TODO: b/346283526 - AMD-specific, triggering UBSAN +# ], +# test_file_exts = [ +# "mlir", +# "ll", +# ], +# ) +# +# build_test( +# name = "build_test", +# allow_empty_target = False, +# targets = [ +# "//:TritonAnalysis", +# "//:TritonDialects", +# "//:TritonGPUToLLVM", +# "//:TritonGPUTransforms", +# "//:TritonLLVMIR", +# "//:TritonPTX", +# "//:TritonToTritonGPU", +# "//:TritonTools", +# "//:TritonTransforms", +# "//:triton-opt", +# ], +# ) +# copybara:uncomment_end + +cc_library( + name = "TritonTestAnalysis", + srcs = glob(["lib/Analysis/*.cpp"]), + deps = [ + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) diff --git a/test/Conversion/amd/decompose-unsupported-conversions-cdna.mlir b/test/Conversion/amd/decompose-unsupported-conversions-cdna.mlir new file mode 100644 index 000000000000..f30e0aa6d98e --- /dev/null +++ b/test/Conversion/amd/decompose-unsupported-conversions-cdna.mlir @@ -0,0 +1,33 @@ +// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions=arch=gfx942 | FileCheck %s + +// CHECK-DAG: #[[DST_ENC:.+]] = #triton_gpu.blocked<{{.*}}> +// CHECK-DAG: #[[SRC_ENC:.+]] = #triton_gpu.amd_mfma<{{.*}}> +// CHECK-DAG: #[[TMP_ENC:.+]] = #triton_gpu.amd_mfma<{{.*}}> +// CHECK: large_tensor_conversion +#src = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = false}> +#dst = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @large_tensor_conversion(%arg0: tensor<128x128xf32, #src>) { + // CHECK: %[[TMP:.*]] = triton_gpu.convert_layout {{.*}} : tensor<128x128xf32, #[[SRC_ENC]]> -> tensor<128x128xf32, #[[TMP_ENC]]> + // CHECK: {{.*}} = triton_gpu.convert_layout %[[TMP]] : tensor<128x128xf32, #[[TMP_ENC]]> -> tensor<128x128xf32, #[[DST_ENC]]> + %0 = triton_gpu.convert_layout %arg0 : tensor<128x128xf32, #src> -> tensor<128x128xf32, #dst> + tt.return + } +} + +// ----- + +// CHECK-DAG: #[[DST_ENC:.+]] = #triton_gpu.blocked<{{.*}}> +// CHECK-DAG: #[[SRC_ENC:.+]] = #triton_gpu.amd_mfma<{{.*}}> +// CHECK-DAG: #[[TMP_ENC:.+]] = #triton_gpu.amd_mfma<{{.*}}> +// CHECK: large_tensor_3d_conversion +#src = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 1, 2], instrShape = [32, 32], isTransposed = false}> +#dst = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 64, 1], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @large_tensor_3d_conversion(%arg0: tensor<2x128x64xf32, #src>) { + // CHECK: %[[TMP:.*]] = triton_gpu.convert_layout {{.*}} : tensor<2x128x64xf32, #[[SRC_ENC]]> -> tensor<2x128x64xf32, #[[TMP_ENC]]> + // CHECK: {{.*}} = triton_gpu.convert_layout %[[TMP]] : tensor<2x128x64xf32, #[[TMP_ENC]]> -> tensor<2x128x64xf32, #[[DST_ENC]]> + %0 = triton_gpu.convert_layout %arg0 : tensor<2x128x64xf32, #src> -> tensor<2x128x64xf32, #dst> + tt.return + } +} diff --git a/test/Conversion/amd/decompose-unsupported-conversions.mlir b/test/Conversion/amd/decompose-unsupported-conversions.mlir index b03387e2647f..0d6220c80da2 100644 --- a/test/Conversion/amd/decompose-unsupported-conversions.mlir +++ b/test/Conversion/amd/decompose-unsupported-conversions.mlir @@ -1,8 +1,9 @@ -// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions=arch=gfx942 | FileCheck %s +// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions=arch=gfx1130 | FileCheck %s // CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}> // CHECK: #[[WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}> // CHECK: #[[SHARED:.+]] = #triton_gpu.shared<{{.*}}> +// CHECK: wmma_to_wmma_dot_op #mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { tt.func @wmma_to_wmma_dot_op(%arg0: tensor<16x16xf16, #mma>) { @@ -13,3 +14,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}> +// CHECK: #[[WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}> +// CHECK: #[[SHARED:.+]] = #triton_gpu.shared<{{.*}}> +// CHECK: wmma_to_wmma_dot3d_op +#mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2, 2]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @wmma_to_wmma_dot3d_op(%arg0: tensor<2x16x16xf16, #mma>) { + // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[WMMA]]> -> tensor<2x16x16xf16, #[[BLOCKED]]> + // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<2x16x16xf16, #[[SHARED]], #triton_gpu.shared_memory> + // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA]], kWidth = 16}>> + %0 = triton_gpu.convert_layout %arg0 : tensor<2x16x16xf16, #mma> -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + tt.return + } +} diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 923324616940..54189efbc9c5 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -142,3 +142,38 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : tt.return } } + + +// ----- + +// Verify that we use mmav2 when the k dim is too small for mmav3. +// CHECK: #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 4], instrShape = [16, 8]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @small_k_size( + %a: tensor<128x16xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, + %b: tensor<16x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) + -> tensor<128x128xf32, #blocked> { + %zero_f32 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %result = tt.dot %a, %b, %zero_f32 : tensor<128x16xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<16x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked> + tt.return %result : tensor<128x128xf32, #blocked> + } +} + +// ----- + +// CHECK-DAG: #[[$BLOCKED:.*]] = #triton_gpu.blocked +// CHECK-DAG: #mma = #triton_gpu.nvidia_mma<{versionMajor = 3 +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @local_alloc_dot_operand(%in0: tensor<64x256xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> {tt.divisibility = 16 : i32}, %in1: f32, %in2: tensor<64x32xf32, #blocked>) -> (tensor<64x32xf32, #blocked>) { + // CHECK-LABEL: local_alloc_dot_operand + %splat_in1 = tt.splat %in1 : f32 -> tensor<256x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK: %[[LHS_LOCAL_ALLOC:.*]] = triton_gpu.local_alloc + // CHECK: %[[RHS_CVT:.*]] = triton_gpu.convert_layout {{.*}} #triton_gpu.dot_op<{{.*}}> -> {{.*}} #[[$BLOCKED]] + // CHECK: %[[RHS_LOCAL_ALLOC:.*]] = triton_gpu.local_alloc %[[RHS_CVT]] + // CHECK: triton_nvidia_gpu.warp_group_dot %[[LHS_LOCAL_ALLOC]], %[[RHS_LOCAL_ALLOC]] + %res = tt.dot %in0, %splat_in1, %in2, inputPrecision = tf32 : tensor<64x256xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<256x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x32xf32, #blocked> + tt.return %res : tensor<64x32xf32, #blocked> + } +} diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir index d680c08c1852..43dbeea6c2a1 100644 --- a/test/TritonGPU/amd/amd-reorder-instructions.mlir +++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir @@ -108,12 +108,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-LABEL: tt.func @matmul_loop // CHECK: %{{.*}}:6 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}) // Stage 0.a -// CHECK: %[[ADDPTR_20:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// CHECK: %[[ADDPTR_20:.*]] = tt.addptr %[[ARG6]], %{{.*}} // CHECK: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}} // CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_21]] // CHECK: %[[SPLAT_23:.*]] = tt.splat %[[CMPI_22]] -// CHECK: %[[LOAD_24:.*]] = tt.load %[[ADDPTR_20]], %[[SPLAT_23]], %{{.*}} -// CHECK: %[[ADDPTR_25:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// CHECK: %[[LOAD_24:.*]] = tt.load %[[ADDPTR_20]], %[[SPLAT_23]] +// CHECK: %[[ADDPTR_25:.*]] = tt.addptr %[[ARG7]], %{{.*}} // CHECK: %[[SPLAT_26:.*]] = tt.splat %[[CMPI_22]] // CHECK: %[[LOAD_27:.*]] = tt.load %[[ADDPTR_25]], %[[SPLAT_26]] // Stage 1 @@ -126,10 +126,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[CMPI_33:.*]] = arith.cmpi slt, %[[ADDI_32]], %{{.*}} // CHECK: %[[SELECT_34:.*]] = arith.select %[[CMPI_33]], %[[ADDI_32]], %{{.*}} // CHECK: %[[MEMDESC_SUBVIEW_35:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[LOAD_27]], %[[MEMDESC_SUBVIEW_35]] +// CHECK: triton_gpu.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_35]] // CHECK: %[[MEMDESC_SUBVIEW_36:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_36]] -// CHECK: scf.yield %[[ADDPTR_25]], %[[ADDPTR_20]], %[[DOT_31]], %[[SELECT_34]], %[[MEMDESC_SUBVIEW_35]], %[[MEMDESC_SUBVIEW_36]] +// CHECK: triton_gpu.local_store %[[LOAD_27]], %[[MEMDESC_SUBVIEW_36]] +// CHECK: scf.yield %[[ADDPTR_20]], %[[ADDPTR_25]], %[[DOT_31]], %[[SELECT_34]], %[[MEMDESC_SUBVIEW_35]], %[[MEMDESC_SUBVIEW_36]] // CHECK: } tt.func @matmul_loop(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { @@ -188,36 +188,37 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return %19#2 : tensor<128x128xf32, #mma> } + // This example tests that tt.load overlaps with independent ttg.local_store which // overlaps with independent tt.dot. // num_stages == 3, double buffered // CHECK-LABEL: tt.func @matmul_loop_mb // CHECK: %{{.*}}:8 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}) -// Stage 1 -// CHECK: %[[ADDI_28:.*]] = arith.addi %[[ARG9]], %{{.*}} -// CHECK: %[[CMPI_29:.*]] = arith.cmpi slt, %[[ADDI_28]], %{{.*}} -// CHECK: %[[SELECT_30:.*]] = arith.select %[[CMPI_29]], %[[ADDI_28]], %{{.*}} -// CHECK: %[[MEMDESC_SUBVIEW_31:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_30]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_31]] -// CHECK: %[[MEMDESC_SUBVIEW_32:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_30]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_32]] // Stage 0 -// CHECK: %[[ADDPTR_33:.*]] = tt.addptr %[[ARG7]], %{{.*}} -// CHECK: %[[MULI_34:.*]] = arith.muli %{{.*}}, %{{.*}} -// CHECK: %[[SUBI_35:.*]] = arith.subi %{{.*}}, %[[MULI_34]] -// CHECK: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_35]] -// CHECK: %[[SPLAT_37:.*]] = tt.splat %[[CMPI_36]] -// CHECK: %[[LOAD_38:.*]] = tt.load %[[ADDPTR_33]], %[[SPLAT_37]], %{{.*}} -// CHECK: %[[ADDPTR_39:.*]] = tt.addptr %[[ARG6]], %{{.*}} -// CHECK: %[[SPLAT_40:.*]] = tt.splat %[[CMPI_36]] -// CHECK: %[[LOAD_41:.*]] = tt.load %[[ADDPTR_39]], %[[SPLAT_40]] +// CHECK: %[[ADDPTR_28:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// CHECK: %[[MULI_29:.*]] = arith.muli %{{.*}}, %{{.*}} +// CHECK: %[[SUBI_30:.*]] = arith.subi %{{.*}}, %[[MULI_29]] +// CHECK: %[[CMPI_31:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_30]] +// CHECK: %[[SPLAT_32:.*]] = tt.splat %[[CMPI_31]] +// CHECK: %[[LOAD_33:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_32]] +// CHECK: %[[ADDPTR_34:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// CHECK: %[[SPLAT_35:.*]] = tt.splat %[[CMPI_31]] +// CHECK: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_34]], %[[SPLAT_35]] +// Stage 1 +// CHECK: %[[ADDI_37:.*]] = arith.addi %[[ARG9]], %{{.*}} +// CHECK: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}} +// CHECK: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_40:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_40]] +// CHECK: %[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_41]] // Stage 2 // CHECK: %[[LOCAL_LOAD_42:.*]] = triton_gpu.local_load %[[ARG10]] // CHECK: %[[LOCAL_LOAD_43:.*]] = triton_gpu.local_load %[[ARG11]] // CHECK: %[[MULF_44:.*]] = arith.mulf %[[LOCAL_LOAD_43]], %{{.*}} // CHECK: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_42]], %[[MULF_44]], %[[ARG8]] -// CHECK: scf.yield %[[ADDPTR_39]], %[[ADDPTR_33]], %[[DOT_45]], %[[SELECT_30]], %[[MEMDESC_SUBVIEW_32]], %[[MEMDESC_SUBVIEW_31]], %[[LOAD_41]], %[[LOAD_38]] +// CHECK: scf.yield %[[ADDPTR_28]], %[[ADDPTR_34]], %[[DOT_45]], %[[SELECT_39]], %[[MEMDESC_SUBVIEW_40]], %[[MEMDESC_SUBVIEW_41]], %[[LOAD_33]], %[[LOAD_36]] // CHECK: } tt.func @matmul_loop_mb(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { @@ -291,7 +292,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-LABEL: tt.func @indirect_bmm_vector // CHECK: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}) // Stage 0 -// CHECK: %[[ADDPTR_20:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// CHECK: %[[ADDPTR_20:.*]] = tt.addptr %[[ARG8]], %{{.*}} // CHECK: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}} // CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_21]] // CHECK: %[[SPLAT_23:.*]] = tt.splat %[[CMPI_22]] @@ -301,13 +302,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[BROADCAST_26:.*]] = tt.broadcast %[[EXPAND_DIMS_25]] // CHECK: %[[MULI_27:.*]] = arith.muli %{{.*}}, %[[BROADCAST_26]] // CHECK: %[[ADDPTR_28:.*]] = tt.addptr %{{.*}}, %[[MULI_27]] -// CHECK: %[[SUBI_29:.*]] = arith.subi %{{.*}}, %{{.*}} -// CHECK: %[[CMPI_30:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_29]] -// CHECK: %[[SPLAT_31:.*]] = tt.splat %[[CMPI_30]] -// CHECK: %[[LOAD_32:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_31]] -// CHECK: %[[ADDPTR_33:.*]] = tt.addptr %[[ARG8]], %{{.*}} -// CHECK: %[[SPLAT_34:.*]] = tt.splat %[[CMPI_30]] -// CHECK: %[[LOAD_35:.*]] = tt.load %[[ADDPTR_33]], %[[SPLAT_34]] +// CHECK: %[[SPLAT_29:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[LOAD_30:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_29]] +// CHECK: %[[ADDPTR_31:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// CHECK: %[[SUBI_32:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_33:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_32]] +// CHECK: %[[SPLAT_34:.*]] = tt.splat %[[CMPI_33]] +// CHECK: %[[LOAD_35:.*]] = tt.load %[[ADDPTR_31]], %[[SPLAT_34]] // Stage 2 // CHECK: %[[LOCAL_LOAD_36:.*]] = triton_gpu.local_load %[[ARG11]] // CHECK: %[[LOCAL_LOAD_37:.*]] = triton_gpu.local_load %[[ARG12]] @@ -317,10 +318,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[CMPI_40:.*]] = arith.cmpi slt, %[[ADDI_39]], %{{.*}} // CHECK: %[[SELECT_41:.*]] = arith.select %[[CMPI_40]], %[[ADDI_39]], %{{.*}} // CHECK: %[[MEMDESC_SUBVIEW_42:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[LOAD_35]], %[[MEMDESC_SUBVIEW_42]] +// CHECK: triton_gpu.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_42]] // CHECK: %[[MEMDESC_SUBVIEW_43:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[LOAD_32]], %[[MEMDESC_SUBVIEW_43]] -// CHECK: scf.yield %[[DOT_38]], %[[ADDPTR_33]], %[[ADDPTR_20]], %[[SELECT_41]], %[[MEMDESC_SUBVIEW_42]], %[[MEMDESC_SUBVIEW_43]], %[[LOAD_24]] +// CHECK: triton_gpu.local_store %[[LOAD_30]], %[[MEMDESC_SUBVIEW_43]] +// CHECK: scf.yield %[[DOT_38]], %[[ADDPTR_20]], %[[ADDPTR_31]], %[[SELECT_41]], %[[MEMDESC_SUBVIEW_42]], %[[MEMDESC_SUBVIEW_43]], %[[LOAD_35]] // CHECK: } tt.func @indirect_bmm_vector(%arg0: tensor<16x16xi64, #blocked> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg1: index, %arg2: tensor<16x16x!tt.ptr, #blocked1> {tt.contiguity = 2 : i32, tt.divisibility = 16 : i32}, %arg3: tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg4: tensor<16x16xi32, #blocked1> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg5: tensor<16x16x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) -> tensor<16x16xf32, #mma> { diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir index ecee359cb19a..f015f9651065 100644 --- a/test/TritonGPU/canonicalize.mlir +++ b/test/TritonGPU/canonicalize.mlir @@ -133,3 +133,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : tt.return %2 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> } } // end module + +// ----- + +// CHECK: #[[$BLOCKED:.*]] = #triton_gpu.blocked +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @cvt_from_dot_op_into_local_allow_not_canonicalized(%in: tensor<256x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> !tt.memdesc<256x32xf32, #shared1> { + // CHECK-LABEL: cvt_from_dot_op_into_local_allow_not_canonicalized + %cvt_in = triton_gpu.convert_layout %in : tensor<256x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<256x32xf32, #blocked> + %alloc = triton_gpu.local_alloc %cvt_in : (tensor<256x32xf32, #blocked>) -> !tt.memdesc<256x32xf32, #shared1> + // CHECK: %[[ALLOC:.*]] = triton_gpu.local_alloc {{.*}} (tensor<{{.*}}, #[[$BLOCKED]]{{.*}}>) -> + tt.return %alloc : !tt.memdesc<256x32xf32, #shared1> + } +} // end module + diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir index 3a9e80c04b0a..848e2b88f342 100644 --- a/test/TritonGPU/prefetch.mlir +++ b/test/TritonGPU/prefetch.mlir @@ -173,3 +173,22 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +// CHECK: tt.func @matmul_loop_on_blocked_layout +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @matmul_loop_on_blocked_layout(%arg_lhs: !tt.memdesc<16x512xf32, #shared, mutable>, %arg_rhs: !tt.memdesc<512x32xf32, #shared, mutable>, %arg_init: tensor<16x32xf32, #blocked>, %itr_val : i32) -> (tensor<16x32xf32, #blocked>) { + %loop:3 = scf.for %itr = %itr_val to %itr_val step %itr_val iter_args(%init = %arg_init, %lhs = %arg_lhs, %rhs = %arg_rhs) -> (tensor<16x32xf32, #blocked>, !tt.memdesc<16x512xf32, #shared, mutable>, !tt.memdesc<512x32xf32, #shared, mutable>) : i32 { + %lhs_ll = triton_gpu.local_load %lhs : !tt.memdesc<16x512xf32, #shared, mutable> -> tensor<16x512xf32, #blocked> + %lhs_ll_cvt = triton_gpu.convert_layout %lhs_ll : tensor<16x512xf32, #blocked> -> tensor<16x512xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %rhs_ll = triton_gpu.local_load %rhs : !tt.memdesc<512x32xf32, #shared, mutable> -> tensor<512x32xf32, #blocked> + %rhs_ll_cvt = triton_gpu.convert_layout %rhs_ll : tensor<512x32xf32, #blocked> -> tensor<512x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %res = tt.dot %lhs_ll_cvt, %rhs_ll_cvt, %init : tensor<16x512xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<512x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x32xf32, #blocked> + scf.yield %res, %lhs, %rhs : tensor<16x32xf32, #blocked>, !tt.memdesc<16x512xf32, #shared, mutable>, !tt.memdesc<512x32xf32, #shared, mutable> + } + tt.return %loop#0 : tensor<16x32xf32, #blocked> + } +} // end module diff --git a/third_party/amd/BUILD b/third_party/amd/BUILD new file mode 100644 index 000000000000..f7510899f69e --- /dev/null +++ b/third_party/amd/BUILD @@ -0,0 +1,144 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:gce"], + # default_visibility = [ + # "//third_party/tensorflow/compiler/xla/service/gpu/fusions/triton:__subpackages__", + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +# TODO(csigg): fix, enable error upstream, remove. +_no_unused_variable = select({ + "//:compiler_is_msvc": [], + "//conditions:default": ["-Wno-unused-variable"], +}) + +cc_library( + name = "TritonAMDGPUTransforms", + srcs = glob([ + "lib/TritonAMDGPUTransforms/**/*.h", + "lib/TritonAMDGPUTransforms/**/*.cpp", + ]) + ["include/TritonAMDGPUToLLVM/TargetUtils.h"], + hdrs = glob([ + "include/TritonAMDGPUTransforms/**/*.h", + ]), + copts = _no_unused_variable, + includes = [ + "include", + "lib/TritonAMDGPUTransforms", + ], + deps = [ + ":triton_conversion_amdgpu_transforms_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConvertToLLVM", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonGPUTransforms", + ], +) + +cc_library( + name = "TritonAMDGPUToLLVM", + srcs = glob([ + "lib/TritonAMDGPUToLLVM/**/*.h", + "lib/TritonAMDGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/TritonAMDGPUToLLVM/**/*.h", + ]), + copts = _no_unused_variable, + includes = [ + "include", + "lib/TritonAMDGPUToLLVM", + ], + deps = [ + ":TritonAMDGPUTransforms", + ":triton_conversion_amdgpu_to_llvm_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:ConvertToLLVM", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:GPUToROCDLTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLDialect", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + ], +) + +td_library( + name = "td_files", + srcs = glob(["include/**/*.td"]), + includes = ["include"], + deps = ["//:td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_amdgpu_to_llvm_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonAMDGPUToLLVM", + ], + "include/TritonAMDGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonAMDGPUToLLVM/Passes.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_amdgpu_transforms_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonAMDGPU", + ], + "include/TritonAMDGPUTransforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonAMDGPUTransforms/Passes.td", + deps = [":td_files"], +) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp index 23dd5d37d520..cece47227ea0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -1,3 +1,4 @@ +#include "OptimizeLDSUtility.h" #include "TargetInfo.h" #include "TritonAMDGPUToLLVM/Passes.h" #include "mlir/Pass/Pass.h" @@ -5,7 +6,6 @@ #include "triton/Analysis/Utility.h" #include "triton/Conversion/TritonGPUToLLVM/Patterns.h" #include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include @@ -20,79 +20,6 @@ namespace triton { namespace { -constexpr int kPtrBitWidth = 64; - -static void addAttrs(Operation *op, ArrayRef attrs) { - for (const NamedAttribute attr : attrs) - op->setAttr(attr.getName(), attr.getValue()); -} - -static int getCvtOpLDSUsage(triton::gpu::ConvertLayoutOp &cvtOp) { - auto scratchConfig = mlir::triton::getScratchConfigForCvt( - cvtOp.getSrc().getType(), cvtOp.getType()); - unsigned elems = getNumScratchElements(scratchConfig.paddedRepShape); - auto srcType = cvtOp.getSrc().getType(); - auto bytes = - isa(srcType.getElementType()) - ? elems * kPtrBitWidth / 8 - : elems * std::max(8, srcType.getElementTypeBitWidth()) / 8; - - return bytes; -} - -static std::vector> factorizePowerOf2(int n) { - assert(llvm::isPowerOf2_32(n)); - int x = log2(n); - std::vector> pairs; - - for (int i = 0; i <= x / 2; ++i) { - int j = x - i; - pairs.push_back({pow(2, i), pow(2, j)}); - pairs.push_back({pow(2, j), pow(2, i)}); - } - - return pairs; -} - -static std::pair -createNewConvertOps(ModuleOp &mod, OpBuilder &builder, - triton::gpu::ConvertLayoutOp &cvtOp, - std::pair warpsPerCta) { - unsigned warpsPerCtaX = warpsPerCta.first; - unsigned warpsPerCtaY = warpsPerCta.second; - auto srcType = cvtOp.getSrc().getType(); - auto dstType = cvtOp.getType(); - - auto newDstType = RankedTensorType::get( - dstType.getShape(), dstType.getElementType(), dstType.getEncoding()); - RankedTensorType newSrcType; - if (auto srcMfma = - dyn_cast(srcType.getEncoding())) { - auto newMfmaEnc = triton::gpu::AMDMfmaEncodingAttr::get( - mod.getContext(), srcMfma.getVersionMajor(), srcMfma.getVersionMinor(), - {warpsPerCtaX, warpsPerCtaY}, srcMfma.getMDim(), srcMfma.getNDim(), - srcMfma.getIsTransposed(), srcMfma.getCTALayout()); - - newSrcType = RankedTensorType::get(srcType.getShape(), - srcType.getElementType(), newMfmaEnc); - } else if (auto srcWmma = dyn_cast( - srcType.getEncoding())) { - auto newWmmaEnc = triton::gpu::AMDWmmaEncodingAttr::get( - mod.getContext(), srcWmma.getVersion(), {warpsPerCtaX, warpsPerCtaY}, - srcWmma.getCTALayout()); - - newSrcType = RankedTensorType::get(srcType.getShape(), - srcType.getElementType(), newWmmaEnc); - } - - auto tmpCvt = builder.create( - cvtOp.getLoc(), newSrcType, cvtOp.getSrc()); - auto newEpilogueCvt = builder.create( - cvtOp.getLoc(), newDstType, tmpCvt); - - return std::make_pair(tmpCvt, newEpilogueCvt); -} - struct DecomposeUnsupportedAMDConversions : public mlir::triton::impl::DecomposeUnsupportedAMDConversionsBase< DecomposeUnsupportedAMDConversions> { @@ -172,52 +99,48 @@ struct DecomposeUnsupportedAMDConversions return; } - auto currLDSUsage = getCvtOpLDSUsage(cvtOp); + auto currLDSUsage = triton::AMD::getCvtOpLDSUsage(cvtOp); if (currLDSUsage <= sharedMemoryLimit) { return; } unsigned numWarps = triton::gpu::getNumWarpsPerCTA(srcEnc); - triton::gpu::ConvertLayoutOp tmpCvt; - triton::gpu::ConvertLayoutOp newEpilogueCvt; - // Find all possible shapes of WarpsPerCTA by finding all possible // factorizations of numWarps. Pick shape for which both conversions in - // decomposition use LDS less than limit and for which sum of LDS usage - // is minimal. If no such shape exists, do not decompose. + // decomposition use LDS less than sharedMemoryLimit and for which sum of + // LDS usage is minimal. If no such shape exists, do not decompose. unsigned minLDSUsage = 2 * sharedMemoryLimit; int minIdx = -1; - auto factorizedNumWarps = factorizePowerOf2(numWarps); + int rank = dstBlocked.getWarpsPerCTA().size(); + auto factorizedNumWarps = + mlir::triton::AMD::factorizePowerOf2(numWarps, rank); + SmallVector tmpLayouts; for (int i = 0; i < factorizedNumWarps.size(); i++) { - auto warpsPerCTAPair = factorizedNumWarps[i]; - std::tie(tmpCvt, newEpilogueCvt) = - createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair); - - int tmpCvtLDS = getCvtOpLDSUsage(tmpCvt); - int newCvtLDS = getCvtOpLDSUsage(newEpilogueCvt); - if (tmpCvtLDS <= sharedMemoryLimit && newCvtLDS <= sharedMemoryLimit) { - int LDSUsage = tmpCvtLDS + newCvtLDS; - if (LDSUsage < minLDSUsage) { - minLDSUsage = LDSUsage; - minIdx = i; - } + auto warpsPerCTA = factorizedNumWarps[i]; + tmpLayouts.push_back( + mlir::triton::AMD::createTmpLayout(srcEnc, warpsPerCTA)); + } + + for (int i = 0; i < tmpLayouts.size(); i++) { + auto resources = mlir::triton::AMD::estimateResourcesForReplacement( + builder, cvtOp, tmpLayouts[i]); + if (resources.LDS <= sharedMemoryLimit && resources.LDS < minLDSUsage) { + minLDSUsage = resources.LDS; + minIdx = i; } - newEpilogueCvt.erase(); - tmpCvt.erase(); } - if (minIdx == -1) { + if (minIdx == -1 || minLDSUsage > sharedMemoryLimit) { return; } - assert(minIdx >= 0 && minIdx < factorizedNumWarps.size()); - auto warpsPerCTAPair = factorizedNumWarps[minIdx]; - std::tie(tmpCvt, newEpilogueCvt) = - createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair); + assert(minIdx >= 0 && minIdx < tmpLayouts.size()); + auto replacementCvts = mlir::triton::AMD::createNewConvertOps( + builder, cvtOp, tmpLayouts[minIdx]); - cvtOp.replaceAllUsesWith(newEpilogueCvt.getResult()); + cvtOp.replaceAllUsesWith(replacementCvts.second.getResult()); cvtOp.erase(); }); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index 0d5b6476c80d..ced5933c3fc0 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -69,7 +69,7 @@ using namespace mlir; class PointerCanonicalizer { public: explicit PointerCanonicalizer(ModuleOp moduleOp) - : mod(moduleOp), rewriter(moduleOp.getContext()) {} + : rewriter(moduleOp.getContext()), mod(moduleOp) {} // Propagate fat pointers in all the functions of the module LogicalResult run(); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index c46b76ed48e2..41ab479336de 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -174,7 +174,7 @@ class TritonAMDGPUReorderInstructionsPass // Best perf on GEMM when these precede global loads. m.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); }); - for (auto op : moveOps) { + for (auto op : llvm::reverse(moveOps)) { // Gather use-def chain in block. Block *block = op->getBlock(); bool leadsToLoad = false; diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index 221dba8295ff..56756b5fed25 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -197,8 +197,7 @@ void init_triton_amd(py::module &&m) { target->createMCAsmBackend(*sti, *mri, mcOptions)); std::unique_ptr ow(mab->createObjectWriter(svos)); mcStreamer.reset(target->createMCObjectStreamer( - triple, ctx, std::move(mab), mab->createObjectWriter(svos), - std::move(ce), *sti)); + triple, ctx, std::move(mab), std::move(ow), std::move(ce), *sti)); std::unique_ptr parser( createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai)); diff --git a/third_party/f2reduce/BUILD b/third_party/f2reduce/BUILD new file mode 100644 index 000000000000..a1a4f3a7ca02 --- /dev/null +++ b/third_party/f2reduce/BUILD @@ -0,0 +1,31 @@ +# copybara:uncomment load("//tools/build_defs/license:license.bzl", "license") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:gce"], + # default_visibility = [ + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +# copybara:uncomment_begin +# license( +# name = "license", +# license_text = "LICENCE.txt", +# ) +# +# licenses(["notice"]) +# +# exports_files(["LICENCE.txt"]) +# copybara:uncomment_end + +cc_library( + name = "f2reduce", + srcs = ["f2reduce.cpp"], + hdrs = ["f2reduce.h"], + # copybara:uncomment strip_include_prefix = "/third_party/triton", +) diff --git a/third_party/nvidia/BUILD b/third_party/nvidia/BUILD new file mode 100644 index 000000000000..6af127c11ec6 --- /dev/null +++ b/third_party/nvidia/BUILD @@ -0,0 +1,306 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@pybind11_bazel//:build_defs.bzl", "pybind_library") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:gce"], + # default_visibility = [ + # "//third_party/tensorflow/compiler/xla/service/gpu:__subpackages__", + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +pybind_library( + name = "cublas_headers", + hdrs = glob([ + "include/*.h", + ]), + deps = ["@local_config_cuda//cuda:cuda_headers"], +) + +pybind_library( + name = "triton_nvidia", + srcs = [ + "triton_nvidia.cc", + ], + compatible_with = [], + # copybara:uncomment_begin + # visibility = [ + # "@triton//python:__subpackages__", + # ], + # copybara:uncomment_end + deps = [ + ":NVGPUDialect", + ":NVGPUToLLVM", + ":TritonNVIDIAGPUToLLVM", + ":cublas_headers", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonNvidiaGPUTransforms", + "@triton//python:passes", + ], +) + +cc_library( + name = "NVGPUToLLVM", + srcs = glob([ + "lib/NVGPUToLLVM/*.cpp", + ]), + hdrs = glob([ + "include/NVGPUToLLVM/*.h", + ]), + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:gce"], + # copybara:uncomment_end + copts = select({ + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + includes = [ + "..", + "include", + ], + deps = [ + ":NVGPUDialect", + ":TritonNVIDIAGPUToLLVM", + ":triton_conversion_nvgpu_to_llvm_passes_inc_gen", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + ], +) + +cc_library( + name = "TritonNVIDIAGPUToLLVM", + srcs = glob([ + "lib/TritonNVIDIAGPUToLLVM/*.h", + "lib/TritonNVIDIAGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/TritonNVIDIAGPUToLLVM/*.h", + ]) + [ + "lib/TritonNVIDIAGPUToLLVM/Utility.h", + ], + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:gce"], + # copybara:uncomment_end + copts = select({ + "//conditions:default": [ + "-Wno-reorder-ctor", + "-Wno-unused-variable", + ], + }), + includes = [ + "..", + "include", + "lib/TritonNVIDIAGPUToLLVM", + ], + deps = [ + ":NVGPUDialect", + ":triton_conversion_triton_nvidia_gpu_to_llvm_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:triton_gpu_attr_inc_gen", + ], +) + +gentbl_cc_library( + name = "triton_conversion_nvgpu_to_llvm_passes_inc_gen", + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:gce"], + # copybara:uncomment_end + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=NVGPUToLLVM", + ], + "include/NVGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/NVGPUToLLVM/Passes.td", + deps = ["//:td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_nvidia_gpu_to_llvm_passes_inc_gen", + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:gce"], + # copybara:uncomment_end + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonNVIDIAGPUToLLVM", + ], + "include/TritonNVIDIAGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonNVIDIAGPUToLLVM/Passes.td", + deps = ["//:td_files"], +) + +td_library( + name = "td_files", + srcs = glob(["include/Dialect/NVGPU/IR/*.td"]), + includes = ["include"], + deps = [ + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:CastInterfacesTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LLVMOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@llvm-project//mlir:ViewLikeInterfaceTdFiles", + ], +) + +gentbl_cc_library( + name = "nvgpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-llvmir-conversions"], + "include/Dialect/NVGPU/IR/OpsConversions.inc", + ), + ( + ["--gen-op-decls"], + "include/Dialect/NVGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/Dialect/NVGPU/IR/Ops.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/Dialect/NVGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/Dialect/NVGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "nvgpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "nvgpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/Dialect/NVGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/Dialect/NVGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUDialect.td", + deps = ["td_files"], +) + +cc_library( + name = "NVGPUDialect", + srcs = glob([ + "lib/Dialect/NVGPU/IR/*.cpp", + ]), + hdrs = glob([ + "include/Dialect/NVGPU/IR/*.h", + ]), + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + "-Wno-logical-op-parentheses", + ], + }), + includes = [ + "..", # because nvidia/include/Dialect/NVGPU/IR/Dialect.h.inc + "../..", # because third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h + "include", + ], + deps = [ + ":nvgpu_attr_inc_gen", + ":nvgpu_dialect_inc_gen", + ":nvgpu_ops_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + # The following is added to make Utility compile + "//:TritonTools", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/third_party/nvidia/backend/BUILD b/third_party/nvidia/backend/BUILD new file mode 100644 index 000000000000..a5b34aa5c29b --- /dev/null +++ b/third_party/nvidia/backend/BUILD @@ -0,0 +1,30 @@ +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//third_party/py/triton:__subpackages__", + ], +) + +pybind_extension( + name = "cuda_utils", + srcs = ["cuda_utils.cc"], + visibility = [ + "//learning/deepmind/jax/triton/ops:__subpackages__", + "//third_party/py/triton:__subpackages__", + ], + deps = [ + "//platforms/gpus/cuda/dynamic_libcuda", + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cuda_runtime", + "@llvm-project//llvm:Support", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["**/*.py"], + ), +) diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c index bb0d86888120..19c732c354d1 100644 --- a/third_party/nvidia/backend/driver.c +++ b/third_party/nvidia/backend/driver.c @@ -154,6 +154,7 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) { typedef CUresult (*cuOccupancyMaxActiveClusters_t)( int *numClusters, CUfunction func, const CUlaunchConfig *config); +#if CUDA_VERSION >= 12000 typedef CUresult (*cuTensorMapEncodeTiled_t)( CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, @@ -161,6 +162,7 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)( const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill); +#endif #define defineGetFunctionHandle(name, symbolName) \ static symbolName##_t name() { \ @@ -187,8 +189,10 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)( defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, cuOccupancyMaxActiveClusters); +#if CUDA_VERSION >= 12000 defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle, cuTensorMapEncodeTiled); +#endif static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1, @@ -281,6 +285,9 @@ static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) { // Simple helper to experiment creating TMA descriptors on the host. // This is a useful to test TMA operations independently. static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { +#if CUDA_VERSION < 12000 + return NULL; +#else unsigned long long global_address; uint64_t dim; uint32_t tensorDim; @@ -321,11 +328,15 @@ static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); Py_INCREF(Py_None); return Py_None; +#endif } // Simple helper to experiment creating TMA descriptors on the host. // This is a useful to test TMA operations independently. static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { +#if CUDA_VERSION < 12000 + return NULL; +#else unsigned long long global_address; uint64_t dims[2]; uint32_t tensorDims[2]; @@ -384,6 +395,7 @@ static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); Py_INCREF(Py_None); return Py_None; +#endif } static PyMethodDef ModuleMethods[] = { diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index 8de0efefca84..637071275e39 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -291,10 +291,36 @@ class WGMMAWaitGroupOpPattern : public OpRewritePattern { Constraints getOutputConstraints(ttn::WGMMAWaitGroupOp op) const { auto outputStructType = cast(op.getType()); - uint32_t numOutputRegs = outputStructType.getBody().size(); - std::string output = - outputStructType.getBody().front().isF32() ? "=f" : "=r"; - return Constraints(numOutputRegs, output); + std::vector outputConstraints; + outputConstraints.reserve(outputStructType.getBody().size()); + for (mlir::Type type : outputStructType.getBody()) { + if (type.isF32()) { + outputConstraints.push_back("=f"); + continue; + } else if (type.isF64()) { + outputConstraints.push_back("=d"); + continue; + } + unsigned bitwidth = isa(type) ? + 64 : type.getIntOrFloatBitWidth(); + switch (bitwidth) { + case 1: + outputConstraints.push_back("=b"); + break; + case 16: + outputConstraints.push_back("=h"); + break; + case 32: + outputConstraints.push_back("=r"); + break; + case 64: + outputConstraints.push_back("=l"); + break; + default: + assert(false && "unsupported bitwidth"); + } + } + return outputConstraints; } OperandsAndConstraints diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index c10a6e777987..1bb55373e046 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -316,11 +316,6 @@ SmallVector unpackAccumulator(ConversionPatternRewriter &rewriter, return results; } -static bool isFP8(triton::nvgpu::WGMMAEltType eltType) { - return eltType == triton::nvgpu::WGMMAEltType::e5m2 || - eltType == triton::nvgpu::WGMMAEltType::e4m3; -} - static Value faddAccumulate(ConversionPatternRewriter &rewriter, Location loc, Value a, Value b) { int numEl = cast(a.getType()).getBody().size(); @@ -359,6 +354,7 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, Operation *op, Value a, Value b, Value c, Value d, Value useCOperand, Value loadedA, Value loadedB, Value loadedC, bool allowTF32, + bool needsPartialAccumulator, uint32_t maxNumImpreciseAcc, bool sync, Value thread) { auto aTensorTy = cast(a.getType()); auto bTensorTy = cast(b.getType()); @@ -420,10 +416,6 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, auto func = op->getParentOfType(); Operation *startSequence = rewriter.create(loc); - // WGMMA fp8 -> fp32 accumulates in lower precision than fp32. - bool needsPartialAccumulator = isFP8(eltTypeA) && - eltTypeC == triton::nvgpu::WGMMAEltType::f32 && - maxNumImpreciseAcc <= aTensorTy.getShape()[1]; SmallVector mmaResults; for (int m = 0; m < numRepM; ++m) { for (int n = 0; n < numRepN; ++n) { @@ -439,7 +431,7 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, Value useC = i1_val(0); if (!zeroAcc) { d = packLLElements(loc, typeConverter, mmaOut, rewriter, accTy); - useC = i1_val(true); + useC = i1_val(1); } if (useCOperand) useC = and_(useC, useCOperand); @@ -520,5 +512,6 @@ LogicalResult convertWGMMA(triton::nvidia_gpu::WarpGroupDotOp op, op.getA(), op.getB(), op.getC(), op.getD(), op.getUseC(), // adaptor.getA(), adaptor.getB(), adaptor.getC(), // op.getInputPrecision() == InputPrecision::TF32, - op.getMaxNumImpreciseAcc(), !op.getIsAsync(), thread); + op.needsPartialAccumulator(), op.getMaxNumImpreciseAcc(), + !op.getIsAsync(), thread); } diff --git a/third_party/nvidia/triton_nvidia.cc b/third_party/nvidia/triton_nvidia.cc index 1269dcda00aa..3cccc5fb6a1c 100644 --- a/third_party/nvidia/triton_nvidia.cc +++ b/third_party/nvidia/triton_nvidia.cc @@ -1,4 +1,4 @@ -#include "Dialect/NVGPU/IR/Dialect.h" +#include "Dialect/NVGPU/IR/Dialect.h" #include "NVGPUToLLVM/NVGPUToLLVMPass.h" #include "TritonNVIDIAGPUToLLVM/Passes.h" #include "cublas_instance.h" diff --git a/third_party/proton/proton/_C/include b/third_party/proton/proton/_C/include index fe4f4a1aa9bd..4400934bdf78 120000 --- a/third_party/proton/proton/_C/include +++ b/third_party/proton/proton/_C/include @@ -1 +1 @@ -../../csrc/include/ \ No newline at end of file +../../csrc/include \ No newline at end of file diff --git a/unittest/BUILD b/unittest/BUILD new file mode 100644 index 000000000000..cb885459e19c --- /dev/null +++ b/unittest/BUILD @@ -0,0 +1,144 @@ +load("//tools/build_defs/build_test:build_test.bzl", "build_test") + +package( + default_applicable_licenses = ["//:license"], + default_compatible_with = ["//buildenv/target:gce"], + default_visibility = ["//:__subpackages__"], +) + +cc_test( + name = "AnalysisTest", + srcs = glob(["Analysis/*.cpp"]), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "//:TritonDialects", + ], +) + +cc_test( + name = "DialectTestCatchAll", + srcs = glob( + [ + "Dialect/**/*.cpp", + ], + exclude = [ + "Dialect/TritonGPU/DialectTest.cpp", + "Dialect/TritonGPU/LinearLayoutConversionsTest.cpp", + "Dialect/TritonGPU/SwizzleTest.cpp", + ], + ), + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "DialectTest", + srcs = [ + "Dialect/TritonGPU/DialectTest.cpp", + ], + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "LinearLayoutConversionsTest", + srcs = [ + "Dialect/TritonGPU/LinearLayoutConversionsTest.cpp", + ], + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "SwizzleTest", + srcs = [ + "Dialect/TritonGPU/SwizzleTest.cpp", + ], + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "ConversionTest", + srcs = glob( + [ + "Conversion/**/*.cpp", + "Conversion/**/*.h", + ], + exclude = [ + "Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp", + "Conversion/TritonGPUToLLVM/DumpLayout.cpp", + "Conversion/TritonGPUToLLVM/DumpLayout.h", + ], + ), + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "//:TritonDialects", + "//:TritonNvidiaGPUTransforms", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) + +build_test( + name = "build_test", + allow_empty_target = False, + targets = [ + ":ConversionTest", + ":AnalysisTest", + ":DialectTest", + ], +)