diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 216609bb783a..b13c74801ee3 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -34,7 +34,6 @@ jobs: echo '::set-output name=matrix-optional::["ubuntu-latest"]' fi - Integration-Tests: needs: Runner-Preparation @@ -49,7 +48,7 @@ jobs: - name: Checkout uses: actions/checkout@v3 with: - submodules: 'true' + submodules: "true" - name: Set CUDA ENV if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}} run: | diff --git a/python/setup.py b/python/setup.py index a2f522d89539..38f24d7f8901 100644 --- a/python/setup.py +++ b/python/setup.py @@ -354,6 +354,7 @@ def build_extension(self, ext): "triton/_C", "triton/common", "triton/compiler", + "triton/compiler/backends", "triton/language", "triton/language/extra", "triton/ops", diff --git a/python/src/triton.cc b/python/src/triton.cc index f173fb3286ae..e1f9f482e1a3 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1685,9 +1685,9 @@ void init_triton_ir(py::module &&m) { self.addPass(mlir::triton::createReorderBroadcastPass()); }) .def("add_rewrite_tensor_pointer_pass", - [](mlir::PassManager &self, int computeCapability) { - self.addPass(mlir::triton::createRewriteTensorPointerPass( - computeCapability)); + [](mlir::PassManager &self, int capability) { + self.addPass( + mlir::triton::createRewriteTensorPointerPass(capability)); }) .def("add_tritongpu_ws_feasibility_checking_pass", [](mlir::PassManager &self, int computeCapability) { @@ -1761,9 +1761,9 @@ void init_triton_ir(py::module &&m) { self.addPass(mlir::createTritonGPUReorderInstructionsPass()); }) .def("add_tritongpu_rewrite_tensor_pointer_pass", - [](mlir::PassManager &self, int computeCapability) { - self.addPass(mlir::createTritonGPURewriteTensorPointerPass( - computeCapability)); + [](mlir::PassManager &self, int capability) { + self.addPass( + mlir::createTritonGPURewriteTensorPointerPass(capability)); }) .def("add_tritongpu_decompose_conversions_pass", [](mlir::PassManager &self) { diff --git a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py index abd5c5edcbc4..1101b8906688 100644 --- a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py +++ b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py @@ -900,11 +900,12 @@ def process_epilogue(d, bias, w, epilogue): num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count if NUM_CTAS > 1: device = get_current_device() - null_kernel = triton.compile(empty_kernel, signature="i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}) + src = triton.compiler.ASTSource(fn=empty_kernel, signature="i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}) + null_kernel = triton.compile(src) null_kernel._init_handles() max_shared_mem = driver.utils.get_device_properties(device)["max_shared_mem"] - num_clusters = driver.utils.cu_occupancy_max_active_clusters(null_kernel.cu_function, max_shared_mem, NUM_CTAS, - 1, 1) + num_clusters = driver.utils.cu_occupancy_max_active_clusters(null_kernel.function, max_shared_mem, NUM_CTAS, 1, + 1) num_SMs = num_clusters def grid(META): diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index f1039d011e2c..63401f28e42b 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ 
-1,12 +1,12 @@ import multiprocessing import os import shutil -from collections import namedtuple import torch import triton import triton.language as tl +from triton.compiler import ASTSource tmpdir = ".tmp" @@ -17,32 +17,26 @@ def reset_tmp_dir(): shutil.rmtree(tmpdir, ignore_errors=True) -instance_descriptor = namedtuple("instance_descriptor", - ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"]) - - -def compile_fn(config, cc): +def compile_fn(attrs, capability): @triton.jit def kernel_sub(a, b, o, N: tl.constexpr): idx = tl.arange(0, N) tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777) - triton.compile( + src = ASTSource( fn=kernel_sub, - signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, - device=0, constants={3: 32}, - configs=[config], - warm_cache_only=True, - cc=cc, + signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, + attrs=attrs, ) + triton.compile(src=src, target=("cuda", capability)) def test_compile_in_subproc() -> None: major, minor = torch.cuda.get_device_capability(0) cc = major * 10 + minor - config = instance_descriptor(tuple(range(4)), (), (), ()) + config = triton.compiler.AttrsDescriptor(tuple(range(4)), (), (), ()) multiprocessing.set_start_method('fork') proc = multiprocessing.Process(target=compile_fn, args=(config, cc)) @@ -51,7 +45,7 @@ def test_compile_in_subproc() -> None: assert proc.exitcode == 0 -def compile_fn_dot(config, cc): +def compile_fn_dot(attrs, capability): @triton.jit def kernel_dot(Z): @@ -60,24 +54,18 @@ def kernel_dot(Z): z = tl.dot(z, z) tl.store(Z + offs, z) - triton.compile( - fn=kernel_dot, - signature={0: "*fp32"}, - device=0, - configs=[config], - warm_cache_only=True, - cc=cc, - ) + src = ASTSource(fn=kernel_dot, signature={0: "*fp32"}, attrs=attrs, constants=dict()) + triton.compile(src=src, target=("cuda", capability)) def test_compile_in_forked_subproc() -> None: reset_tmp_dir() major, minor = torch.cuda.get_device_capability(0) - cc = major * 10 + minor - config = instance_descriptor(tuple(range(1)), (), (), ()) + capability = major * 10 + minor + config = triton.compiler.AttrsDescriptor(tuple(range(1)), (), (), ()) assert multiprocessing.get_start_method() == 'fork' - proc = multiprocessing.Process(target=compile_fn_dot, args=(config, cc)) + proc = multiprocessing.Process(target=compile_fn_dot, args=(config, capability)) proc.start() proc.join() assert proc.exitcode == 0 diff --git a/python/test/unit/tools/test_aot.py b/python/test/unit/tools/test_aot.py index 92b5562e9527..613836d24e36 100644 --- a/python/test/unit/tools/test_aot.py +++ b/python/test/unit/tools/test_aot.py @@ -446,7 +446,7 @@ def test_ttgir_to_ptx(): kernel_path = os.path.join(tmp_dir, "empty_kernel.ttgir") with open(kernel_path, "w") as fp: fp.write(src) - k = triton.compile(kernel_path, cc=80) + k = triton.compile(kernel_path, target=("cuda", 80)) ptx = k.asm["ptx"] assert ".target sm_80" in ptx assert ".address_size 64" in ptx diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 55484acd5bf2..3cda7668e50e 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -21,6 +21,7 @@ from . import language from . import testing +from . 
import tools __all__ = [ "autotune", diff --git a/python/triton/compiler/__init__.py b/python/triton/compiler/__init__.py index fd0665e1e549..0ced3fed6af1 100644 --- a/python/triton/compiler/__init__.py +++ b/python/triton/compiler/__init__.py @@ -1,8 +1,7 @@ -from .compiler import (CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps, - instance_descriptor) +from .compiler import (CompiledKernel, ASTSource, compile, AttrsDescriptor) from .errors import CompilationError __all__ = [ - "compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", + "compile", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", "get_arch_default_num_stages" ] diff --git a/python/triton/compiler/target.py b/python/triton/compiler/backends/__init__.py similarity index 100% rename from python/triton/compiler/target.py rename to python/triton/compiler/backends/__init__.py diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py new file mode 100644 index 000000000000..a019539ad812 --- /dev/null +++ b/python/triton/compiler/backends/cuda.py @@ -0,0 +1,225 @@ +from triton.common.backend import BaseBackend +from dataclasses import dataclass +from ..._C.libtriton.triton import ClusterInfo, get_num_warps, TMAInfos, translate_triton_gpu_to_llvmir, get_shared_memory_size, translate_llvmir_to_ptx, compile_ptx_to_cubin, add_external_libs +from ...common.backend import get_cuda_version_key, path_to_ptxas +from ..._C.libtriton.triton import ir, runtime +import functools +from typing import Any +from ..utils import get_ids_of_tensormaps, parse_tma_info +from ..make_launcher import make_stub +import hashlib + + +def get_kernel_name(src: str, pattern: str) -> str: + ''' + Get kernel name from PTX code. + This Kernel name is required when launching the kernel. + ''' + # There is a name mangling in PTX codegen, so the original kernel names in Triton IR are not available in PTX/cubin. + assert src + for line in src.split('\n'): + line = line.strip() + if line.startswith(pattern): + return line.split()[-1] + + +@functools.lru_cache() +def ptx_get_version(cuda_version) -> int: + ''' + Get the highest PTX version supported by the current CUDA driver. 
+ ''' + assert isinstance(cuda_version, str) + major, minor = map(int, cuda_version.split('.')) + if major == 12: + return 80 + minor + if major == 11: + return 70 + minor + if major == 10: + return 63 + minor + raise RuntimeError("Triton only support CUDA 10.0 or higher") + + +@dataclass +class CUDAOptions: + num_warps: int = 4 + num_ctas: int = 1 + num_stages: int = 3 + cluster_dims: tuple = (1, 1, 1) + ptx_version: int = None + enable_warp_specialization: bool = False + enable_persistent: bool = False + optimize_epilogue: bool = False + enable_fp_fusion: bool = True + allow_fp8e4nv: bool = False + max_num_imprecise_acc_default: bool = None + debug: bool = False + + def __post_init__(self): + assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \ + "num_warps must be a power of 2" + + def hash(self): + key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()]) + return hashlib.md5(key.encode("utf-8")).hexdigest() + + +@dataclass +class CUDALinkerOptions: + libs: dict = None + + def __post_init__(self): + if self.libs is not None: + self.libs = {k: v for k, v in self.libs.items() if v} + + +class CUDABackend(BaseBackend): + + def __init__(self, device_type: tuple) -> None: + super().__init__(device_type) + self.capability = device_type[1] + assert isinstance(self.capability, int) + + def parse_compiler_options(self, opts) -> Any: + options = CUDAOptions(**opts) + options.allow_fp8e4nv = self.capability >= 89 + options.max_num_imprecise_acc_default = 0 if self.capability >= 89 else None + return options + + def parse_linker_options(self, opts) -> Any: + return CUDALinkerOptions(**opts) + + @staticmethod + def make_ttir(mod, metadata, opt): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + pm.add_inliner_pass() + pm.add_triton_combine_pass() + pm.add_canonicalizer_pass() + pm.add_reorder_broadcast_pass() + pm.add_cse_pass() + pm.add_licm_pass() + pm.add_symbol_dce_pass() + pm.run(mod) + return mod + + @staticmethod + def make_ttgir(mod, metadata, opt, capability): + cluster_info = ClusterInfo() + if opt.cluster_dims is not None: + cluster_info.clusterDimX = opt.cluster_dims[0] + cluster_info.clusterDimY = opt.cluster_dims[1] + cluster_info.clusterDimZ = opt.cluster_dims[2] + # TTIR -> TTGIR + pm = ir.pass_manager(mod.context) + pm.enable_debug() + pm.add_convert_triton_to_tritongpu_pass(opt.num_warps, 32, opt.num_ctas, capability) + # optimize TTGIR + pm.add_tritongpu_coalesce_pass() + # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass + pm.add_plan_cta_pass(cluster_info) + pm.add_tritongpu_rewrite_tensor_pointer_pass(capability) + pm.add_plan_cta_pass(cluster_info) + pm.add_tritongpu_remove_layout_conversions_pass() + pm.add_tritongpu_accelerate_matmul_pass(capability) + pm.add_tritongpu_remove_layout_conversions_pass() + if opt.optimize_epilogue: + pm.add_tritongpu_optimize_epilogue_pass() + pm.add_tritongpu_optimize_dot_operands_pass() + pm.add_cse_pass() + ws_enabled = False + # `num_warps` does not mean the total number of warps of a CTA when + # warp specialization is enabled. + # it's the responsibility of the compiler to figure out the exact + # `num_warps` to use. + # TODO: support the case where `num_warps` from user is not 4. 
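The new `CUDAOptions` dataclass gathers every CUDA-specific compile flag into one object whose `hash()` later feeds the compilation cache key. Below is a minimal standalone sketch of that pattern, assuming nothing beyond the standard library; `OptionsSketch` and its fields are illustrative stand-ins, not the real class.

```python
# Toy sketch of the options-hashing pattern used by CUDAOptions: every field is
# folded into a stable string key and md5-hashed, so changing any compile
# option yields a different cache entry. Field names here are illustrative.
import hashlib
from dataclasses import dataclass


@dataclass
class OptionsSketch:
    num_warps: int = 4
    num_stages: int = 3
    enable_fp_fusion: bool = True

    def __post_init__(self):
        # same power-of-two invariant that CUDAOptions.__post_init__ enforces
        assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
            "num_warps must be a power of 2"

    def hash(self):
        key = '_'.join(f'{name}-{val}' for name, val in self.__dict__.items())
        return hashlib.md5(key.encode("utf-8")).hexdigest()


assert OptionsSketch().hash() != OptionsSketch(num_warps=8).hash()
```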
+ if capability // 10 >= 9 and opt.enable_warp_specialization and opt.num_warps == 4: + pm.add_tritongpu_ws_feasibility_checking_pass(capability) + pm.run(mod) + ws_enabled = ir.is_ws_supported(mod) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + if ws_enabled: + pm.add_tritongpu_wsdecomposing_pass(capability) + pm.add_tritongpu_wspipeline_pass(opt.num_stages, opt.num_warps, capability) + pm.add_tritongpu_wsmutex_pass(capability) + pm.add_tritongpu_wsmaterialization_pass(capability) + pm.add_licm_pass() + pm.add_cse_pass() + else: + pm.add_tritongpu_pipeline_pass(opt.num_stages, opt.num_warps, opt.num_ctas, capability) + pm.add_tritongpu_materialize_load_store_pass(opt.num_warps, capability) + if capability // 10 <= 8: + pm.add_tritongpu_prefetch_pass() + pm.add_tritongpu_optimize_dot_operands_pass() + pm.add_tritongpu_remove_layout_conversions_pass() + pm.add_tritongpu_decompose_conversions_pass() + pm.add_tritongpu_ws_fixup_missing_attrs_pass() + pm.add_tritongpu_reorder_instructions_pass() + pm.add_cse_pass() + pm.add_symbol_dce_pass() + if capability // 10 >= 9: + pm.add_tritongpu_fence_insertion_pass() + pm.add_tritongpu_ws_fixup_missing_attrs_pass() + pm.add_tritongpu_optimize_thread_locality_pass() + pm.add_canonicalizer_pass() + pm.run(mod) + metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ) + return mod + + @staticmethod + def make_llir(src, metadata, linker_options, capability): + metadata["enable_warp_specialization"] = ir.is_ws_supported(src) + metadata["num_warps"] = get_num_warps(src) + tma_infos = TMAInfos() + # link libraries + if linker_options.libs: + add_external_libs(src, list(linker_options.libs.keys()), list(linker_options.libs.values())) + # TritonGPU -> LLVM-IR + ret = translate_triton_gpu_to_llvmir(src, capability, tma_infos, runtime.TARGET.NVVM) + if len(tma_infos) > 0: + metadata["tensormaps_info"] = parse_tma_info(tma_infos, metadata["ids_of_folded_args"]) + for i, _ in enumerate(metadata["tensormaps_info"]): + metadata["tensormaps_info"][i].ids_of_folded_args = metadata["ids_of_folded_args"] + metadata["ids_of_tensormaps"] = get_ids_of_tensormaps(metadata.get("tensormaps_info", None)) + metadata["shared"] = get_shared_memory_size(src) + return ret + + @staticmethod + def make_ptx(src, metadata, opt, capability): + ptx_version = opt.ptx_version + if ptx_version is None: + _, cuda_version = path_to_ptxas() + ptx_version = ptx_get_version(cuda_version) + return translate_llvmir_to_ptx(src, capability, ptx_version, opt.enable_fp_fusion) + + @staticmethod + def make_cubin(src, metadata, opt, capability): + metadata["name"] = get_kernel_name(src, pattern='// .globl') + ptxas, _ = path_to_ptxas() + return compile_ptx_to_cubin(src, ptxas, capability, opt.enable_fp_fusion) + + def add_stages(self, stages, compiler_options, linker_options): + stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, compiler_options) + stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, compiler_options, self.capability) + stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, linker_options, self.capability) + stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, compiler_options, self.capability) + stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, compiler_options, self.capability) + + def hash(self): + return f'{get_cuda_version_key()}-{self.capability}' + + def make_launcher_stub(self, src, metadata): + ids = { + "ids_of_tensormaps": 
metadata.get("ids_of_tensormaps", tuple()), "ids_of_folded_args": + metadata.get("ids_of_folded_args", + tuple()), "ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple() + } + constants = src.constants if hasattr(src, "constants") else dict() + enable_warp_specialization = False + + # set constant + return make_stub(src.name, src.signature, constants, ids, enable_warp_specialization=enable_warp_specialization) + + @classmethod + def create_backend(cls, device_type: str): + return cls(device_type) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index f23af910ccdd..53812ca32ba7 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -208,8 +208,8 @@ def visit_Call(self, node: ast.Call) -> bool: class CodeGenerator(ast.NodeVisitor): - def __init__(self, context, prototype, gscope, attributes, constants, function_name, target, module=None, - is_kernel=False, function_types: Optional[Dict] = None, debug=False, noinline=False, + def __init__(self, context, prototype, gscope, attributes, constants, function_name, options, module=None, + is_kernel=False, function_types: Optional[Dict] = None, noinline=False, file_name: Optional[str] = None, begin_line=0): self.context = context self.builder = ir.builder(context) @@ -217,7 +217,7 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n # node.lineno starts from 1, so we need to subtract 1 self.begin_line = begin_line - 1 self.builder.set_loc(file_name, begin_line, 0) - self.builder.target = target + self.builder.options = options self.module = self.builder.create_module() if module is None else module self.function_ret_types = {} if function_types is None else function_types self.prototype = prototype @@ -228,7 +228,7 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n self.function_name = function_name self.is_kernel = is_kernel self.last_node = None - self.debug = debug + self.debug = options.debug self.noinline = noinline self.scf_stack = [] self.last_ret_type = None @@ -980,12 +980,13 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): prototype = language.function_type([], arg_types) gscope = sys.modules[fn.fn.__module__].__dict__ # If the callee is not set, we use the same debug setting as the caller - debug = self.debug if fn.debug is None else fn.debug file_name, begin_line = _get_fn_file_line(fn) + options = self.builder.options + options.debug = self.debug if fn.debug is None else fn.debug generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, - function_name=fn_name, function_types=self.function_ret_types, debug=debug, + function_name=fn_name, function_types=self.function_ret_types, noinline=fn.noinline, file_name=file_name, begin_line=begin_line, - target=self.builder.target) + options=self.builder.options) generator.visit(fn.parse()) callee_ret_type = generator.last_ret_type self.function_ret_types[fn_name] = callee_ret_type @@ -1188,28 +1189,22 @@ def kernel_suffix(signature, specialization): return suffix -def ast_to_ttir(fn, signature, specialization, constants, debug, target): - # canonicalize signature - if isinstance(signature, str): - signature = {k: v.strip() for k, v in enumerate(signature.split(","))} +def ast_to_ttir(fn, specialization, options): + attrs = specialization.attrs context = ir.context() context.load_triton() # create kernel prototype cst_key = lambda i: fn.arg_names.index(i) if 
isinstance(i, str) else i - constants = {cst_key(key): value for key, value in constants.items()} + constants = {cst_key(key): value for key, value in specialization.constants.items()} # visit kernel AST gscope = fn.__globals__.copy() - function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)]) - tys = list(signature.values()) - new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in specialization.equal_to_1} - new_attrs = {k: [("tt.divisibility", 16)] for k in specialization.divisible_by_16} - - # Note: Here we defines 'max_divisibility' for later TMA usage. - # fp16 requires 'max_divisibility >= 8' and fp8 requires 'max_divisibility >= 16'. - # Since we only need to support TMA for fp16 and fp8 now, 'max_divisibility' is either 8 or 16. - for k in specialization.divisible_by_8: + function_name = '_'.join([fn.__name__, kernel_suffix(specialization.signature.values(), attrs)]) + tys = list(specialization.signature.values()) + new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in attrs.equal_to_1} + new_attrs = {k: [("tt.divisibility", 16)] for k in attrs.divisible_by_16} + for k in attrs.divisible_by_8: attr = new_attrs[k] if k in new_attrs else [] - if k in specialization.divisible_by_16: + if k in attrs.divisible_by_16: attr.append(("tt.max_divisibility", 16)) else: attr.append(("tt.max_divisibility", 8)) @@ -1217,13 +1212,13 @@ def ast_to_ttir(fn, signature, specialization, constants, debug, target): all_constants = constants.copy() all_constants.update(new_constants) - arg_types = [str_to_ty(v) for k, v in signature.items() if k not in constants] + arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in constants] file_name, begin_line = _get_fn_file_line(fn) prototype = language.function_type([], arg_types) generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, - attributes=new_attrs, is_kernel=True, debug=debug, file_name=file_name, - begin_line=begin_line, target=target) + attributes=new_attrs, is_kernel=True, file_name=file_name, begin_line=begin_line, + options=options) try: generator.visit(fn.parse()) except CompilationError as e: diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 6e574b5416ae..528b824a61d5 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -1,268 +1,43 @@ from __future__ import annotations -import functools import hashlib import json -import os -import re -from collections import namedtuple -from pathlib import Path -from typing import Any - -from dataclasses import dataclass -from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs, compile_ptx_to_cubin, get_env_vars, - get_num_warps, get_shared_memory_size, ir, runtime, translate_llvmir_to_ptx, - translate_triton_gpu_to_llvmir) -from ..common.backend import get_backend, get_cuda_version_key, path_to_ptxas -from ..common.build import is_hip +from .._C.libtriton.triton import (get_env_vars, ir) # from ..runtime import driver, jit, JITFunction # TODO: runtime.errors from ..runtime.autotuner import OutOfResources -from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager +from ..runtime.cache import get_cache_manager +from ..runtime.jit import get_current_device, get_cuda_stream, get_current_target from ..runtime.driver import driver -from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device, get_device_capability) 
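`ast_to_ttir` now pulls its signature, constants, and specialization hints straight from the `ASTSource` and its `AttrsDescriptor`. The following self-contained sketch shows how the divisibility hints above become per-argument MLIR attributes; plain sets stand in for the descriptor fields, and the argument indices are made up.

```python
# Sketch of the attribute mapping in ast_to_ttir: args divisible by 16 get
# tt.divisibility=16, and args divisible by 8 additionally get a
# tt.max_divisibility of 16 or 8 depending on whether they are also 16-divisible.
divisible_by_16 = {0, 1}
divisible_by_8 = {0, 1, 2}

new_attrs = {k: [("tt.divisibility", 16)] for k in divisible_by_16}
for k in divisible_by_8:
    attr = new_attrs.get(k, [])
    attr.append(("tt.max_divisibility", 16 if k in divisible_by_16 else 8))
    new_attrs[k] = attr

print(new_attrs[2])  # [('tt.max_divisibility', 8)]: divisible by 8 but not by 16
```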
-from ..tools.disasm import get_sass +from .utils import InfoFromBackendForTensorMap +from .backends.cuda import CUDABackend +from dataclasses import dataclass from .code_generator import ast_to_ttir -from .make_launcher import make_stub -from .utils import (InfoFromBackendForTensorMap, TensorMapManager, get_ids_of_tensormaps, parse_tma_info) +from pathlib import Path +import re @dataclass -class CudaTargetDescriptor: - capability: int - num_warps: int - enable_fp_fusion: bool - - -def _is_cuda(target): - return isinstance(target, CudaTargetDescriptor) - - -class LazyDict(dict): - - def __getitem__(self, key): - val = dict.__getitem__(self, key) - if callable(val): - return val() - return val - - -def inline_triton_ir(mod): - pm = ir.pass_manager(mod.context) - pm.enable_debug() - pm.add_inliner_pass() - pm.run(mod) - return mod - - -def ttir_compute_capability_rewrite(mod, target): - # For hardware without support, we must rewrite all load/store - # with block (tensor) pointers into tensors of pointers - pm = ir.pass_manager(mod.context) - pm.enable_debug() - if _is_cuda(target): - pm.add_rewrite_tensor_pointer_pass(target.capability) - pm.run(mod) - return mod - - -def optimize_ttir(mod, target): - mod = inline_triton_ir(mod) - mod = ttir_compute_capability_rewrite(mod, target) - pm = ir.pass_manager(mod.context) - pm.enable_debug() - pm.add_inliner_pass() - pm.add_triton_combine_pass() - pm.add_canonicalizer_pass() - pm.add_reorder_broadcast_pass() - pm.add_cse_pass() - pm.add_licm_pass() - pm.add_symbol_dce_pass() - pm.run(mod) - return mod - - -def ttir_to_ttgir(mod, num_warps, num_ctas, target): - pm = ir.pass_manager(mod.context) - pm.enable_debug() - pm.add_convert_triton_to_tritongpu_pass(num_warps, 32, num_ctas, target.capability) - pm.run(mod) - return mod - - -def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, - enable_persistent, optimize_epilogue): - is_cuda = _is_cuda(target) - if is_cuda: - capability = target.capability - pm = ir.pass_manager(mod.context) - pm.enable_debug() - pm.add_tritongpu_coalesce_pass() - # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass - pm.add_plan_cta_pass(cluster_info) - if is_cuda: - pm.add_tritongpu_rewrite_tensor_pointer_pass(capability) - pm.add_plan_cta_pass(cluster_info) - pm.add_tritongpu_remove_layout_conversions_pass() - if is_cuda: - pm.add_tritongpu_accelerate_matmul_pass(capability) - pm.add_tritongpu_remove_layout_conversions_pass() - if optimize_epilogue: - pm.add_tritongpu_optimize_epilogue_pass() - pm.add_tritongpu_optimize_dot_operands_pass() - pm.add_cse_pass() - ws_enabled = False - # `num_warps` does not mean the total number of warps of a CTA when - # warp specialization is enabled. - # it's the responsibility of the compiler to figure out the exact - # `num_warps` to use. - # TODO: support the case where `num_warps` from user is not 4. 
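Much of what is deleted below is the old monolithic `compiler.py` pipeline, including its `make_hash` helper. In the new `compile()` further down, the cache key is instead composed from independent hashes (source, backend, options) plus the environment variables. A rough sketch of that composition, with placeholder hash strings rather than real values:

```python
# Rough sketch of the composed cache key in the new compile(); the real code
# uses src.hash(), backend.hash(), compiler_options.hash() and get_env_vars().
import hashlib

src_hash = "ast-or-ir-source-hash"       # placeholder
backend_hash = "cuda-version-key-90"     # placeholder for f"{version_key}-{capability}"
options_hash = "cuda-options-hash"       # placeholder
env_vars = {"TRITON_DEBUG": "0"}         # placeholder environment

key = f"{src_hash}-{backend_hash}-{options_hash}-{frozenset(sorted(env_vars.items()))}"
cache_dir_name = hashlib.md5(key.encode("utf-8")).hexdigest()
print(cache_dir_name)
```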
- if capability // 10 >= 9 and enable_warp_specialization and num_warps == 4: - pm.add_tritongpu_ws_feasibility_checking_pass(capability) - pm.run(mod) - ws_enabled = ir.is_ws_supported(mod) - pm = ir.pass_manager(mod.context) - pm.enable_debug() - if ws_enabled: - pm.add_tritongpu_wsdecomposing_pass(capability) - pm.add_tritongpu_wspipeline_pass(num_stages, num_warps, capability) - pm.add_tritongpu_wsmutex_pass(capability) - pm.add_tritongpu_wsmaterialization_pass(capability) - pm.add_licm_pass() - pm.add_cse_pass() - else: - pm.add_tritongpu_pipeline_pass(num_stages, num_warps, num_ctas, capability) - pm.add_tritongpu_materialize_load_store_pass(num_warps, capability) - if capability // 10 <= 8: - pm.add_tritongpu_prefetch_pass() - pm.add_tritongpu_optimize_dot_operands_pass() - pm.add_tritongpu_remove_layout_conversions_pass() - pm.add_tritongpu_decompose_conversions_pass() - pm.add_tritongpu_ws_fixup_missing_attrs_pass() - pm.add_tritongpu_reorder_instructions_pass() - pm.add_cse_pass() - pm.add_symbol_dce_pass() - if capability // 10 >= 9: - pm.add_tritongpu_fence_insertion_pass() - pm.add_tritongpu_ws_fixup_missing_attrs_pass() - pm.add_tritongpu_optimize_thread_locality_pass() - pm.add_canonicalizer_pass() - pm.run(mod) - return mod - - -def _add_external_libs(mod, libs): - for name, path in libs.items(): - if len(name) == 0 or len(path) == 0: - return - add_external_libs(mod, list(libs.keys()), list(libs.values())) - - -def ttgir_to_llir(mod, extern_libs, target, tma_infos): - if extern_libs: - _add_external_libs(mod, extern_libs) - # TODO: separate tritongpu_to_llvmir for different backends - if _is_cuda(target): - return translate_triton_gpu_to_llvmir(mod, target.capability, tma_infos, runtime.TARGET.NVVM) - else: - return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL) - - -# PTX translation - - -@functools.lru_cache() -def ptx_get_version(cuda_version) -> int: - ''' - Get the highest PTX version supported by the current CUDA driver. - ''' - assert isinstance(cuda_version, str) - major, minor = map(int, cuda_version.split('.')) - if major == 12: - return 80 + minor - if major == 11: - return 70 + minor - if major == 10: - return 63 + minor - raise RuntimeError("Triton only support CUDA 10.0 or higher") - - -def llir_to_ptx(mod: Any, target: CudaTargetDescriptor, ptx_version: int = None) -> str: - ''' - Translate TritonGPU module to PTX code. - :param mod: a TritonGPU dialect module - :return: PTX code - ''' - if ptx_version is None: - _, cuda_version = path_to_ptxas() - ptx_version = ptx_get_version(cuda_version) - return translate_llvmir_to_ptx(mod, target.capability, ptx_version, target.enable_fp_fusion) - - -def ptx_to_cubin(ptx: str, target: CudaTargetDescriptor): - ''' - Compile TritonGPU module to cubin. - :param ptx: ptx code - :param compute_capability: compute capability - :return: str - ''' - ptxas, _ = path_to_ptxas() - return compile_ptx_to_cubin(ptx, ptxas, target.capability, target.enable_fp_fusion) - - -# ------------------------------------------------------------------------------ -# compiler -# ------------------------------------------------------------------------------ -def get_kernel_name(src: str, pattern: str) -> str: - ''' - Get kernel name from PTX code. - This Kernel name is required when launching the kernel. - ''' - # There is a name mangling in PTX codegen, so the original kernel names in Triton IR are not available in PTX/cubin. 
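The free functions removed here (`optimize_ttgir`, `ttgir_to_llir`, `llir_to_ptx`, `ptx_to_cubin`, ...) reappear as `CUDABackend` stages registered via `add_stages`. The toy model below shows only the driving mechanism: an ordered `ext -> compile_ir` mapping threaded with a shared metadata dict, starting at whichever IR the source already is. The stage bodies are placeholders, not the real passes.

```python
# Toy model of the stage loop in the new compile(): each backend registers
# ordered ext -> compile_ir callables; compilation starts at the source's own
# extension and threads the module plus a shared metadata dict through the rest.
def make_llir(mod, metadata):
    metadata["shared"] = 0            # the real stage records shared-memory usage
    return mod + " -> llir"


stages = {
    "ttir": lambda mod, md: mod + " -> ttir",
    "ttgir": lambda mod, md: mod + " -> ttgir",
    "llir": make_llir,
    "ptx": lambda mod, md: mod + " -> ptx",
    "cubin": lambda mod, md: mod + " -> cubin",
}

src_ext = "ttgir"                      # e.g. compiling a pre-existing .ttgir file
module = "parsed-ttgir"
metadata = {}
for ext, compile_ir in list(stages.items())[list(stages).index(src_ext):]:
    module = compile_ir(module, metadata)

print(module)    # parsed-ttgir -> ttgir -> llir -> ptx -> cubin
print(metadata)  # {'shared': 0}
```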
- assert src - for line in src.split('\n'): - line = line.strip() - if line.startswith(pattern): - return line.split()[-1] - - -def convert_type_repr(x): - # Currently we only capture the pointer type and assume the pointer is on global memory. - # TODO: Capture and support shared memory space - match = re.search(r'!tt\.ptr<([^,]+)', x) - if match is not None: - return '*' + convert_type_repr(match.group(1)) - return x - - -def make_hash(fn, target, env_vars, device_backend, **kwargs): - if device_backend is None: - version_key = get_cuda_version_key() - else: - version_key = device_backend.get_version_key() - if isinstance(fn, JITFunction): - configs = kwargs["configs"] - signature = kwargs["signature"] - constants = kwargs.get("constants", dict()) - num_warps = kwargs.get("num_warps", 4) - num_ctas = kwargs.get("num_ctas", 1) - num_stages = kwargs.get("num_stages", 3) - enable_warp_specialization = kwargs.get("enable_warp_specialization", False) - enable_persistent = kwargs.get("enable_persistent", False) - debug = kwargs.get("debug", False) - # Get unique key for the compiled code - get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), - sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8)) - configs_key = [get_conf_key(conf) for conf in configs] - env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())] - key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}" +class AttrsDescriptor: + divisible_by_16: set = None + equal_to_1: set = None + ids_of_folded_args: set = None + divisible_by_8: set = None + + def __post_init__(self): + if self.divisible_by_16 is None: + self.divisible_by_16 = set() + if self.equal_to_1 is None: + self.equal_to_1 = set() + if self.ids_of_folded_args is None: + self.ids_of_folded_args = set() + if self.divisible_by_8 is None: + self.divisible_by_8 = set() + + def hash(self): + key = str([sorted(x) for x in self.__dict__.values()]) return hashlib.md5(key.encode("utf-8")).hexdigest() - assert isinstance(fn, str) - ignore_version = kwargs.get('ignore_version', False) - if (ignore_version): - return hashlib.md5((Path(fn).read_text()).encode("utf-8")).hexdigest() - return hashlib.md5((Path(fn).read_text() + version_key).encode("utf-8")).hexdigest() # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, @@ -281,10 +56,6 @@ def make_hash(fn, target, env_vars, device_backend, **kwargs): "ptx": ptx_prototype_pattern, } -# - ((?:[^,\s<]+|<[^>]+>)+): Capturing group that matches one or more of either: -# [^,\s<]+: One or more characters that are not a comma, whitespace, or the < symbol. -# |: OR -# <[^>]+>: A string that starts with < and ends with >, containing any characters except > in between. mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<]+|<[^>]+>)+),?' 
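`AttrsDescriptor` is the dataclass replacement for the `instance_descriptor` namedtuple that the updated tests now construct via `triton.compiler.AttrsDescriptor`. A short usage sketch, assuming a Triton build that already contains this patch:

```python
# AttrsDescriptor keeps the old namedtuple field order:
# (divisible_by_16, equal_to_1, ids_of_folded_args, divisible_by_8).
from triton.compiler import AttrsDescriptor

attrs = AttrsDescriptor(tuple(range(4)), (), (), ())
print(attrs.hash())   # mixed into ASTSource.hash() and therefore the cache key

# Omitted fields default to empty sets rather than None.
assert AttrsDescriptor().equal_to_1 == set()
```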
ptx_arg_type_pattern = r"\.param\s+\.(\w+)" arg_type_pattern = { @@ -292,29 +63,19 @@ def make_hash(fn, target, env_vars, device_backend, **kwargs): "ttgir": mlir_arg_type_pattern, "ptx": ptx_arg_type_pattern, } -if is_hip(): - ttgir_num_warps_pattern = r'"triton_gpu_rocm.num-warps"\s?=\s?(\d+)\s?:' -else: - ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' - - -def _get_jsonable_constants(constants): - def _is_jsonable(x): - try: - json.dumps(x) - return True - except (TypeError, OverflowError): - return False - serialized_constants = {} - for constant in constants: - if _is_jsonable(constants[constant]): - serialized_constants[constant] = constants[constant] - return serialized_constants +def convert_type_repr(x): + # Currently we only capture the pointer type and assume the pointer is on global memory. + # TODO: Capture and support shared memory space + match = re.search(r'!tt\.ptr<([^,]+)', x) + if match is not None: + return '*' + convert_type_repr(match.group(1)) + return x def _get_num_warps_from_ir_str(src: str): + ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' # TODO(jlebar): Using a regex to get num-warps is a hack, and will break if # e.g. someone has an instruction (not module) attribute named "num-warps". num_warps_matches = re.findall(ttgir_num_warps_pattern, src) @@ -333,387 +94,178 @@ def _get_num_warps_from_ir_str(src: str): return num_warps -def parse_mlir_module(path, context): - module = ir.parse_mlir_module(path, context) - # module takes ownership of the context - module.context = context - return module - - -instance_descriptor = namedtuple("instance_descriptor", - ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], - defaults=[set(), set(), set(), set()]) - - -def get_cuda_capability(capability): - if capability is None: - device = get_current_device() - capability = get_device_capability(device) - capability = capability[0] * 10 + capability[1] - return capability - - -def get_arch_default_num_warps(device_type): - if device_type in ["cuda", "hip"]: - num_warps = 4 - else: - _device_backend = get_backend(device_type) - assert _device_backend - arch = _device_backend.get_architecture_descriptor() - num_warps = arch["num_warps"] - return num_warps - +class ASTSource: -def get_arch_default_num_stages(device_type, capability=None): - if device_type == "cuda": - num_stages = 3 if get_cuda_capability(capability) >= 75 else 2 - else: - _device_backend = get_backend(device_type) - assert _device_backend - arch = _device_backend.get_architecture_descriptor() - num_stages = arch["num_stages"] - - return num_stages - - -def add_cuda_stages(target, extern_libs, stages): - - stages["ptx"] = (lambda path: Path(path).read_text(), lambda src: llir_to_ptx(src, target)) - stages["cubin"] = (lambda path: Path(path).read_bytes(), lambda src: ptx_to_cubin(src, target)) - - -def compile(fn, **kwargs): - # Get device type to decide which backend should be used - device_type = kwargs.get("device_type", "cuda") - capability = kwargs.get("cc", None) - - if is_hip(): - device_type = "hip" - is_cuda = device_type == "cuda" - if is_hip(): - is_cuda = False - - context = ir.context() - constants = kwargs.get("constants", dict()) - num_warps = kwargs.get("num_warps", get_arch_default_num_warps(device_type)) - assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2" - num_ctas = kwargs.get("num_ctas", 1) - num_stages = kwargs.get("num_stages", get_arch_default_num_stages(device_type, 
capability=capability)) - enable_fp_fusion = kwargs.get("enable_fp_fusion", True) - # TODO[shuhaoj]: Default should be to enable warp specialization once possible - enable_warp_specialization = kwargs.get("enable_warp_specialization", False) - # TODO[shuhaoj]: persistent can be decoupled with warp specialization - enable_persistent = kwargs.get("enable_persistent", enable_warp_specialization) - extern_libs = kwargs.get("extern_libs", dict()) - if extern_libs is None: - extern_libs = dict() - debug = kwargs.get("debug", False) - # Flag to control whether to store mma layout directly - optimize_epilogue = False - if os.environ.get('OPTIMIZE_EPILOGUE', '') == '1': - optimize_epilogue = True - # - cluster_info = ClusterInfo() - if "clusterDims" in kwargs: - cluster_info.clusterDimX = kwargs["clusterDims"][0] - cluster_info.clusterDimY = kwargs["clusterDims"][1] - cluster_info.clusterDimZ = kwargs["clusterDims"][2] - tma_infos = TMAInfos() - # build architecture descriptor - if device_type == "cuda": - _device_backend = get_backend(device_type) - target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps, - enable_fp_fusion=enable_fp_fusion) - else: - _device_backend = get_backend(device_type) - assert _device_backend - target = _device_backend.get_architecture_descriptor(**kwargs) - # build compilation stages - stages = dict() - stages["ast"] = (lambda path: fn, None) - stages["ttir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttir( - ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target)) - if is_cuda: - stages["ttgir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttgir( - ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info, - enable_warp_specialization, enable_persistent, optimize_epilogue)) - stages["llir"] = (lambda path: Path(path).read_text(), - lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos)) - add_cuda_stages(target, extern_libs, stages) - elif device_type == "hip": - _device_backend.add_stages(target, extern_libs, stages, num_warps=num_warps, num_stages=num_stages) - else: - # pass the user's configuration to the backend device. 
- target["num_warps"] = num_warps - target["num_stages"] = num_stages - target["num_ctas"] = num_ctas - _device_backend.add_stages(target, extern_libs, stages) - - # find out the signature of the function - if isinstance(fn, JITFunction): - configs = kwargs.get("configs", None) - signature = kwargs["signature"] - if configs is None: - configs = [instance_descriptor()] - assert len(configs) == 1 - kwargs["configs"] = configs - name = fn.__name__ - first_stage = 0 - if isinstance(signature, str): - signature = {k: v.strip() for k, v in enumerate(signature.split(","))} - kwargs["signature"] = signature - else: - assert isinstance(fn, str) - _, ir_name = os.path.basename(fn).split(".") - src = Path(fn).read_text() - import re - match = re.search(prototype_pattern[ir_name], src, re.MULTILINE) - # TODO: support function attributes at group 3 (e.g., device function) - name, signature = match.group(1), match.group(2) - types = re.findall(arg_type_pattern[ir_name], signature) - if ir_name == 'ttgir': - num_warps_from_ir = _get_num_warps_from_ir_str(src) - assert "num_warps" not in kwargs or num_warps_from_ir == num_warps, "num_warps in ttgir does not match num_warps in compile" - num_warps = num_warps_from_ir - - param_tys = [convert_type_repr(ty) for ty in types] - signature = {k: v for k, v in enumerate(param_tys)} - first_stage = list(stages.keys()).index(ir_name) + def __init__(self, fn, signature, constants=None, attrs=None) -> None: + self.fn = fn + self.ext = "ttir" + self.name = fn.__name__ + self.signature = signature + self.constants = constants + self.attrs = attrs + if isinstance(self.signature, str): + self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} + if self.constants is None: + self.constants = dict() + if self.attrs is None: + self.attrs = AttrsDescriptor() + + def hash(self): + key = f"{self.fn.cache_key}-{self.attrs.hash()}-{self.signature.values()}-{self.constants}" + return hashlib.md5(key.encode("utf-8")).hexdigest() + def make_ir(self, options): + return ast_to_ttir(self.fn, self, options=options) + + def metadata(self): + # TODO: remove once TMA support is cleaned up + return {"ids_of_folded_args": tuple([int(k) for k in self.attrs.ids_of_folded_args])} + + def update_options(self, options): + pass + + +class IRSource: + + def __init__(self, path): + self.path = path + path = Path(path) + self.ext = path.suffix[1:] + self.src = path.read_text() + match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) + self.name = match.group(1) + signature = match.group(2) + types = re.findall(arg_type_pattern[self.ext], signature) + self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} + + def hash(self): + return hashlib.md5(self.src.encode("utf-8")).hexdigest() + + def make_ir(self, options): + context = ir.context() + module = ir.parse_mlir_module(self.path, context) + module.context = context + return module + + def metadata(self): + return dict() + + def update_options(self, options): + if self.ext == "ttgir": + options.num_warps = _get_num_warps_from_ir_str(self.src) + + +def compile(src, target=None, compiler_options=None, linker_options=None): + if target is None: + target = get_current_target() + backend = CUDABackend(target) + # create backend + compiler_options = backend.parse_compiler_options(compiler_options or dict()) + linker_options = backend.parse_linker_options(linker_options or dict()) + if not isinstance(src, ASTSource): + assert isinstance(src, str), "source must be either AST or a filepath" + src = 
IRSource(src) + src.update_options(compiler_options) # create cache manager - fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs)) - # managers used to dump and override IR for debugging - enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1" - fn_override_manager = get_override_manager( - make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True)) - fn_dump_manager = get_dump_manager( - make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True)) - - # determine name and extension type of provided function - if isinstance(fn, JITFunction): - name, ext = fn.__name__, "ast" - else: - name, ext = os.path.basename(fn).split(".") - - # load metadata if any - metadata = None - metadata_filename = f"{name}.json" - - # The group is addressed by the metadata + key = f"{src.hash()}-{backend.hash()}-{compiler_options.hash()}-{frozenset(sorted(get_env_vars().items()))}" + hash = hashlib.md5(key.encode("utf-8")).hexdigest() + fn_cache_manager = get_cache_manager(hash) + metadata_filename = f"{src.name}.json" metadata_group = fn_cache_manager.get_group(metadata_filename) or {} - metadata_path = metadata_group.get(metadata_filename) - if metadata_path is not None: - with open(metadata_path) as f: - metadata = json.load(f) - if 'tensormaps_info' in metadata: - metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']] - else: - metadata = { - "num_warps": num_warps, - "num_ctas": num_ctas, - "num_stages": num_stages, - "enable_warp_specialization": enable_warp_specialization, - "enable_persistent": enable_persistent, - "constants": _get_jsonable_constants(constants), - "debug": debug, - "target": target, - } - metadata.update(get_env_vars()) - if ext == "ptx": - assert "shared" in kwargs, "ptx compilation must provide shared memory size" - metadata["shared"] = kwargs["shared"] - - # Add device type to meta information - metadata["device_type"] = device_type - - first_stage = list(stages.keys()).index(ext) - asm = LazyDict() - module = fn + # cache hit! 
+ metadata = json.loads(Path(metadata_path).read_text()) + so_path = backend.make_launcher_stub(src, metadata) + return CompiledKernel(so_path, metadata_path) + # initialize metadata + metadata = { + "target": target, + **compiler_options.__dict__, + **linker_options.__dict__, + **get_env_vars(), + **src.metadata(), + } # run compilation pipeline and populate metadata - for ir_name, (parse, compile_kernel) in list(stages.items())[first_stage:]: - ir_filename = f"{name}.{ir_name}" - - if ir_name == ext: - next_module = parse(fn) - else: - path = metadata_group.get(ir_filename) - if path is None: - next_module = compile_kernel(module) - if ir_name == "amdgcn": - extra_file_name = f"{name}.hsaco_path" - metadata_group[ir_filename] = fn_cache_manager.put(next_module[0], ir_filename) - metadata_group[extra_file_name] = fn_cache_manager.put(next_module[1], extra_file_name) - else: - metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) - fn_dump_manager.put(next_module, ir_filename) - if (enable_override and fn_override_manager.has_file(ir_filename)): - print(f"\nOverriding kernel with file {ir_filename}") - full_name = fn_override_manager.get_file(ir_filename) - next_module = parse(full_name) - else: - if ir_name == "amdgcn": - extra_file_name = f"{name}.hsaco_path" - hasco_path = metadata_group.get(extra_file_name) - assert hasco_path is not None, "Expected to have hsaco_path in metadata when we have the amdgcn" - next_module = (parse(path), parse(hasco_path)) - else: - next_module = parse(path) - - if ir_name == "cubin": - asm[ir_name] = next_module - asm["sass"] = lambda: get_sass(next_module) - elif ir_name == "amdgcn": - asm[ir_name] = str(next_module[0]) - else: - asm[ir_name] = str(next_module) - if ir_name == "llir" and "shared" not in metadata: - if is_hip(): - metadata["shared"] = _device_backend.get_shared_memory_size(module) - else: - metadata["shared"] = get_shared_memory_size(module) - if ir_name == "ttgir": - if is_hip(): - metadata["num_warps"] = _device_backend.get_num_warps(next_module) - else: - metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module) - if metadata["enable_warp_specialization"]: - metadata["num_warps"] = get_num_warps(next_module) - if ir_name == "ptx": - metadata["name"] = get_kernel_name(next_module, pattern='// .globl') - if ir_name == "amdgcn": - metadata["name"] = get_kernel_name(next_module[0], pattern='.globl') - asm["hsaco_path"] = next_module[1] - if not is_cuda and not is_hip(): - _device_backend.add_meta_info(ir_name, module, next_module, metadata, asm) + stages = dict() + backend.add_stages(stages, compiler_options, linker_options) + first_stage = list(stages.keys()).index(src.ext) + module = src.make_ir(compiler_options) + for ext, compile_ir in list(stages.items())[first_stage:]: + next_module = compile_ir(module, metadata) + metadata_group[f"{src.name}.{ext}"] = fn_cache_manager.put(next_module, f"{src.name}.{ext}") module = next_module - - ids_of_folded_args = tuple([int(k) for k in configs[0].ids_of_folded_args]) if isinstance(fn, JITFunction) else () - if "clusterDims" not in metadata: - metadata["clusterDims"] = [cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ] - - if len(tma_infos) > 0: - metadata["tensormaps_info"] = parse_tma_info(tma_infos, ids_of_folded_args) - # set constant - if "tensormaps_info" in metadata: - for i, _ in enumerate(metadata["tensormaps_info"]): - metadata["tensormaps_info"][i].ids_of_folded_args = ids_of_folded_args - - ids_of_tensormaps = 
get_ids_of_tensormaps(metadata.get("tensormaps_info", None)) - if isinstance(fn, JITFunction) and "tensormaps_info" in metadata: - fn.tensormaps_info = metadata["tensormaps_info"] - - ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else () - ids = { - "ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs": - ids_of_const_exprs - } - # cache manager - if is_cuda: - so_path = make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization) - else: - so_path = _device_backend.make_launcher_stub(name, signature, constants, ids) - # write-back metadata, if it didn't come from the cache - if metadata_path is None: - metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, - binary=False) + # write-back metadata + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, + binary=False) fn_cache_manager.put_group(metadata_filename, metadata_group) - + so_path = backend.make_launcher_stub(src, metadata) # return handle to compiled kernel - return CompiledKernel(fn, so_path, metadata, asm) + return CompiledKernel(so_path, metadata_group.get(metadata_filename)) class CompiledKernel: # Hooks for external tools to monitor the execution of triton kernels + # TODO: move out of this namespace since it's a runtime thing launch_enter_hook = None launch_exit_hook = None - tensormap_manager = TensorMapManager() - def __init__(self, fn, so_path, metadata, asm): + def __init__(self, so_path, metadata_path): + metadata_path = Path(metadata_path) # initialize launcher import importlib.util spec = importlib.util.spec_from_file_location("__triton_launcher", so_path) mod = importlib.util.module_from_spec(spec) - self.fn = fn spec.loader.exec_module(mod) self.c_wrapper = getattr(mod, "launch") # initialize metadata - self.shared = metadata["shared"] - self.num_warps = metadata["num_warps"] - if "threads_per_warp" in metadata: - self.threads_per_warp = metadata["threads_per_warp"] - self.num_ctas = metadata["num_ctas"] - self.num_stages = metadata["num_stages"] - self.clusterDims = metadata["clusterDims"] - if "tensormaps_info" in metadata: - self.tensormaps_info = metadata["tensormaps_info"] - self.constants = metadata["constants"] - self.device_type = metadata["device_type"] - self.device_backend = get_backend(self.device_type) if self.device_type not in ["cuda"] else None - # initialize asm dict - self.asm = asm + self.metadata = json.loads(metadata_path.read_text()) + self.metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in self.metadata['tensormaps_info'] + ] if 'tensormaps_info' in self.metadata else [] + for i, _ in enumerate(self.metadata["tensormaps_info"]): + self.metadata["tensormaps_info"][i].ids_of_folded_args = tuple(self.metadata["ids_of_folded_args"]) + for key, val in self.metadata.items(): + setattr(self, key, val) + # stores the text of each level of IR that was generated during compilation + asm_files = [file for file in metadata_path.parent.glob(f'{metadata_path.stem}.*') if file.suffix != '.json'] + self.asm = { + file.suffix[1:]: file.read_bytes() if file.suffix[1:] == driver.binary_ext else file.read_text() + for file in asm_files + } + self.kernel = self.asm[driver.binary_ext] # binaries are lazily initialized # because it involves doing runtime things # (e.g., checking amount of shared memory on current device) - self.metadata = metadata - self.cu_module = None 
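Taken together, the new `compile()` accepts an `ASTSource` (or an IR file path), a `target` tuple, and plain option dicts, and returns a `CompiledKernel` whose `asm` entries are read back from the cache directory. The sketch below mirrors the updated tests; it needs a CUDA GPU plus a Triton build with this patch, and `add_one` is a made-up example kernel rather than anything from the diff.

```python
# Compiling a kernel through the new entry point: build an ASTSource, pick a
# ("cuda", capability) target, and pass compile options as a plain dict.
import torch
import triton
import triton.language as tl
from triton.compiler import ASTSource, AttrsDescriptor


@triton.jit
def add_one(x_ptr, N: tl.constexpr):
    offs = tl.arange(0, N)
    tl.store(x_ptr + offs, tl.load(x_ptr + offs) + 1)


major, minor = torch.cuda.get_device_capability(0)
src = ASTSource(fn=add_one, signature={0: "*fp32"}, constants={1: 16}, attrs=AttrsDescriptor())
kernel = triton.compile(src, target=("cuda", major * 10 + minor),
                        compiler_options={"num_warps": 4})
print(list(kernel.asm.keys()))   # IRs read back from the cache: ttir, ttgir, llir, ptx, cubin
```

The same call accepts a path to a `.ttgir` file instead of an `ASTSource`, in which case `IRSource` parses it and `update_options` takes `num_warps` from the IR itself.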
- self.cu_function = None + self.module = None + self.function = None def _init_handles(self): - if self.cu_module is not None: + if self.module is not None: return - - if self.device_type in ["cuda"]: - device = get_current_device() - bin_path = {driver.HIP: "hsaco_path", driver.CUDA: "cubin"}[driver.backend] - max_shared = driver.utils.get_device_properties(device)["max_shared_mem"] - fn_load_binary = driver.utils.load_binary - else: - assert self.device_backend - device = self.device_backend.get_current_device() - bin_path = self.device_backend.get_kernel_bin() - max_shared = self.device_backend.get_device_properties(device)["max_shared_mem"] - fn_load_binary = self.device_backend.get_load_binary_fn() - + device = get_current_device() + # not enough shared memory to run the kernel + max_shared = driver.utils.get_device_properties(device)["max_shared_mem"] if self.shared > max_shared: raise OutOfResources(self.shared, max_shared, "shared memory") - - mod, func, n_regs, n_spills = fn_load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device) - - self.n_spills = n_spills - self.n_regs = n_regs - self.cu_module = mod - self.cu_function = func + # TODO: n_regs, n_spills should be metadata generated when calling `ptxas` + self.module, self.function, self.n_regs, self.n_spills = driver.utils.load_binary( + self.name, self.kernel, self.shared, device) def __getattribute__(self, name): if name == 'c_wrapper': self._init_handles() return super().__getattribute__(name) - # capture args and expand args with cutensormap* - def assemble_tensormap_to_arg(self, args): - args_with_tma = list(args) - if hasattr(self, 'tensormaps_info'): - # tuple for hashable - args_ptr = tuple([arg.data_ptr() if hasattr(arg, 'data_ptr') else arg for arg in args]) - for i, e in enumerate(self.tensormaps_info): - args_with_tma.append(CompiledKernel.tensormap_manager[(e, args_ptr)]) - return args_with_tma - def __getitem__(self, grid): self._init_handles() def runner(*args, stream=None): - args_expand = self.assemble_tensormap_to_arg(args) + args_expand = driver.assemble_tensormap_to_arg(self.tensormaps_info, args) if stream is None: - if self.device_type in ["cuda"]: - stream = get_cuda_stream() - else: - stream = get_backend(self.device_type).get_stream(None) - self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.clusterDims[0], - self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function, + stream = get_cuda_stream() + self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.cluster_dims[0], + self.cluster_dims[1], self.cluster_dims[2], self.shared, stream, self.function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand) return runner diff --git a/python/triton/compiler/utils.py b/python/triton/compiler/utils.py index 7844c1ebb265..48233afedd4c 100644 --- a/python/triton/compiler/utils.py +++ b/python/triton/compiler/utils.py @@ -280,25 +280,3 @@ def __eq__(self, other): other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims, other.elementStrides, other.interleave, other.swizzle, other.l2Promotion, other.oobFill) - - -class TensorMapManager: - - def __init__(self): - self.tensormaps_device = {} - - def __getitem__(self, key: tuple): - if key in self.tensormaps_device: - return int(self.tensormaps_device[key]) - else: - (e, args) = key - t_tensormap = e.tensormap(args) - TENSORMAP_SIZE_IN_BYTES = 128 - t_tensormap_device = driver.utils.cuMemAlloc(TENSORMAP_SIZE_IN_BYTES) - 
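`CompiledKernel` keeps its lazy-initialization trick: the GPU module and function handles are only created when the launch wrapper is first touched. A standalone sketch of that pattern, with toy objects standing in for `driver.utils.load_binary`:

```python
# Standalone sketch of the lazy-handle pattern in CompiledKernel: attribute
# access to the launch wrapper triggers _init_handles() exactly once, so
# loading the cubin onto a device is deferred until the first launch.
class LazyHandles:

    def __init__(self):
        self.module = None
        self.function = None
        self.c_wrapper = lambda *args: None   # stands in for the launcher stub

    def _init_handles(self):
        if self.module is not None:
            return
        # stand-ins for driver.utils.load_binary(name, kernel, shared, device)
        self.module, self.function = object(), object()

    def __getattribute__(self, name):
        if name == 'c_wrapper':
            self._init_handles()
        return super().__getattribute__(name)


k = LazyHandles()
assert k.function is None        # nothing loaded yet
k.c_wrapper()                    # first touch loads the "binary"
assert k.function is not None
```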
driver.utils.cuMemcpyHtoD(t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES) - self.tensormaps_device[key] = t_tensormap_device - return int(self.tensormaps_device[key]) - - def __del__(self): - for _, v in self.tensormaps_device.items(): - driver.utils.cuMemFree(v) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 56fb2f231a37..a60a9b7bc832 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1447,7 +1447,6 @@ def _promote_reduction_input(t, _builder=None): # hardware doesn't support FMAX, FMIN, CMP for bfloat16 if scalar_ty is bfloat16: return t.to(float32, _builder=_builder) - return t diff --git a/python/triton/language/extra/cuda.py b/python/triton/language/extra/cuda.py index 9400ae797887..1cb494d9fa7f 100644 --- a/python/triton/language/extra/cuda.py +++ b/python/triton/language/extra/cuda.py @@ -15,4 +15,4 @@ def smid(_builder=None): @core.builtin def num_threads(_builder=None): - return core.constexpr(_builder.target.num_warps * 32) + return core.constexpr(_builder.options.num_warps * 32) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 4b8723c56e45..c1ee1036ba6f 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -9,16 +9,6 @@ T = TypeVar('T') -# TODO: redundant code -- remove after 3P backend refactor - - -def _is_cuda(target): - from ..compiler.compiler import CudaTargetDescriptor - return isinstance(target, CudaTargetDescriptor) - - -# Create custom exception that prints message "hello" - class IncompatibleTypeErrorImpl(Exception): @@ -651,9 +641,8 @@ def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: src_sca_ty = src_ty.scalar dst_sca_ty = dst_ty.scalar - if _is_cuda(builder.target) and builder.target.capability < 89 and \ - (src_sca_ty.is_fp8e4nv() or dst_sca_ty.is_fp8e4nv()): - assert False, "fp8e4nv data type is not supported on CUDA arch < 89" + if (src_sca_ty.is_fp8e4nv() or dst_sca_ty.is_fp8e4nv()): + assert builder.options.allow_fp8e4nv, "fp8e4nv data type is not supported on CUDA arch < 89" # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ @@ -1185,13 +1174,8 @@ def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool: def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_num_imprecise_acc: int, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: - def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): - # Checks for non-cuda archs - if not _is_cuda(target): - assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" - return - # Checks for cuda arch - if target.capability < 90: + def assert_dtypes_valid(lhs_dtype, rhs_dtype, options): + if not options.allow_fp8e4nv: assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv( ), "Dot op does not support fp8e4nv on CUDA arch < 90" if lhs_dtype.is_fp8() and rhs_dtype.is_fp8(): @@ -1220,7 +1204,7 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): assert lhs.type.is_block() and rhs.type.is_block() - assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.target) + assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.options) assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!" assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!" 
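The `semantic.py` hunks above replace direct `builder.target.capability` checks with option flags that the backend computes once (`parse_compiler_options` sets `allow_fp8e4nv` for sm_89 and newer). A minimal sketch of that decoupling; `Options`, `parse_options`, and `check_fp8e4nv_cast` are toy stand-ins, not the real code.

```python
# Sketch of the capability -> option-flag decoupling: the backend derives the
# feature flag once, and frontend checks consult only the flag.
from dataclasses import dataclass


@dataclass
class Options:
    allow_fp8e4nv: bool = False


def parse_options(capability: int) -> Options:
    # mirrors CUDABackend.parse_compiler_options: fp8e4nv needs sm_89 or newer
    return Options(allow_fp8e4nv=capability >= 89)


def check_fp8e4nv_cast(options: Options):
    assert options.allow_fp8e4nv, "fp8e4nv data type is not supported on CUDA arch < 89"


check_fp8e4nv_cast(parse_options(90))      # fine on Hopper
try:
    check_fp8e4nv_cast(parse_options(80))  # rejected on Ampere
except AssertionError as e:
    print(e)
```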
@@ -1279,11 +1263,11 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target):
         assert acc.type == ret_ty
 
     # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
-    if not (_is_cuda(builder.target) and builder.target.capability == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8()
-            and ret_scalar_ty.is_fp32()):
-        max_num_imprecise_acc = 0
-    if max_num_imprecise_acc is None:
-        max_num_imprecise_acc = 2**30
+    max_num_imprecise_acc = 0
+    if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
+        max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default
+        if max_num_imprecise_acc is None:
+            max_num_imprecise_acc = 2**30
     return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc),
                      ret_ty)
diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py
index 249471062775..bf158bbb0e54 100644
--- a/python/triton/runtime/driver.py
+++ b/python/triton/runtime/driver.py
@@ -6,6 +6,7 @@
 
 from ..common.build import _build
 from .cache import get_cache_manager
+from ..runtime import driver
 
 
 class DriverBase(metaclass=abc.ABCMeta):
@@ -66,7 +67,30 @@ def __init__(self):
         self.cu_occupancy_max_active_clusters = mod.cu_occupancy_max_active_clusters
 
 
+class TensorMapManager:
+
+    def __init__(self):
+        self.tensormaps_device = {}
+
+    def __getitem__(self, key: tuple):
+        if key in self.tensormaps_device:
+            return int(self.tensormaps_device[key])
+        else:
+            (e, args) = key
+            t_tensormap = e.tensormap(args)
+            TENSORMAP_SIZE_IN_BYTES = 128
+            t_tensormap_device = driver.utils.cuMemAlloc(TENSORMAP_SIZE_IN_BYTES)
+            driver.utils.cuMemcpyHtoD(t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES)
+            self.tensormaps_device[key] = t_tensormap_device
+            return int(self.tensormaps_device[key])
+
+    def __del__(self):
+        for _, v in self.tensormaps_device.items():
+            driver.utils.cuMemFree(v)
+
+
 class CudaDriver(DriverBase):
+    tensormap_manager = TensorMapManager()
 
     def __new__(cls):
         if not hasattr(cls, "instance"):
@@ -76,6 +100,16 @@ def __new__(cls):
     def __init__(self):
         self.utils = CudaUtils()
         self.backend = self.CUDA
+        self.binary_ext = "cubin"
+
+    def assemble_tensormap_to_arg(self, tensormaps_info, args):
+        args_with_tma = list(args)
+        if tensormaps_info is not None:
+            # tuple for hashable
+            args_ptr = tuple([arg.data_ptr() if hasattr(arg, 'data_ptr') else arg for arg in args])
+            for i, e in enumerate(tensormaps_info):
+                args_with_tma.append(CudaDriver.tensormap_manager[(e, args_ptr)])
+        return args_with_tma
 
 
 # -----------------------------
diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py
index 286d069c15f1..42b16e3dba74 100644
--- a/python/triton/runtime/jit.py
+++ b/python/triton/runtime/jit.py
@@ -13,6 +13,7 @@
 from .._C.libtriton.triton import TMAInfos
 from ..common.backend import get_backend, get_cuda_version_key
 from .interpreter import InterpretedFunction
+from ..runtime.driver import driver
 
 
 def get_cuda_stream(idx=None):
@@ -46,6 +47,14 @@ def get_device_capability(idx):
     return torch.cuda.get_device_capability(idx)
 
 
+def get_current_target():
+    import torch
+    device = get_current_device()
+    capability = get_device_capability(device)
+    capability = capability[0] * 10 + capability[1]
+    return ("cuda", capability)
+
+
 T = TypeVar("T")
 
 # -----------------------------------------------------------------------------
@@ -252,6 +261,7 @@
     # TODO(jlebar): Fold this into the KernelArg class.
     def _get_config(self, *args):
+        from ..compiler import AttrsDescriptor
 
         def is_divisible_by_16(x):
             if hasattr(x, "data_ptr"):
@@ -288,10 +298,8 @@ def is_divisible_by_8(x):
         # TODO: method to collect all folded args
         none_args = {param.num for param, arg in zip(self.params, args) if arg is None and not param.do_not_specialize}
         ids_of_folded_args = equal_to_1 | none_args
-        return namedtuple("instance_descriptor",
-                          ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])(  #
-                              tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args),
-                              tuple(divisible_by_8))
+        return AttrsDescriptor(tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args),
+                               tuple(divisible_by_8))
         # return _triton.code_gen.instance_descriptor(divisible_by_16,
         # equal_to_1)
@@ -400,7 +408,7 @@ def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: Li
         return device_types[0] if len(device_types) > 0 else "cuda"
 
     def run(self, *args, **kwargs):
-        from ..compiler import CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps
+        from ..compiler import CompiledKernel, compile, ASTSource
 
         # Get a compiler-flags arg like `num_warps` and remove it from kwargs.
         def get_special_arg(name: str, default=None):
@@ -472,9 +480,9 @@ def get_special_arg(name: str, default=None):
                 stream = device_backend.get_stream()
 
             if num_warps is None:
-                num_warps = get_arch_default_num_warps(device_type)
+                num_warps = 4
             if num_stages is None:
-                num_stages = get_arch_default_num_stages(device_type)
+                num_stages = 3
 
             if device_type in ["cuda"]:
                 version_key = get_cuda_version_key()
@@ -529,20 +537,21 @@ def get_special_arg(name: str, default=None):
             ):
                 return None
 
+            capability = get_device_capability(device)
+            capability = capability[0] * 10 + capability[1]
+            src = ASTSource(self, signature, constants, configs[0])
             self.cache[device][key] = compile(
-                self,
-                signature=signature,
-                device=device,
-                constants=constants,
-                num_warps=num_warps,
-                num_ctas=num_ctas,
-                num_stages=num_stages,
-                enable_warp_specialization=enable_warp_specialization,
-                enable_fp_fusion=enable_fp_fusion,
-                extern_libs=extern_libs,
-                configs=configs,
-                debug=self.debug,
-                device_type=device_type,
+                src,
+                target=(device_type, capability),
+                compiler_options={
+                    "num_warps": num_warps,
+                    "num_ctas": num_ctas,
+                    "num_stages": num_stages,
+                    "enable_warp_specialization": enable_warp_specialization,
+                    "enable_fp_fusion": enable_fp_fusion,
+                    "debug": self.debug,
+                },
+                linker_options={"libs": extern_libs},
             )
 
             bin = self.cache[device][key]
@@ -553,16 +562,16 @@ def get_special_arg(name: str, default=None):
                 grid_2,
                 bin.num_warps,
                 bin.num_ctas,
-                bin.clusterDims[0],
-                bin.clusterDims[1],
-                bin.clusterDims[2],
+                bin.cluster_dims[0],
+                bin.cluster_dims[1],
+                bin.cluster_dims[2],
                 bin.shared,
                 stream,
-                bin.cu_function,
+                bin.function,
                 CompiledKernel.launch_enter_hook,
                 CompiledKernel.launch_exit_hook,
                 bin,
-                *bin.assemble_tensormap_to_arg(non_constexpr_arg_values),
+                *driver.assemble_tensormap_to_arg(bin.metadata["tensormaps_info"], non_constexpr_arg_values),
             )
             return bin
diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py
index a69c7100ddd0..e28a7a404c02 100644
--- a/python/triton/tools/compile.py
+++ b/python/triton/tools/compile.py
@@ -92,11 +92,11 @@ def constexpr(s):
     hints = {i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s}
     hints = {k: v for k, v in hints.items() if v is not None}
-    constexprs = {i: constexpr(s) for i, s in enumerate(signature)}
-    constexprs = {k: v for k, v in constexprs.items() if v is not None}
-    signature = {i: s.split(":")[0] for i, s in enumerate(signature) if i not in constexprs}
-    const_sig = 'x'.join([str(v) for v in constexprs.values()])
-    doc_string = [f"{kernel.arg_names[i]}={constexprs[i]}" for i in constexprs.keys()]
+    constants = {i: constexpr(s) for i, s in enumerate(signature)}
+    constants = {k: v for k, v in constants.items() if v is not None}
+    signature = {i: s.split(":")[0] for i, s in enumerate(signature) if i not in constants}
+    const_sig = 'x'.join([str(v) for v in constants.values()])
+    doc_string = [f"{kernel.arg_names[i]}={constants[i]}" for i in constants.keys()]
     doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"]
 
     # compile ast into cubin
@@ -104,11 +104,12 @@ def constexpr(s):
         assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}"
     divisible_by_16 = [i for i, h in hints.items() if h == 16]
     equal_to_1 = [i for i, h in hints.items() if h == 1]
-    config = triton.compiler.instance_descriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1)
+    attrs = triton.compiler.AttrsDescriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1)
     for i in equal_to_1:
-        constexprs.update({i: 1})
-    ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config],
-                            num_warps=args.num_warps, num_stages=args.num_stages)
+        constants.update({i: 1})
+    src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs)
+    opts = {"num_warps": args.num_warps, "num_stages": args.num_stages}
+    ccinfo = triton.compile(src, compiler_options=opts)
     arg_names = []
     arg_types = []
     for i in signature.keys():
@@ -117,7 +118,7 @@ def constexpr(s):
         arg_types += [signature[i]]
 
     # dump C stub code
-    suffix = kernel_suffix(signature.values(), config)
+    suffix = kernel_suffix(signature.values(), attrs)
     func_name = '_'.join([out_name, sig_hash, suffix])
     triton_kernel_name = '_'.join([args.kernel_name, suffix])
     hex_ = str(binascii.hexlify(ccinfo.asm["cubin"]))[2:-1]
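
Taken together, the hunks above replace the old keyword-heavy triton.compile() entry point with an ASTSource object plus an explicit target tuple and option dictionaries, and move argument-specialization hints into AttrsDescriptor. The stand-alone sketch below is not part of the patch; the kernel, block size, signature values, and capability handling are illustrative assumptions chosen to mirror the call sites changed above.

    import torch

    import triton
    import triton.language as tl
    from triton.compiler import ASTSource


    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, BLOCK: tl.constexpr):
        # toy element-wise add over a single block (hypothetical example kernel)
        offs = tl.arange(0, BLOCK)
        tl.store(out_ptr + offs, tl.load(x_ptr + offs) + tl.load(y_ptr + offs))


    major, minor = torch.cuda.get_device_capability(0)
    capability = major * 10 + minor

    # Specialization hints now travel in an AttrsDescriptor instead of the old
    # `instance_descriptor` namedtuple.
    attrs = triton.compiler.AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())

    # The kernel is wrapped in an ASTSource; compile() takes the source plus an
    # explicit target and option dictionaries rather than a long list of flags.
    src = ASTSource(fn=add_kernel, signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
                    constants={3: 64}, attrs=attrs)
    ccinfo = triton.compile(src, target=("cuda", capability),
                            compiler_options={"num_warps": 4, "num_stages": 3})
    print(list(ccinfo.asm))  # intermediate stages such as 'ttir', 'ptx', 'cubin'

With this shape, the JIT cache in runtime/jit.py, the AOT path in tools/compile.py, and the subprocess tests all drive the same compile() signature, which appears to be the point of routing every caller through ASTSource.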