From 09d5d6b198a6af9bdab0ee5c211a459070f71b88 Mon Sep 17 00:00:00 2001 From: Someone Serge Date: Tue, 21 Mar 2023 03:29:07 +0200 Subject: [PATCH 01/10] python3Packages.torch: 1.13.1 -> 2.0.0 --- .../python-modules/torch/default.nix | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/pkgs/development/python-modules/torch/default.nix b/pkgs/development/python-modules/torch/default.nix index 062fcea4334ab..3373c3d3d3b2c 100644 --- a/pkgs/development/python-modules/torch/default.nix +++ b/pkgs/development/python-modules/torch/default.nix @@ -12,6 +12,7 @@ Accelerate, CoreServices, libobjc, # Propagated build inputs + sympy, numpy, pyyaml, cffi, click, typing-extensions, # Unit tests @@ -49,9 +50,7 @@ let inherit (cudaPackages) cudatoolkit cudaFlags cudnn nccl; in -# assert that everything needed for cuda is present and that the correct cuda versions are used -assert !cudaSupport || (let majorIs = lib.versions.major cudatoolkit.version; - in majorIs == "9" || majorIs == "10" || majorIs == "11"); +assert cudaSupport -> (cudaPackages.cudaMajorVersion == "11"); # confirm that cudatoolkits are sync'd across dependencies assert !(MPISupport && cudaSupport) || mpi.cudatoolkit == cudatoolkit; @@ -129,10 +128,10 @@ let in buildPythonPackage rec { pname = "torch"; # Don't forget to update torch-bin to the same version. - version = "1.13.1"; + version = "2.0.0"; format = "setuptools"; - disabled = pythonOlder "3.7.0"; + disabled = pythonOlder "3.8.0"; outputs = [ "out" # output standard python package @@ -145,7 +144,7 @@ in buildPythonPackage rec { repo = "pytorch"; rev = "refs/tags/v${version}"; fetchSubmodules = true; - hash = "sha256-yQz+xHPw9ODRBkV9hv1th38ZmUr/fXa+K+d+cvmX3Z8="; + hash = "sha256-cSw7+AYBUcZLz3UyK/+JWWjQxKwVBXcFvBq0XAcL3tE="; }; patches = lib.optionals (stdenv.isDarwin && stdenv.isx86_64) [ @@ -155,15 +154,6 @@ in buildPythonPackage rec { # base is 10.12. Until we upgrade, we can fall back on the older # pthread support. ./pthreadpool-disable-gcd.diff - ] ++ [ - # PyTorch fails to build on gcc 12 due to gloo - # https://github.com/pytorch/pytorch/issues/77614 - (fetchpatch { - url = "https://github.com/facebookincubator/gloo/commit/4a5e339b764261d20fc409071dc7a8b8989aa195.patch"; - stripLen = 1; - extraPrefix = "third_party/gloo/"; - hash = "sha256-UxR1r7F6g76BWj3GBIrSy5t+YZDCWy6mMddwx+hon5w="; - }) ]; postPatch = lib.optionalString rocmSupport '' @@ -261,7 +251,16 @@ in buildPythonPackage rec { # Suppress gcc regression: avx512 math function raises uninitialized variable warning # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105593 # See also: Fails to compile with GCC 12.1.0 https://github.com/pytorch/pytorch/issues/77939 - ++ lib.optionals stdenv.cc.isGNU [ "-Wno-error=maybe-uninitialized" "-Wno-error=uninitialized" ])); + ++ lib.optionals stdenv.cc.isGNU [ + "-Wno-error=maybe-uninitialized" + "-Wno-error=uninitialized" + ] + # Since pytorch 2.0: + # gcc-12.2.0/include/c++/12.2.0/bits/new_allocator.h:158:33: error: ‘void operator delete(void*, std::size_t)’ + # ... called on pointer ‘’ with nonzero offset [1, 9223372036854775800] [-Werror=free-nonheap-object] + ++ lib.optionals (stdenv.cc.isGNU && lib.versions.major stdenv.cc.version == "12" ) [ + "-Wno-error=free-nonheap-object" + ])); nativeBuildInputs = [ cmake @@ -287,6 +286,7 @@ in buildPythonPackage rec { numpy pyyaml typing-extensions + sympy # the following are required for tensorboard support pillow six future tensorboard protobuf ] ++ lib.optionals MPISupport [ mpi ] From a9faf1b9efee58458e0174e696471a7db3ff97ed Mon Sep 17 00:00:00 2001 From: Someone Serge Date: Tue, 21 Mar 2023 03:34:20 +0200 Subject: [PATCH 02/10] python3Packages.torch: add missing install_requires --- pkgs/development/python-modules/torch/default.nix | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pkgs/development/python-modules/torch/default.nix b/pkgs/development/python-modules/torch/default.nix index 3373c3d3d3b2c..9f701493f07ae 100644 --- a/pkgs/development/python-modules/torch/default.nix +++ b/pkgs/development/python-modules/torch/default.nix @@ -12,7 +12,10 @@ Accelerate, CoreServices, libobjc, # Propagated build inputs + filelock, sympy, + networkx, + jinja2, numpy, pyyaml, cffi, click, typing-extensions, # Unit tests @@ -285,8 +288,14 @@ in buildPythonPackage rec { click numpy pyyaml + + # From install_requires: + filelock typing-extensions sympy + networkx + jinja2 + # the following are required for tensorboard support pillow six future tensorboard protobuf ] ++ lib.optionals MPISupport [ mpi ] From 0f76efb48143e4f38f7714588e3cb38d055404b2 Mon Sep 17 00:00:00 2001 From: Someone Serge Date: Tue, 21 Mar 2023 16:30:06 +0200 Subject: [PATCH 03/10] python3Packages.torchWithRocm: ignore config.cudaSupport --- pkgs/top-level/python-packages.nix | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pkgs/top-level/python-packages.nix b/pkgs/top-level/python-packages.nix index 1e9075ace6ad0..5349c922a4e91 100644 --- a/pkgs/top-level/python-packages.nix +++ b/pkgs/top-level/python-packages.nix @@ -11856,6 +11856,7 @@ self: super: with self; { torchWithCuda = self.torch.override { magma = pkgs.magma-cuda; cudaSupport = true; + rocmSupport = false; }; torchWithoutCuda = self.torch.override { @@ -11865,6 +11866,7 @@ self: super: with self; { torchWithRocm = self.torch.override { magma = pkgs.magma-hip; rocmSupport = true; + cudaSupport = false; }; torchWithoutRocm = self.torch.override { From 378c0c69832c350f672a21128b5d2782a29c2d6e Mon Sep 17 00:00:00 2001 From: Someone Serge Date: Wed, 22 Mar 2023 05:34:47 +0200 Subject: [PATCH 04/10] python3Packages.openai-triton: init at 2.0.0 --- pkgs/development/compilers/llvm/rocm/llvm.nix | 6 +- .../python-modules/openai-triton/default.nix | 246 + .../python-modules/openai-triton/llvm15.patch | 4617 +++++++++++++++++ .../python-modules/torch/default.nix | 25 +- pkgs/tools/audio/tts/default.nix | 2 + pkgs/top-level/python-packages.nix | 2 + 6 files changed, 4893 insertions(+), 5 deletions(-) create mode 100644 pkgs/development/python-modules/openai-triton/default.nix create mode 100644 pkgs/development/python-modules/openai-triton/llvm15.patch diff --git a/pkgs/development/compilers/llvm/rocm/llvm.nix b/pkgs/development/compilers/llvm/rocm/llvm.nix index 1f1add5cf6799..6092bc1a9fc05 100644 --- a/pkgs/development/compilers/llvm/rocm/llvm.nix +++ b/pkgs/development/compilers/llvm/rocm/llvm.nix @@ -24,6 +24,8 @@ , targetDir ? "llvm" , targetProjects ? [ ] , targetRuntimes ? [ ] +# "NATIVE" resolves into x86 or aarch64 depending on stdenv +, llvmTargetsToBuild ? [ "NATIVE" ] , extraPatches ? [ ] , extraNativeBuildInputs ? [ ] , extraBuildInputs ? [ ] @@ -46,6 +48,8 @@ let if stdenv.isx86_64 then "X86" else if stdenv.isAarch64 then "AArch64" else throw "Unsupported ROCm LLVM platform"; + inferNativeTarget = t: if t == "NATIVE" then llvmNativeTarget else t; + llvmTargetsToBuild' = [ "AMDGPU" ] ++ builtins.map inferNativeTarget llvmTargetsToBuild; in stdenv.mkDerivation (finalAttrs: { pname = "rocm-llvm-${targetName}"; version = "5.4.4"; @@ -98,7 +102,7 @@ in stdenv.mkDerivation (finalAttrs: { sourceRoot = "${finalAttrs.src.name}/${targetDir}"; cmakeFlags = [ - "-DLLVM_TARGETS_TO_BUILD=AMDGPU;${llvmNativeTarget}" + "-DLLVM_TARGETS_TO_BUILD=${builtins.concatStringsSep ";" llvmTargetsToBuild'}" ] ++ lib.optionals (finalAttrs.passthru.isLLVM && targetProjects != [ ]) [ "-DLLVM_ENABLE_PROJECTS=${lib.concatStringsSep ";" targetProjects}" ] ++ lib.optionals ((finalAttrs.passthru.isLLVM || targetDir == "runtimes") && targetRuntimes != [ ]) [ diff --git a/pkgs/development/python-modules/openai-triton/default.nix b/pkgs/development/python-modules/openai-triton/default.nix new file mode 100644 index 0000000000000..9340aad3a9545 --- /dev/null +++ b/pkgs/development/python-modules/openai-triton/default.nix @@ -0,0 +1,246 @@ +{ lib +, buildPythonPackage +, python +, fetchpatch +, fetchFromGitHub +, addOpenGLRunpath +, cmake +, cudaPackages +, llvmPackages +, pybind11 +, gtest +, zlib +, ncurses +, libxml2 +, lit +, filelock +, torchWithRocm +, pytest +, pytestCheckHook +, pythonRelaxDepsHook +, pkgsTargetTarget +}: + +let + pname = "triton"; + version = "2.0.0"; + + inherit (cudaPackages) cuda_cudart backendStdenv; + ptxas = "${pkgsTargetTarget.cudaPackages.cuda_nvcc}/bin/ptxas"; + + llvm = (llvmPackages.llvm.override { + llvmTargetsToBuild = [ "NATIVE" "NVPTX" ]; + # Upstream CI sets these too: + # targetProjects = [ "mlir" ]; + extraCMakeFlags = [ + "-DLLVM_INSTALL_UTILS=ON" + ]; + }); +in +buildPythonPackage { + inherit pname version; + + format = "setuptools"; + + src = fetchFromGitHub { + owner = "openai"; + repo = pname; + rev = "v${version}"; + hash = "sha256-9GZzugab+Pdt74Dj6zjlEzjj4BcJ69rzMJmqcVMxsKU="; + }; + + patches = [ + # Prerequisite for llvm15 patch + (fetchpatch { + url = "https://github.com/openai/triton/commit/2aba985daaa70234823ea8f1161da938477d3e02.patch"; + hash = "sha256-LGv0+Ut2WYPC4Ksi4803Hwmhi3FyQOF9zElJc/JCobk="; + }) + (fetchpatch { + url = "https://github.com/openai/triton/commit/e3941f9d09cdd31529ba4a41018cfc0096aafea6.patch"; + hash = "sha256-A+Gor6qzFlGQhVVhiaaYOzqqx8yO2MdssnQS6TIfUWg="; + }) + + # Source: https://github.com/openai/triton/commit/fc7a8e35819bda632bdcf1cf75fd9abe4d4e077a.patch + # The original patch adds ptxas binary, so we include our own clean copy + # Drop with the next update + ./llvm15.patch + + # TODO: there have been commits upstream aimed at removing the "torch" + # circular dependency, but the patches fail to apply on the release + # revision. Keeping the link for future reference + # Also cf. https://github.com/openai/triton/issues/1374 + + # (fetchpatch { + # url = "https://github.com/openai/triton/commit/fc7c0b0e437a191e421faa61494b2ff4870850f1.patch"; + # hash = "sha256-f0shIqHJkVvuil2Yku7vuqWFn7VCRKFSFjYRlwx25ig="; + # }) + ]; + + postPatch = '' + substituteInPlace python/setup.py \ + --replace \ + '= get_thirdparty_packages(triton_cache_path)' \ + '= os.environ["cmakeFlags"].split()' + '' + # Wiring triton=2.0.0 with llcmPackages_rocm.llvm=5.4.3 + # Revisit when updating either triton or llvm + + '' + substituteInPlace CMakeLists.txt \ + --replace "nvptx" "NVPTX" \ + --replace "LLVM 11" "LLVM" + sed -i '/AddMLIR/a set(MLIR_TABLEGEN_EXE "${llvmPackages.mlir}/bin/mlir-tblgen")' CMakeLists.txt + sed -i '/AddMLIR/a set(MLIR_INCLUDE_DIR ''${MLIR_INCLUDE_DIRS})' CMakeLists.txt + find -iname '*.td' -exec \ + sed -i \ + -e '\|include "mlir/IR/OpBase.td"|a include "mlir/IR/AttrTypeBase.td"' \ + -e 's|include "mlir/Dialect/StandardOps/IR/Ops.td"|include "mlir/Dialect/Func/IR/FuncOps.td"|' \ + '{}' ';' + substituteInPlace unittest/CMakeLists.txt --replace "include(GoogleTest)" "find_package(GTest REQUIRED)" + sed -i 's/^include.*$//' unittest/CMakeLists.txt + sed -i '/LINK_LIBS/i NVPTXInfo' lib/Target/PTX/CMakeLists.txt + sed -i '/LINK_LIBS/i NVPTXCodeGen' lib/Target/PTX/CMakeLists.txt + '' + # TritonMLIRIR already links MLIRIR. Not transitive? + # + '' + # echo "target_link_libraries(TritonPTX PUBLIC MLIRIR)" >> lib/Target/PTX/CMakeLists.txt + # '' + # Already defined in llvm, when built with -DLLVM_INSTALL_UTILS + + '' + substituteInPlace bin/CMakeLists.txt \ + --replace "add_subdirectory(FileCheck)" "" + + rm cmake/FindLLVM.cmake + '' + + + ( + let + # Bash was getting weird without linting, + # but basically upstream contains [cc, ..., "-lcuda", ...] + # and we replace it with [..., "-lcuda", "-L/run/opengl-driver/lib", "-L$stubs", ...] + old = [ "-lcuda" ]; + new = [ "-lcuda" "-L${addOpenGLRunpath.driverLink}" "-L${cuda_cudart}/lib/stubs/" ]; + + quote = x: ''"${x}"''; + oldStr = lib.concatMapStringsSep ", " quote old; + newStr = lib.concatMapStringsSep ", " quote new; + in + '' + substituteInPlace python/triton/compiler.py \ + --replace '${oldStr}' '${newStr}' + '' + ) + # Triton seems to be looking up cuda.h + + '' + sed -i 's|cu_include_dir = os.path.join.*$|cu_include_dir = "${cuda_cudart}/include"|' python/triton/compiler.py + ''; + + nativeBuildInputs = [ + cmake + pythonRelaxDepsHook + + # Requires torch (circular dependency) and probably needs GPUs: + # pytestCheckHook + + # Note for future: + # These *probably* should go in depsTargetTarget + # ...but we cannot test cross right now anyway + # because we only support cudaPackages on x86_64-linux atm + lit + llvm + llvmPackages.mlir + ]; + + buildInputs = [ + gtest + libxml2.dev + ncurses + pybind11 + zlib + ]; + + propagatedBuildInputs = [ + filelock + ]; + + # Avoid GLIBCXX mismatch with other cuda-enabled python packages + preConfigure = + '' + export CC="${backendStdenv.cc}/bin/cc"; + export CXX="${backendStdenv.cc}/bin/c++"; + '' + # Upstream's setup.py tries to write cache somewhere in ~/ + + '' + export HOME=$TMPDIR + '' + # Upstream's github actions patch setup.cfg to write base-dir. May be redundant + + '' + echo "" >> python/setup.cfg + echo "[build_ext]" >> python/setup.cfg + echo "base-dir=$PWD" >> python/setup.cfg + '' + # The rest (including buildPhase) is relative to ./python/ + + '' + cd python/ + '' + # Work around download_and_copy_ptxas() + + '' + dst_cuda="$PWD/triton/third_party/cuda/bin" + mkdir -p "$dst_cuda" + ln -s "${ptxas}" "$dst_cuda/" + ''; + + # CMake is run by setup.py instead + dontUseCmakeConfigure = true; + cmakeFlags = [ + "-DMLIR_DIR=${llvmPackages.mlir}/lib/cmake/mlir" + ]; + + postFixup = + let + ptxasDestination = "$out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas"; + in + # Setuptools (?) strips runpath and +x flags. Let's just restore the symlink + '' + rm -f ${ptxasDestination} + ln -s ${ptxas} ${ptxasDestination} + ''; + + checkInputs = [ + cmake # ctest + ]; + dontUseSetuptoolsCheck = true; + preCheck = + # build/temp* refers to build_ext.build_temp (looked up in the build logs) + '' + (cd /build/source/python/build/temp* ; ctest) + '' # For pytestCheckHook + + '' + cd test/unit + ''; + pythonImportsCheck = [ + # Circular dependency on torch + # "triton" + # "triton.language" + ]; + + # Ultimately, torch is our test suite: + passthru.tests = { + inherit torchWithRocm; + }; + + pythonRemoveDeps = [ + # Circular dependency, cf. https://github.com/openai/triton/issues/1374 + "torch" + + # CLI tools without dist-info + "cmake" + "lit" + ]; + meta = with lib; { + description = "Development repository for the Triton language and compiler"; + homepage = "https://github.com/openai/triton/"; + platforms = lib.platforms.unix; + license = licenses.mit; + maintainers = with maintainers; [ SomeoneSerge ]; + }; +} diff --git a/pkgs/development/python-modules/openai-triton/llvm15.patch b/pkgs/development/python-modules/openai-triton/llvm15.patch new file mode 100644 index 0000000000000..3e20cce238013 --- /dev/null +++ b/pkgs/development/python-modules/openai-triton/llvm15.patch @@ -0,0 +1,4617 @@ +From fc7a8e35819bda632bdcf1cf75fd9abe4d4e077a Mon Sep 17 00:00:00 2001 +From: Christian Sigg +Date: Thu, 16 Feb 2023 15:40:53 +0100 +Subject: [PATCH] Rebase Triton to LLVM-15. (#1070) + +This PR rebases Triton from LLVM-14 to LLVM-15. Most changes are +mechanical, except for the analysis framework changes. +--- + CMakeLists.txt | 6 +- + bin/CMakeLists.txt | 2 +- + bin/FileCheck/FileCheck.cpp | 3 + + bin/triton-opt.cpp | 6 +- + bin/triton-translate.cpp | 7 +- + include/triton/Analysis/Alias.h | 21 +- + include/triton/Analysis/Allocation.h | 2 + + include/triton/Analysis/AxisInfo.h | 56 ++- + include/triton/Analysis/Utility.h | 6 +- + include/triton/Conversion/Passes.td | 4 +- + include/triton/Dialect/Triton/IR/Dialect.h | 7 +- + .../triton/Dialect/Triton/IR/TritonDialect.td | 8 +- + include/triton/Dialect/Triton/IR/TritonOps.td | 12 +- + .../triton/Dialect/Triton/IR/TritonTypes.td | 2 + + .../Dialect/Triton/Transforms/Passes.td | 3 +- + include/triton/Dialect/TritonGPU/IR/Dialect.h | 4 +- + .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 7 + + .../Dialect/TritonGPU/IR/TritonGPUDialect.td | 2 +- + .../Dialect/TritonGPU/IR/TritonGPUOps.td | 13 +- + lib/Analysis/Alias.cpp | 14 +- + lib/Analysis/Allocation.cpp | 30 +- + lib/Analysis/AxisInfo.cpp | 79 ++-- + lib/Analysis/CMakeLists.txt | 2 +- + lib/Analysis/Membar.cpp | 2 +- + lib/Analysis/Utility.cpp | 54 +++ + .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 3 - + lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h | 10 +- + .../TritonGPUToLLVM/DotOpToLLVM.cpp | 5 - + .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 2 - + .../TritonGPUToLLVM/LoadStoreOpToLLVM.cpp | 5 +- + .../TritonGPUToLLVM/ReduceOpToLLVM.cpp | 2 - + .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 7 +- + .../TritonGPUToLLVM/TritonGPUToLLVMBase.h | 26 +- + .../TritonGPUToLLVM/TritonGPUToLLVMPass.cpp | 52 +-- + lib/Conversion/TritonGPUToLLVM/Utility.h | 5 +- + .../TritonToTritonGPUPass.cpp | 69 ++-- + lib/Dialect/Triton/IR/CMakeLists.txt | 10 +- + lib/Dialect/Triton/IR/Ops.cpp | 34 +- + lib/Dialect/Triton/Transforms/Combine.cpp | 6 +- + lib/Dialect/Triton/Transforms/Combine.td | 2 +- + lib/Dialect/TritonGPU/IR/Dialect.cpp | 27 +- + lib/Dialect/TritonGPU/Transforms/Coalesce.cpp | 20 +- + lib/Dialect/TritonGPU/Transforms/Combine.cpp | 2 +- + lib/Dialect/TritonGPU/Transforms/Combine.td | 1 + + .../Transforms/DecomposeConversions.cpp | 2 +- + lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 10 +- + .../Transforms/ReorderInstructions.cpp | 2 +- + .../Transforms/TritonGPUConversion.cpp | 12 +- + .../Transforms/UpdateMmaForVolta.cpp | 6 +- + lib/Dialect/TritonGPU/Transforms/Utility.cpp | 2 +- + lib/Target/LLVMIR/CMakeLists.txt | 3 +- + lib/Target/PTX/PTXTranslation.cpp | 3 + + python/setup.py | 15 +- + python/src/triton.cc | 85 +++-- + python/test/unit/language/test_core.py | 2 +- + python/triton/compiler.py | 4 +- + test/Analysis/test-alias.mlir | 24 +- + test/Analysis/test-alignment.mlir | 344 +++++++++--------- + test/Analysis/test-allocation.mlir | 32 +- + test/Analysis/test-membar.mlir | 38 +- + test/Conversion/triton_ops.mlir | 10 +- + test/Conversion/triton_to_tritongpu.mlir | 6 +- + test/Conversion/tritongpu_to_llvm.mlir | 94 ++--- + test/Target/tritongpu_to_llvmir.mlir | 4 +- + test/Target/tritongpu_to_ptx.mlir | 2 +- + test/Triton/combine.mlir | 40 +- + test/Triton/vecadd.mlir | 4 +- + test/TritonGPU/coalesce.mlir | 2 +- + test/TritonGPU/combine.mlir | 38 +- + test/TritonGPU/loop-pipeline.mlir | 22 +- + test/TritonGPU/matmul.mlir | 4 +- + test/TritonGPU/prefetch.mlir | 4 +- + test/TritonGPU/update-mma-for-volta.mlir | 4 +- + test/lib/Analysis/TestAlias.cpp | 29 +- + test/lib/Analysis/TestAllocation.cpp | 5 +- + test/lib/Analysis/TestAxisInfo.cpp | 51 +-- + test/lib/Analysis/TestMembar.cpp | 7 +- + 78 files changed, 808 insertions(+), 742 deletions(-) + +diff --git a/CMakeLists.txt b/CMakeLists.txt +index d0d361fc7c..b281a28400 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -1,4 +1,7 @@ + cmake_minimum_required(VERSION 3.6) ++ ++cmake_policy(SET CMP0116 OLD) ++ + include(ExternalProject) + + set(CMAKE_CXX_STANDARD 17) +@@ -155,7 +158,6 @@ if(TRITON_BUILD_PYTHON_MODULE) + endif() + endif() + +- + # # Triton + # file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc) + # if (WIN32 AND TRITON_BUILD_PYTHON_MODULE) +@@ -212,7 +214,7 @@ if(TRITON_BUILD_PYTHON_MODULE) + # optimizations + MLIRPass + MLIRTransforms +- MLIRLLVMIR ++ MLIRLLVMDialect + MLIRSupport + MLIRTargetLLVMIRExport + MLIRExecutionEngine +diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt +index 906f635f8b..695b3479fd 100644 +--- a/bin/CMakeLists.txt ++++ b/bin/CMakeLists.txt +@@ -48,7 +48,7 @@ llvm_update_compile_flags(triton-translate) + # MLIR core + MLIROptLib + MLIRIR +- MLIRLLVMIR ++ MLIRLLVMDialect + MLIRPass + MLIRSupport + MLIRTransforms +diff --git a/bin/FileCheck/FileCheck.cpp b/bin/FileCheck/FileCheck.cpp +index 819efc3541..9ac6f1b277 100644 +--- a/bin/FileCheck/FileCheck.cpp ++++ b/bin/FileCheck/FileCheck.cpp +@@ -19,6 +19,7 @@ + #include "llvm/Support/CommandLine.h" + #include "llvm/Support/InitLLVM.h" + #include "llvm/Support/Process.h" ++#include "llvm/Support/SourceMgr.h" + #include "llvm/Support/WithColor.h" + #include "llvm/Support/raw_ostream.h" + #include +@@ -360,6 +361,8 @@ static std::string GetCheckTypeAbbreviation(Check::FileCheckType Ty) { + return "bad-not"; + case Check::CheckBadCount: + return "bad-count"; ++ case Check::CheckMisspelled: ++ return "misspelled"; + case Check::CheckNone: + llvm_unreachable("invalid FileCheckType"); + } +diff --git a/bin/triton-opt.cpp b/bin/triton-opt.cpp +index 9f3b53b7ae..f96232e1b0 100644 +--- a/bin/triton-opt.cpp ++++ b/bin/triton-opt.cpp +@@ -8,7 +8,7 @@ + + #include "mlir/IR/Dialect.h" + #include "mlir/InitAllPasses.h" +-#include "mlir/Support/MlirOptMain.h" ++#include "mlir/Tools/mlir-opt/MlirOptMain.h" + + namespace mlir { + namespace test { +@@ -33,8 +33,8 @@ int main(int argc, char **argv) { + // TODO: register Triton & TritonGPU passes + mlir::DialectRegistry registry; + registry.insert(); + + return mlir::asMainReturnCode(mlir::MlirOptMain( +diff --git a/bin/triton-translate.cpp b/bin/triton-translate.cpp +index 05ba15e453..56b5d65857 100644 +--- a/bin/triton-translate.cpp ++++ b/bin/triton-translate.cpp +@@ -3,7 +3,7 @@ + #include "mlir/IR/AsmState.h" + #include "mlir/IR/BuiltinOps.h" + #include "mlir/IR/Dialect.h" +-#include "mlir/Parser.h" ++#include "mlir/Parser/Parser.h" + #include "mlir/Pass/Pass.h" + #include "mlir/Pass/PassManager.h" + #include "mlir/Support/FileUtilities.h" +@@ -38,7 +38,7 @@ OwningOpRef loadMLIRModule(llvm::StringRef inputFilename, + mlir::DialectRegistry registry; + registry.insert(); ++ scf::SCFDialect>(); + + context.appendDialectRegistry(registry); + +@@ -50,7 +50,8 @@ OwningOpRef loadMLIRModule(llvm::StringRef inputFilename, + context.loadAllAvailableDialects(); + context.allowUnregisteredDialects(); + +- OwningOpRef module(parseSourceFile(sourceMgr, &context)); ++ OwningOpRef module = ++ parseSourceFile(sourceMgr, &context); + if (!module) { + llvm::errs() << "Parse MLIR file failed."; + return nullptr; +diff --git a/include/triton/Analysis/Alias.h b/include/triton/Analysis/Alias.h +index fa6b906fc9..631df518bc 100644 +--- a/include/triton/Analysis/Alias.h ++++ b/include/triton/Analysis/Alias.h +@@ -2,7 +2,7 @@ + #define TRITON_ANALYSIS_ALIAS_H + + #include "mlir/Analysis/AliasAnalysis.h" +-#include "mlir/Analysis/DataFlowAnalysis.h" ++#include "mlir/Analysis/DataFlow/SparseAnalysis.h" + #include "llvm/ADT/DenseSet.h" + + namespace mlir { +@@ -21,7 +21,7 @@ class AliasInfo { + } + + /// The pessimistic value state of a value without alias +- static AliasInfo getPessimisticValueState(MLIRContext *context) { ++ static AliasInfo getPessimisticValueState(MLIRContext *context = nullptr) { + return AliasInfo(); + } + static AliasInfo getPessimisticValueState(Value value) { return AliasInfo(); } +@@ -29,6 +29,10 @@ class AliasInfo { + /// The union of both arguments + static AliasInfo join(const AliasInfo &lhs, const AliasInfo &rhs); + ++ void print(raw_ostream &os) const { ++ llvm::interleaveComma(allocs, os, [&](Value alloc) { alloc.print(os); }); ++ } ++ + private: + /// The set of allocated values that are aliased by this lattice. + /// For now, we only consider aliased value produced by the following +@@ -58,9 +62,13 @@ class AliasInfo { + //===----------------------------------------------------------------------===// + // Shared Memory Alias Analysis + //===----------------------------------------------------------------------===// +-class SharedMemoryAliasAnalysis : public ForwardDataFlowAnalysis { ++class SharedMemoryAliasAnalysis ++ : public dataflow::SparseDataFlowAnalysis> { + public: +- using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis; ++ using dataflow::SparseDataFlowAnalysis< ++ dataflow::Lattice>::SparseDataFlowAnalysis; ++ using dataflow::SparseDataFlowAnalysis< ++ dataflow::Lattice>::getLatticeElement; + + /// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use. + /// Given two values, returns their aliasing behavior. +@@ -70,9 +78,10 @@ class SharedMemoryAliasAnalysis : public ForwardDataFlowAnalysis { + ModRefResult getModRef(Operation *op, Value location); + + /// Computes if the alloc set of the results are changed. +- ChangeResult ++ void + visitOperation(Operation *op, +- ArrayRef *> operands) override; ++ ArrayRef *> operands, ++ ArrayRef *> results) override; + }; + + } // namespace mlir +diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h +index b7c136d602..89b77034cc 100644 +--- a/include/triton/Analysis/Allocation.h ++++ b/include/triton/Analysis/Allocation.h +@@ -188,6 +188,8 @@ class Allocation { + friend class triton::AllocationAnalysis; + }; + ++template Interval(T, T) -> Interval; ++ + } // namespace mlir + + #endif // TRITON_ANALYSIS_ALLOCATION_H +diff --git a/include/triton/Analysis/AxisInfo.h b/include/triton/Analysis/AxisInfo.h +index fdfbd8fbb3..7083b9c43b 100644 +--- a/include/triton/Analysis/AxisInfo.h ++++ b/include/triton/Analysis/AxisInfo.h +@@ -1,9 +1,10 @@ + #ifndef TRITON_ANALYSIS_AXISINFO_H + #define TRITON_ANALYSIS_AXISINFO_H + +-#include "mlir/Analysis/DataFlowAnalysis.h" ++#include "mlir/Analysis/DataFlow/SparseAnalysis.h" + #include "llvm/Support/raw_ostream.h" + ++#include "mlir/Support/LLVM.h" + #include "triton/Analysis/Utility.h" + #include "triton/Dialect/Triton/IR/Dialect.h" + #include "triton/Dialect/TritonGPU/IR/Dialect.h" +@@ -62,7 +63,7 @@ class AxisInfo { + } + + /// The pessimistic value state of the contiguity is unknown. +- static AxisInfo getPessimisticValueState(MLIRContext *context) { ++ static AxisInfo getPessimisticValueState(MLIRContext *context = nullptr) { + return AxisInfo(); + } + static AxisInfo getPessimisticValueState(Value value); +@@ -70,6 +71,22 @@ class AxisInfo { + /// The gcd of both arguments for each dimension + static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs); + ++ void print(raw_ostream &os) const { ++ auto print = [&](StringRef name, DimVectorT vec) { ++ os << name << " = ["; ++ llvm::interleaveComma(vec, os); ++ os << "]"; ++ }; ++ print("contiguity", contiguity); ++ print(", divisibility", divisibility); ++ print(", constancy", constancy); ++ os << ", constant_value = "; ++ if (constantValue) ++ os << *constantValue; ++ else ++ os << ""; ++ } ++ + private: + /// The _contiguity_ information maps the `d`-th + /// dimension to the length of the shortest +@@ -147,7 +164,8 @@ class AxisInfoVisitor { + } + + virtual AxisInfo +- getAxisInfo(Operation *op, ArrayRef *> operands) = 0; ++ getAxisInfo(Operation *op, ++ ArrayRef *> operands) = 0; + + virtual bool match(Operation *op) = 0; + }; +@@ -157,15 +175,16 @@ template class AxisInfoVisitorImpl : public AxisInfoVisitor { + public: + using AxisInfoVisitor::AxisInfoVisitor; + +- AxisInfo getAxisInfo(Operation *op, +- ArrayRef *> operands) final { ++ AxisInfo ++ getAxisInfo(Operation *op, ++ ArrayRef *> operands) final { + return getAxisInfo(cast(op), operands); + } + + bool match(Operation *op) final { return isa(op); } + +- virtual AxisInfo getAxisInfo(OpTy op, +- ArrayRef *> operands) { ++ virtual AxisInfo ++ getAxisInfo(OpTy op, ArrayRef *> operands) { + llvm_unreachable("Unimplemented getAxisInfo"); + } + }; +@@ -176,8 +195,9 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl { + public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + +- AxisInfo getAxisInfo(OpTy op, +- ArrayRef *> operands) override { ++ AxisInfo ++ getAxisInfo(OpTy op, ++ ArrayRef *> operands) override { + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + auto rank = lhsInfo.getRank(); +@@ -230,7 +250,8 @@ class AxisInfoVisitorList { + (visitors.emplace_back(std::make_unique()), ...); + } + +- AxisInfo apply(Operation *op, ArrayRef *> operands) { ++ AxisInfo apply(Operation *op, ++ ArrayRef *> operands) { + for (auto &visitor : visitors) + if (visitor->match(op)) + return visitor->getAxisInfo(op, operands); +@@ -241,16 +262,19 @@ class AxisInfoVisitorList { + std::vector> visitors; + }; + +-class AxisInfoAnalysis : public ForwardDataFlowAnalysis { ++class AxisInfoAnalysis ++ : public dataflow::SparseDataFlowAnalysis> { + private: + AxisInfoVisitorList visitors; + + public: +- AxisInfoAnalysis(MLIRContext *context); ++ AxisInfoAnalysis(DataFlowSolver &solver); ++ using dataflow::SparseDataFlowAnalysis< ++ dataflow::Lattice>::getLatticeElement; + +- ChangeResult +- visitOperation(Operation *op, +- ArrayRef *> operands) override; ++ void visitOperation(Operation *op, ++ ArrayRef *> operands, ++ ArrayRef *> results) override; + + unsigned getPtrContiguity(Value ptr); + +@@ -261,4 +285,4 @@ class AxisInfoAnalysis : public ForwardDataFlowAnalysis { + + } // namespace mlir + +-#endif +\ No newline at end of file ++#endif +diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h +index c5ac137dc1..ee7fadb59d 100644 +--- a/include/triton/Analysis/Utility.h ++++ b/include/triton/Analysis/Utility.h +@@ -1,6 +1,7 @@ + #ifndef TRITON_ANALYSIS_UTILITY_H + #define TRITON_ANALYSIS_UTILITY_H + ++#include "mlir/Analysis/DataFlowFramework.h" + #include "mlir/Analysis/SliceAnalysis.h" + #include "triton/Dialect/TritonGPU/IR/Dialect.h" + #include +@@ -12,7 +13,7 @@ namespace mlir { + class ReduceOpHelper { + public: + explicit ReduceOpHelper(triton::ReduceOp op) : op(op) { +- srcTy = op.operand().getType().cast(); ++ srcTy = op.getOperand().getType().cast(); + } + + ArrayRef getSrcShape() { return srcTy.getShape(); } +@@ -103,6 +104,9 @@ SetVector + multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter = nullptr, + TransitiveFilter forwardFilter = nullptr); + ++// Create a basic DataFlowSolver with constant and dead code analysis included. ++std::unique_ptr createDataFlowSolver(); ++ + } // namespace mlir + + #endif // TRITON_ANALYSIS_UTILITY_H +diff --git a/include/triton/Conversion/Passes.td b/include/triton/Conversion/Passes.td +index 70bb20b78e..be00eb2dac 100644 +--- a/include/triton/Conversion/Passes.td ++++ b/include/triton/Conversion/Passes.td +@@ -12,7 +12,6 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO + + let dependentDialects = ["mlir::arith::ArithmeticDialect", + "mlir::math::MathDialect", +- "mlir::StandardOpsDialect", + // TODO: Does this pass depend on SCF? + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", +@@ -41,8 +40,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp" + "mlir::tensor::TensorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", +- "mlir::NVVM::NVVMDialect", +- "mlir::StandardOpsDialect"]; ++ "mlir::NVVM::NVVMDialect"]; + + let options = [ + Option<"computeCapability", "compute-capability", +diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h +index e8012a51df..15869e262e 100644 +--- a/include/triton/Dialect/Triton/IR/Dialect.h ++++ b/include/triton/Dialect/Triton/IR/Dialect.h +@@ -1,14 +1,15 @@ + #ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_ + #define TRITON_DIALECT_TRITON_IR_DIALECT_H_ + ++#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" ++#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" ++#include "mlir/Dialect/Func/IR/FuncOps.h" + #include "mlir/Dialect/Math/IR/Math.h" +-#include "mlir/Dialect/SCF/SCF.h" +-#include "mlir/Dialect/StandardOps/IR/Ops.h" ++#include "mlir/Dialect/SCF/IR/SCF.h" + #include "mlir/Dialect/Tensor/IR/Tensor.h" + #include "mlir/IR/BuiltinOps.h" + #include "mlir/IR/Dialect.h" + #include "mlir/Interfaces/ControlFlowInterfaces.h" +- + #include "triton/Dialect/Triton/IR/Dialect.h.inc" + #include "triton/Dialect/Triton/IR/OpsEnums.h.inc" + #include "triton/Dialect/Triton/IR/Traits.h" +diff --git a/include/triton/Dialect/Triton/IR/TritonDialect.td b/include/triton/Dialect/Triton/IR/TritonDialect.td +index 07b069e14f..d98ce73884 100644 +--- a/include/triton/Dialect/Triton/IR/TritonDialect.td ++++ b/include/triton/Dialect/Triton/IR/TritonDialect.td +@@ -25,12 +25,9 @@ def Triton_Dialect : Dialect { + let dependentDialects = [ + "arith::ArithmeticDialect", + "math::MathDialect", +- "StandardOpsDialect", + "scf::SCFDialect", +- +- // Since LLVM 15 +- // "cf::ControlFlowDialect", +- // "func::FuncDialect" ++ "cf::ControlFlowDialect", ++ "func::FuncDialect" + ]; + + let extraClassDeclaration = [{ +@@ -38,6 +35,7 @@ def Triton_Dialect : Dialect { + }]; + + let hasConstantMaterializer = 1; ++ let useDefaultTypePrinterParser = 1; + } + + include "triton/Dialect/Triton/IR/TritonTypes.td" +diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td +index 779e0b648c..0a69211179 100644 +--- a/include/triton/Dialect/Triton/IR/TritonOps.td ++++ b/include/triton/Dialect/Triton/IR/TritonOps.td +@@ -141,11 +141,7 @@ def TT_LoadOp : TT_Op<"load", + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + ]; + +- // let assemblyFormat = "operands attr-dict `:` type($result)"; +- let parser = [{ return mlir::triton::parseLoadOp(parser, result); }]; +- +- let printer = [{ return mlir::triton::printLoadOp(p, *this); }]; +- ++ let hasCustomAssemblyFormat = 1; + let hasCanonicalizer = 1; + } + +@@ -170,11 +166,7 @@ def TT_StoreOp : TT_Op<"store", + "triton::EvictionPolicy":$evict)>, + ]; + +- // let assemblyFormat = "operands attr-dict `:` type($value)"; +- let parser = [{ return mlir::triton::parseStoreOp(parser, result); }]; +- +- let printer = [{ return mlir::triton::printStoreOp(p, *this); }]; +- ++ let hasCustomAssemblyFormat = 1; + let hasCanonicalizer = 1; + } + +diff --git a/include/triton/Dialect/Triton/IR/TritonTypes.td b/include/triton/Dialect/Triton/IR/TritonTypes.td +index 66d2a7b9a9..2fe2fd077d 100644 +--- a/include/triton/Dialect/Triton/IR/TritonTypes.td ++++ b/include/triton/Dialect/Triton/IR/TritonTypes.td +@@ -1,6 +1,7 @@ + #ifndef TRITON_TYPES + #define TRITON_TYPES + ++include "mlir/IR/AttrTypeBase.td" + include "triton/Dialect/Triton/IR/TritonDialect.td" + + // +@@ -58,6 +59,7 @@ def TT_Ptr : TritonTypeDef<"Pointer", "ptr"> { + }]> + ]; + ++ let hasCustomAssemblyFormat = 1; + let skipDefaultBuilders = 1; + } + def TT_PtrTensor : TensorOf<[TT_Ptr]>; +diff --git a/include/triton/Dialect/Triton/Transforms/Passes.td b/include/triton/Dialect/Triton/Transforms/Passes.td +index 8f77aed774..a25cdc5680 100644 +--- a/include/triton/Dialect/Triton/Transforms/Passes.td ++++ b/include/triton/Dialect/Triton/Transforms/Passes.td +@@ -16,8 +16,7 @@ def TritonCombineOps : Pass + + let constructor = "mlir::triton::createCombineOpsPass()"; + +- let dependentDialects = ["mlir::arith::ArithmeticDialect", +- /*SelectOp*/"mlir::StandardOpsDialect"]; ++ let dependentDialects = ["mlir::arith::ArithmeticDialect"]; + } + + #endif +diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h +index b4c8daec7b..dfc5f53ab1 100644 +--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h ++++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h +@@ -1,19 +1,17 @@ + #ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ + #define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ + +-#include "mlir/Dialect/GPU/GPUDialect.h" ++#include "mlir/Dialect/GPU/IR/GPUDialect.h" + #include "mlir/Dialect/Tensor/IR/Tensor.h" + #include "mlir/IR/BuiltinOps.h" + #include "mlir/IR/Dialect.h" + + // TritonGPU depends on Triton + #include "triton/Dialect/Triton/IR/Dialect.h" +- + #include "triton/Dialect/TritonGPU/IR/Dialect.h.inc" + #include "triton/Dialect/TritonGPU/IR/Traits.h" + + #define GET_ATTRDEF_CLASSES +-#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" + #include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc" + + #define GET_OP_CLASSES +diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +index 0242c3cc17..af2aeb03a8 100644 +--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td ++++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +@@ -1,6 +1,7 @@ + #ifndef TRITONGPU_ATTRDEFS + #define TRITONGPU_ATTRDEFS + ++include "mlir/IR/AttrTypeBase.td" + include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" + include "triton/Dialect/Triton/IR/TritonInterfaces.td" + +@@ -136,6 +137,7 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] / + ]; + + let extraClassDeclaration = extraBaseClassDeclaration; ++ let hasCustomAssemblyFormat = 1; + } + + //===----------------------------------------------------------------------===// +@@ -273,6 +275,7 @@ for + // ArrayRefParameter<"unsigned">:$sizePerCTA + ); + ++ let hasCustomAssemblyFormat = 1; + } + + //===----------------------------------------------------------------------===// +@@ -422,6 +425,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: + static constexpr int numBitsToHoldMmaV1ID{5}; + }]; + ++ let hasCustomAssemblyFormat = 1; + } + + def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> { +@@ -456,6 +460,8 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> { + template + SmallVector paddedShape(ArrayRef shape) const; + }]; ++ ++ let hasCustomAssemblyFormat = 1; + } + + def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> { +@@ -492,6 +498,7 @@ section 9.7.13.4.1 for more details. + + ]; + ++ let hasCustomAssemblyFormat = 1; + let extraClassDeclaration = extraBaseClassDeclaration; + } + +diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td +index 87ec1d36c6..6489a721b4 100644 +--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td ++++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td +@@ -30,7 +30,7 @@ def TritonGPU_Dialect : Dialect { + } + }]; + +- ++ let useDefaultAttributePrinterParser = 1; + } + + #endif +diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +index 510f8d0183..7aba11dc75 100644 +--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td ++++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +@@ -59,7 +59,7 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> { + // This is needed because these ops don't + // handle encodings + // e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111 +-def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise, ++def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding]> { + let summary = "integer comparison operation"; +@@ -73,7 +73,7 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise, + let results = (outs TT_BoolLike:$result); + } + +-def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise, ++def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding]> { + let summary = "floating-point comparison operation"; +@@ -88,8 +88,8 @@ def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise, + } + + // TODO: migrate to arith::SelectOp on LLVM16 +-def TTG_SelectOp : TTG_Op<"select", [NoSideEffect, Elementwise, +- SameOperandsAndResultShape, ++def TTG_SelectOp : TTG_Op<"select", [NoSideEffect, Elementwise, ++ SameOperandsAndResultShape, + SameOperandsAndResultEncoding]> { + let summary = "select operation"; + +@@ -188,10 +188,7 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async", + } + }]; + +- // The custom parser could be replaced with oilist in LLVM-16 +- let parser = [{ return parseInsertSliceAsyncOp(parser, result); }]; +- +- let printer = [{ return printInsertSliceAsyncOp(p, *this); }]; ++ let hasCustomAssemblyFormat = 1; + } + + def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [MemoryEffects<[MemAlloc]>, // Allocate shared memory +diff --git a/lib/Analysis/Alias.cpp b/lib/Analysis/Alias.cpp +index a39e4de9aa..208fdd4afc 100644 +--- a/lib/Analysis/Alias.cpp ++++ b/lib/Analysis/Alias.cpp +@@ -18,8 +18,9 @@ AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) { + return ret; + } + +-ChangeResult SharedMemoryAliasAnalysis::visitOperation( +- Operation *op, ArrayRef *> operands) { ++void SharedMemoryAliasAnalysis::visitOperation( ++ Operation *op, ArrayRef *> operands, ++ ArrayRef *> results) { + AliasInfo aliasInfo; + bool pessimistic = true; + if (maybeSharedAllocationOp(op)) { +@@ -44,14 +45,11 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation( + } + + if (pessimistic) { +- return markAllPessimisticFixpoint(op->getResults()); ++ return markAllPessimisticFixpoint(results); + } + // Join all lattice elements +- ChangeResult result = ChangeResult::NoChange; +- for (Value value : op->getResults()) { +- result |= getLatticeElement(value).join(aliasInfo); +- } +- return result; ++ for (auto *result : results) ++ propagateIfChanged(result, result->join(aliasInfo)); + } + + AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) { +diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp +index 712c08c475..b4de8dcd9d 100644 +--- a/lib/Analysis/Allocation.cpp ++++ b/lib/Analysis/Allocation.cpp +@@ -1,4 +1,5 @@ + #include "triton/Analysis/Allocation.h" ++#include "mlir/Analysis/DataFlowFramework.h" + #include "mlir/Analysis/Liveness.h" + #include "mlir/Analysis/SliceAnalysis.h" + #include "mlir/Dialect/Tensor/IR/Tensor.h" +@@ -33,10 +34,8 @@ constexpr int kPtrBitWidth = 64; + + static std::pair, SmallVector> + getCvtOrder(const Attribute &srcLayout, const Attribute &dstLayout) { +- auto srcBlockedLayout = srcLayout.dyn_cast(); + auto srcMmaLayout = srcLayout.dyn_cast(); + auto srcDotLayout = srcLayout.dyn_cast(); +- auto dstBlockedLayout = dstLayout.dyn_cast(); + auto dstMmaLayout = dstLayout.dyn_cast(); + auto dstDotLayout = dstLayout.dyn_cast(); + assert(!(srcMmaLayout && dstMmaLayout) && +@@ -224,14 +223,12 @@ class AllocationAnalysis { + } + + void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) { +- LatticeElement *latticeElement = +- analysis.lookupLatticeElement(value); +- if (latticeElement) { +- auto &info = latticeElement->getValue(); +- if (!info.getAllocs().empty()) { +- for (auto alloc : info.getAllocs()) { +- allocation->addAlias(value, alloc); +- } ++ dataflow::Lattice *latticeElement = ++ analysis.getLatticeElement(value); ++ if (latticeElement && !latticeElement->isUninitialized()) { ++ AliasInfo &info = latticeElement->getValue(); ++ for (auto alloc : info.getAllocs()) { ++ allocation->addAlias(value, alloc); + } + } + } +@@ -244,14 +241,19 @@ class AllocationAnalysis { + getScratchValueSize(op); + }); + // Get the alias values +- SharedMemoryAliasAnalysis aliasAnalysis(operation->getContext()); +- aliasAnalysis.run(operation); ++ std::unique_ptr solver = createDataFlowSolver(); ++ SharedMemoryAliasAnalysis *aliasAnalysis = ++ solver->load(); ++ if (failed(solver->initializeAndRun(operation))) { ++ // TODO: return error instead of bailing out.. ++ llvm_unreachable("failed to run SharedMemoryAliasAnalysis"); ++ } + operation->walk([&](Operation *op) { + for (auto operand : op->getOperands()) { +- getValueAlias(operand, aliasAnalysis); ++ getValueAlias(operand, *aliasAnalysis); + } + for (auto value : op->getResults()) { +- getValueAlias(value, aliasAnalysis); ++ getValueAlias(value, *aliasAnalysis); + } + }); + } +diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp +index 0b7142b04d..4af46c3fbb 100644 +--- a/lib/Analysis/AxisInfo.cpp ++++ b/lib/Analysis/AxisInfo.cpp +@@ -1,4 +1,4 @@ +-#include "mlir/Analysis/DataFlowAnalysis.h" ++#include "mlir/Analysis/DataFlowFramework.h" + #include "mlir/Dialect/LLVMIR/LLVMDialect.h" + #include "llvm/Support/raw_ostream.h" + +@@ -52,7 +52,7 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) { + BlockArgument blockArg = value.dyn_cast(); + if (blockArg && blockArg.getOwner()->isEntryBlock()) { + Operation *op = blockArg.getOwner()->getParentOp(); +- if (FuncOp fun = dyn_cast(op)) { ++ if (func::FuncOp fun = dyn_cast(op)) { + Attribute attr = + fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility"); + if (attr) +@@ -136,8 +136,9 @@ class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl { + public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + +- AxisInfo getAxisInfo(OpTy op, +- ArrayRef *> operands) override { ++ AxisInfo ++ getAxisInfo(OpTy op, ++ ArrayRef *> operands) override { + return operands[0]->getValue(); + } + }; +@@ -147,8 +148,9 @@ class MakeRangeOpAxisInfoVisitor final + public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + +- AxisInfo getAxisInfo(triton::MakeRangeOp op, +- ArrayRef *> operands) override { ++ AxisInfo ++ getAxisInfo(triton::MakeRangeOp op, ++ ArrayRef *> operands) override { + auto start = op.start(); + auto end = op.end(); + return AxisInfo(/*contiguity=*/{end - start}, +@@ -162,8 +164,9 @@ class ConstantOpAxisInfoVisitor final + public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + +- AxisInfo getAxisInfo(arith::ConstantOp op, +- ArrayRef *> operands) override { ++ AxisInfo ++ getAxisInfo(arith::ConstantOp op, ++ ArrayRef *> operands) override { + auto intAttr = op.getValue().dyn_cast(); + auto boolAttr = op.getValue().dyn_cast(); + if (intAttr || boolAttr) { +@@ -416,8 +419,9 @@ class SplatOpAxisInfoVisitor final + public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + +- AxisInfo getAxisInfo(triton::SplatOp op, +- ArrayRef *> operands) override { ++ AxisInfo ++ getAxisInfo(triton::SplatOp op, ++ ArrayRef *> operands) override { + Type _retTy = *op->result_type_begin(); + TensorType retTy = _retTy.cast(); + AxisInfo opInfo = operands[0]->getValue(); +@@ -439,8 +443,9 @@ class ExpandDimsOpAxisInfoVisitor final + public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + +- AxisInfo getAxisInfo(triton::ExpandDimsOp op, +- ArrayRef *> operands) override { ++ AxisInfo ++ getAxisInfo(triton::ExpandDimsOp op, ++ ArrayRef *> operands) override { + AxisInfo opInfo = operands[0]->getValue(); + AxisInfo::DimVectorT contiguity = opInfo.getContiguity(); + AxisInfo::DimVectorT divisibility = opInfo.getDivisibility(); +@@ -458,8 +463,9 @@ class BroadcastOpAxisInfoVisitor final + public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + +- AxisInfo getAxisInfo(triton::BroadcastOp op, +- ArrayRef *> operands) override { ++ AxisInfo ++ getAxisInfo(triton::BroadcastOp op, ++ ArrayRef *> operands) override { + Type _retTy = *op->result_type_begin(); + Type _opTy = *op->operand_type_begin(); + TensorType retTy = _retTy.cast(); +@@ -486,8 +492,9 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl { + public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + +- AxisInfo getAxisInfo(OpTy op, +- ArrayRef *> operands) override { ++ AxisInfo ++ getAxisInfo(OpTy op, ++ ArrayRef *> operands) override { + auto resTy = op.getResult().getType().template dyn_cast(); + if (!resTy) + return AxisInfo(); +@@ -596,8 +603,9 @@ class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl { + public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + +- AxisInfo getAxisInfo(OpTy op, +- ArrayRef *> operands) override { ++ AxisInfo ++ getAxisInfo(OpTy op, ++ ArrayRef *> operands) override { + auto resTy = op.getResult().getType().template dyn_cast(); + if (!resTy) + return AxisInfo(); +@@ -757,8 +765,9 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl { + public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + +- AxisInfo getAxisInfo(OpTy op, +- ArrayRef *> operands) override { ++ AxisInfo ++ getAxisInfo(OpTy op, ++ ArrayRef *> operands) override { + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + std::optional constantValue; +@@ -786,8 +795,8 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl { + // AxisInfoAnalysis + //===----------------------------------------------------------------------===// + +-AxisInfoAnalysis::AxisInfoAnalysis(MLIRContext *context) +- : ForwardDataFlowAnalysis(context) { ++AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver) ++ : dataflow::SparseDataFlowAnalysis>(solver) { + // UnrealizedConversionCast: + // This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is + // in the process of a PartialConversion, where UnrealizedConversionCast +@@ -819,7 +828,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(MLIRContext *context) + visitors.append, + LogicalOpAxisInfoVisitor, + LogicalOpAxisInfoVisitor>(); +- visitors.append, ++ visitors.append, + SelectOpAxisInfoVisitor>(); + visitors.append, + ShROpAxisInfoVisitor>(); +@@ -829,11 +838,12 @@ AxisInfoAnalysis::AxisInfoAnalysis(MLIRContext *context) + MaxMinOpAxisInfoVisitor>(); + } + +-ChangeResult AxisInfoAnalysis::visitOperation( +- Operation *op, ArrayRef *> operands) { ++void AxisInfoAnalysis::visitOperation( ++ Operation *op, ArrayRef *> operands, ++ ArrayRef *> results) { + AxisInfo curr = visitors.apply(op, operands); + if (curr.getRank() == 0) { +- return markAllPessimisticFixpoint(op->getResults()); ++ return markAllPessimisticFixpoint(results); + } + // override with hint + auto newContiguity = curr.getContiguity(); +@@ -854,11 +864,8 @@ ChangeResult AxisInfoAnalysis::visitOperation( + curr = mlir::AxisInfo(newContiguity, newDivisibility, newConstancy, + curr.getConstantValue()); + // join all lattice elements +- ChangeResult result = ChangeResult::NoChange; +- for (Value value : op->getResults()) { +- result |= getLatticeElement(value).join(curr); +- } +- return result; ++ for (auto *result : results) ++ propagateIfChanged(result, result->join(curr)); + } + + unsigned AxisInfoAnalysis::getPtrContiguity(Value ptr) { +@@ -884,7 +891,10 @@ unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) { + auto tensorTy = ptr.getType().dyn_cast(); + if (!tensorTy) + return 1; +- auto axisInfo = lookupLatticeElement(ptr)->getValue(); ++ dataflow::Lattice *latticeElement = getLatticeElement(ptr); ++ if (!latticeElement || latticeElement->isUninitialized()) ++ return 1; ++ auto axisInfo = latticeElement->getValue(); + auto layout = tensorTy.getEncoding(); + auto order = triton::gpu::getOrder(layout); + auto maxMultipleBytes = axisInfo.getDivisibility(order[0]); +@@ -900,8 +910,11 @@ unsigned AxisInfoAnalysis::getMaskAlignment(Value mask) { + auto tensorTy = mask.getType().dyn_cast(); + if (!tensorTy) + return 1; ++ dataflow::Lattice *latticeElement = getLatticeElement(mask); ++ if (!latticeElement || latticeElement->isUninitialized()) ++ return 1; ++ auto maskAxis = latticeElement->getValue(); + auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding()); +- auto maskAxis = lookupLatticeElement(mask)->getValue(); + auto alignment = std::max(maskAxis.getConstancy(maskOrder[0]), 1); + return alignment; + } +diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt +index afbc692510..1f761f845c 100644 +--- a/lib/Analysis/CMakeLists.txt ++++ b/lib/Analysis/CMakeLists.txt +@@ -8,7 +8,7 @@ add_mlir_library(TritonAnalysis + DEPENDS + TritonTableGen + TritonGPUAttrDefsIncGen +- ++ + LINK_LIBS PUBLIC + MLIRAnalysis + ) +diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp +index acc885e827..910274b2ac 100644 +--- a/lib/Analysis/Membar.cpp ++++ b/lib/Analysis/Membar.cpp +@@ -2,7 +2,7 @@ + #include "triton/Analysis/Alias.h" + #include "triton/Dialect/TritonGPU/IR/Dialect.h" + +-#include "mlir/Dialect/GPU/GPUDialect.h" ++#include "mlir/Dialect/GPU/IR/GPUDialect.h" + #include "mlir/Dialect/Tensor/IR/Tensor.h" + + namespace mlir { +diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp +index d9e917e731..6ea52df272 100644 +--- a/lib/Analysis/Utility.cpp ++++ b/lib/Analysis/Utility.cpp +@@ -1,5 +1,8 @@ + #include "triton/Analysis/Utility.h" ++#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" ++#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" + #include "mlir/IR/Dialect.h" ++#include "mlir/IR/Matchers.h" + #include "triton/Dialect/Triton/IR/Dialect.h" + #include "triton/Dialect/TritonGPU/IR/Dialect.h" + #include +@@ -325,4 +328,55 @@ SetVector multiRootGetSlice(Operation *op, + return multiRootTopologicalSort(slice); + } + ++namespace { ++// Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis ++// interacts with constant propagation, but SparseConstantPropagation ++// doesn't seem to be sufficient. ++struct ConstantAnalysis : public DataFlowAnalysis { ++ using DataFlowAnalysis::DataFlowAnalysis; ++ ++ LogicalResult initialize(Operation *top) override { ++ WalkResult result = top->walk([&](Operation *op) { ++ if (failed(visit(op))) ++ return WalkResult::interrupt(); ++ return WalkResult::advance(); ++ }); ++ return success(!result.wasInterrupted()); ++ } ++ ++ LogicalResult visit(ProgramPoint point) override { ++ Operation *op = point.get(); ++ Attribute value; ++ if (matchPattern(op, m_Constant(&value))) { ++ auto *constant = getOrCreate>( ++ op->getResult(0)); ++ propagateIfChanged(constant, constant->join(dataflow::ConstantValue( ++ value, op->getDialect()))); ++ return success(); ++ } ++ setAllToUnknownConstants(op->getResults()); ++ for (Region ®ion : op->getRegions()) ++ setAllToUnknownConstants(region.getArguments()); ++ return success(); ++ } ++ ++ /// Set all given values as not constants. ++ void setAllToUnknownConstants(ValueRange values) { ++ dataflow::ConstantValue unknownConstant(nullptr, nullptr); ++ for (Value value : values) { ++ auto *constant = ++ getOrCreate>(value); ++ propagateIfChanged(constant, constant->join(unknownConstant)); ++ } ++ } ++}; ++} // namespace ++ ++std::unique_ptr createDataFlowSolver() { ++ auto solver = std::make_unique(); ++ solver->load(); ++ solver->load(); ++ return solver; ++} ++ + } // namespace mlir +diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +index 6a46265bd7..e352eb3698 100644 +--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp ++++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +@@ -159,9 +159,6 @@ struct ConvertLayoutOpConversion + Value smemBase) const { + auto accumNumCTAsEachRep = product(numCTAsEachRep); + auto layout = type.getEncoding(); +- auto blockedLayout = layout.dyn_cast(); +- auto sliceLayout = layout.dyn_cast(); +- auto mmaLayout = layout.dyn_cast(); + auto rank = type.getRank(); + auto sizePerThread = getSizePerThread(layout); + auto accumSizePerThread = product(sizePerThread); +diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h b/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h +index 4b89965aa9..1d9e00519b 100644 +--- a/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h ++++ b/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h +@@ -7,10 +7,8 @@ + #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" + #include "mlir/Conversion/LLVMCommon/Pattern.h" + #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +-#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" +-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" + #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +-#include "mlir/Dialect/GPU/GPUDialect.h" ++#include "mlir/Dialect/GPU/IR/GPUDialect.h" + #include "mlir/Dialect/LLVMIR/LLVMDialect.h" + #include "mlir/Dialect/Tensor/IR/Tensor.h" + #include "mlir/IR/Matchers.h" +@@ -422,9 +420,9 @@ struct MMA16816ConversionHelper { + MMA16816ConversionHelper(Type dotOperand, MmaEncodingAttr mmaLayout, + Value thread, ConversionPatternRewriter &rewriter, + TypeConverter *typeConverter, Location loc) +- : mmaLayout(mmaLayout), thread(thread), helper(mmaLayout), +- rewriter(rewriter), typeConverter(typeConverter), loc(loc), +- ctx(mmaLayout.getContext()), wpt(mmaLayout.getWarpsPerCTA()) { ++ : mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()), thread(thread), ++ helper(mmaLayout), rewriter(rewriter), typeConverter(typeConverter), ++ loc(loc), ctx(mmaLayout.getContext()) { + helper.deduceMmaType(dotOperand); + + Value _32 = i32_val(32); +diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp +index 0f8070ca9f..e4bd47c411 100644 +--- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp ++++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp +@@ -115,8 +115,6 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { + auto DTensorTy = D.getType().cast(); + auto AShape = ATensorTy.getShape(); + auto BShape = BTensorTy.getShape(); +- auto DShape = DTensorTy.getShape(); +- auto wpt = mmaLayout.getWarpsPerCTA(); + + bool isARow = ALayout.getIsMMAv1Row().cast().getValue(); + bool isBRow = BLayout.getIsMMAv1Row().cast().getValue(); +@@ -221,7 +219,6 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { + ConversionPatternRewriter &rewriter) const { + auto *ctx = rewriter.getContext(); + auto loc = op.getLoc(); +- auto threadId = getThreadId(rewriter, loc); + + auto A = op.a(); + auto B = op.b(); +@@ -230,12 +227,10 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { + + auto aTensorTy = A.getType().cast(); + auto bTensorTy = B.getType().cast(); +- auto cTensorTy = C.getType().cast(); + auto dTensorTy = D.getType().cast(); + + auto aShape = aTensorTy.getShape(); + auto bShape = bTensorTy.getShape(); +- auto cShape = cTensorTy.getShape(); + + BlockedEncodingAttr dLayout = + dTensorTy.getEncoding().cast(); +diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +index deb71b9597..0b9e67674b 100644 +--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp ++++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +@@ -61,7 +61,6 @@ struct FpToFpOpConversion + convertFp16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter, + const Value &v0, const Value &v1, const Value &v2, + const Value &v3) { +- auto ctx = rewriter.getContext(); + auto fp16x2VecTy = vec_ty(f16_ty, 2); + Value fp16x2Vec0 = undef(fp16x2VecTy); + Value fp16x2Vec1 = undef(fp16x2VecTy); +@@ -153,7 +152,6 @@ struct FpToFpOpConversion + convertBf16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter, + const Value &v0, const Value &v1, const Value &v2, + const Value &v3) { +- auto ctx = rewriter.getContext(); + auto bf16x2VecTy = vec_ty(i16_ty, 2); + Value bf16x2Vec0 = undef(bf16x2VecTy); + Value bf16x2Vec1 = undef(bf16x2VecTy); +diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp +index 9a8b4702bc..bae675f0cb 100644 +--- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp ++++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp +@@ -109,7 +109,8 @@ struct LoadOpConversion + DenseElementsAttr constAttr; + int64_t splatVal = 0; + if (other && valueElemTy.isa() && +- matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat()) { ++ matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat() && ++ constAttr.getElementType().isa()) { + otherIsSplatConstInt = true; + splatVal = constAttr.getSplatValue().getSExtValue(); + } +@@ -333,7 +334,6 @@ struct StoreOpConversion + elem = rewriter.create(loc, type::i8Ty(ctx), elem); + elem = bitcast(elem, valueElemTy); + +- Type u32Ty = typeConverter->convertType(type::u32Ty(ctx)); + llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx)); + } + llWord = bitcast(llWord, valArgTy); +@@ -387,7 +387,6 @@ struct AtomicCASOpConversion + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); +- Value ptr = op.ptr(); + + Value llPtr = adaptor.ptr(); + Value llCmp = adaptor.cmp(); +diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +index 69abd889be..1c973dc196 100644 +--- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp ++++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +@@ -286,7 +286,6 @@ struct ReduceOpConversion + auto srcTy = op.operand().getType().cast(); + auto srcLayout = srcTy.getEncoding(); + auto srcShape = srcTy.getShape(); +- auto srcRank = srcTy.getRank(); + auto order = getOrder(srcLayout); + + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout); +@@ -351,7 +350,6 @@ struct ReduceOpConversion + + Value zero = i32_val(0); + Value laneZero = icmp_eq(laneIdAxis, zero); +- Value warpZero = icmp_eq(warpIdAxis, zero); + + for (auto it : accs) { + const SmallVector &key = it.first; +diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +index 5b77150b1a..78cfa076bd 100644 +--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp ++++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +@@ -11,11 +11,11 @@ using ::mlir::LLVM::getStructFromElements; + using ::mlir::triton::gpu::getElemsPerThread; + using ::mlir::triton::gpu::SharedEncodingAttr; + +-struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> { +- using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; ++struct ReturnOpConversion : public ConvertOpToLLVMPattern { ++ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult +- matchAndRewrite(ReturnOp op, OpAdaptor adaptor, ++ matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + unsigned numArguments = op.getNumOperands(); + +@@ -476,7 +476,6 @@ struct ExtractSliceOpConversion + + auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + auto elemPtrTy = ptr_ty(llvmElemTy, 3); +- auto resTy = op.getType().dyn_cast(); + smemObj = SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset), + strideVals, offsetVals); + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); +diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h +index bb10d5b24a..00e399f848 100644 +--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h ++++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h +@@ -4,6 +4,7 @@ + // TODO: refactor so that it doesn't fail if Allocation.h + // is included after utility.h (due to conflict in `store` macro + // and ++#include "mlir/Dialect/Func/IR/FuncOps.h" + #include "triton/Analysis/Allocation.h" + + // +@@ -39,15 +40,15 @@ void vprintf_array(Value thread, ArrayRef arr, std::string info, + // TODO(Superjomn): remove the code when MLIR v15.0 is included. + // All the rights are reserved by the LLVM community. + +-struct FuncOpConversionBase : public ConvertOpToLLVMPattern { ++struct FuncOpConversionBase : public ConvertOpToLLVMPattern { + private: + /// Only retain those attributes that are not constructed by + /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument + /// attributes. +- static void filterFuncAttributes(ArrayRef attrs, +- bool filterArgAttrs, ++ static void filterFuncAttributes(func::FuncOp op, bool filterArgAttrs, + SmallVectorImpl &result) { +- for (const auto &attr : attrs) { ++ ++ for (const auto &attr : op->getAttrs()) { + if (attr.getName() == SymbolTable::getSymbolAttrName() || + attr.getName() == FunctionOpInterface::getTypeAttrName() || + attr.getName() == "std.varargs" || +@@ -65,27 +66,27 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern { + } + + protected: +- using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; ++ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided + // to this legalization pattern. + LLVM::LLVMFuncOp +- convertFuncOpToLLVMFuncOp(FuncOp funcOp, ++ convertFuncOpToLLVMFuncOp(func::FuncOp funcOp, + ConversionPatternRewriter &rewriter) const { + // Convert the original function arguments. They are converted using the + // LLVMTypeConverter provided to this legalization pattern. + auto varargsAttr = funcOp->getAttrOfType("func.varargs"); + TypeConverter::SignatureConversion result(funcOp.getNumArguments()); + auto llvmType = getTypeConverter()->convertFunctionSignature( +- funcOp.getType(), varargsAttr && varargsAttr.getValue(), result); ++ funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(), ++ result); + if (!llvmType) + return nullptr; + + // Propagate argument/result attributes to all converted arguments/result + // obtained after converting a given original argument/result. + SmallVector attributes; +- filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true, +- attributes); ++ filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, attributes); + if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) { + assert(!resAttrDicts.empty() && "expected array to be non-empty"); + auto newResAttrDicts = +@@ -131,7 +132,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern { + } + auto newFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), llvmType, linkage, +- /*dsoLocal*/ false, attributes); ++ /*dsoLocal*/ false, LLVM::CConv::C, attributes); + rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), + newFuncOp.end()); + if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter, +@@ -191,8 +192,8 @@ class ConvertTritonGPUOpToLLVMPatternBase { + const Allocation *allocation, + Value smem, + IndexCacheInfo indexCacheInfo) +- : converter(&typeConverter), indexCacheInfo(indexCacheInfo), +- allocation(allocation), smem(smem) {} ++ : converter(&typeConverter), allocation(allocation), smem(smem), ++ indexCacheInfo(indexCacheInfo) {} + + LLVMTypeConverter *getTypeConverter() const { return converter; } + +@@ -861,7 +862,6 @@ class ConvertTritonGPUOpToLLVMPatternBase { + ArrayRef shape) const { + auto parent = sliceLayout.getParent(); + unsigned dim = sliceLayout.getDim(); +- size_t rank = shape.size(); + auto parentIndices = + emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape)); + unsigned numIndices = parentIndices.size(); +diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +index ff1af09835..6f66af4e34 100644 +--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp ++++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +@@ -1,10 +1,11 @@ + #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h" + ++#include "mlir/Analysis/DataFlowFramework.h" + #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" ++#include "mlir/Conversion/ControlFlowToLLVM//ControlFlowToLLVM.h" + #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" + #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +-#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" +-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" ++#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" + #include "mlir/Dialect/LLVMIR/LLVMDialect.h" + #include "mlir/Dialect/LLVMIR/NVVMDialect.h" + #include "mlir/Pass/Pass.h" +@@ -40,7 +41,6 @@ class TritonLLVMConversionTarget : public ConversionTarget { + addIllegalDialect(); + addIllegalDialect(); + addIllegalDialect(); +- addIllegalDialect(); + addLegalOp(); + } + }; +@@ -51,7 +51,7 @@ class TritonLLVMFunctionConversionTarget : public ConversionTarget { + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); +- addIllegalOp(); ++ addIllegalOp(); + addLegalOp(); + } + }; +@@ -69,7 +69,7 @@ struct FuncOpConversion : public FuncOpConversionBase { + : FuncOpConversionBase(converter, benefit), numWarps(numWarps) {} + + LogicalResult +- matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, ++ matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); + if (!newFuncOp) +@@ -133,7 +133,8 @@ class ConvertTritonGPUToLLVM + decomposeBlockedToDotOperand(mod); + + // Step 2 +- decomposeInsertSliceAsyncOp(mod); ++ if (failed(decomposeInsertSliceAsyncOp(mod))) ++ return signalPassFailure(); + + // Step 3 + Allocation allocation(mod); +@@ -142,7 +143,7 @@ class ConvertTritonGPUToLLVM + + // Step 4 + RewritePatternSet scf_patterns(context); +- mlir::populateLoopToStdConversionPatterns(scf_patterns); ++ mlir::populateSCFToControlFlowConversionPatterns(scf_patterns); + mlir::ConversionTarget scf_target(*context); + scf_target.addIllegalOp(); +@@ -159,8 +160,10 @@ class ConvertTritonGPUToLLVM + return signalPassFailure(); + + // Step 6 - get axis and shared memory info +- AxisInfoAnalysis axisInfoAnalysis(mod.getContext()); +- axisInfoAnalysis.run(mod); ++ std::unique_ptr solver = createDataFlowSolver(); ++ AxisInfoAnalysis *axisInfoAnalysis = solver->load(); ++ if (failed(solver->initializeAndRun(mod))) ++ return signalPassFailure(); + initSharedMemory(allocation.getSharedMemorySize(), typeConverter); + mod->setAttr("triton_gpu.shared", + mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32), +@@ -178,38 +181,39 @@ class ConvertTritonGPUToLLVM + + // Normal conversions + populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps, +- axisInfoAnalysis, &allocation, smem, ++ *axisInfoAnalysis, &allocation, smem, + indexCacheInfo, /*benefit=*/10); + // ConvertLayoutOp + populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, numWarps, +- axisInfoAnalysis, &allocation, smem, ++ *axisInfoAnalysis, &allocation, smem, + indexCacheInfo, /*benefit=*/10); + // DotOp + populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps, +- axisInfoAnalysis, &allocation, smem, ++ *axisInfoAnalysis, &allocation, smem, + /*benefit=*/10); + // ElementwiseOp + populateElementwiseOpToLLVMPatterns(typeConverter, patterns, numWarps, +- axisInfoAnalysis, &allocation, smem, ++ *axisInfoAnalysis, &allocation, smem, + /*benefit=*/10); + // LoadStoreOp + populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps, +- axisInfoAnalysis, &allocation, smem, ++ *axisInfoAnalysis, &allocation, smem, + indexCacheInfo, /*benefit=*/10); + // ReduceOp + populateReduceOpToLLVMPatterns(typeConverter, patterns, numWarps, +- axisInfoAnalysis, &allocation, smem, ++ *axisInfoAnalysis, &allocation, smem, + indexCacheInfo, /*benefit=*/10); + // ViewOp + populateViewOpToLLVMPatterns(typeConverter, patterns, numWarps, +- axisInfoAnalysis, &allocation, smem, ++ *axisInfoAnalysis, &allocation, smem, + /*benefit=*/10); + + // Add arith/math's patterns to help convert scalar expression to LLVM. + mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, + patterns); + mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns); +- mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns); ++ mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, ++ patterns); + mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns); + + if (failed(applyPartialConversion(mod, target, std::move(patterns)))) +@@ -306,9 +310,11 @@ class ConvertTritonGPUToLLVM + }); + } + +- void decomposeInsertSliceAsyncOp(ModuleOp mod) const { +- AxisInfoAnalysis axisInfoAnalysis(mod.getContext()); +- axisInfoAnalysis.run(mod); ++ LogicalResult decomposeInsertSliceAsyncOp(ModuleOp mod) const { ++ std::unique_ptr solver = createDataFlowSolver(); ++ AxisInfoAnalysis *axisInfoAnalysis = solver->load(); ++ if (failed(solver->initializeAndRun(mod))) ++ return failure(); + // TODO(Keren): This is a hacky knob that may cause performance regression + // when decomposition has been performed. We should remove this knob once we + // have thorough analysis on async wait. Currently, we decompose +@@ -342,7 +348,7 @@ class ConvertTritonGPUToLLVM + auto resSharedLayout = + dstTy.getEncoding().dyn_cast(); + auto resElemTy = dstTy.getElementType(); +- unsigned inVec = axisInfoAnalysis.getPtrContiguity(src); ++ unsigned inVec = axisInfoAnalysis->getPtrContiguity(src); + unsigned outVec = resSharedLayout.getVec(); + unsigned minVec = std::min(outVec, inVec); + auto maxBitWidth = +@@ -400,11 +406,11 @@ class ConvertTritonGPUToLLVM + } else if (decomposed) { + // Wait for all previous async ops + OpBuilder builder(asyncWaitOp); +- auto newAsyncWaitOp = +- builder.create(asyncWaitOp.getLoc(), 0); ++ builder.create(asyncWaitOp.getLoc(), 0); + asyncWaitOp.erase(); + } + }); ++ return success(); + } + }; + +diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h +index d35dac28c5..11976908cf 100644 +--- a/lib/Conversion/TritonGPUToLLVM/Utility.h ++++ b/lib/Conversion/TritonGPUToLLVM/Utility.h +@@ -220,10 +220,7 @@ struct SharedMemoryObject { + ConversionPatternRewriter &rewriter) + : base(base) { + strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter); +- +- for (auto idx : order) { +- offsets.emplace_back(i32_val(0)); +- } ++ offsets.append(order.size(), i32_val(0)); + } + + SmallVector getElems() const { +diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +index fe42202c34..5f230f787f 100644 +--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp ++++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +@@ -1,10 +1,10 @@ + #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" + + #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +-#include "mlir/Dialect/GPU/GPUDialect.h" ++#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" ++#include "mlir/Dialect/GPU/IR/GPUDialect.h" + #include "mlir/Dialect/LLVMIR/LLVMDialect.h" + #include "mlir/Dialect/LLVMIR/NVVMDialect.h" +-#include "mlir/Dialect/StandardOps/IR/Ops.h" + #include "mlir/Pass/Pass.h" + #include "mlir/Transforms/DialectConversion.h" + #include "triton/Dialect/Triton/IR/Dialect.h" +@@ -59,10 +59,13 @@ class ArithConstantPattern : public OpConversionPattern { + Type retType = getTypeConverter()->convertType(op.getType()); + auto value = adaptor.getValue().dyn_cast(); + assert(value); +- rewriter.replaceOpWithNewOp( +- op, retType, +- value.reshape(retType) // This is a hack. We just want to add encoding +- ); ++ if (value.getElementType().isInteger(1) && value.isSplat()) ++ // Workaround until https://reviews.llvm.org/D133743 is included. ++ value = DenseElementsAttr::get(retType, value.getSplatValue()); ++ else ++ // This is a hack. We just want to add encoding ++ value = value.reshape(retType); ++ rewriter.replaceOpWithNewOp(op, retType, value); + return success(); + } + }; +@@ -127,12 +130,12 @@ void populateArithmeticPatternsAndLegality( + } + + // this shouldn't exist if mlir's SelectOp checked encodings properly +-class StdSelectPattern : public OpConversionPattern { ++class StdSelectPattern : public OpConversionPattern { + public: +- using OpConversionPattern::OpConversionPattern; ++ using OpConversionPattern::OpConversionPattern; + + LogicalResult +- matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor, ++ matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type retType = this->getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp( +@@ -148,8 +151,8 @@ void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter, + MLIRContext *context = patterns.getContext(); + // Rewrite rule + patterns.add(typeConverter, context); +- target.addLegalOp(); // this is ok because all functions are inlined +- // by the frontend ++ target.addLegalOp(); // this is ok because all functions are ++ // inlined by the frontend + } + + void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter, +@@ -455,18 +458,19 @@ struct TritonPrintfPattern : public OpConversionPattern { + void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); +- patterns.add< // TODO: view should have custom pattern that views the layout +- TritonGenericPattern, +- TritonGenericPattern, +- TritonGenericPattern, +- TritonGenericPattern, +- TritonGenericPattern, +- TritonGenericPattern, TritonBroadcastPattern, +- TritonGenericPattern, TritonCatPattern, +- TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern, +- TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern, +- TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern, +- TritonAtomicRMWPattern>(typeConverter, context); ++ patterns ++ .insert< // TODO: view should have custom pattern that views the layout ++ TritonGenericPattern, ++ TritonGenericPattern, ++ TritonGenericPattern, ++ TritonGenericPattern, ++ TritonGenericPattern, ++ TritonGenericPattern, TritonBroadcastPattern, ++ TritonGenericPattern, TritonCatPattern, ++ TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern, ++ TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern, ++ TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern, ++ TritonAtomicRMWPattern>(typeConverter, context); + } + + // +@@ -623,29 +627,28 @@ void populateSCFPatterns(TritonGPUTypeConverter &typeConverter, + + // CF + +-class CFBranchPattern : public OpConversionPattern { ++class CFBranchPattern : public OpConversionPattern { + public: +- using OpConversionPattern::OpConversionPattern; ++ using OpConversionPattern::OpConversionPattern; + + LogicalResult +- matchAndRewrite(BranchOp op, BranchOp::Adaptor adaptor, ++ matchAndRewrite(cf::BranchOp op, cf::BranchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { +- auto converter = getTypeConverter(); +- auto newOp = rewriter.replaceOpWithNewOp(op, op.getSuccessor(), +- adaptor.getOperands()); ++ auto newOp = rewriter.replaceOpWithNewOp( ++ op, op.getSuccessor(), adaptor.getOperands()); + return success(); + } + }; + +-class CFCondBranchPattern : public OpConversionPattern { ++class CFCondBranchPattern : public OpConversionPattern { + public: +- using OpConversionPattern::OpConversionPattern; ++ using OpConversionPattern::OpConversionPattern; + + LogicalResult +- matchAndRewrite(CondBranchOp op, CondBranchOp::Adaptor adaptor, ++ matchAndRewrite(cf::CondBranchOp op, cf::CondBranchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); +- auto newOp = rewriter.replaceOpWithNewOp( ++ auto newOp = rewriter.replaceOpWithNewOp( + op, adaptor.getCondition(), op.getTrueDest(), + adaptor.getTrueDestOperands(), op.getFalseDest(), + adaptor.getFalseDestOperands()); +diff --git a/lib/Dialect/Triton/IR/CMakeLists.txt b/lib/Dialect/Triton/IR/CMakeLists.txt +index 2d679b21fd..705554ba6b 100644 +--- a/lib/Dialect/Triton/IR/CMakeLists.txt ++++ b/lib/Dialect/Triton/IR/CMakeLists.txt +@@ -10,11 +10,7 @@ add_mlir_dialect_library(TritonIR + + LINK_LIBS PUBLIC + MLIRIR +- MLIRArithmetic +- MLIRSCF +- +- # Since LLVM 15 +- # MLIRFunc +- # else +- MLIRStandard ++ MLIRArithmeticDialect ++ MLIRSCFDialect ++ MLIRFuncDialect + ) +diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp +index 3aadbfa0c0..86570359c5 100644 +--- a/lib/Dialect/Triton/IR/Ops.cpp ++++ b/lib/Dialect/Triton/IR/Ops.cpp +@@ -1,10 +1,9 @@ +-#include "triton/Dialect/Triton/IR/Dialect.h" +-#include "triton/Dialect/Triton/IR/Types.h" +- + #include "mlir/IR/Builders.h" + #include "mlir/IR/BuiltinAttributes.h" + #include "mlir/IR/BuiltinTypes.h" + #include "mlir/IR/OperationSupport.h" ++#include "triton/Dialect/Triton/IR/Dialect.h" ++#include "triton/Dialect/Triton/IR/Types.h" + + namespace mlir { + namespace triton { +@@ -38,8 +37,8 @@ static Type getPointerTypeSameShape(Type type) { + } + + // Parser & printer for assembly forms +-ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { +- SmallVector allOperands; ++ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { ++ SmallVector allOperands; + Type resultTypes[1]; + SMLoc allOperandLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(allOperands) || +@@ -73,18 +72,18 @@ ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { + return success(); + } + +-void printLoadOp(OpAsmPrinter &printer, LoadOp loadOp) { ++void LoadOp::print(OpAsmPrinter &printer) { + printer << " "; +- printer << loadOp.getOperation()->getOperands(); ++ printer << getOperation()->getOperands(); + // "operand_segment_sizes" can be deduced, so we don't print it. +- printer.printOptionalAttrDict(loadOp->getAttrs(), +- {loadOp.operand_segment_sizesAttrName()}); ++ printer.printOptionalAttrDict(getOperation()->getAttrs(), ++ {operand_segment_sizesAttrName()}); + printer << " : "; +- printer.printStrippedAttrOrType(loadOp.result().getType()); ++ printer.printStrippedAttrOrType(getResult().getType()); + } + +-ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { +- SmallVector allOperands; ++ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { ++ SmallVector allOperands; + Type valueType; + SMLoc allOperandLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(allOperands) || +@@ -104,12 +103,12 @@ ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { + return success(); + } + +-void printStoreOp(OpAsmPrinter &printer, StoreOp storeOp) { ++void StoreOp::print(OpAsmPrinter &printer) { + printer << " "; +- printer << storeOp.getOperation()->getOperands(); +- printer.printOptionalAttrDict(storeOp->getAttrs(), /*elidedAttrs=*/{}); ++ printer << getOperation()->getOperands(); ++ printer.printOptionalAttrDict(getOperation()->getAttrs(), /*elidedAttrs=*/{}); + printer << " : "; +- printer.printStrippedAttrOrType(storeOp.value().getType()); ++ printer.printStrippedAttrOrType(value().getType()); + } + + } // namespace triton +@@ -319,7 +318,8 @@ OpFoldResult SplatOp::fold(ArrayRef operands) { + if (!constOperand) + return {}; + auto shapedType = getType().cast(); +- auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()}); ++ auto ret = SplatElementsAttr::get( ++ shapedType, ArrayRef(constOperand.getValue())); + return ret; + } + +diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp +index 2261472170..11570283d6 100644 +--- a/lib/Dialect/Triton/Transforms/Combine.cpp ++++ b/lib/Dialect/Triton/Transforms/Combine.cpp +@@ -57,13 +57,13 @@ DenseElementsAttr getConstantValue(Builder &builder, Attribute value, + class CombineSelectMaskedLoadPattern : public mlir::RewritePattern { + public: + CombineSelectMaskedLoadPattern(mlir::MLIRContext *context) +- : mlir::RewritePattern(mlir::SelectOp::getOperationName(), 3, context, +- {triton::LoadOp::getOperationName()}) {} ++ : mlir::RewritePattern(mlir::arith::SelectOp::getOperationName(), 3, ++ context, {triton::LoadOp::getOperationName()}) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { +- auto selectOp = llvm::dyn_cast(op); ++ auto selectOp = llvm::dyn_cast(op); + if (!selectOp) + return mlir::failure(); + +diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td +index 14f286b26e..ded0e346e6 100644 +--- a/lib/Dialect/Triton/Transforms/Combine.td ++++ b/lib/Dialect/Triton/Transforms/Combine.td +@@ -1,9 +1,9 @@ + #ifndef TRITON_PATTERNS + #define TRITON_PATTERNS + +-include "mlir/Dialect/StandardOps/IR/Ops.td" + include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td" + include "triton/Dialect/Triton/IR/TritonOps.td" ++include "mlir/IR/PatternBase.td" + + + // AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) +diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp +index 1fbc609e88..bfc3f3d3da 100644 +--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp ++++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp +@@ -1,14 +1,14 @@ ++#include "triton/Dialect/Triton/IR/Dialect.h" ++ + #include + + #include "mlir/IR/DialectImplementation.h" + #include "mlir/IR/OpImplementation.h" + #include "triton/Analysis/Utility.h" +-#include "triton/Dialect/Triton/IR/Dialect.h" ++#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" + #include "triton/Dialect/TritonGPU/IR/Dialect.h" + #include "llvm/ADT/TypeSwitch.h" + +-#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" +- + using namespace mlir; + using namespace mlir::triton::gpu; + +@@ -366,7 +366,6 @@ template SmallVector + SliceEncodingAttr::paddedShape(ArrayRef shape) const; + + unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef shape) const { +- size_t rank = shape.size(); + auto parent = getParent(); + return ::getElemsPerThread(parent, paddedShape(shape)); + } +@@ -655,9 +654,9 @@ void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const { + // InsertSliceAsyncOp + //===----------------------------------------------------------------------===// + +-ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser, +- OperationState &result) { +- SmallVector allOperands; ++ParseResult InsertSliceAsyncOp::parse(OpAsmParser &parser, ++ OperationState &result) { ++ SmallVector allOperands; + Type srcType, dstType; + SMLoc allOperandLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(allOperands) || +@@ -696,18 +695,16 @@ ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser, + return success(); + } + +-void printInsertSliceAsyncOp(OpAsmPrinter &printer, +- InsertSliceAsyncOp insertSliceAsyncOp) { ++void InsertSliceAsyncOp::print(OpAsmPrinter &printer) { + printer << " "; +- printer << insertSliceAsyncOp.getOperation()->getOperands(); ++ printer << getOperation()->getOperands(); + // "operand_segment_sizes" can be deduced, so we don't print it. +- printer.printOptionalAttrDict( +- insertSliceAsyncOp->getAttrs(), +- {insertSliceAsyncOp.operand_segment_sizesAttrName()}); ++ printer.printOptionalAttrDict(getOperation()->getAttrs(), ++ {operand_segment_sizesAttrName()}); + printer << " : "; +- printer.printStrippedAttrOrType(insertSliceAsyncOp.src().getType()); ++ printer.printStrippedAttrOrType(src().getType()); + printer << " -> "; +- printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType()); ++ printer.printStrippedAttrOrType(result().getType()); + } + + //===----------------------------------------------------------------------===// +diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +index 82407980d3..ee6009f44a 100644 +--- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +@@ -27,7 +27,11 @@ struct CoalescePass : public TritonGPUCoalesceBase { + auto origType = ptr.getType().cast(); + // Get the shape of the tensor. + size_t rank = origType.getRank(); +- AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue(); ++ dataflow::Lattice *latticeElement = ++ axisInfo.getLatticeElement(ptr); ++ AxisInfo info = latticeElement && !latticeElement->isUninitialized() ++ ? latticeElement->getValue() ++ : AxisInfo(); + // Get the contiguity order of `ptr` + auto order = argSort(info.getContiguity()); + // The desired divisibility is the maximum divisibility +@@ -40,7 +44,7 @@ struct CoalescePass : public TritonGPUCoalesceBase { + for (Value val : op->getResults()) { + if (val.getType() != origType) + continue; +- auto valInfo = axisInfo.lookupLatticeElement(val); ++ auto valInfo = axisInfo.getLatticeElement(val); + auto currOrder = argSort(valInfo->getValue().getContiguity()); + if (order == currOrder) + withSameOrder.insert(val); +@@ -55,7 +59,7 @@ struct CoalescePass : public TritonGPUCoalesceBase { + unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); + unsigned perThread = 1; + for (Value val : withSameOrder) { +- AxisInfo info = axisInfo.lookupLatticeElement(val)->getValue(); ++ AxisInfo info = axisInfo.getLatticeElement(val)->getValue(); + unsigned maxMultipleBytes = info.getDivisibility(order[0]); + unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u); + unsigned maxContig = info.getContiguity(order[0]); +@@ -123,8 +127,10 @@ struct CoalescePass : public TritonGPUCoalesceBase { + void runOnOperation() override { + Operation *op = getOperation(); + // Run axis info analysis +- AxisInfoAnalysis axisInfo(&getContext()); +- axisInfo.run(op); ++ std::unique_ptr solver = createDataFlowSolver(); ++ AxisInfoAnalysis *axisInfo = solver->load(); ++ if (failed(solver->initializeAndRun(op))) ++ return signalPassFailure(); + + // For each i/o operation, we determine what layout + // the pointers should have for best memory coalescing +@@ -146,10 +152,10 @@ struct CoalescePass : public TritonGPUCoalesceBase { + RankedTensorType ty = ptr.getType().template dyn_cast(); + if (!ty || !ty.getElementType().isa()) + return; +- AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue(); ++ AxisInfo info = axisInfo->getLatticeElement(ptr)->getValue(); + auto mod = curr->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); +- auto convertType = getTypeConverter(axisInfo, ptr, numWarps); ++ auto convertType = getTypeConverter(*axisInfo, ptr, numWarps); + layoutMap[ptr] = convertType; + }); + +diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp +index efa37ff2dc..089ce3996c 100644 +--- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp +@@ -1,6 +1,6 @@ + #include "Utility.h" + #include "mlir/Analysis/SliceAnalysis.h" +-#include "mlir/Dialect/SCF/SCF.h" ++#include "mlir/Dialect/SCF/IR/SCF.h" + #include "mlir/IR/BlockAndValueMapping.h" + #include "mlir/IR/BuiltinAttributes.h" + #include "mlir/IR/Matchers.h" +diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.td b/lib/Dialect/TritonGPU/Transforms/Combine.td +index 6bf1b14866..6a7b10dbcb 100644 +--- a/lib/Dialect/TritonGPU/Transforms/Combine.td ++++ b/lib/Dialect/TritonGPU/Transforms/Combine.td +@@ -3,5 +3,6 @@ + + include "triton/Dialect/TritonGPU/IR/TritonGPUOps.td" + include "triton/Dialect/Triton/IR/TritonOps.td" ++include "mlir/IR/PatternBase.td" + + #endif +diff --git a/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp b/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp +index 4bd3bc76bf..b2f8defd81 100644 +--- a/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp +@@ -1,5 +1,5 @@ + #include "mlir/Analysis/SliceAnalysis.h" +-#include "mlir/Dialect/SCF/SCF.h" ++#include "mlir/Dialect/SCF/IR/SCF.h" + #include "mlir/IR/BlockAndValueMapping.h" + #include "mlir/IR/BuiltinAttributes.h" + #include "mlir/IR/Matchers.h" +diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +index 9b2f42231e..85f746c1dc 100644 +--- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +@@ -2,6 +2,7 @@ + #include "mlir/IR/BlockAndValueMapping.h" + #include "mlir/IR/TypeUtilities.h" + #include "triton/Analysis/AxisInfo.h" ++#include "triton/Analysis/Utility.h" + #include "triton/Dialect/TritonGPU/IR/Dialect.h" + #include "triton/Dialect/TritonGPU/Transforms/Passes.h" + +@@ -160,15 +161,18 @@ ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op, + LogicalResult LoopPipeliner::initialize() { + Block *loop = forOp.getBody(); + +- AxisInfoAnalysis axisInfoAnalysis(forOp.getContext()); +- axisInfoAnalysis.run(forOp->getParentOfType()); ++ std::unique_ptr solver = createDataFlowSolver(); ++ AxisInfoAnalysis *axisInfoAnalysis = solver->load(); ++ if (failed(solver->initializeAndRun(forOp->getParentOfType()))) { ++ return failure(); ++ } + + // can we use forOp.walk(...) here? + SmallVector allLoads; + for (Operation &op : *loop) + if (auto loadOp = dyn_cast(&op)) { + auto ptr = loadOp.ptr(); +- unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); ++ unsigned vec = axisInfoAnalysis->getPtrContiguity(ptr); + auto tensorTy = ptr.getType().dyn_cast(); + if (!tensorTy) + continue; +diff --git a/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp b/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp +index 0e7dbe5264..b95a4f50a6 100644 +--- a/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp +@@ -1,5 +1,5 @@ + #include "mlir/Analysis/SliceAnalysis.h" +-#include "mlir/Dialect/SCF/SCF.h" ++#include "mlir/Dialect/SCF/IR/SCF.h" + #include "mlir/IR/BlockAndValueMapping.h" + #include "mlir/IR/BuiltinAttributes.h" + #include "mlir/IR/Matchers.h" +diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +index 37ac710995..762e887f36 100644 +--- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +@@ -82,12 +82,12 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( + scf::ReduceReturnOp>(); + + addDynamicallyLegalDialect([&](Operation *op) { +- if (typeConverter.isLegal(op)) +- return true; +- return false; +- }); ++ triton::TritonDialect, scf::SCFDialect>( ++ [&](Operation *op) { ++ if (typeConverter.isLegal(op)) ++ return true; ++ return false; ++ }); + + // We have requirements for the data layouts + addDynamicallyLegalOp([](triton::DotOp dotOp) -> bool { +diff --git a/lib/Dialect/TritonGPU/Transforms/UpdateMmaForVolta.cpp b/lib/Dialect/TritonGPU/Transforms/UpdateMmaForVolta.cpp +index c229104286..c911fd4a5c 100644 +--- a/lib/Dialect/TritonGPU/Transforms/UpdateMmaForVolta.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/UpdateMmaForVolta.cpp +@@ -1,5 +1,5 @@ + #include "Utility.h" +-#include "mlir/Dialect/SCF/SCF.h" ++#include "mlir/Dialect/SCF/IR/SCF.h" + #include "mlir/IR/Matchers.h" + #include "mlir/IR/PatternMatch.h" + #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +@@ -118,8 +118,8 @@ void setOpResultType(Operation *op, ArrayRef newTypes) { + .get("value") + .dyn_cast(); + if (attr) { +- auto newAttr = mlir::DenseElementsAttr::getFromRawBuffer( +- newType, attr.getRawData(), true); ++ auto newAttr = ++ mlir::DenseElementsAttr::getFromRawBuffer(newType, attr.getRawData()); + op->setAttr("value", newAttr); + } + } +diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp +index ed15f02f67..6400f1633a 100644 +--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp +@@ -1,5 +1,5 @@ + #include "Utility.h" +-#include "mlir/Dialect/SCF/SCF.h" ++#include "mlir/Dialect/SCF/IR/SCF.h" + #include "mlir/IR/BlockAndValueMapping.h" + #include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +diff --git a/lib/Target/LLVMIR/CMakeLists.txt b/lib/Target/LLVMIR/CMakeLists.txt +index f1bbd0bf4e..ac8973ad19 100644 +--- a/lib/Target/LLVMIR/CMakeLists.txt ++++ b/lib/Target/LLVMIR/CMakeLists.txt +@@ -6,8 +6,7 @@ add_mlir_translation_library(TritonLLVMIR + + LINK_LIBS PUBLIC + MLIRIR +- MLIRLLVMIR +- MLIRSCFToStandard ++ MLIRLLVMDialect + MLIRSupport + MLIRTargetLLVMIRExport + ) +diff --git a/lib/Target/PTX/PTXTranslation.cpp b/lib/Target/PTX/PTXTranslation.cpp +index 4cb0d8193c..6a5453a6e7 100644 +--- a/lib/Target/PTX/PTXTranslation.cpp ++++ b/lib/Target/PTX/PTXTranslation.cpp +@@ -1,11 +1,14 @@ + #include "triton/Target/PTX/PTXTranslation.h" + #include "triton/Target/LLVMIR/LLVMIRTranslation.h" ++#include + + #include "llvm/IR/IRBuilder.h" + #include "llvm/IR/LegacyPassManager.h" + #include "llvm/IR/Module.h" + #include "llvm/IR/Verifier.h" + #include "llvm/MC/TargetRegistry.h" ++#include "llvm/Pass.h" ++#include "llvm/Support/CommandLine.h" + #include "llvm/Support/TargetSelect.h" + #include "llvm/Target/TargetMachine.h" + +diff --git a/python/setup.py b/python/setup.py +index 2ac3accd25..4530b36714 100644 +--- a/python/setup.py ++++ b/python/setup.py +@@ -57,19 +57,10 @@ def get_pybind11_package_info(): + def get_llvm_package_info(): + # download if nothing is installed + system = platform.system() +- if system == "Darwin": +- system_suffix = "apple-darwin" +- elif system == "Linux": +- vglibc = tuple(map(int, platform.libc_ver()[1].split('.'))) +- vglibc = vglibc[0] * 100 + vglibc[1] +- linux_suffix = 'ubuntu-18.04' if vglibc > 217 else 'centos-7' +- system_suffix = f"linux-gnu-{linux_suffix}" +- else: +- raise RuntimeError(f"unsupported system: {system}") ++ system_suffix = {"Linux": "linux-gnu-ubuntu-18.04", "Darwin": "apple-darwin"}[system] + use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False") +- release_suffix = "assert" if use_assert_enabled_llvm else "release" +- name = f'llvm+mlir-14.0.6-x86_64-{system_suffix}-{release_suffix}' +- url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/llvm-14.0.6-f28c006a5895/{name}.tar.xz" ++ name = 'llvm+mlir-15.0.7-x86_64-{}-{}'.format(system_suffix, "assert" if use_assert_enabled_llvm else "release") ++ url = "https://github.com/ptillet/triton-llvm-releases/releases/download/llvm-15.0.7-8dfdcc7b7bf6/{}.tar.xz".format(name) + return Package("llvm", name, url, "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH") + + +diff --git a/python/src/triton.cc b/python/src/triton.cc +index c40b117a55..f190eacc34 100644 +--- a/python/src/triton.cc ++++ b/python/src/triton.cc +@@ -8,9 +8,10 @@ + #include "mlir/Pass/PassManager.h" + #include "mlir/Transforms/Passes.h" + +-#include "mlir/Parser.h" ++#include "mlir/Parser/Parser.h" + #include "mlir/Support/FileUtilities.h" + ++#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" + #include "mlir/Dialect/LLVMIR/LLVMDialect.h" + #include "triton/Analysis/Allocation.h" + #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h" +@@ -195,7 +196,7 @@ void init_triton_ir(py::module &&m) { + std::string attrName = name + "_arg" + std::to_string(id); + mlir::Block *owner = arg.getOwner(); + if (owner->isEntryBlock() && +- !mlir::isa(owner->getParentOp())) { ++ !mlir::isa(owner->getParentOp())) { + owner->getParentOp()->setAttr(attrName, attr); + } + } +@@ -348,7 +349,7 @@ void init_triton_ir(py::module &&m) { + return str; + }) + .def("push_back", +- [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void { ++ [](mlir::ModuleOp &self, mlir::func::FuncOp &funcOp) -> void { + self.push_back(funcOp); + }) + .def("has_function", +@@ -358,16 +359,18 @@ void init_triton_ir(py::module &&m) { + return false; + }) + .def("get_function", +- [](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp { +- return self.lookupSymbol(funcName); +- }) +- .def("get_single_function", [](mlir::ModuleOp &self) -> mlir::FuncOp { +- llvm::SmallVector funcs; +- self.walk([&](mlir::FuncOp func) { funcs.push_back(func); }); +- if (funcs.size() != 1) +- throw std::runtime_error("Expected a single function"); +- return funcs[0]; +- }); ++ [](mlir::ModuleOp &self, ++ std::string &funcName) -> mlir::func::FuncOp { ++ return self.lookupSymbol(funcName); ++ }) ++ .def("get_single_function", ++ [](mlir::ModuleOp &self) -> mlir::func::FuncOp { ++ llvm::SmallVector funcs; ++ self.walk([&](mlir::func::FuncOp func) { funcs.push_back(func); }); ++ if (funcs.size() != 1) ++ throw std::runtime_error("Expected a single function"); ++ return funcs[0]; ++ }); + + m.def("make_attr", + [](const std::vector &values, mlir::MLIRContext &context) { +@@ -388,47 +391,48 @@ void init_triton_ir(py::module &&m) { + registry.insert(); ++ mlir::func::FuncDialect, mlir::scf::SCFDialect>(); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + + // parse module +- mlir::OwningOpRef module( +- mlir::parseSourceFile(inputFilename, &context)); ++ mlir::OwningOpRef module = ++ mlir::parseSourceFile(inputFilename, &context); ++ if (!module) ++ throw std::runtime_error("Parse MLIR file failed."); + // locations are incompatible with ptx < 7.5 ! + module->walk([](mlir::Operation *op) { + op->setLoc(mlir::UnknownLoc::get(op->getContext())); + }); +- if (!module) +- throw std::runtime_error("Parse MLIR file failed."); + + return module->clone(); + }, + ret::take_ownership); + +- py::class_(m, "function") ++ py::class_(m, "function") + // .def_property_readonly("attrs", &ir::function::attrs) + // .def("add_attr", &ir::function::add_attr); + .def("args", +- [](mlir::FuncOp &self, unsigned idx) -> mlir::BlockArgument { ++ [](mlir::func::FuncOp &self, unsigned idx) -> mlir::BlockArgument { + return self.getArgument(idx); + }) + .def( + "add_entry_block", +- [](mlir::FuncOp &self) -> mlir::Block * { ++ [](mlir::func::FuncOp &self) -> mlir::Block * { + return self.addEntryBlock(); + }, + ret::reference) + .def( + "set_arg_attr", +- [](mlir::FuncOp &self, int arg_no, const std::string &name, int val) { ++ [](mlir::func::FuncOp &self, int arg_no, const std::string &name, ++ int val) { + // set arg attributes "name" to value "val" + auto attrTy = mlir::IntegerType::get(self.getContext(), 32); + self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val)); + }, + ret::reference) +- .def_property_readonly("type", &mlir::FuncOp::getType) +- .def("reset_type", &mlir::FuncOp::setType); ++ .def_property_readonly("type", &mlir::func::FuncOp::getFunctionType) ++ .def("reset_type", &mlir::func::FuncOp::setType); + + py::class_(m, "InsertPoint"); + +@@ -445,13 +449,13 @@ void init_triton_ir(py::module &&m) { + .def("ret", + [](mlir::OpBuilder &self, std::vector &vals) -> void { + auto loc = self.getUnknownLoc(); +- self.create(loc, vals); ++ self.create(loc, vals); + }) + .def("call", +- [](mlir::OpBuilder &self, mlir::FuncOp &func, ++ [](mlir::OpBuilder &self, mlir::func::FuncOp &func, + std::vector &args) -> mlir::OpState { + auto loc = self.getUnknownLoc(); +- return self.create(loc, func, args); ++ return self.create(loc, func, args); + }) + // insertion block/point + .def("set_insertion_point_to_start", +@@ -618,15 +622,16 @@ void init_triton_ir(py::module &&m) { + .def("get_or_insert_function", + [](mlir::OpBuilder &self, mlir::ModuleOp &module, + std::string &funcName, mlir::Type &funcType, +- std::string &visibility) -> mlir::FuncOp { ++ std::string &visibility) -> mlir::func::FuncOp { + if (mlir::Operation *funcOperation = module.lookupSymbol(funcName)) +- return llvm::dyn_cast(funcOperation); ++ return llvm::dyn_cast(funcOperation); + auto loc = self.getUnknownLoc(); + if (auto funcTy = funcType.dyn_cast()) { + llvm::SmallVector attrs = { + mlir::NamedAttribute(self.getStringAttr("sym_visibility"), + self.getStringAttr(visibility))}; +- return self.create(loc, funcName, funcTy, attrs); ++ return self.create(loc, funcName, funcTy, ++ attrs); + } + throw std::runtime_error("invalid function type"); + }) +@@ -658,15 +663,15 @@ void init_triton_ir(py::module &&m) { + [](mlir::OpBuilder &self, mlir::Value condition, + mlir::Block *trueDest, mlir::Block *falseDest) { + auto loc = self.getUnknownLoc(); +- self.create(loc, condition, trueDest, +- falseDest); ++ self.create(loc, condition, trueDest, ++ falseDest); + return; + }) + .def("create_branch", + [](mlir::OpBuilder &self, mlir::Block *dest, + std::vector &args) { + auto loc = self.getUnknownLoc(); +- self.create(loc, dest, args); ++ self.create(loc, dest, args); + return; + }) + // Structured control flow +@@ -792,14 +797,14 @@ void init_triton_ir(py::module &&m) { + .def("create_to_index", + [](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value { + auto loc = self.getUnknownLoc(); +- return self.create(loc, input, +- self.getIndexType()); ++ return self.create( ++ loc, self.getIndexType(), input); + }) + .def("create_index_to_si", + [](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value { + auto loc = self.getUnknownLoc(); +- return self.create(loc, input, +- self.getI32Type()); ++ return self.create( ++ loc, self.getI32Type(), input); + }) + .def("create_fmul", + [](mlir::OpBuilder &self, mlir::Value &lhs, +@@ -1316,8 +1321,8 @@ void init_triton_ir(py::module &&m) { + [](mlir::OpBuilder &self, mlir::Value &condition, + mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value { + auto loc = self.getUnknownLoc(); +- return self.create(loc, condition, trueValue, +- falseValue); ++ return self.create(loc, condition, ++ trueValue, falseValue); + }) + .def("create_printf", + [](mlir::OpBuilder &self, const std::string &prefix, +@@ -1429,7 +1434,7 @@ void init_triton_ir(py::module &&m) { + self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass()); + }) + .def("add_scf_to_cfg", [](mlir::PassManager &self) { +- self.addPass(mlir::createLowerToCFGPass()); ++ self.addPass(mlir::createConvertSCFToCFPass()); + }); + } + +diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py +index 432544a8a4..018f544714 100644 +--- a/python/test/unit/language/test_core.py ++++ b/python/test/unit/language/test_core.py +@@ -1918,7 +1918,7 @@ def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'): + #dst = {dst_layout} + """ + """ + module attributes {"triton_gpu.num-warps" = 4 : i32} { +- func public @kernel_0d1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { ++ func.func public @kernel_0d1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<128> : tensor<128x1xi32, #src> + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>> + %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>> +diff --git a/python/triton/compiler.py b/python/triton/compiler.py +index 5d167634df..c36589037c 100644 +--- a/python/triton/compiler.py ++++ b/python/triton/compiler.py +@@ -1514,14 +1514,14 @@ def make_hash(fn, **kwargs): + return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest() + + +-# - ^\s*func\s+ : match the start of the string, any leading whitespace, the keyword func, ++# - ^\s*func\.func\s+ : match the start of the string, any leading whitespace, the keyword func, + # and any following whitespace + # - (public\s+)? : optionally match the keyword public and any following whitespace + # - (@\w+) : match an @ symbol followed by one or more word characters + # (letters, digits, or underscores), and capture it as group 1 (the function name) + # - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing + # zero or more arguments separated by commas, and capture it as group 2 (the argument list) +-mlir_prototype_pattern = r'^\s*func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$' ++mlir_prototype_pattern = r'^\s*func\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$' + ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" + prototype_pattern = { + "ttir": mlir_prototype_pattern, +diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir +index b3d5673f85..bb21615e68 100644 +--- a/test/Analysis/test-alias.mlir ++++ b/test/Analysis/test-alias.mlir +@@ -11,7 +11,7 @@ + + // CHECK-LABEL: matmul_loop + // There shouldn't be any aliasing with the dot op encoding. +-func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { ++func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %a_ptr_init = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> + %b_ptr_init = tt.broadcast %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL> + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> +@@ -36,7 +36,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B + } + + // CHECK-LABEL: alloc +-func @alloc(%A : !tt.ptr) { ++func.func @alloc(%A : !tt.ptr) { + // CHECK: %cst -> %cst + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> +@@ -46,7 +46,7 @@ func @alloc(%A : !tt.ptr) { + } + + // CHECK-LABEL: convert +-func @convert(%A : !tt.ptr) { ++func.func @convert(%A : !tt.ptr) { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + // CHECK: %0 -> %0 + %cst1 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED> +@@ -54,7 +54,7 @@ func @convert(%A : !tt.ptr) { + } + + // CHECK-LABEL: trans +-func @trans(%A : !tt.ptr) { ++func.func @trans(%A : !tt.ptr) { + // CHECK: %cst -> %cst + %tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED> + // CHECK: %0 -> %cst +@@ -63,7 +63,7 @@ func @trans(%A : !tt.ptr) { + } + + // CHECK-LABEL: insert_slice_async +-func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { ++func.func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { + %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<16x16x!tt.ptr, #AL> + %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> + %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> +@@ -76,7 +76,7 @@ func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { + } + + // CHECK-LABEL: insert_slice +-func @insert_slice(%A : !tt.ptr, %i1 : i1) { ++func.func @insert_slice(%A : !tt.ptr, %i1 : i1) { + %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<16x16x!tt.ptr, #AL> + %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> + %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> +@@ -90,7 +90,7 @@ func @insert_slice(%A : !tt.ptr, %i1 : i1) { + } + + // CHECK-LABEL: extract_slice +-func @extract_slice(%A : !tt.ptr) { ++func.func @extract_slice(%A : !tt.ptr) { + // CHECK: %cst -> %cst + %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED> + %index = arith.constant 0 : index +@@ -100,7 +100,7 @@ func @extract_slice(%A : !tt.ptr) { + } + + // CHECK-LABEL: if_cat +-func @if_cat(%i1 : i1) { ++func.func @if_cat(%i1 : i1) { + // CHECK: %cst -> %cst + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + // CHECK: %cst_0 -> %cst_0 +@@ -119,7 +119,7 @@ func @if_cat(%i1 : i1) { + } + + // CHECK-LABEL: if_alias +-func @if_alias(%i1 : i1) { ++func.func @if_alias(%i1 : i1) { + // CHECK: %cst -> %cst + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + // CHECK-NEXT: %cst_0 -> %cst_0 +@@ -134,7 +134,7 @@ func @if_alias(%i1 : i1) { + } + + // CHECK-LABEL: for +-func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { ++func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + // CHECK: %cst -> %cst + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + // CHECK-NEXT: %cst_0 -> %cst_0 +@@ -154,7 +154,7 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.p + } + + // CHECK-LABEL: for_if +-func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { ++func.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { + // CHECK: %cst -> %cst + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + // CHECK-NEXT: %cst_0 -> %cst_0 +@@ -180,7 +180,7 @@ func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t + } + + // CHECK-LABEL: for_if_for +-func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { ++func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { + // CHECK: %cst -> %cst + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + // CHECK-NEXT: %cst_0 -> %cst_0 +diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir +index 0ab34c7a78..af8ea6f856 100644 +--- a/test/Analysis/test-alignment.mlir ++++ b/test/Analysis/test-alignment.mlir +@@ -1,288 +1,288 @@ +-// RUN: triton-opt %s -test-print-alignment -split-input-file 2>&1 | FileCheck %s ++// RUN: triton-opt %s -test-print-alignment -split-input-file -o %t 2>&1 | FileCheck %s + +-// CHECK-LABEL: cast +-func @cast() { +- // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [1] ++// CHECK-LABEL: @cast ++func.func @cast() { ++ // CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1 + %cst = arith.constant 1 : i32 +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [1] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1 + %0 = arith.extsi %cst : i32 to i64 +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 + %cst_tensor = arith.constant dense<1> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 + %1 = tt.bitcast %cst_tensor : tensor<128xi32> -> tensor<128xi64> + return + } + + // ----- + +-// CHECK-LABEL: add +-func @add() { +- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] ++// CHECK-LABEL: @add ++func.func @add() { ++ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 + %1 = arith.constant dense<1> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [128], divisibility = [1], constancy = [1], constant_value = + %2 = arith.addi %0, %1 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [127] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 127 + %3 = arith.constant dense<127> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128] ++ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 + %4 = arith.addi %1, %3 : tensor<128xi32> + return + } + + // ----- + +-// CHECK-LABEL: sub +-func @sub() { +- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] ++// CHECK-LABEL: @sub ++func.func @sub() { ++ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 + %1 = arith.constant dense<1> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [128], divisibility = [1], constancy = [1], constant_value = + %2 = arith.subi %0, %1 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [129] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 129 + %3 = arith.constant dense<129> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128] ++ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 + %4 = arith.subi %3, %1 : tensor<128xi32> + return + } + + // ----- + +-// CHECK-LABEL: mul +-func @mul() { +- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] ++// CHECK-LABEL: @mul ++func.func @mul() { ++ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 + %1 = arith.constant dense<1> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %2 = arith.muli %0, %1 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128] ++ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 + %3 = arith.constant dense<128> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128] ++ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 + %4 = arith.muli %3, %1 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [2] ++ // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 2 + %5 = arith.constant dense<2> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [256] ; Constancy: [128] ; ConstantValue: [256] ++ // CHECK-NEXT: contiguity = [1], divisibility = [256], constancy = [128], constant_value = 256 + %6 = arith.muli %4, %5 : tensor<128xi32> + return + } + + // ----- + +-// CHECK-LABEL: div +-func @div() { +- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] ++// CHECK-LABEL: @div ++func.func @div() { ++ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 + %1 = arith.constant dense<1> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %2 = arith.divsi %0, %1 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %3 = arith.divui %1, %0 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64] ++ // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 + %4 = arith.constant dense<64> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16777216] ; Constancy: [64] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = + %5 = arith.divsi %0, %4 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %6 = arith.divsi %4, %0 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64] ++ // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 + %7 = arith.divsi %4, %1 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [66] ++ // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 66 + %8 = arith.constant dense<66> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [2] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [2], constant_value = + %9 = arith.divui %0, %8 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [8192] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [128], divisibility = [8192], constancy = [1], constant_value = + %10 = tt.make_range {end = 8320 : i32, start = 8192 : i32} : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [64] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [64], constant_value = + %11 = arith.divsi %10, %4 : tensor<128xi32> +- return ++ return + } + + // ----- + +-// CHECK-LABEL: rem +-func @rem() { +- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] ++// CHECK-LABEL: @rem ++func.func @rem() { ++ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 + %1 = arith.constant dense<1> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0] ++ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 + %2 = arith.remsi %0, %1 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %3 = arith.remui %1, %0 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64] ++ // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 + %4 = arith.constant dense<64> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [64] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [64], divisibility = [64], constancy = [1], constant_value = + %5 = arith.remsi %0, %4 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [1], constant_value = + %6 = arith.remsi %4, %0 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [66] ++ // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 66 + %7 = arith.constant dense<66> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [2] ; Divisibility: [2] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [2], divisibility = [2], constancy = [1], constant_value = + %8 = arith.remui %0, %7 : tensor<128xi32> +- return ++ return + } + + // ----- + +-// CHECK-LABEL: broadcast +-func @broadcast() { +- // CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64] ++// CHECK-LABEL: @broadcast ++func.func @broadcast() { ++ // CHECK: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 + %0 = arith.constant dense<64> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [64, 1] ; Constancy: [128, 1] ; ConstantValue: [64] ++ // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 1], constant_value = 64 + %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> +- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [64, 1] ; Constancy: [128, 128] ; ConstantValue: [64] ++ // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 128], constant_value = 64 + %2 = tt.broadcast %1 : (tensor<128x1xi32>) -> tensor<128x128xi32> + return + } + + // ----- + +-// CHECK-LABEL: splat +-func @splat(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { +- // CHECK: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 128] ; ConstantValue: [None] ++// CHECK-LABEL: @splat ++func.func @splat(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { ++ // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = + %0 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x128x!tt.ptr> + return + } + + // ----- + +-// CHECK-LABEL: cmp +-func @cmp() { +- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] ++// CHECK-LABEL: @cmp ++func.func @cmp() { ++ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0] ++ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 + %1 = arith.constant dense<0> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = + %2 = arith.cmpi eq, %0, %1 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = + %3 = arith.cmpi slt, %0, %1 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %4 = arith.cmpi sle, %0, %1 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = + %5 = arith.cmpi sge, %0, %1 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8] ++ // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8 + %6 = arith.constant dense<8> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = + %7 = arith.cmpi sgt, %0, %6 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [0] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 0 + %8 = arith.cmpi sgt, %1, %6 : tensor<128xi32> + return + } + + // ----- + +-// CHECK-LABEL: logic +-func @logic() { +- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] ++// CHECK-LABEL: @logic ++func.func @logic() { ++ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64] ++ // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 + %1 = arith.constant dense<64> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16777216] ; Constancy: [64] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = + %2 = arith.divsi %0, %1 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8] ++ // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8 + %3 = arith.constant dense<8> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [134217728] ; Constancy: [8] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [134217728], constancy = [8], constant_value = + %4 = arith.divsi %0, %3 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %5 = arith.andi %0, %1 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %6 = arith.ori %0, %1 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %7 = arith.xori %0, %1 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = + %8 = arith.andi %2, %4 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = + %9 = arith.ori %2, %4 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = + %10 = arith.xori %2, %4 : tensor<128xi32> + return + } + + // ----- + +-// CHECK-LABEL: select +-func @select() { +- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] ++// CHECK-LABEL: @select ++func.func @select() { ++ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0] ++ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 + %1 = arith.constant dense<0> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = + %2 = arith.cmpi eq, %0, %1 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = + %3 = arith.cmpi slt, %0, %1 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [1] ; ConstantValue: [0] ++ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0 + %4 = arith.constant 0 : i1 +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0] ++ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 + %7 = tt.splat %4 : (i1) -> tensor<128xi1> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0] +- %5 = select %4, %3, %7 : tensor<128xi1> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 ++ %5 = arith.select %4, %3, %7 : tensor<128xi1> ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = + %8 = "triton_gpu.select"(%7, %3, %2) : (tensor<128xi1>, tensor<128xi1>, tensor<128xi1>) -> tensor<128xi1> + return + } + + // ----- + +-func @shift() { +- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] ++func.func @shift() { ++ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8] ++ // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8 + %1 = arith.constant dense<8> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4] ; Constancy: [128] ; ConstantValue: [4] ++ // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4 + %2 = arith.constant dense<4> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [274877906944] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [274877906944], constancy = [1], constant_value = + %3 = arith.shli %0, %1 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [67108864] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [67108864], constancy = [1], constant_value = + %4 = arith.shrsi %0, %2 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128] ++ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 + %5 = arith.shli %1, %2 : tensor<128xi32> + return + } + + // ----- + +-func @max_min() { +- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] ++func.func @max_min() { ++ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [128], divisibility = [64], constancy = [1], constant_value = + %1 = tt.make_range {end = 192 : i32, start = 64 : i32} : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %2 = arith.maxsi %0, %1 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %3 = arith.minsi %0, %1 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8] ++ // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8 + %4 = arith.constant dense<8> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4] ; Constancy: [128] ; ConstantValue: [4] ++ // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4 + %5 = arith.constant dense<4> : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [8] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 8 + %6 = arith.maxsi %4, %5 : tensor<128xi32> + return + } + + // ----- + +-// CHECK-LABEL: for +-func @for() { +- // CHECK: Contiguity: [1, 1] ; Divisibility: [4611686018427387904, 4611686018427387904] ; Constancy: [128, 32] ; ConstantValue: [0] ++// CHECK-LABEL: @for ++func.func @for() { ++ // CHECK: contiguity = [1, 1], divisibility = [4611686018427387904, 4611686018427387904], constancy = [128, 32], constant_value = 0 + %a_init = arith.constant dense<0> : tensor<128x32xi32> +- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [1] ++ // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = 1 + %b_init = arith.constant dense<1> : tensor<128x32xi32> +- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [4, 4] ; Constancy: [128, 32] ; ConstantValue: [4] ++ // CHECK-NEXT: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4 + %c_init = arith.constant dense<4> : tensor<128x32xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [128] ++ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128 + %ub = arith.constant 128 : index +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [1] ; ConstantValue: [0] ++ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0 + %lb = arith.constant 0 : index +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [16] ++ // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = 16 + %step = arith.constant 16 : index + %a, %b, %c = scf.for %iv = %lb to %ub step %step iter_args(%a = %a_init, %b = %b_init, %c = %c_init) -> (tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>) { +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = + %t = arith.index_cast %iv : index to i32 +- // CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [None] +- // CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [None] +- // CHECK: Contiguity: [1, 1] ; Divisibility: [4, 4] ; Constancy: [128, 32] ; ConstantValue: [4] ++ // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = ++ // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = ++ // CHECK: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4 + scf.yield %b, %a, %c : tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32> + } + return +@@ -290,53 +290,53 @@ func @for() { + + // ----- + +-// CHECK-LABEL: permute_2d +-func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { +- // CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 128] ; ConstantValue: [1] ++// CHECK-LABEL: @permute_2d ++func.func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { ++ // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 128], constant_value = 1 + %cst = arith.constant dense : tensor<128x128xi1> +- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> +- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [1073741824, 1] ; Constancy: [1, 1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [128, 1], divisibility = [1073741824, 1], constancy = [1, 1], constant_value = + %2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> +- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = + %3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32> +- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [17179869184, 16] ; Constancy: [1, 1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1, 1], divisibility = [17179869184, 16], constancy = [1, 1], constant_value = + %4 = arith.muli %2, %3 : tensor<128x1xi32> +- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = + %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x1x!tt.ptr> +- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = + %6 = tt.addptr %5, %4 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> +- // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [1, 1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = + %7 = tt.expand_dims %1 {axis = 0 : i32}: (tensor<128xi32>) -> tensor<1x128xi32> +- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 128], constant_value = + %8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr>) -> tensor<128x128x!tt.ptr> +- // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [128, 1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [128, 1], constant_value = + %9 = tt.broadcast %7 : (tensor<1x128xi32>) -> tensor<128x128xi32> +- // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 16], constancy = [1, 1], constant_value = + %10 = tt.addptr %8, %9 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> +- // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [1073741824, 1] ; Constancy: [1, 1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [128, 1], divisibility = [1073741824, 1], constancy = [1, 1], constant_value = + %11 = tt.expand_dims %0 {axis = 1 : i32}: (tensor<128xi32>) -> tensor<128x1xi32> +- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = + %12 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x1x!tt.ptr> +- // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = + %13 = tt.addptr %12, %11 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> +- // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [1, 1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = + %14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32> +- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 128], constant_value = + %15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32> +- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 17179869184] ; Constancy: [1, 1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [1, 1], constant_value = + %16 = arith.muli %14, %15 : tensor<1x128xi32> +- // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 128] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 128], constant_value = + %17 = tt.broadcast %13 : (tensor<128x1x!tt.ptr>) -> tensor<128x128x!tt.ptr> +- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 17179869184] ; Constancy: [128, 1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [128, 1], constant_value = + %18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32> +- // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = + %19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> +- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = + %20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32> + tt.store %19, %20, %cst : tensor<128x128xf32> + return +@@ -347,29 +347,29 @@ func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {t + module { + + // This is a tiny test for verifying StoreOp-related alignment, It simply store a constant to a buffer. +-// CHECK-LABEL: store_constant_align +-func @store_constant_align(%addr: !tt.ptr {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) { +- // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ++// CHECK-LABEL: @store_constant_align ++func.func @store_constant_align(%addr: !tt.ptr {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) { ++ // CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %pid = tt.get_program_id {axis = 0 : i32} : i32 +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [128] ++ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128 + %c128_i32 = arith.constant 128 : i32 +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = + %1 = arith.muli %pid, %c128_i32 : i32 +- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = + %3 = tt.splat %1 : (i32) -> tensor<128xi32> +- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [128], divisibility = [128], constancy = [1], constant_value = + %4 = arith.addi %3, %2 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [128], constant_value = + %5 = tt.splat %addr : (!tt.ptr) -> tensor<128x!tt.ptr> +- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [128], divisibility = [16], constancy = [1], constant_value = + %6 = tt.addptr %5, %4 : tensor<128x!tt.ptr>, tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [128], constant_value = + %9 = tt.splat %n : (i32) -> tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [16] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [16], constant_value = + %mask = arith.cmpi slt, %4, %9 : tensor<128xi32> +- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ++ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %cst = arith.constant dense<0.0> : tensor<128xf32> + tt.store %5, %cst, %mask : tensor<128xf32> + return +@@ -381,8 +381,8 @@ func @store_constant_align(%addr: !tt.ptr {tt.divisibility = 16 : i32}, %n: + + // This IR is dumped from vecadd test. + // Note, the hint {tt.divisibility = 16 : i32} for %n_elements affects the alignment of mask. +-// CHECK-LABEL: vecadd_mask_align_16 +-func @vecadd_mask_align_16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) { ++// CHECK-LABEL: @vecadd_mask_align_16 ++func.func @vecadd_mask_align_16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) { + %c64_i32 = arith.constant 64 : i32 + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = arith.muli %0, %c64_i32 : i32 +@@ -394,13 +394,13 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %ar + %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<64x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr>, tensor<64xi32> + %9 = tt.splat %n_elements : (i32) -> tensor<64xi32> +- // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [16] ; ConstantValue: [None] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> ) ++ // CHECK: arith.cmpi slt, %{{.*}} => contiguity = [1], divisibility = [1], constancy = [16], constant_value = + %mask = arith.cmpi slt, %4, %9 : tensor<64xi32> + %11 = tt.load %6, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32> + %12 = tt.load %8, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32> + %13 = arith.addf %11, %12 : tensor<64xf32> + %14 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x!tt.ptr> +- // CHECK: Contiguity: [64] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None] ( %{{.*}} = tt.addptr %{{.*}}, %{{.*}} : tensor<64x!tt.ptr>, tensor<64xi32> ) ++ // CHECK: tt.addptr %{{.*}} => contiguity = [64], divisibility = [16], constancy = [1], constant_value = + %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %15, %13, %mask : tensor<64xf32> + return +@@ -410,8 +410,8 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %ar + + // This IR is dumped from vecadd test. + // Note, there is no divisibility hint for %n_elements, Triton should assume its divisibility to be 1 by default. +-// CHECK-LABEL: vecadd_mask_align_1 +-func @vecadd_mask_align_1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %n_elements: i32) { ++// CHECK-LABEL: @vecadd_mask_align_1 ++func.func @vecadd_mask_align_1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %n_elements: i32) { + %c64_i32 = arith.constant 64 : i32 + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = arith.muli %0, %c64_i32 : i32 +@@ -423,7 +423,7 @@ func @vecadd_mask_align_1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg + %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<64x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr>, tensor<64xi32> + %9 = tt.splat %n_elements : (i32) -> tensor<64xi32> +- // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> ) ++ // CHECK: arith.cmpi slt, %{{.*}} => contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %10 = arith.cmpi slt, %4, %9 : tensor<64xi32> + %11 = tt.load %6, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32> + %12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32> +diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir +index efb00c404d..f79222aa7b 100644 +--- a/test/Analysis/test-allocation.mlir ++++ b/test/Analysis/test-allocation.mlir +@@ -13,7 +13,7 @@ + module attributes {"triton_gpu.num-warps" = 4 : i32} { + + // CHECK-LABEL: matmul_loop +-func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { ++func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %a_ptr_init = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> + %b_ptr_init = tt.broadcast %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL> + +@@ -46,7 +46,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B + + // Shared memory is available after a tensor's liveness range ends + // CHECK-LABEL: reusable +-func @reusable(%A : !tt.ptr) { ++func.func @reusable(%A : !tt.ptr) { + %cst1 = arith.constant dense : tensor<128x32xi1, #AL> + %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> + %cst3 = arith.constant dense : tensor<32x128xi1, #AL> +@@ -78,7 +78,7 @@ func @reusable(%A : !tt.ptr) { + // %cst1->%cst4 + // %cst3->%g->%h->%i + // CHECK-LABEL: preallocate +-func @preallocate(%A : !tt.ptr) { ++func.func @preallocate(%A : !tt.ptr) { + // CHECK: offset = 0, size = 512 + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + // CHECK-NEXT: offset = 1024, size = 512 +@@ -113,7 +113,7 @@ func @preallocate(%A : !tt.ptr) { + + // Unused tensors are immediately released + // CHECK-LABEL: unused +-func @unused(%A : !tt.ptr) { ++func.func @unused(%A : !tt.ptr) { + // CHECK: offset = 0, size = 1024 + %cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A_SHARED> + // CHECK-NEXT: offset = 0, size = 512 +@@ -128,7 +128,7 @@ func @unused(%A : !tt.ptr) { + + // cst0 is alive through the entire function, it cannot be released before the end of the function + // CHECK-LABEL: longlive +-func @longlive(%A : !tt.ptr) { ++func.func @longlive(%A : !tt.ptr) { + // CHECK: offset = 0, size = 512 + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + // CHECK-NEXT: offset = 512, size = 512 +@@ -156,7 +156,7 @@ func @longlive(%A : !tt.ptr) { + } + + // CHECK-LABEL: alloc +-func @alloc(%A : !tt.ptr) { ++func.func @alloc(%A : !tt.ptr) { + // CHECK: offset = 0, size = 512 + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> +@@ -167,7 +167,7 @@ func @alloc(%A : !tt.ptr) { + } + + // CHECK-LABEL: scratch +-func @scratch() { ++func.func @scratch() { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + // CHECK: scratch offset = 0, size = 512 + %b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #sliceAd0> +@@ -176,7 +176,7 @@ func @scratch() { + } + + // CHECK-LABEL: trans +-func @trans(%A : !tt.ptr) { ++func.func @trans(%A : !tt.ptr) { + // CHECK: offset = 0, size = 1024 + %tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED> + %b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T> +@@ -184,7 +184,7 @@ func @trans(%A : !tt.ptr) { + } + + // CHECK-LABEL: insert_slice_async +-func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { ++func.func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { + %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<16x16x!tt.ptr, #AL> + %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> + %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> +@@ -197,7 +197,7 @@ func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { + } + + // CHECK-LABEL: extract_slice +-func @extract_slice(%A : !tt.ptr) { ++func.func @extract_slice(%A : !tt.ptr) { + // CHECK: offset = 0, size = 512 + %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED> + %index = arith.constant 0 : index +@@ -209,7 +209,7 @@ func @extract_slice(%A : !tt.ptr) { + // B0 -> (B1) -> B0 + // Memory used by B1 can be reused by B0. + // CHECK-LABEL: if +-func @if(%i1 : i1) { ++func.func @if(%i1 : i1) { + // CHECK: offset = 0, size = 512 + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + // CHECK-NEXT: offset = 512, size = 512 +@@ -233,7 +233,7 @@ func @if(%i1 : i1) { + // B0 -> (B1) -> (B2) -> B0 + // Memory used by B0 cannot be reused by B1 or B2. + // CHECK-LABEL: if_else +-func @if_else(%i1 : i1) { ++func.func @if_else(%i1 : i1) { + // CHECK: offset = 0, size = 512 + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + // CHECK-NEXT: offset = 512, size = 512 +@@ -260,7 +260,7 @@ func @if_else(%i1 : i1) { + // Block arguments and yields are memory aliases that do not trigger a new + // allocation. + // CHECK-LABEL: for +-func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { ++func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + // CHECK: offset = 0, size = 8192 + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + // CHECK-NEXT: offset = 8192, size = 8192 +@@ -275,7 +275,7 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.p + } + + // CHECK-LABEL: for_if_slice +-func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { ++func.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { + // CHECK: offset = 0, size = 8192 + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + // CHECK-NEXT: offset = 8192, size = 8192 +@@ -296,7 +296,7 @@ func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr, % + + // c0 cannot be released in the loop + // CHECK-LABEL: for_use_ancestor +-func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { ++func.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { + // CHECK: offset = 0, size = 8192 + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + // CHECK-NEXT: offset = 8192, size = 8192 +@@ -316,7 +316,7 @@ func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { ++func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { + // CHECK: offset = 0, size = 8192 + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + // CHECK-NEXT: offset = 8192, size = 8192 +diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir +index 7199e5f53d..17880b2094 100644 +--- a/test/Analysis/test-membar.mlir ++++ b/test/Analysis/test-membar.mlir +@@ -14,7 +14,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + + // CHECK-LABEL: matmul_loop + // There shouldn't be any membar with the dot op encoding. +-func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { ++func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %a_ptr_init = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> + %b_ptr_init = tt.broadcast %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL> + +@@ -42,7 +42,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B + } + + // CHECK-LABEL: raw_single_block +-func @raw_single_block(%A : !tt.ptr) { ++func.func @raw_single_block(%A : !tt.ptr) { + %cst1 = arith.constant dense : tensor<128x32xi1, #AL> + %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> + %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> +@@ -54,7 +54,7 @@ func @raw_single_block(%A : !tt.ptr) { + } + + // CHECK-LABEL: war_single_block +-func @war_single_block(%A : !tt.ptr) { ++func.func @war_single_block(%A : !tt.ptr) { + %cst1 = arith.constant dense : tensor<128x32xi1, #AL> + %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> + %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> +@@ -70,7 +70,7 @@ func @war_single_block(%A : !tt.ptr) { + } + + // CHECK-LABEL: scratch +-func @scratch() { ++func.func @scratch() { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + // CHECK: Membar 1 + %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> +@@ -81,7 +81,7 @@ func @scratch() { + } + + // CHECK-LABEL: async_wait +-func @async_wait() { ++func.func @async_wait() { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + // CHECK: Membar 1 + %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> +@@ -92,7 +92,7 @@ func @async_wait() { + } + + // CHECK-LABEL: alloc +-func @alloc() { ++func.func @alloc() { + %cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED> + %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> + // CHECK: Membar 2 +@@ -101,7 +101,7 @@ func @alloc() { + } + + // CHECK-LABEL: extract_slice +-func @extract_slice() { ++func.func @extract_slice() { + %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED> + %index = arith.constant 0 : index + %cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED> +@@ -113,14 +113,14 @@ func @extract_slice() { + } + + // CHECK-LABEL: trans +-func @trans() { ++func.func @trans() { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED> + %b = tt.trans %cst0 : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T> + return + } + + // CHECK-LABEL: insert_slice_async +-func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { ++func.func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { + %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<16x16x!tt.ptr, #AL> + %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> + %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> +@@ -135,7 +135,7 @@ func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { + } + + // CHECK-LABEL: insert_slice +-func @insert_slice(%A : !tt.ptr, %i1 : i1) { ++func.func @insert_slice(%A : !tt.ptr, %i1 : i1) { + %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<16x16x!tt.ptr, #AL> + %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> + %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> +@@ -153,7 +153,7 @@ func @insert_slice(%A : !tt.ptr, %i1 : i1) { + + // If branch inserted a barrier for %cst0 and %cst1, but else didn't, then the barrier should be inserted in the parent region + // CHECK-LABEL: multi_blocks +-func @multi_blocks(%i1 : i1) { ++func.func @multi_blocks(%i1 : i1) { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + scf.if %i1 { +@@ -174,7 +174,7 @@ func @multi_blocks(%i1 : i1) { + + // Both branches inserted a barrier for %cst0 and %cst1, then the barrier doesn't need to be inserted in the parent region + // CHECK-LABEL: multi_blocks_join_barrier +-func @multi_blocks_join_barrier(%i1 : i1) { ++func.func @multi_blocks_join_barrier(%i1 : i1) { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + scf.if %i1 { +@@ -192,7 +192,7 @@ func @multi_blocks_join_barrier(%i1 : i1) { + + // Read yielded tensor requires a barrier + // CHECK-LABEL: multi_blocks_yield +-func @multi_blocks_yield(%i1 : i1) { ++func.func @multi_blocks_yield(%i1 : i1) { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %a = scf.if %i1 -> (tensor<32x16xf16, #A_SHARED>) { +@@ -212,7 +212,7 @@ func @multi_blocks_yield(%i1 : i1) { + + // Conservatively add a barrier as if the branch (%i1) is never taken + // CHECK-LABEL: multi_blocks_noelse +-func @multi_blocks_noelse(%i1 : i1) { ++func.func @multi_blocks_noelse(%i1 : i1) { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + scf.if %i1 { +@@ -226,7 +226,7 @@ func @multi_blocks_noelse(%i1 : i1) { + + // Conservatively add a barrier as if the branch (%i2) is never taken + // CHECK-LABEL: multi_blocks_nested_scf +-func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) { ++func.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + scf.if %i1 { +@@ -247,7 +247,7 @@ func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) { + } + + // CHECK-LABEL: for +-func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { ++func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> +@@ -262,7 +262,7 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.p + // Although a_shared and b_shared are synced before entering the loop, + // they are reassociated with aliases (c_shared) and thus require a barrier. + // CHECK-LABEL: for_alias +-func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { ++func.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + // CHECK-NEXT: Membar 2 +@@ -282,7 +282,7 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : + // Although cst2 is not an argument of scf.yield, its memory is reused by cst1. + // So we need a barrier both before and after cst1 + // CHECK-LABEL: for_reuse +-func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { ++func.func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + // CHECK-NEXT: Membar 2 +@@ -302,7 +302,7 @@ func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : + + + // CHECK-LABEL: for_reuse_nested +-func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { ++func.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + // CHECK-NEXT: Membar 2 +diff --git a/test/Conversion/triton_ops.mlir b/test/Conversion/triton_ops.mlir +index e9ee502435..0e979b148d 100644 +--- a/test/Conversion/triton_ops.mlir ++++ b/test/Conversion/triton_ops.mlir +@@ -1,6 +1,6 @@ + // RUN: triton-opt %s | FileCheck %s + +-func @cast_ops(%scalar_ptr: !tt.ptr, %scalar_f32: f32, %scalar_i64: i64) { ++func.func @cast_ops(%scalar_ptr: !tt.ptr, %scalar_f32: f32, %scalar_i64: i64) { + // scalar -> scalar + // CHECK: i64 -> !tt.ptr + %0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr +@@ -35,7 +35,7 @@ func @cast_ops(%scalar_ptr: !tt.ptr, %scalar_f32: f32, %scalar_i64: i64) { + return + } + +-func @addptr_ops(%scalar_ptr: !tt.ptr, %scalar_i32: i32) { ++func.func @addptr_ops(%scalar_ptr: !tt.ptr, %scalar_i32: i32) { + // scalar -> scalar + // CHECK: !tt.ptr + %0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr, i32 +@@ -54,7 +54,7 @@ func @addptr_ops(%scalar_ptr: !tt.ptr, %scalar_i32: i32) { + return + } + +-func @load_store_ops_scalar(%ptr: !tt.ptr {tt.divisibility = 16 : i32}, %mask : i1) { ++func.func @load_store_ops_scalar(%ptr: !tt.ptr {tt.divisibility = 16 : i32}, %mask : i1) { + // Test if Load/Store ops can handle scalar values + %other = arith.constant 0.0e+0 : f32 + +@@ -76,7 +76,7 @@ func @load_store_ops_scalar(%ptr: !tt.ptr {tt.divisibility = 16 : i32}, %ma + return + } + +-func @reduce_ops_infer(%ptr: !tt.ptr, %v : tensor<1x2x4xf32>) { ++func.func @reduce_ops_infer(%ptr: !tt.ptr, %v : tensor<1x2x4xf32>) { + // Test if reduce ops infer types correctly + + // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<2x4xf32> +@@ -101,7 +101,7 @@ func @reduce_ops_infer(%ptr: !tt.ptr, %v : tensor<1x2x4xf32>) { + return + } + +-func @dot_ops_infer(%ptr: !tt.ptr, %v : f32) { ++func.func @dot_ops_infer(%ptr: !tt.ptr, %v : f32) { + // Test if reduce ops infer types correctly + %v128x32 = tt.splat %v : (f32) -> tensor<128x32xf32> + %v32x128 = tt.splat %v : (f32) -> tensor<32x128xf32> +diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir +index a160bc8815..b461ca542f 100644 +--- a/test/Conversion/triton_to_tritongpu.mlir ++++ b/test/Conversion/triton_to_tritongpu.mlir +@@ -1,6 +1,6 @@ + // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s + +-func @ops() { ++func.func @ops() { + // CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}} + %a = arith.constant dense<1.00e+00> : tensor<128x32xf16> + %b = arith.constant dense<2.00e+00> : tensor<32x128xf16> +@@ -11,7 +11,7 @@ func @ops() { + + // ----- + +-func @load_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { ++func.func @load_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { + // Test if LoadOp is lowered properly (see #771) + %ptrs = tt.splat %ptr : (!tt.ptr) -> tensor<128x!tt.ptr> + %mask = arith.constant dense : tensor<128xi1> +@@ -30,7 +30,7 @@ func @load_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { + + // ----- + +-func @reduce_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { ++func.func @reduce_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { + // Test if the total number of threadsPerWarp is 32 + // Test if the total number of warps is 2 + // CHECK: #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 2], order = [0, 1]}> +diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir +index e9e7d5a340..507b362c99 100644 +--- a/test/Conversion/tritongpu_to_llvm.mlir ++++ b/test/Conversion/tritongpu_to_llvm.mlir +@@ -4,7 +4,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr) + // Here the 128 comes from the 4 in module attribute multiples 32 + // CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : i32} {{.*}} +- func @test_empty_kernel(%lb : index, %A : !tt.ptr) { ++ func.func @test_empty_kernel(%lb : index, %A : !tt.ptr) { + // CHECK: llvm.return + return + } +@@ -15,7 +15,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_load +- func @basic_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { ++ func.func @basic_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { + // CHECK: llvm.inline_asm + // CHECK: llvm.inline_asm + %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> +@@ -28,7 +28,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: vectorized_load +- func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { ++ func.func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { + // CHECK: llvm.inline_asm + // CHECK-SAME: ld.global.b32 + // CHECK: llvm.inline_asm +@@ -43,7 +43,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> + module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: vectorized_load_f16 +- func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) { ++ func.func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) { + // CHECK: llvm.inline_asm + // CHECK-SAME: ld.global.b16 + // CHECK: llvm.inline_asm +@@ -59,7 +59,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { + #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: masked_load_const_other +- func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>) { ++ func.func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0> + %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> + return +@@ -72,7 +72,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: masked_load_const_other_vec +- func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>) { ++ func.func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0> + %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> + return +@@ -84,7 +84,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> + module attributes {"triton_gpu.num-warps" = 2 : i32} { + // CHECK-LABEL: global_load_store_no_vec +- func @global_load_store_no_vec(%arg0: !tt.ptr {tt.divisibility = 4 : i32}, %arg1: !tt.ptr {tt.divisibility = 4 : i32}, %arg2: !tt.ptr {tt.divisibility = 4 : i32}, %arg3: i32) { ++ func.func @global_load_store_no_vec(%arg0: !tt.ptr {tt.divisibility = 4 : i32}, %arg1: !tt.ptr {tt.divisibility = 4 : i32}, %arg2: !tt.ptr {tt.divisibility = 4 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = arith.muli %0, %c256_i32 : i32 +@@ -128,7 +128,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { + #blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> + module attributes {"triton_gpu.num-warps" = 2 : i32} { + // CHECK-LABEL: global_load_store_vec4 +- func @global_load_store_vec4(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { ++ func.func @global_load_store_vec4(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = arith.muli %0, %c256_i32 : i32 +@@ -165,7 +165,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { + #blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> + // Note, the %n_elements doesn't have a "tt.divisibility" hint, so Triton assumes it's divisibility is 1, this should effect the mask's alignment and further restrict the load/store ops' vector width to be 1. + module attributes {"triton_gpu.num-warps" = 2 : i32} { +- func @vecadd_masked_vec1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %n_elements: i32) { ++ func.func @vecadd_masked_vec1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %n_elements: i32) { + %c64_i32 = arith.constant 64 : i32 + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = arith.muli %0, %c64_i32 : i32 +@@ -195,7 +195,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { + #blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> + module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: global_load_store_vec2 +- func @global_load_store_vec2(%arg0: !tt.ptr {tt.divisibility = 8 : i32}, %arg1: !tt.ptr {tt.divisibility = 8 : i32}, %arg2: !tt.ptr {tt.divisibility = 8 : i32}, %arg3: i32) { ++ func.func @global_load_store_vec2(%arg0: !tt.ptr {tt.divisibility = 8 : i32}, %arg1: !tt.ptr {tt.divisibility = 8 : i32}, %arg2: !tt.ptr {tt.divisibility = 8 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = arith.muli %0, %c256_i32 : i32 +@@ -240,7 +240,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { + #blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> + module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: global_load_store_vec8 +- func @global_load_store_vec8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { ++ func.func @global_load_store_vec8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = arith.muli %0, %c256_i32 : i32 +@@ -283,7 +283,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { + #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_view_broadcast +- func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) { ++ func.func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) { + // CHECK: llvm.mlir.undef + // CHECK: %[[T0:.*]] = llvm.extractvalue + // CHECK: %[[T1:.*]] = llvm.extractvalue +@@ -307,7 +307,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_make_range +- func @basic_make_range() { ++ func.func @basic_make_range() { + // CHECK: nvvm.read.ptx.sreg.tid.x + // CHECK: llvm.mlir.undef + // CHECK: llvm.insertvalue +@@ -322,7 +322,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_addf +- func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) { ++ func.func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) { + // CHECK: llvm.fadd + // CHECK: llvm.fadd + %1 = arith.addf %arg0, %arg1 : tensor<256xf32,#blocked0> +@@ -335,7 +335,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_addi +- func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { ++ func.func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { + // CHECK: llvm.add + // CHECK: llvm.add + %1 = arith.addi %arg0, %arg1 : tensor<256xi32,#blocked0> +@@ -347,7 +347,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_program_id +- func @basic_program_id() { ++ func.func @basic_program_id() { + // CHECK: nvvm.read.ptx.sreg.ctaid.x : i32 + %0 = tt.get_program_id {axis = 0 : i32} : i32 + return +@@ -359,7 +359,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_addptr +- func @basic_addptr(%arg0 : tensor<256x!tt.ptr,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { ++ func.func @basic_addptr(%arg0 : tensor<256x!tt.ptr,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { + // CHECK: llvm.getelementptr + // CHECK: llvm.getelementptr + %0 = tt.addptr %arg0, %arg1 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> +@@ -373,7 +373,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK: llvm.mlir.global external @global_smem + // CHECK-LABEL: basic_alloc_tensor +- func @basic_alloc_tensor() { ++ func.func @basic_alloc_tensor() { + // CHECK: llvm.mlir.addressof @global_smem + // CHECK-NEXT: llvm.bitcast + // CHECK-NEXT: llvm.mlir.constant +@@ -390,7 +390,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK: llvm.mlir.global external @global_smem + // CHECK-LABEL: basic_extract_slice +- func @basic_extract_slice() { ++ func.func @basic_extract_slice() { + // CHECK: llvm.mlir.addressof @global_smem + // CHECK: llvm.extractvalue + // CHECK-NEXT: llvm.extractvalue +@@ -423,7 +423,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_async_wait +- func @basic_async_wait() { ++ func.func @basic_async_wait() { + // CHECK: cp.async.wait_group 0x4 + triton_gpu.async_wait {num = 4: i32} + return +@@ -442,7 +442,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_insert_slice_async_fallback +- func @basic_insert_slice_async_fallback(%arg0: !tt.ptr {tt.divisibility = 1 : i32}) { ++ func.func @basic_insert_slice_async_fallback(%arg0: !tt.ptr {tt.divisibility = 1 : i32}) { + %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> + %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0> + %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2> +@@ -481,7 +481,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_insert_slice_async_v4 +- func @basic_insert_slice_async_v4(%arg0: !tt.ptr {tt.divisibility = 32 : i32}) { ++ func.func @basic_insert_slice_async_v4(%arg0: !tt.ptr {tt.divisibility = 32 : i32}) { + %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> + %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0> + %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2> +@@ -523,7 +523,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_insert_slice_async_v1 +- func @basic_insert_slice_async_v1(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { ++ func.func @basic_insert_slice_async_v1(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { + %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> + %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0> + %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2> +@@ -568,7 +568,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_insert_slice_async_v1_multictas +- func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { ++ func.func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { + %off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice2d1> + %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0> + %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<32xi32, #slice2d1>) -> tensor<32x1xi32, #block2> +@@ -619,7 +619,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK: basic_splat +- func @basic_splat(%ptr: !tt.ptr) { ++ func.func @basic_splat(%ptr: !tt.ptr) { + // CHECK: llvm.mlir.undef + // CHECK: llvm.insertvalue + // CHECK: llvm.insertvalue +@@ -633,7 +633,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_store +- func @basic_store(%ptrs: tensor<256x!tt.ptr, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) { ++ func.func @basic_store(%ptrs: tensor<256x!tt.ptr, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) { + // CHECK: llvm.inline_asm + // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} }; + // CHECK: llvm.inline_asm +@@ -650,7 +650,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> + // CHECK-LABEL: convert_layout_blocked_blocked +- func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) { ++ func.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) { + // CHECK: llvm.mlir.addressof @global_smem + // CHECK: llvm.store + // CHECK-SAME: !llvm.ptr, 3> +@@ -697,7 +697,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { + module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> + // CHECK-LABEL: convert_layout_blocked_blocked_vec +- func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) { ++ func.func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) { + // CHECK: llvm.mlir.addressof @global_smem + // CHECK: llvm.store + // CHECK-SAME: !llvm.ptr, 3> +@@ -720,7 +720,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { + module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> + // CHECK-LABEL: convert_layout_blocked_blocked_multi_rep +- func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) { ++ func.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) { + // CHECK: llvm.mlir.addressof @global_smem + // CHECK: llvm.store + // CHECK-SAME: !llvm.ptr, 3> +@@ -751,7 +751,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { + #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0}> + module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_dot +- func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { ++ func.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { + %AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0> + %BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0> + // CHECK: llvm.inline_asm +@@ -775,7 +775,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { + // TODO: problems in MLIR's parser on slice layout + // #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}> + // module attributes {"triton_gpu.num-warps" = 1 : i32} { +-// func @make_range_sliced_layout() { ++// func.func @make_range_sliced_layout() { + // %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> + // return + // } +@@ -788,7 +788,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { + module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> + // CHECK-LABEL: convert_layout_mmav2_block +- func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) { ++ func.func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) { + // CHECK: llvm.store + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.store +@@ -808,7 +808,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { + module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> + // CHECK-LABEL: convert_layout_mmav1_block +- func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) { ++ func.func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) { + // CHECK: llvm.store + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.store +@@ -831,7 +831,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { + module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> + // CHECK-LABEL: convert_layout_blocked_shared +- func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) { ++ func.func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) { + // CHECK: llvm.store + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.store +@@ -847,7 +847,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { + #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}> + module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_blocked1d_to_slice0 +- func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) { ++ func.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) { + // CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr, 3> + %cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + return +@@ -860,7 +860,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { + #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}> + module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_blocked1d_to_slice1 +- func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) { ++ func.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) { + // CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr, 3> + %cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + return +@@ -873,7 +873,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { + #blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> + module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_blocked_to_blocked_ptr +- func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr, #blocked0>) { ++ func.func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr, #blocked0>) { + // CHECK: llvm.ptrtoint + // CHECK: llvm.store + // CHECK: nvvm.barrier0 +@@ -892,7 +892,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { + #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}> + #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { +- func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, ++ func.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, + %a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> + // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 +@@ -918,7 +918,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}> + #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { +- func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, ++ func.func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, + %a:tensor<32x64xf16, #shared0>, %b:tensor<64x64xf16, #shared1>) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x64xf32, #mma> + // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 +@@ -941,7 +941,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#blocked}> + #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { +- func @matmul_fmadot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, ++ func.func @matmul_fmadot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, + %a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> + // CHECK: llvm.intr.fmuladd +@@ -965,7 +965,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: matmul_tf32dot +- func @matmul_tf32dot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, ++ func.func @matmul_tf32dot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, + %a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + // CHECK: llvm.inline_asm +@@ -1000,7 +1000,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_f32 +- func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { ++ func.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { + // CHECK: llvm.inline_asm + // CHECK-SAME: atom.global.gpu.add.f32 + %0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (tensor<256x!tt.ptr, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0> +@@ -1012,7 +1012,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { + +-func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { ++func.func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { + %blockidx = tt.get_program_id {axis=0:i32} : i32 + %blockidy = tt.get_program_id {axis=1:i32} : i32 + %blockidz = tt.get_program_id {axis=2:i32} : i32 +@@ -1032,7 +1032,7 @@ func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { + // ----- + #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { +- func @test_get_num_program(%a: tensor<32x!tt.ptr, #blocked0>) { ++ func.func @test_get_num_program(%a: tensor<32x!tt.ptr, #blocked0>) { + // CHECK: nvvm.read.ptx.sreg.nctaid.x + // CHECK: nvvm.read.ptx.sreg.nctaid.y + // CHECK: nvvm.read.ptx.sreg.nctaid.z +@@ -1052,7 +1052,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> + module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: test_index_cache +- func @test_index_cache() { ++ func.func @test_index_cache() { + // CHECK: nvvm.read.ptx.sreg.tid.x + %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + // CHECK-NOT: nvvm.read.ptx.sreg.tid.x +@@ -1066,7 +1066,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + #shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> + module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: test_base_index_cache +- func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) { ++ func.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) { + // CHECK: nvvm.read.ptx.sreg.tid.x + %0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0> + // CHECK-NOT: nvvm.read.ptx.sreg.tid.x +@@ -1080,7 +1080,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { + #shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> + module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: test_index_cache_different_block +- func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) { ++ func.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) { + // CHECK: nvvm.read.ptx.sreg.tid.x + %0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0> + scf.if %arg1 { +diff --git a/test/Target/tritongpu_to_llvmir.mlir b/test/Target/tritongpu_to_llvmir.mlir +index cafff3ca60..114d3a9eb2 100644 +--- a/test/Target/tritongpu_to_llvmir.mlir ++++ b/test/Target/tritongpu_to_llvmir.mlir +@@ -4,11 +4,11 @@ + // CHECK-LABEL: ; ModuleID = 'LLVMDialectModule' + // CHECK: define void @test_empty_kernel + // CHECK: !nvvm.annotations +-// CHECK: !{void (i32, half addrspace(1)*)* @test_empty_kernel, !"maxntidx", i32 128} ++// CHECK: !{ptr @test_empty_kernel, !"maxntidx", i32 128} + + module attributes {"triton_gpu.num-warps" = 4 : i32} { + +-func @test_empty_kernel(%lb : index, %A : !tt.ptr) { ++func.func @test_empty_kernel(%lb : index, %A : !tt.ptr) { + + return + } +diff --git a/test/Target/tritongpu_to_ptx.mlir b/test/Target/tritongpu_to_ptx.mlir +index 404e970a29..12742ad9e2 100644 +--- a/test/Target/tritongpu_to_ptx.mlir ++++ b/test/Target/tritongpu_to_ptx.mlir +@@ -6,7 +6,7 @@ + + module attributes {"triton_gpu.num-warps" = 4 : i32} { + +-func @test_empty_kernel(%lb : index, %A : !tt.ptr) { ++func.func @test_empty_kernel(%lb : index, %A : !tt.ptr) { + + return + } +diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir +index 050a3f7565..5ef6790e69 100644 +--- a/test/Triton/combine.mlir ++++ b/test/Triton/combine.mlir +@@ -2,10 +2,10 @@ + // RUN: triton-opt %s -split-input-file -canonicalize -triton-combine | FileCheck %s + + // CHECK-LABEL: @test_combine_dot_add_pattern +-func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) { +- // CHECK: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32> +- // CHECK: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32> +- // CHECK: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32> ++func.func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) { ++ // CHECK-DAG: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32> ++ // CHECK-DAG: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32> ++ // CHECK-DAG: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32> + %a = arith.constant dense<1.0> : tensor<128x128xf32> + %b = arith.constant dense<2.0> : tensor<128x128xf32> + %zero = arith.constant dense<0.0> : tensor<128x128xf32> +@@ -24,7 +24,7 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32 + + + // COM: CHECK-LABEL: @test_combine_addptr_pattern +-func @test_combine_addptr_pattern(%base: !tt.ptr) -> tensor<8x!tt.ptr> { ++func.func @test_combine_addptr_pattern(%base: !tt.ptr) -> tensor<8x!tt.ptr> { + %off0 = arith.constant 10 : i32 + %off1 = arith.constant 15 : i32 + +@@ -47,46 +47,46 @@ func @test_combine_addptr_pattern(%base: !tt.ptr) -> tensor<8x!tt.ptr> + + + // CHECK-LABEL: @test_combine_select_masked_load_pattern +-func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) { ++func.func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) { + %mask = tt.broadcast %cond : (i1) -> tensor<8xi1> + %false_val = arith.constant dense<0.0> : tensor<8xf32> + + // CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> + %x = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> +- %0 = select %cond, %x, %false_val : tensor<8xf32> ++ %0 = arith.select %cond, %x, %false_val : tensor<8xf32> + + // CHECK: %[[res2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> + %y = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> +- %1 = select %cond, %y, %false_val : tensor<8xf32> ++ %1 = arith.select %cond, %y, %false_val : tensor<8xf32> + + // CHECK: return %[[res1]], %[[res2]] : tensor<8xf32>, tensor<8xf32> + return %0, %1 : tensor<8xf32>, tensor<8xf32> + } + + // CHECK-LABEL: @test_combine_select_masked_load_fail_pattern +-func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond0: i1, %cond1: i1) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) { ++func.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond0: i1, %cond1: i1) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) { + %false_val = arith.constant dense<0.0> : tensor<8xf32> + + // Case 1: value at the "load" position is not an "op". Select should not be canonicalized. +- // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32> +- %0 = select %cond0, %dummy_load, %false_val : tensor<8xf32> ++ // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32> ++ %0 = arith.select %cond0, %dummy_load, %false_val : tensor<8xf32> + + // Case 2: value at the "broadcast" position is not an "op". Select should not be canonicalized. + %real_load0 = tt.load %ptr, %dummy_broadcast, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> +- // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32> +- %1 = select %cond0, %real_load0, %false_val : tensor<8xf32> ++ // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32> ++ %1 = arith.select %cond0, %real_load0, %false_val : tensor<8xf32> + + // Case 3: condition of "broadcast" is not the same as the condition of "select". Select should not be canonicalized. + %cond0_ = tt.broadcast %cond0 : (i1) -> tensor<8xi1> + %real_load1 = tt.load %ptr, %cond0_, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> +- // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32> +- %2 = select %cond1, %real_load1, %false_val : tensor<8xf32> ++ // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32> ++ %2 = arith.select %cond1, %real_load1, %false_val : tensor<8xf32> + + return %0, %1, %2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32> + } + + // CHECK-LABEL: @test_combine_broadcast_constant_pattern +-func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> { ++func.func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> { + // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32> + %const = arith.constant dense<1.0> : tensor<8xf32> + %bst_out = tt.broadcast %const : (tensor<8xf32>) -> tensor<8x2xf32> +@@ -96,7 +96,7 @@ func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> { + } + + // CHECK-LABEL: @test_canonicalize_masked_load_pattern +-func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) { ++func.func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) { + %true_mask = arith.constant dense : tensor<8xi1> + %false_mask = arith.constant dense : tensor<8xi1> + %other_val = arith.constant dense<0.0> : tensor<8xf32> +@@ -117,7 +117,7 @@ func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr>) -> (te + } + + // CHECK-LABEL: @test_canonicalize_masked_load_fail_pattern +-func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) { ++func.func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) { + %other_val = arith.constant dense<0.0> : tensor<8xf32> + + // Case: value at the "mask" position is not an "op". Load should not be canonicalized. +@@ -130,7 +130,7 @@ func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr>, % + } + + // CHECK-LABEL: @test_canonicalize_masked_store_pattern +-func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr>, %val: tensor<8xf32>) { ++func.func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr>, %val: tensor<8xf32>) { + %true_mask = arith.constant dense : tensor<8xi1> + %false_mask = arith.constant dense : tensor<8xi1> + +@@ -144,7 +144,7 @@ func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr>, %val: + } + + // CHECK-LABEL: @test_canonicalize_masked_store_fail_pattern +-func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr>, %val: tensor<8xf32>, %mask: tensor<8xi1>) { ++func.func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr>, %val: tensor<8xf32>, %mask: tensor<8xi1>) { + // Case: value at the "mask" position is not an "op". Store should not be canonicalized. + // CHECK: tt.store %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32> + tt.store %ptr, %val, %mask : tensor<8xf32> +diff --git a/test/Triton/vecadd.mlir b/test/Triton/vecadd.mlir +index 0b69ef3054..f5019b1cdd 100644 +--- a/test/Triton/vecadd.mlir ++++ b/test/Triton/vecadd.mlir +@@ -1,7 +1,7 @@ + // RUN: triton-opt %s -verify-diagnostics + + module { +- func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32) { ++ func.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32) { + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %c256_i32 = arith.constant 256 : i32 + %1 = arith.muli %0, %c256_i32 : i32 +@@ -43,7 +43,7 @@ module { + } + } + // module { +-// func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32) { ++// func.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32) { + // %c64 = arith.constant 64 : index + // %c32 = arith.constant 32 : index + // %c0 = arith.constant 0 : index +diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir +index 60e359f527..51cccccfbd 100644 +--- a/test/TritonGPU/coalesce.mlir ++++ b/test/TritonGPU/coalesce.mlir +@@ -19,7 +19,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK: [[store_val:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[col_layout]]> + // CHECK: [[store_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[col_layout]]> + // CHECK: tt.store [[store_ptr]], [[store_val]], [[store_mask]] +-func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, ++func.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}, + %arg3: i32 {tt.divisibility = 16 : i32}) { +diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir +index 2c009ffa48..7e9cb9d504 100644 +--- a/test/TritonGPU/combine.mlir ++++ b/test/TritonGPU/combine.mlir +@@ -9,7 +9,7 @@ + // CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> + // CHECK: [[col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> + // CHECK-LABEL: cst +-func @cst() -> tensor<1024xi32, #layout1> { ++func.func @cst() -> tensor<1024xi32, #layout1> { + %cst = arith.constant dense<0> : tensor<1024xi32, #layout0> + %1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1> + // CHECK-NOT: triton_gpu.convert_layout +@@ -18,7 +18,7 @@ func @cst() -> tensor<1024xi32, #layout1> { + } + + // CHECK-LABEL: range +-func @range() -> tensor<1024xi32, #layout1> { ++func.func @range() -> tensor<1024xi32, #layout1> { + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0> + %1 = triton_gpu.convert_layout %0 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1> + // CHECK-NOT: triton_gpu.convert_layout +@@ -27,7 +27,7 @@ func @range() -> tensor<1024xi32, #layout1> { + } + + // CHECK-LABEL: splat +-func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> { ++func.func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> { + %0 = tt.splat %arg0 : (i32) -> tensor<1024xi32, #layout0> + %1 = triton_gpu.convert_layout %0 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1> + // CHECK-NOT: triton_gpu.convert_layout +@@ -36,7 +36,7 @@ func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> { + } + + // CHECK-LABEL: remat +-func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> { ++func.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> { + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0> + %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0> + %2 = arith.muli %0, %1 : tensor<1024xi32, #layout0> +@@ -56,7 +56,7 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> { + } + + // CHECK-LABEL: remat_load_store +-func @remat_load_store(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { ++func.func @remat_load_store(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout0> + %1 = tt.splat %arg : (!tt.ptr) -> tensor<64x!tt.ptr, #layout0> + %2 = tt.addptr %1, %0 : tensor<64x!tt.ptr, #layout0>, tensor<64xi32, #layout0> +@@ -70,7 +70,7 @@ func @remat_load_store(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { + + // Don't rematerialize vectorized loads + // CHECK-LABEL: remat_expensive +-func @remat_expensive(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { ++func.func @remat_expensive(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout1> + %1 = tt.splat %arg : (!tt.ptr) -> tensor<64x!tt.ptr, #layout1> + %2 = tt.addptr %1, %0 : tensor<64x!tt.ptr, #layout1>, tensor<64xi32, #layout1> +@@ -85,7 +85,7 @@ func @remat_expensive(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { + + // Don't rematerialize loads when original and target layouts are different + // CHECK-LABEL: remat_multi_layout +-func @remat_multi_layout(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { ++func.func @remat_multi_layout(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout0> + %1 = tt.splat %arg : (!tt.ptr) -> tensor<64x!tt.ptr, #layout0> + %2 = tt.addptr %1, %0 : tensor<64x!tt.ptr, #layout0>, tensor<64xi32, #layout0> +@@ -100,7 +100,7 @@ func @remat_multi_layout(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { + + // Always rematerialize single value loads + // CHECK-LABEL: remat_single_value +-func @remat_single_value(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { ++func.func @remat_single_value(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { + %0 = tt.splat %arg : (!tt.ptr) -> tensor<1x!tt.ptr, #layout1> + %1 = tt.load %0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1xi32, #layout1> + // CHECK-NOT: triton_gpu.convert_layout +@@ -111,7 +111,7 @@ func @remat_single_value(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { + } + + // CHECK-LABEL: if +-func @if(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { ++func.func @if(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + // CHECK-NOT: triton_gpu.convert_layout + %c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout1> + %0 = tt.get_program_id {axis = 0 : i32} : i32 +@@ -128,7 +128,7 @@ func @if(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + } + + // CHECK-LABEL: if_convert_else_not +-func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { ++func.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + %c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0> + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0> +@@ -149,7 +149,7 @@ func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 + } + + // CHECK-LABEL: if_not_else_convert +-func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { ++func.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + %c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0> + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0> +@@ -170,7 +170,7 @@ func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 + } + + // CHECK-LABEL: if_else_both_convert +-func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { ++func.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + %c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0> + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0> +@@ -200,7 +200,7 @@ func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 + #blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> + + // CHECK-LABEL: transpose +-func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { ++func.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { + // CHECK-NOT: triton_gpu.convert_layout + // CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]> + // CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]> +@@ -241,7 +241,7 @@ func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt + } + + // CHECK-LABEL: loop +-func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %arg4: i32) { ++func.func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %arg4: i32) { + // CHECK-NOT: triton_gpu.convert_layout + // CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[row_layout]]>, tensor<64x64x!tt.ptr, [[row_layout]]>) + // CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64xf32, [[row_layout]]> +@@ -295,7 +295,7 @@ func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %ar + } + + // CHECK-LABEL: vecadd +-func @vecadd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { ++func.func @vecadd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { + // CHECK-NOT: triton_gpu.convert_layout + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id {axis = 0 : i32} : i32 +@@ -327,7 +327,7 @@ func @vecadd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) { ++func.func @select(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) { + // CHECK-NOT: triton_gpu.convert_layout + %cst = arith.constant dense<30000> : tensor<1x1xi32, #blocked2> + %cst_0 = arith.constant dense<30000> : tensor<1x512xi32, #blocked2> +@@ -378,7 +378,7 @@ func @select(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: !tt.ptr {tt.divisibility = 16 : i32}, %arg10: !tt.ptr {tt.divisibility = 16 : i32}, %arg11: !tt.ptr {tt.divisibility = 16 : i32}, %arg12: !tt.ptr {tt.divisibility = 16 : i32}, %arg13: !tt.ptr {tt.divisibility = 16 : i32}, %arg14: !tt.ptr {tt.divisibility = 16 : i32}, %arg15: !tt.ptr {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) { ++func.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: !tt.ptr {tt.divisibility = 16 : i32}, %arg10: !tt.ptr {tt.divisibility = 16 : i32}, %arg11: !tt.ptr {tt.divisibility = 16 : i32}, %arg12: !tt.ptr {tt.divisibility = 16 : i32}, %arg13: !tt.ptr {tt.divisibility = 16 : i32}, %arg14: !tt.ptr {tt.divisibility = 16 : i32}, %arg15: !tt.ptr {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<1.000000e+00> : tensor<1024xf32, #blocked0> + %cst_0 = arith.constant dense<5.000000e-04> : tensor<1024xf32, #blocked0> + %cst_1 = arith.constant dense<0.999499976> : tensor<1024xf32, #blocked0> +@@ -775,7 +775,7 @@ func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: + // A mnist model from torch inductor. + // Check if topological sort is working correct and there's no unnecessary convert + // CHECK-LABEL: mnist +-func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) { ++func.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) { + // CHECK-NOT: triton_gpu.convert_layout + %cst = arith.constant dense<10> : tensor<16x1xi32, #blocked2> + %cst_0 = arith.constant dense<10> : tensor<1x16xi32, #blocked3> +@@ -862,7 +862,7 @@ func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt. + #blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> + // cmpf and cmpi have different operands and result types + // CHECK-LABEL: cmp +-func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { ++func.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + %c0 = arith.constant 0 : index +diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir +index 6ee3b15fbc..663f2da7b0 100644 +--- a/test/TritonGPU/loop-pipeline.mlir ++++ b/test/TritonGPU/loop-pipeline.mlir +@@ -10,7 +10,7 @@ + #A = #triton_gpu.dot_op<{opIdx = 0, parent = #C}> + #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C}> + +-// CHECK: func @matmul_loop ++// CHECK: func.func @matmul_loop + // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 + // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 + // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 +@@ -46,8 +46,8 @@ + // CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] + // CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] + // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]] +-func @matmul_loop(%lb : index, %ub : index, %step : index, +- %A : !tt.ptr {tt.divisibility = 16 : i32}, ++func.func @matmul_loop(%lb : index, %ub : index, %step : index, ++ %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}) { + // A ptrs + %a_ptr_splat = tt.splat %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> +@@ -61,7 +61,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : (tensor<128xi32, #BLs0>) -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : (tensor<1x128xi32, #BL>) -> tensor<32x128xi32, #BL> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> +- ++ + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> +@@ -88,7 +88,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, + } + + +-// CHECK: func @matmul_loop_nested ++// CHECK: func.func @matmul_loop_nested + // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 + // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 + // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 +@@ -118,8 +118,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, + // CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] + // CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] + // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]] +-func @matmul_loop_nested(%lb : index, %ub : index, %step : index, +- %A : !tt.ptr {tt.divisibility = 16 : i32}, ++func.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, ++ %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}) { + scf.for %iv0 = %lb to %ub step %step { + // A ptrs +@@ -134,7 +134,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : (tensor<128xi32, #BLs0>) -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : (tensor<1x128xi32, #BL>) -> tensor<32x128xi32, #BL> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> +- ++ + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> +@@ -161,7 +161,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, + } + + +-// CHECK: func @matmul_loop_single_pipeline ++// CHECK: func.func @matmul_loop_single_pipeline + // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 + // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 + // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 +@@ -183,8 +183,8 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, + // CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] + // CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] + // CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]] +-func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, +- %A : !tt.ptr {tt.divisibility = 16 : i32}, ++func.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, ++ %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}) { + // A ptrs + %a_ptr_splat = tt.splat %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> +diff --git a/test/TritonGPU/matmul.mlir b/test/TritonGPU/matmul.mlir +index 9bd5318e1e..01dc3f0ab1 100644 +--- a/test/TritonGPU/matmul.mlir ++++ b/test/TritonGPU/matmul.mlir +@@ -4,7 +4,7 @@ + // CHECK: offset = 49152, size = 49152 + // CHECK: size = 98304 + module { +-func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) { ++func.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) { + %cst = arith.constant dense : tensor<64x64xi1> + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index +@@ -22,7 +22,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6 + %7 = arith.muli %6, %c8_i32 : i32 + %8 = arith.subi %2, %7 : i32 + %9 = arith.cmpi slt, %8, %c8_i32 : i32 +- %10 = select %9, %8, %c8_i32 : i32 ++ %10 = arith.select %9, %8, %c8_i32 : i32 + %11 = arith.remsi %0, %10 : i32 + %12 = arith.addi %7, %11 : i32 + %13 = arith.remsi %0, %5 : i32 +diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir +index 52b4dddec1..b427547890 100644 +--- a/test/TritonGPU/prefetch.mlir ++++ b/test/TritonGPU/prefetch.mlir +@@ -11,7 +11,7 @@ + #B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C}> + + +-// CHECK: func @matmul_loop ++// CHECK: func.func @matmul_loop + // CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = tensor.extract_slice %[[A0:.*]][0, 0] [128, 16] + // CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.convert_layout %[[A0_PREFETCH_SMEM]] + // CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = tensor.extract_slice %[[B0:.*]][0, 0] [16, 128] +@@ -28,7 +28,7 @@ + // CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = tensor.extract_slice {{.*}}[0, 0] [16, 128] + // CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_B_PREFETCH_SMEM]] + // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH]], %[[NEXT_B_PREFETCH]] +-func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { ++func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %a_ptr_init = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> + %b_ptr_init = tt.broadcast %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL> + +diff --git a/test/TritonGPU/update-mma-for-volta.mlir b/test/TritonGPU/update-mma-for-volta.mlir +index d587fffcca..7571ec6185 100644 +--- a/test/TritonGPU/update-mma-for-volta.mlir ++++ b/test/TritonGPU/update-mma-for-volta.mlir +@@ -15,7 +15,7 @@ + // CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [4, 2]}> + module attributes {"triton_gpu.num-warps" = 16 : i32} { + // CHECK-LABEL: dot_mmav1 +- func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> { ++ func.func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> { + %C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked0> + %AA = triton_gpu.convert_layout %A : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_a> + %BB = triton_gpu.convert_layout %B : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_b> +@@ -50,7 +50,7 @@ module attributes {"triton_gpu.num-warps" = 16 : i32} { + + module attributes {"triton_gpu.num-warps" = 16 : i32} { + // CHECK-LABEL: dot_mmav1 +- func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> { ++ func.func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> { + %C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked0> + %AA = triton_gpu.convert_layout %A : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_a> + %BB = triton_gpu.convert_layout %B : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_b> +diff --git a/test/lib/Analysis/TestAlias.cpp b/test/lib/Analysis/TestAlias.cpp +index 88a4118fe9..3fd0cfd0d3 100644 +--- a/test/lib/Analysis/TestAlias.cpp ++++ b/test/lib/Analysis/TestAlias.cpp +@@ -9,10 +9,10 @@ using namespace mlir; + namespace { + + struct TestAliasPass +- : public PassWrapper> { ++ : public PassWrapper> { ++ ++ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass); + +- // LLVM15+ +- // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass); + static void print(StringRef name, SmallVector &vals, + raw_ostream &os) { + if (vals.empty()) +@@ -39,23 +39,24 @@ struct TestAliasPass + auto opName = SymbolTable::getSymbolName(operation).getValue().str(); + os << opName << "\n"; + +- SharedMemoryAliasAnalysis analysis(&getContext()); +- analysis.run(operation); ++ std::unique_ptr solver = createDataFlowSolver(); ++ SharedMemoryAliasAnalysis *analysis = ++ solver->load(); ++ if (failed(solver->initializeAndRun(operation))) ++ return signalPassFailure(); + + AsmState state(operation->getParentOfType()); + // Get operation ids of value's aliases + auto getAllocOpNames = [&](Value value) { +- LatticeElement *latticeElement = +- analysis.lookupLatticeElement(value); ++ dataflow::Lattice *latticeElement = ++ analysis->getLatticeElement(value); + SmallVector opNames; +- if (latticeElement) { ++ if (latticeElement && !latticeElement->isUninitialized()) { + auto &info = latticeElement->getValue(); +- if (!info.getAllocs().empty()) { +- for (auto &alias : info.getAllocs()) { +- auto opName = +- getValueOperandName(alias.getDefiningOp()->getResult(0), state); +- opNames.push_back(std::move(opName)); +- } ++ for (auto &alias : info.getAllocs()) { ++ auto opName = ++ getValueOperandName(alias.getDefiningOp()->getResult(0), state); ++ opNames.push_back(std::move(opName)); + } + } + // Ensure deterministic output +diff --git a/test/lib/Analysis/TestAllocation.cpp b/test/lib/Analysis/TestAllocation.cpp +index 84108c4d36..35e42242bd 100644 +--- a/test/lib/Analysis/TestAllocation.cpp ++++ b/test/lib/Analysis/TestAllocation.cpp +@@ -6,10 +6,9 @@ using namespace mlir; + namespace { + + struct TestAllocationPass +- : public PassWrapper> { ++ : public PassWrapper> { + +- // LLVM15+ +- // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass); ++ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass); + + StringRef getArgument() const final { return "test-print-allocation"; } + StringRef getDescription() const final { +diff --git a/test/lib/Analysis/TestAxisInfo.cpp b/test/lib/Analysis/TestAxisInfo.cpp +index a5205bb0a0..22347c32f0 100644 +--- a/test/lib/Analysis/TestAxisInfo.cpp ++++ b/test/lib/Analysis/TestAxisInfo.cpp +@@ -1,25 +1,15 @@ + #include "mlir/Pass/Pass.h" + #include "triton/Analysis/AxisInfo.h" ++#include "triton/Analysis/Utility.h" + + using namespace mlir; + + namespace { + + struct TestAxisInfoPass +- : public PassWrapper> { ++ : public PassWrapper> { + +- // LLVM15+ +- // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAlignmentPass); +- +- void print(const std::string &name, raw_ostream &os, ArrayRef vals) { +- os << name << ": ["; +- for (size_t d = 0; d < vals.size(); d++) { +- if (d != 0) +- os << ", "; +- os << vals[d]; +- } +- os << "]"; +- } ++ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass); + + StringRef getArgument() const final { return "test-print-alignment"; } + StringRef getDescription() const final { +@@ -30,38 +20,19 @@ struct TestAxisInfoPass + Operation *operation = getOperation(); + auto &os = llvm::errs(); + auto opName = SymbolTable::getSymbolName(operation).getValue().str(); +- os << opName << "\n"; +- AxisInfoAnalysis analysis(&getContext()); +- analysis.run(operation); ++ os << "@" << opName << "\n"; ++ ++ std::unique_ptr solver = createDataFlowSolver(); ++ AxisInfoAnalysis *analysis = solver->load(); ++ if (failed(solver->initializeAndRun(operation))) ++ return signalPassFailure(); + operation->walk([&](Operation *op) { + if (op->getNumResults() < 1) + return; + for (Value result : op->getResults()) { +- // std::ostringstream oss; +- // result.print(oss); +- // os << " => "; +- LatticeElement *latticeElement = +- analysis.lookupLatticeElement(result); +- if (!latticeElement) { +- os << "None\n"; +- return; +- } +- AxisInfo &info = latticeElement->getValue(); +- print("Contiguity", os, info.getContiguity()); +- os << " ; "; +- print("Divisibility", os, info.getDivisibility()); +- os << " ; "; +- print("Constancy", os, info.getConstancy()); +- os << " ; "; +- auto constantValue = info.getConstantValue(); +- os << "ConstantValue: ["; +- if (constantValue.has_value()) +- os << constantValue.value(); +- else +- os << "None"; +- os << "] ( "; + result.print(os); +- os << " ) "; ++ os << " => "; ++ analysis->getLatticeElement(result)->getValue().print(os); + os << "\n"; + } + }); +diff --git a/test/lib/Analysis/TestMembar.cpp b/test/lib/Analysis/TestMembar.cpp +index df4279fe24..ab9b9f3fb7 100644 +--- a/test/lib/Analysis/TestMembar.cpp ++++ b/test/lib/Analysis/TestMembar.cpp +@@ -1,4 +1,4 @@ +-#include "mlir/Dialect/GPU/GPUDialect.h" ++#include "mlir/Dialect/GPU/IR/GPUDialect.h" + #include "mlir/IR/Dialect.h" + #include "mlir/Pass/Pass.h" + #include "triton/Analysis/Allocation.h" +@@ -9,10 +9,9 @@ using namespace mlir; + namespace { + + struct TestMembarPass +- : public PassWrapper> { ++ : public PassWrapper> { + +- // LLVM15+ +- // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMembarPass); ++ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMembarPass); + + StringRef getArgument() const final { return "test-print-membar"; } + StringRef getDescription() const final { diff --git a/pkgs/development/python-modules/torch/default.nix b/pkgs/development/python-modules/torch/default.nix index 9f701493f07ae..95cd7955053fd 100644 --- a/pkgs/development/python-modules/torch/default.nix +++ b/pkgs/development/python-modules/torch/default.nix @@ -6,6 +6,7 @@ # Native build inputs cmake, util-linux, linkFarm, symlinkJoin, which, pybind11, removeReferencesTo, + pythonRelaxDepsHook, # Build inputs numactl, @@ -13,9 +14,10 @@ # Propagated build inputs filelock, - sympy, - networkx, jinja2, + networkx, + openai-triton, + sympy, numpy, pyyaml, cffi, click, typing-extensions, # Unit tests @@ -271,6 +273,7 @@ in buildPythonPackage rec { which ninja pybind11 + pythonRelaxDepsHook removeReferencesTo ] ++ lib.optionals cudaSupport [ cudatoolkit_joined ] ++ lib.optionals rocmSupport [ rocmtoolkit_joined ]; @@ -298,8 +301,17 @@ in buildPythonPackage rec { # the following are required for tensorboard support pillow six future tensorboard protobuf - ] ++ lib.optionals MPISupport [ mpi ] - ++ lib.optionals rocmSupport [ rocmtoolkit_joined ]; + ] + ++ lib.optionals MPISupport [ mpi ] + ++ lib.optionals rocmSupport [ rocmtoolkit_joined ] + # rocm build requires openai-triton; + # openai-triton currently requires cuda_nvcc, + # so not including it in the cpu-only build; + # torch.compile relies on openai-triton, + # so we include it for the cuda build as well + ++ lib.optionals (rocmSupport || cudaSupport) [ + openai-triton + ]; # Tests take a long time and may be flaky, so just sanity-check imports doCheck = false; @@ -327,6 +339,11 @@ in buildPythonPackage rec { "runHook postCheck" ]; + pythonRemoveDeps = [ + # In our dist-info the name is just "triton" + "pytorch-triton-rocm" + ]; + postInstall = '' find "$out/${python.sitePackages}/torch/include" "$out/${python.sitePackages}/torch/lib" -type f -exec remove-references-to -t ${stdenv.cc} '{}' + diff --git a/pkgs/tools/audio/tts/default.nix b/pkgs/tools/audio/tts/default.nix index 1c3cc91616c7a..9887353665ea0 100644 --- a/pkgs/tools/audio/tts/default.nix +++ b/pkgs/tools/audio/tts/default.nix @@ -184,6 +184,8 @@ python.pkgs.buildPythonApplication rec { "tests/vocoder_tests/test_multiband_melgan_train.py" "tests/vocoder_tests/test_melgan_train.py" "tests/vocoder_tests/test_wavernn_train.py" + # only a feed forward test, but still takes too long + "tests/tts_tests/test_overflow.py" ]; passthru = { diff --git a/pkgs/top-level/python-packages.nix b/pkgs/top-level/python-packages.nix index 5349c922a4e91..cc8d2da989eaa 100644 --- a/pkgs/top-level/python-packages.nix +++ b/pkgs/top-level/python-packages.nix @@ -6801,6 +6801,8 @@ self: super: with self; { open-meteo = callPackage ../development/python-modules/open-meteo { }; + openai-triton = callPackage ../development/python-modules/openai-triton { llvmPackages = pkgs.llvmPackages_rocm; }; + openai-whisper = callPackage ../development/python-modules/openai-whisper { }; openant = callPackage ../development/python-modules/openant { }; From 5e8008a536aeb17f3b90a9926e59d217591697e7 Mon Sep 17 00:00:00 2001 From: Someone Serge Date: Wed, 22 Mar 2023 16:20:14 +0200 Subject: [PATCH 05/10] python3Packages.torchWithCuda: avoid "unknown-warning" when building with cuda-compatible stdenv --- pkgs/development/python-modules/torch/default.nix | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkgs/development/python-modules/torch/default.nix b/pkgs/development/python-modules/torch/default.nix index 95cd7955053fd..d9cf4e5760d88 100644 --- a/pkgs/development/python-modules/torch/default.nix +++ b/pkgs/development/python-modules/torch/default.nix @@ -256,7 +256,7 @@ in buildPythonPackage rec { # Suppress gcc regression: avx512 math function raises uninitialized variable warning # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105593 # See also: Fails to compile with GCC 12.1.0 https://github.com/pytorch/pytorch/issues/77939 - ++ lib.optionals stdenv.cc.isGNU [ + ++ lib.optionals (stdenv.cc.isGNU && lib.versionAtLeast stdenv.cc.version "12.0.0") [ "-Wno-error=maybe-uninitialized" "-Wno-error=uninitialized" ] From 455d23b95d67f1774a59eabdbaafa9650ca2f61f Mon Sep 17 00:00:00 2001 From: Someone Serge Date: Thu, 23 Mar 2023 20:35:01 +0200 Subject: [PATCH 06/10] python3Packages.torchinfo: 1.64 -> 1.7.2 catch up with pytorch 2.0.0 and updated interfaces --- pkgs/development/python-modules/torchinfo/default.nix | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pkgs/development/python-modules/torchinfo/default.nix b/pkgs/development/python-modules/torchinfo/default.nix index c36372235e785..67945d9bd2f29 100644 --- a/pkgs/development/python-modules/torchinfo/default.nix +++ b/pkgs/development/python-modules/torchinfo/default.nix @@ -9,7 +9,7 @@ buildPythonPackage rec { pname = "torchinfo"; - version = "1.64"; + version = "1.7.2"; format = "setuptools"; disabled = pythonOlder "3.7"; @@ -18,7 +18,7 @@ buildPythonPackage rec { owner = "TylerYep"; repo = pname; rev = "refs/tags/v${version}"; - hash = "sha256-gcl8RxCD017FP4LtB60WVtOh7jg2Otv/vNd9hKneEAU="; + hash = "sha256-O+I7BNQ5moV/ZcbbuP/IFoi0LO0WsGHBbSfgPmFu1Ec="; }; propagatedBuildInputs = [ @@ -37,6 +37,11 @@ buildPythonPackage rec { "test_google" ]; + disabledTestPaths = [ + # Wants "compressai", which we don't package (2023-03-23) + "tests/torchinfo_xl_test.py" + ]; + pythonImportsCheck = [ "torchvision" ]; From 9b5fb1838d85f74a4c7f5b2362cec1570ba682e3 Mon Sep 17 00:00:00 2001 From: Someone Serge Date: Thu, 23 Mar 2023 20:35:15 +0200 Subject: [PATCH 07/10] python3Packages.torchinfo: fix pythonImportsCheck --- pkgs/development/python-modules/torchinfo/default.nix | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkgs/development/python-modules/torchinfo/default.nix b/pkgs/development/python-modules/torchinfo/default.nix index 67945d9bd2f29..420a5fd8dfc52 100644 --- a/pkgs/development/python-modules/torchinfo/default.nix +++ b/pkgs/development/python-modules/torchinfo/default.nix @@ -43,7 +43,7 @@ buildPythonPackage rec { ]; pythonImportsCheck = [ - "torchvision" + "torchinfo" ]; meta = with lib; { From 91f24957262e76a6cff959ce49867703349d839d Mon Sep 17 00:00:00 2001 From: Someone Serge Date: Thu, 23 Mar 2023 22:27:56 +0200 Subject: [PATCH 08/10] ocamlPackages.torch: patch for pytorch 2.0.0 compatibility --- pkgs/development/ocaml-modules/torch/default.nix | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/pkgs/development/ocaml-modules/torch/default.nix b/pkgs/development/ocaml-modules/torch/default.nix index e46374ee1c277..9ba356fa93998 100644 --- a/pkgs/development/ocaml-modules/torch/default.nix +++ b/pkgs/development/ocaml-modules/torch/default.nix @@ -2,6 +2,7 @@ , stdenv , buildDunePackage , fetchFromGitHub +, fetchpatch , cmdliner , ctypes , dune-configurator @@ -24,11 +25,19 @@ buildDunePackage rec { src = fetchFromGitHub { owner = "LaurentMazare"; - repo = "ocaml-${pname}"; - rev = version; + repo = "ocaml-${pname}"; + rev = version; hash = "sha256-z/9NUBjeFWE63Z/e8OyzDiy8hrn6qzjaiBH8G9MPeos="; }; + patches = [ + # Pytorch 2.0 support. Drop when it reaches a release + (fetchpatch { + url = "https://github.com/LaurentMazare/ocaml-torch/commit/ef7ef30cafecb09e45ec1ed8ce4bedae5947cfa5.patch"; + hash = "sha256-smdwKy40iIISp/25L2J4az6KmqFS1soeChBElUyhl5A="; + }) + ]; + buildInputs = [ dune-configurator ]; propagatedBuildInputs = [ From 24d20fefbf967ce83126098a1c788ca9866c4570 Mon Sep 17 00:00:00 2001 From: Someone Date: Mon, 3 Apr 2023 16:21:15 +0000 Subject: [PATCH 09/10] python3Packages.openai-triton: inline bash comments The drawback of this is that the comments now affect outPath's. Hopefully, though, we'll remove this preFixup soon anyway Co-authored-by: Sandro --- .../python-modules/openai-triton/default.nix | 37 ++++++++----------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/pkgs/development/python-modules/openai-triton/default.nix b/pkgs/development/python-modules/openai-triton/default.nix index 9340aad3a9545..90c973d63b6e1 100644 --- a/pkgs/development/python-modules/openai-triton/default.nix +++ b/pkgs/development/python-modules/openai-triton/default.nix @@ -163,31 +163,26 @@ buildPythonPackage { ]; # Avoid GLIBCXX mismatch with other cuda-enabled python packages - preConfigure = - '' - export CC="${backendStdenv.cc}/bin/cc"; - export CXX="${backendStdenv.cc}/bin/c++"; - '' + preConfigure = '' + export CC="${backendStdenv.cc}/bin/cc"; + export CXX="${backendStdenv.cc}/bin/c++"; + # Upstream's setup.py tries to write cache somewhere in ~/ - + '' - export HOME=$TMPDIR - '' + export HOME=$TMPDIR + # Upstream's github actions patch setup.cfg to write base-dir. May be redundant - + '' - echo "" >> python/setup.cfg - echo "[build_ext]" >> python/setup.cfg - echo "base-dir=$PWD" >> python/setup.cfg - '' + echo " + [build_ext] + base-dir=$PWD" >> python/setup.cfg + # The rest (including buildPhase) is relative to ./python/ - + '' - cd python/ - '' + cd python/ + # Work around download_and_copy_ptxas() - + '' - dst_cuda="$PWD/triton/third_party/cuda/bin" - mkdir -p "$dst_cuda" - ln -s "${ptxas}" "$dst_cuda/" - ''; + dst_cuda="$PWD/triton/third_party/cuda/bin" + mkdir -p "$dst_cuda" + ln -s "${ptxas}" "$dst_cuda/" + ''; # CMake is run by setup.py instead dontUseCmakeConfigure = true; From 632cff6fcef0895202515d875dc89cc41fc14a8c Mon Sep 17 00:00:00 2001 From: Someone Serge Date: Tue, 4 Apr 2023 14:41:31 +0300 Subject: [PATCH 10/10] python3Packages.openai-triton: justify the use of pkgsTargetTarget --- .../python-modules/openai-triton/default.nix | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pkgs/development/python-modules/openai-triton/default.nix b/pkgs/development/python-modules/openai-triton/default.nix index 90c973d63b6e1..0e10642f0693f 100644 --- a/pkgs/development/python-modules/openai-triton/default.nix +++ b/pkgs/development/python-modules/openai-triton/default.nix @@ -26,6 +26,18 @@ let version = "2.0.0"; inherit (cudaPackages) cuda_cudart backendStdenv; + + # A time may come we'll want to be cross-friendly + # + # Short explanation: we need pkgsTargetTarget, because we use string + # interpolation instead of buildInputs. + # + # Long explanation: OpenAI/triton downloads and vendors a copy of NVidia's + # ptxas compiler. We're not running this ptxas on the build machine, but on + # the user's machine, i.e. our Target platform. The second "Target" in + # pkgsTargetTarget maybe doesn't matter, because ptxas compiles programs to + # be executed on the GPU. + # Cf. https://nixos.org/manual/nixpkgs/unstable/#sec-cross-infra ptxas = "${pkgsTargetTarget.cudaPackages.cuda_nvcc}/bin/ptxas"; llvm = (llvmPackages.llvm.override {