From adfdb35997781346346a399053b2812b960140d5 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 12 Nov 2023 17:45:59 -0800 Subject: [PATCH 01/64] some cleaning --- python/triton/compiler/__init__.py | 4 + python/triton/compiler/compiler.py | 425 +++++------------------------ python/triton/compiler/target.py | 0 3 files changed, 74 insertions(+), 355 deletions(-) delete mode 100644 python/triton/compiler/target.py diff --git a/python/triton/compiler/__init__.py b/python/triton/compiler/__init__.py index fd0665e1e549..d700656239a4 100644 --- a/python/triton/compiler/__init__.py +++ b/python/triton/compiler/__init__.py @@ -1,8 +1,12 @@ from .compiler import (CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps, instance_descriptor) from .errors import CompilationError +from .backends.cuda import CUDABackend +from ..common.backend import register_backend __all__ = [ "compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", "get_arch_default_num_stages" ] + +register_backend("cuda", CUDABackend) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 9ecd15d718d3..4fbcdaf61d26 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -1,38 +1,23 @@ from __future__ import annotations -import functools import hashlib import json import os import re from collections import namedtuple from pathlib import Path -from typing import Any -from dataclasses import dataclass - -from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs, compile_ptx_to_cubin, get_env_vars, - get_num_warps, get_shared_memory_size, ir, runtime, translate_llvmir_to_ptx, - translate_triton_gpu_to_llvmir) -from ..common.backend import get_backend, get_cuda_version_key, path_to_ptxas +from .._C.libtriton.triton import (get_env_vars, ir) +from ..common.backend import get_backend from ..common.build import is_hip # from ..runtime import driver, jit, JITFunction # TODO: runtime.errors from ..runtime.autotuner import OutOfResources -from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager -from ..runtime.driver import driver +from ..runtime.cache import get_cache_manager from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device, get_device_capability) -from ..tools.disasm import get_sass from .code_generator import ast_to_ttir -from .make_launcher import make_stub -from .utils import (InfoFromBackendForTensorMap, TensorMapManager, get_ids_of_tensormaps, parse_tma_info) - - -@dataclass -class CudaTargetDescriptor: - capability: int - num_warps: int - enable_fp_fusion: bool +from .utils import (InfoFromBackendForTensorMap, TensorMapManager) +from .backends.cuda import CudaTargetDescriptor def _is_cuda(target): @@ -83,148 +68,9 @@ def optimize_ttir(mod, target): return mod -def ttir_to_ttgir(mod, num_warps, num_ctas, target): - pm = ir.pass_manager(mod.context) - pm.enable_debug() - pm.add_convert_triton_to_tritongpu_pass(num_warps, 32, num_ctas, target.capability) - pm.run(mod) - return mod - - -def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, - enable_persistent, optimize_epilogue): - is_cuda = _is_cuda(target) - if is_cuda: - capability = target.capability - pm = ir.pass_manager(mod.context) - pm.enable_debug() - pm.add_tritongpu_coalesce_pass() - # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass - pm.add_plan_cta_pass(cluster_info) - if is_cuda: - 
pm.add_tritongpu_rewrite_tensor_pointer_pass(capability) - pm.add_plan_cta_pass(cluster_info) - pm.add_tritongpu_remove_layout_conversions_pass() - if is_cuda: - pm.add_tritongpu_accelerate_matmul_pass(capability) - pm.add_tritongpu_remove_layout_conversions_pass() - if optimize_epilogue: - pm.add_tritongpu_optimize_epilogue_pass() - pm.add_tritongpu_optimize_dot_operands_pass() - pm.add_cse_pass() - ws_enabled = False - # `num_warps` does not mean the total number of warps of a CTA when - # warp specialization is enabled. - # it's the responsibility of the compiler to figure out the exact - # `num_warps` to use. - # TODO: support the case where `num_warps` from user is not 4. - if capability // 10 >= 9 and enable_warp_specialization and num_warps == 4: - pm.add_tritongpu_ws_feasibility_checking_pass(capability) - pm.run(mod) - ws_enabled = ir.is_ws_supported(mod) - pm = ir.pass_manager(mod.context) - pm.enable_debug() - if ws_enabled: - pm.add_tritongpu_wsdecomposing_pass(capability) - pm.add_tritongpu_wspipeline_pass(num_stages, num_warps, capability) - pm.add_tritongpu_wsmutex_pass(capability) - pm.add_tritongpu_wsmaterialization_pass(capability) - pm.add_licm_pass() - pm.add_cse_pass() - else: - pm.add_tritongpu_pipeline_pass(num_stages, num_warps, num_ctas, capability) - pm.add_tritongpu_materialize_load_store_pass(num_warps, capability) - if capability // 10 <= 8: - pm.add_tritongpu_prefetch_pass() - pm.add_tritongpu_optimize_dot_operands_pass() - pm.add_tritongpu_remove_layout_conversions_pass() - pm.add_tritongpu_decompose_conversions_pass() - pm.add_tritongpu_ws_fixup_missing_attrs_pass() - pm.add_tritongpu_reorder_instructions_pass() - pm.add_cse_pass() - pm.add_symbol_dce_pass() - if capability // 10 >= 9: - pm.add_tritongpu_fence_insertion_pass() - pm.add_tritongpu_ws_fixup_missing_attrs_pass() - pm.add_tritongpu_optimize_thread_locality_pass() - pm.add_canonicalizer_pass() - pm.run(mod) - return mod - - -def _add_external_libs(mod, libs): - for name, path in libs.items(): - if len(name) == 0 or len(path) == 0: - return - add_external_libs(mod, list(libs.keys()), list(libs.values())) - - -def ttgir_to_llir(mod, extern_libs, target, tma_infos): - if extern_libs: - _add_external_libs(mod, extern_libs) - # TODO: separate tritongpu_to_llvmir for different backends - if _is_cuda(target): - return translate_triton_gpu_to_llvmir(mod, target.capability, tma_infos, runtime.TARGET.NVVM) - else: - return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL) - - -# PTX translation - - -@functools.lru_cache() -def ptx_get_version(cuda_version) -> int: - ''' - Get the highest PTX version supported by the current CUDA driver. - ''' - assert isinstance(cuda_version, str) - major, minor = map(int, cuda_version.split('.')) - if major == 12: - return 80 + minor - if major == 11: - return 70 + minor - if major == 10: - return 63 + minor - raise RuntimeError("Triton only support CUDA 10.0 or higher") - - -def llir_to_ptx(mod: Any, target: CudaTargetDescriptor, ptx_version: int = None) -> str: - ''' - Translate TritonGPU module to PTX code. - :param mod: a TritonGPU dialect module - :return: PTX code - ''' - if ptx_version is None: - _, cuda_version = path_to_ptxas() - ptx_version = ptx_get_version(cuda_version) - return translate_llvmir_to_ptx(mod, target.capability, ptx_version, target.enable_fp_fusion) - - -def ptx_to_cubin(ptx: str, target: CudaTargetDescriptor): - ''' - Compile TritonGPU module to cubin. 
- :param ptx: ptx code - :param compute_capability: compute capability - :return: str - ''' - ptxas, _ = path_to_ptxas() - return compile_ptx_to_cubin(ptx, ptxas, target.capability, target.enable_fp_fusion) - - # ------------------------------------------------------------------------------ # compiler # ------------------------------------------------------------------------------ -def get_kernel_name(src: str, pattern: str) -> str: - ''' - Get kernel name from PTX code. - This Kernel name is required when launching the kernel. - ''' - # There is a name mangling in PTX codegen, so the original kernel names in Triton IR are not available in PTX/cubin. - assert src - for line in src.split('\n'): - line = line.strip() - if line.startswith(pattern): - return line.split()[-1] def convert_type_repr(x): @@ -237,10 +83,7 @@ def convert_type_repr(x): def make_hash(fn, target, env_vars, device_backend, **kwargs): - if device_backend is None: - version_key = get_cuda_version_key() - else: - version_key = device_backend.get_version_key() + version_key = device_backend.get_version_key() if isinstance(fn, JITFunction): configs = kwargs["configs"] signature = kwargs["signature"] @@ -326,6 +169,7 @@ def parse_mlir_module(path, context): defaults=[set(), set(), set(), set()]) +# TODO: remove def get_cuda_capability(capability): if capability is None: device = get_current_device() @@ -357,96 +201,64 @@ def get_arch_default_num_stages(device_type, capability=None): return num_stages -def add_cuda_stages(target, extern_libs, stages): - - stages["ptx"] = (lambda path: Path(path).read_text(), lambda src: llir_to_ptx(src, target)) - stages["cubin"] = (lambda path: Path(path).read_bytes(), lambda src: ptx_to_cubin(src, target)) - - -def compile(fn, **kwargs): +def compile(src, **kwargs): # Get device type to decide which backend should be used - device_type = kwargs.get("device_type", "cuda") - capability = kwargs.get("cc", None) - - if is_hip(): - device_type = "hip" - is_cuda = device_type == "cuda" - if is_hip(): - is_cuda = False - - context = ir.context() constants = kwargs.get("constants", dict()) - num_warps = kwargs.get("num_warps", get_arch_default_num_warps(device_type)) - assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2" - num_ctas = kwargs.get("num_ctas", 1) - num_stages = kwargs.get("num_stages", get_arch_default_num_stages(device_type, capability=capability)) - enable_fp_fusion = kwargs.get("enable_fp_fusion", True) - # TODO[shuhaoj]: Default should be to enable warp specialization once possible - enable_warp_specialization = kwargs.get("enable_warp_specialization", False) - # TODO[shuhaoj]: persistent can be decoupled with warp specialization - enable_persistent = kwargs.get("enable_persistent", enable_warp_specialization) + # create backend handler + device_type = kwargs.get("device_type", "cuda") + _device_backend = get_backend(device_type) + print(_device_backend) + assert _device_backend + target = _device_backend.get_architecture_descriptor(**kwargs) + # extern libs extern_libs = kwargs.get("extern_libs", dict()) if extern_libs is None: extern_libs = dict() - debug = kwargs.get("debug", False) - # Flag to control whether to store mma layout directly - optimize_epilogue = False - if os.environ.get('OPTIMIZE_EPILOGUE', '') == '1': - optimize_epilogue = True - # - cluster_info = ClusterInfo() - if "clusterDims" in kwargs: - cluster_info.clusterDimX = kwargs["clusterDims"][0] - cluster_info.clusterDimY = kwargs["clusterDims"][1] - 
cluster_info.clusterDimZ = kwargs["clusterDims"][2] - tma_infos = TMAInfos() - # build architecture descriptor - if device_type == "cuda": - _device_backend = get_backend(device_type) - target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps, - enable_fp_fusion=enable_fp_fusion) - else: - _device_backend = get_backend(device_type) - assert _device_backend - target = _device_backend.get_architecture_descriptor(**kwargs) + + # compilation options + opts = dict() + opts["num_warps"] = kwargs.get("num_warps", None) + assert opts["num_warps"] > 0 and (opts["num_warps"] & + (opts["num_warps"] - 1)) == 0, "num_warps must be a power of 2" + opts["num_ctas"] = kwargs.get("num_ctas", None) + opts["num_stages"] = kwargs.get("num_stages", None) + opts["enable_fp_fusion"] = kwargs.get("enable_fp_fusion", True) + opts["enable_warp_specialization"] = kwargs.get("enable_warp_specialization", False) + opts["enable_persistent"] = kwargs.get("enable_persistent", False) + opts["optimize_epilogue"] = os.environ.get('OPTIMIZE_EPILOGUE', '') == '1' + opts["cluster_dims"] = kwargs.get('clusterDims', None) + opts["debug"] = kwargs.get("debug", False) + # build compilation stages + context = ir.context() stages = dict() - stages["ast"] = (lambda path: fn, None) - stages["ttir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttir( - ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target)) - if is_cuda: - stages["ttgir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttgir( - ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info, - enable_warp_specialization, enable_persistent, optimize_epilogue)) - stages["llir"] = (lambda path: Path(path).read_text(), - lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos)) - add_cuda_stages(target, extern_libs, stages) - elif device_type == "hip": - _device_backend.add_stages(target, extern_libs, stages, num_warps=num_warps, num_stages=num_stages) - else: - # pass the user's configuration to the backend device. 
- target["num_warps"] = num_warps - target["num_stages"] = num_stages - target["num_ctas"] = num_ctas - _device_backend.add_stages(target, extern_libs, stages) + # TODO: CompilationStage object w/ both `parser` and `creator` attributes + stages["ast"] = (lambda path: src, None) + + def create_ttir(src): + ttir = ast_to_ttir(src, signature, configs[0], constants, debug=opts["debug"], target=target) + return optimize_ttir(ttir, target=target) + + stages["ttir"] = (lambda path: parse_mlir_module(path, context), create_ttir) + _device_backend.add_stages(target, extern_libs, stages, opts, context) # find out the signature of the function - if isinstance(fn, JITFunction): + if isinstance(src, JITFunction): configs = kwargs.get("configs", None) signature = kwargs["signature"] if configs is None: configs = [instance_descriptor()] assert len(configs) == 1 kwargs["configs"] = configs - name = fn.__name__ + name = src.__name__ first_stage = 0 if isinstance(signature, str): signature = {k: v.strip() for k, v in enumerate(signature.split(","))} kwargs["signature"] = signature else: - assert isinstance(fn, str) - _, ir_name = os.path.basename(fn).split(".") - src = Path(fn).read_text() + assert isinstance(src, str) + _, ir_name = os.path.basename(src).split(".") + src = Path(src).read_text() import re match = re.search(prototype_pattern[ir_name], src, re.MULTILINE) # TODO: support function attributes at group 3 (e.g., device function) @@ -455,155 +267,65 @@ def compile(fn, **kwargs): if ir_name == 'ttgir': num_warps_matches = re.findall(ttgir_num_warps_pattern, src) assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps" - assert "num_warps" not in kwargs or int( - num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile" - num_warps = int(num_warps_matches[0]) + # assert "num_warps" not in kwargs or int( + # num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile" + # num_warps = int(num_warps_matches[0]) param_tys = [convert_type_repr(ty) for ty in types] signature = {k: v for k, v in enumerate(param_tys)} first_stage = list(stages.keys()).index(ir_name) # create cache manager - fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs)) - # managers used to dump and override IR for debugging - enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1" - fn_override_manager = get_override_manager( - make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True)) - fn_dump_manager = get_dump_manager( - make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True)) - + fn_cache_manager = get_cache_manager(make_hash(src, target, get_env_vars(), _device_backend, **kwargs)) # determine name and extension type of provided function - if isinstance(fn, JITFunction): - name, ext = fn.__name__, "ast" + if isinstance(src, JITFunction): + name, ext = src.__name__, "ast" else: - name, ext = os.path.basename(fn).split(".") - + name, ext = os.path.basename(src).split(".") # load metadata if any metadata = None metadata_filename = f"{name}.json" - # The group is addressed by the metadata metadata_group = fn_cache_manager.get_group(metadata_filename) or {} - metadata_path = metadata_group.get(metadata_filename) - + # initialize metadata if metadata_path is not None: with open(metadata_path) as f: metadata = json.load(f) if 'tensormaps_info' in metadata: metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in 
metadata['tensormaps_info']] else: - metadata = { - "num_warps": num_warps, - "num_ctas": num_ctas, - "num_stages": num_stages, - "enable_warp_specialization": enable_warp_specialization, - "enable_persistent": enable_persistent, - "constants": _get_jsonable_constants(constants), - "debug": debug, - "target": target, - } + metadata = {"constants": _get_jsonable_constants(constants), "target": target} + metadata.update(opts) metadata.update(get_env_vars()) if ext == "ptx": assert "shared" in kwargs, "ptx compilation must provide shared memory size" metadata["shared"] = kwargs["shared"] - - # Add device type to meta information metadata["device_type"] = device_type + # run compilation pipeline and populate metadata first_stage = list(stages.keys()).index(ext) asm = LazyDict() - module = fn - # run compilation pipeline and populate metadata + module = src for ir_name, (parse, compile_kernel) in list(stages.items())[first_stage:]: ir_filename = f"{name}.{ir_name}" - + path = metadata_group.get(ir_filename) if ir_name == ext: - next_module = parse(fn) - else: - path = metadata_group.get(ir_filename) - if path is None: - next_module = compile_kernel(module) - if ir_name == "amdgcn": - extra_file_name = f"{name}.hsaco_path" - metadata_group[ir_filename] = fn_cache_manager.put(next_module[0], ir_filename) - metadata_group[extra_file_name] = fn_cache_manager.put(next_module[1], extra_file_name) - else: - metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) - fn_dump_manager.put(next_module, ir_filename) - if (enable_override and fn_override_manager.has_file(ir_filename)): - print(f"\nOverriding kernel with file {ir_filename}") - full_name = fn_override_manager.get_file(ir_filename) - next_module = parse(full_name) - else: - if ir_name == "amdgcn": - extra_file_name = f"{name}.hsaco_path" - hasco_path = metadata_group.get(extra_file_name) - assert hasco_path is not None, "Expected to have hsaco_path in metadata when we have the amdgcn" - next_module = (parse(path), parse(hasco_path)) - else: - next_module = parse(path) - - if ir_name == "cubin": - asm[ir_name] = next_module - asm["sass"] = lambda: get_sass(next_module) - elif ir_name == "amdgcn": - asm[ir_name] = str(next_module[0]) - else: - asm[ir_name] = str(next_module) - if ir_name == "llir" and "shared" not in metadata: - if is_hip(): - metadata["shared"] = _device_backend.get_shared_memory_size(module) - else: - metadata["shared"] = get_shared_memory_size(module) - if ir_name == "ttgir": - if is_hip(): - metadata["num_warps"] = _device_backend.get_num_warps(next_module) - else: - metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module) - if metadata["enable_warp_specialization"]: - metadata["num_warps"] = get_num_warps(next_module) - if ir_name == "ptx": - metadata["name"] = get_kernel_name(next_module, pattern='// .globl') - if ir_name == "amdgcn": - metadata["name"] = get_kernel_name(next_module[0], pattern='.globl') - asm["hsaco_path"] = next_module[1] - if not is_cuda and not is_hip(): - _device_backend.add_meta_info(ir_name, module, next_module, metadata, asm) + next_module = parse(src if name == ext else path) + continue + next_module = compile_kernel(module) + metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) + _device_backend.add_meta_info(ir_name, module, next_module, metadata, asm) module = next_module - ids_of_folded_args = tuple([int(k) for k in configs[0].ids_of_folded_args]) if isinstance(fn, JITFunction) else () - if "clusterDims" not in metadata: - 
metadata["clusterDims"] = [cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ] - - if len(tma_infos) > 0: - metadata["tensormaps_info"] = parse_tma_info(tma_infos, ids_of_folded_args) - # set constant - if "tensormaps_info" in metadata: - for i, _ in enumerate(metadata["tensormaps_info"]): - metadata["tensormaps_info"][i].ids_of_folded_args = ids_of_folded_args - - ids_of_tensormaps = get_ids_of_tensormaps(metadata.get("tensormaps_info", None)) - if isinstance(fn, JITFunction) and "tensormaps_info" in metadata: - fn.tensormaps_info = metadata["tensormaps_info"] - - ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else () - ids = { - "ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs": - ids_of_const_exprs - } # cache manager - if is_cuda: - so_path = make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization) - else: - so_path = _device_backend.make_launcher_stub(name, signature, constants, ids) + so_path = _device_backend.make_launcher_stub(src, kwargs["configs"], metadata, name, signature, constants) # write-back metadata, if it didn't come from the cache if metadata_path is None: metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False) fn_cache_manager.put_group(metadata_filename, metadata_group) - # return handle to compiled kernel - return CompiledKernel(fn, so_path, metadata, asm) + return CompiledKernel(src, so_path, metadata, asm) class CompiledKernel: @@ -631,7 +353,7 @@ def __init__(self, fn, so_path, metadata, asm): self.tensormaps_info = metadata["tensormaps_info"] self.constants = metadata["constants"] self.device_type = metadata["device_type"] - self.device_backend = get_backend(self.device_type) if self.device_type not in ["cuda"] else None + self.device_backend = get_backend(self.device_type) # initialize asm dict self.asm = asm # binaries are lazily initialized @@ -645,17 +367,10 @@ def _init_handles(self): if self.cu_module is not None: return - if self.device_type in ["cuda"]: - device = get_current_device() - bin_path = {driver.HIP: "hsaco_path", driver.CUDA: "cubin"}[driver.backend] - max_shared = driver.utils.get_device_properties(device)["max_shared_mem"] - fn_load_binary = driver.utils.load_binary - else: - assert self.device_backend - device = self.device_backend.get_current_device() - bin_path = self.device_backend.get_kernel_bin() - max_shared = self.device_backend.get_device_properties(device)["max_shared_mem"] - fn_load_binary = self.device_backend.get_load_binary_fn() + device = self.device_backend.get_current_device() + bin_path = self.device_backend.get_kernel_bin() + max_shared = self.device_backend.get_device_properties(device)["max_shared_mem"] + fn_load_binary = self.device_backend.get_load_binary_fn() if self.shared > max_shared: raise OutOfResources(self.shared, max_shared, "shared memory") diff --git a/python/triton/compiler/target.py b/python/triton/compiler/target.py deleted file mode 100644 index e69de29bb2d1..000000000000 From 3d3de11c4f25b40539f070b5a580b040eb78590a Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 12 Nov 2023 17:50:56 -0800 Subject: [PATCH 02/64] . 
--- python/triton/compiler/backends/cuda.py | 279 ++++++++++++++++++++++++ 1 file changed, 279 insertions(+) create mode 100644 python/triton/compiler/backends/cuda.py diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py new file mode 100644 index 000000000000..edd1be0146fd --- /dev/null +++ b/python/triton/compiler/backends/cuda.py @@ -0,0 +1,279 @@ +from triton.common.backend import BaseBackend +from pathlib import Path +from dataclasses import dataclass +import torch +from ..._C.libtriton.triton import ClusterInfo, get_num_warps, TMAInfos, translate_triton_gpu_to_llvmir, get_shared_memory_size, translate_llvmir_to_ptx, compile_ptx_to_cubin, add_external_libs +from ...common.backend import get_cuda_version_key, path_to_ptxas +from ..._C.libtriton.triton import ir, runtime +import functools +from typing import Any +from ...runtime.jit import JITFunction, get_cuda_stream +from ..utils import get_ids_of_tensormaps, parse_tma_info +from ..make_launcher import make_stub +from ...runtime.driver import driver +from ...tools.disasm import get_sass + + +@dataclass +class CudaTargetDescriptor: + capability: int + num_warps: int + enable_fp_fusion: bool + + +def ttir_to_ttgir(mod, num_warps, num_ctas, target): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + pm.add_convert_triton_to_tritongpu_pass(num_warps, 32, num_ctas, target.capability) + pm.run(mod) + return mod + + +def parse_mlir_module(path, context): + module = ir.parse_mlir_module(path, context) + # module takes ownership of the context + module.context = context + return module + + +def get_kernel_name(src: str, pattern: str) -> str: + ''' + Get kernel name from PTX code. + This Kernel name is required when launching the kernel. + ''' + # There is a name mangling in PTX codegen, so the original kernel names in Triton IR are not available in PTX/cubin. + assert src + for line in src.split('\n'): + line = line.strip() + if line.startswith(pattern): + return line.split()[-1] + + +def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, + enable_persistent, optimize_epilogue): + capability = target.capability + pm = ir.pass_manager(mod.context) + pm.enable_debug() + pm.add_tritongpu_coalesce_pass() + # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass + pm.add_plan_cta_pass(cluster_info) + pm.add_tritongpu_rewrite_tensor_pointer_pass(capability) + pm.add_plan_cta_pass(cluster_info) + pm.add_tritongpu_remove_layout_conversions_pass() + pm.add_tritongpu_accelerate_matmul_pass(capability) + pm.add_tritongpu_remove_layout_conversions_pass() + if optimize_epilogue: + pm.add_tritongpu_optimize_epilogue_pass() + pm.add_tritongpu_optimize_dot_operands_pass() + pm.add_cse_pass() + ws_enabled = False + # `num_warps` does not mean the total number of warps of a CTA when + # warp specialization is enabled. + # it's the responsibility of the compiler to figure out the exact + # `num_warps` to use. + # TODO: support the case where `num_warps` from user is not 4. 
+ if capability // 10 >= 9 and enable_warp_specialization and num_warps == 4: + pm.add_tritongpu_ws_feasibility_checking_pass(capability) + pm.run(mod) + ws_enabled = ir.is_ws_supported(mod) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + if ws_enabled: + pm.add_tritongpu_wsdecomposing_pass(capability) + pm.add_tritongpu_wspipeline_pass(num_stages, num_warps, capability) + pm.add_tritongpu_wsmutex_pass(capability) + pm.add_tritongpu_wsmaterialization_pass(capability) + pm.add_licm_pass() + pm.add_cse_pass() + else: + pm.add_tritongpu_pipeline_pass(num_stages, num_warps, num_ctas, capability) + pm.add_tritongpu_materialize_load_store_pass(num_warps, capability) + if capability // 10 <= 8: + pm.add_tritongpu_prefetch_pass() + pm.add_tritongpu_optimize_dot_operands_pass() + pm.add_tritongpu_remove_layout_conversions_pass() + pm.add_tritongpu_decompose_conversions_pass() + pm.add_tritongpu_ws_fixup_missing_attrs_pass() + pm.add_tritongpu_reorder_instructions_pass() + pm.add_cse_pass() + pm.add_symbol_dce_pass() + if capability // 10 >= 9: + pm.add_tritongpu_fence_insertion_pass() + pm.add_tritongpu_ws_fixup_missing_attrs_pass() + pm.add_tritongpu_optimize_thread_locality_pass() + pm.add_canonicalizer_pass() + pm.run(mod) + return mod + + +def _add_external_libs(mod, libs): + for name, path in libs.items(): + if len(name) == 0 or len(path) == 0: + return + add_external_libs(mod, list(libs.keys()), list(libs.values())) + + +def ttgir_to_llir(mod, extern_libs, target, tma_infos): + if extern_libs: + _add_external_libs(mod, extern_libs) + return translate_triton_gpu_to_llvmir(mod, target.capability, tma_infos, runtime.TARGET.NVVM) + + +# PTX translation + + +@functools.lru_cache() +def ptx_get_version(cuda_version) -> int: + ''' + Get the highest PTX version supported by the current CUDA driver. + ''' + assert isinstance(cuda_version, str) + major, minor = map(int, cuda_version.split('.')) + if major == 12: + return 80 + minor + if major == 11: + return 70 + minor + if major == 10: + return 63 + minor + raise RuntimeError("Triton only support CUDA 10.0 or higher") + + +def llir_to_ptx(mod: Any, target: CudaTargetDescriptor, ptx_version: int = None) -> str: + ''' + Translate TritonGPU module to PTX code. + :param mod: a TritonGPU dialect module + :return: PTX code + ''' + if ptx_version is None: + _, cuda_version = path_to_ptxas() + ptx_version = ptx_get_version(cuda_version) + return translate_llvmir_to_ptx(mod, target.capability, ptx_version, target.enable_fp_fusion) + + +def ptx_to_cubin(ptx: str, target: CudaTargetDescriptor): + ''' + Compile TritonGPU module to cubin. 
+ :param ptx: ptx code + :param compute_capability: compute capability + :return: str + ''' + ptxas, _ = path_to_ptxas() + return compile_ptx_to_cubin(ptx, ptxas, target.capability, target.enable_fp_fusion) + + +class CUDABackend(BaseBackend): + + def __init__(self, device_type: str) -> None: + super().__init__(device_type) + + def add_stages(self, target, extern_libs, stages, opt, context): + num_warps = opt['num_warps'] + num_ctas = opt['num_ctas'] + num_stages = opt['num_stages'] + cluster_dims = opt['cluster_dims'] + enable_warp_specialization = opt['enable_warp_specialization'] + enable_persistent = opt['enable_persistent'] + optimize_epilogue = opt['optimize_epilogue'] + + cluster_info = ClusterInfo() + if cluster_dims is not None: + cluster_info.clusterDimX = cluster_dims[0] + cluster_info.clusterDimY = cluster_dims[1] + cluster_info.clusterDimZ = cluster_dims[2] + + # TTIR -> TTGIR stage + def create_ttgir(src): + ttgir = ttir_to_ttgir(src, num_warps, num_ctas, target) + return optimize_ttgir(ttgir, num_stages, num_warps, num_ctas, target, cluster_info, + enable_warp_specialization, enable_persistent, optimize_epilogue) + + stages["ttgir"] = (lambda path: parse_mlir_module(path, context), create_ttgir) + # TTGIR -> LLIR stage + tma_infos = TMAInfos() + + def create_llir(src): + return ttgir_to_llir(src, extern_libs, target, tma_infos) + + stages["llir"] = (lambda path: Path(path).read_text(), create_llir) + + # LLIR -> PTX stage + def create_ptx(src): + return llir_to_ptx(src, target) + + stages["ptx"] = (lambda path: Path(path).read_text(), create_ptx) + + # PTx -> CUBIN stage + def create_cubin(src): + return ptx_to_cubin(src, target) + + stages["cubin"] = (lambda path: Path(path).read_bytes(), create_cubin) + self.tma_infos = tma_infos + + def add_meta_info(self, ir_name, cur_module, next_module, metadata, asm): + if ir_name == "cubin": + asm[ir_name] = next_module + asm["sass"] = lambda: get_sass(next_module) + if ir_name == "llir" and "shared" not in metadata: + metadata["shared"] = get_shared_memory_size(cur_module) + if ir_name == "ttgir": + metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module) + metadata["num_warps"] = get_num_warps(next_module) + if ir_name == "ptx": + metadata["name"] = get_kernel_name(next_module, pattern='// .globl') + + def get_load_binary_fn(self): + return driver.utils.load_binary + + def get_stream(self): + return get_cuda_stream() + + def get_device_properties(self, device): + return driver.utils.get_device_properties(device) + + def get_version_key(self): + return get_cuda_version_key() + + def get_current_device(self): + return torch.cuda.current_device() + + def set_current_device(self, device): + torch.cuda.set_device(device) + + def get_kernel_bin(self): + return "cubin" + + def make_launcher_stub(self, fn, configs, metadata, name, signature, constants): + ids_of_folded_args = tuple([int(k) + for k in configs[0].ids_of_folded_args]) if isinstance(fn, JITFunction) else () + if "clusterDims" not in metadata: + metadata["clusterDims"] = [1, 1, 1] + if len(self.tma_infos) > 0: + metadata["tensormaps_info"] = parse_tma_info(self.tma_infos, ids_of_folded_args) + # set constant + if "tensormaps_info" in metadata: + for i, _ in enumerate(metadata["tensormaps_info"]): + metadata["tensormaps_info"][i].ids_of_folded_args = ids_of_folded_args + ids_of_tensormaps = get_ids_of_tensormaps(metadata.get("tensormaps_info", None)) + if isinstance(fn, JITFunction) and "tensormaps_info" in metadata: + fn.tensormaps_info = 
metadata["tensormaps_info"] + ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else () + ids = { + "ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs": + ids_of_const_exprs + } + enable_warp_specialization = False + + return make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization) + + def get_architecture_descriptor(self, **kwargs): + capability = kwargs.get("cc", None) + if capability is None: + device = self.get_current_device() + capability = torch.cuda.get_device_capability(device) + capability = capability[0] * 10 + capability[1] + return CudaTargetDescriptor(capability=capability, num_warps=kwargs["num_warps"], + enable_fp_fusion=kwargs["enable_fp_fusion"]) + + @classmethod + def create_backend(cls, device_type: str): + return cls(device_type) From 2d2579ff44fe827ace72ad32e8e448db274e72e5 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 12 Nov 2023 20:17:53 -0800 Subject: [PATCH 03/64] . --- python/triton/compiler/backends/cuda.py | 41 +++++++++------- python/triton/compiler/compiler.py | 64 +++++++++++-------------- 2 files changed, 54 insertions(+), 51 deletions(-) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index edd1be0146fd..7c3757471cea 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -161,38 +161,47 @@ def ptx_to_cubin(ptx: str, target: CudaTargetDescriptor): return compile_ptx_to_cubin(ptx, ptxas, target.capability, target.enable_fp_fusion) +@dataclass +class CUDAOptions: + num_warps: int = 4 + num_ctas: int = 1 + num_stages: int = 3 + cluster_dims: list = None + enable_warp_specialization: bool = False + enable_persistent: bool = False + optimize_epilogue: bool = False + enable_fp_fusion: bool = True + extern_libs = None + debug: bool = False + + class CUDABackend(BaseBackend): def __init__(self, device_type: str) -> None: super().__init__(device_type) - def add_stages(self, target, extern_libs, stages, opt, context): - num_warps = opt['num_warps'] - num_ctas = opt['num_ctas'] - num_stages = opt['num_stages'] - cluster_dims = opt['cluster_dims'] - enable_warp_specialization = opt['enable_warp_specialization'] - enable_persistent = opt['enable_persistent'] - optimize_epilogue = opt['optimize_epilogue'] + def parse_options(self, **opts) -> Any: + return CUDAOptions(**opts) + def add_stages(self, target, extern_libs, stages, opt, context): cluster_info = ClusterInfo() - if cluster_dims is not None: - cluster_info.clusterDimX = cluster_dims[0] - cluster_info.clusterDimY = cluster_dims[1] - cluster_info.clusterDimZ = cluster_dims[2] + if opt.cluster_dims is not None: + cluster_info.clusterDimX = opt.cluster_dims[0] + cluster_info.clusterDimY = opt.cluster_dims[1] + cluster_info.clusterDimZ = opt.cluster_dims[2] # TTIR -> TTGIR stage def create_ttgir(src): - ttgir = ttir_to_ttgir(src, num_warps, num_ctas, target) - return optimize_ttgir(ttgir, num_stages, num_warps, num_ctas, target, cluster_info, - enable_warp_specialization, enable_persistent, optimize_epilogue) + ttgir = ttir_to_ttgir(src, opt.num_warps, opt.num_ctas, target) + return optimize_ttgir(ttgir, opt.num_stages, opt.num_warps, opt.num_ctas, target, cluster_info, + opt.enable_warp_specialization, opt.enable_persistent, opt.optimize_epilogue) stages["ttgir"] = (lambda path: parse_mlir_module(path, context), create_ttgir) # TTGIR -> LLIR stage tma_infos = TMAInfos() def create_llir(src): - return 
ttgir_to_llir(src, extern_libs, target, tma_infos) + return ttgir_to_llir(src, opt.extern_libs, target, tma_infos) stages["llir"] = (lambda path: Path(path).read_text(), create_llir) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 4fbcdaf61d26..ed1addf8ae3f 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -82,11 +82,9 @@ def convert_type_repr(x): return x -def make_hash(fn, target, env_vars, device_backend, **kwargs): +def make_hash(fn, target, env_vars, device_backend, configs, signature, **kwargs): version_key = device_backend.get_version_key() if isinstance(fn, JITFunction): - configs = kwargs["configs"] - signature = kwargs["signature"] constants = kwargs.get("constants", dict()) num_warps = kwargs.get("num_warps", 4) num_ctas = kwargs.get("num_ctas", 1) @@ -201,33 +199,30 @@ def get_arch_default_num_stages(device_type, capability=None): return num_stages -def compile(src, **kwargs): +def compile(src, device_type="cuda", signature=None, configs=None, device=None, constants=None, extern_libs=None, + **kwargs): # Get device type to decide which backend should be used - constants = kwargs.get("constants", dict()) + if constants is None: + constants = dict() # create backend handler - device_type = kwargs.get("device_type", "cuda") _device_backend = get_backend(device_type) - print(_device_backend) + options = _device_backend.parse_options(**kwargs) assert _device_backend target = _device_backend.get_architecture_descriptor(**kwargs) - # extern libs - extern_libs = kwargs.get("extern_libs", dict()) - if extern_libs is None: - extern_libs = dict() # compilation options - opts = dict() - opts["num_warps"] = kwargs.get("num_warps", None) - assert opts["num_warps"] > 0 and (opts["num_warps"] & - (opts["num_warps"] - 1)) == 0, "num_warps must be a power of 2" - opts["num_ctas"] = kwargs.get("num_ctas", None) - opts["num_stages"] = kwargs.get("num_stages", None) - opts["enable_fp_fusion"] = kwargs.get("enable_fp_fusion", True) - opts["enable_warp_specialization"] = kwargs.get("enable_warp_specialization", False) - opts["enable_persistent"] = kwargs.get("enable_persistent", False) - opts["optimize_epilogue"] = os.environ.get('OPTIMIZE_EPILOGUE', '') == '1' - opts["cluster_dims"] = kwargs.get('clusterDims', None) - opts["debug"] = kwargs.get("debug", False) + # opts = dict() + # opts["num_warps"] = kwargs.get("num_warps", None) + # assert opts["num_warps"] > 0 and (opts["num_warps"] & + # (opts["num_warps"] - 1)) == 0, "num_warps must be a power of 2" + # opts["num_ctas"] = kwargs.get("num_ctas", None) + # opts["num_stages"] = kwargs.get("num_stages", None) + # opts["enable_fp_fusion"] = kwargs.get("enable_fp_fusion", True) + # opts["enable_warp_specialization"] = kwargs.get("enable_warp_specialization", False) + # opts["enable_persistent"] = kwargs.get("enable_persistent", False) + # opts["optimize_epilogue"] = os.environ.get('OPTIMIZE_EPILOGUE', '') == '1' + # opts["cluster_dims"] = kwargs.get('clusterDims', None) + # opts["debug"] = kwargs.get("debug", False) # build compilation stages context = ir.context() @@ -236,25 +231,22 @@ def compile(src, **kwargs): stages["ast"] = (lambda path: src, None) def create_ttir(src): - ttir = ast_to_ttir(src, signature, configs[0], constants, debug=opts["debug"], target=target) + ttir = ast_to_ttir(src, signature, configs[0], constants, debug=options.debug, target=target) return optimize_ttir(ttir, target=target) stages["ttir"] = (lambda path: parse_mlir_module(path, 
context), create_ttir) - _device_backend.add_stages(target, extern_libs, stages, opts, context) + _device_backend.add_stages(target, extern_libs, stages, options, context) # find out the signature of the function if isinstance(src, JITFunction): - configs = kwargs.get("configs", None) - signature = kwargs["signature"] + # signature = kwargs["signature"] if configs is None: configs = [instance_descriptor()] assert len(configs) == 1 - kwargs["configs"] = configs name = src.__name__ first_stage = 0 if isinstance(signature, str): signature = {k: v.strip() for k, v in enumerate(signature.split(","))} - kwargs["signature"] = signature else: assert isinstance(src, str) _, ir_name = os.path.basename(src).split(".") @@ -275,7 +267,9 @@ def create_ttir(src): first_stage = list(stages.keys()).index(ir_name) # create cache manager - fn_cache_manager = get_cache_manager(make_hash(src, target, get_env_vars(), _device_backend, **kwargs)) + fn_cache_manager = get_cache_manager( + make_hash(src, target, get_env_vars(), _device_backend, configs=configs, signature=signature, + **options.__dict__)) # determine name and extension type of provided function if isinstance(src, JITFunction): name, ext = src.__name__, "ast" @@ -295,11 +289,11 @@ def create_ttir(src): metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']] else: metadata = {"constants": _get_jsonable_constants(constants), "target": target} - metadata.update(opts) + metadata.update(options.__dict__) metadata.update(get_env_vars()) - if ext == "ptx": - assert "shared" in kwargs, "ptx compilation must provide shared memory size" - metadata["shared"] = kwargs["shared"] + # if ext == "ptx": + # assert "shared" in kwargs, "ptx compilation must provide shared memory size" + # metadata["shared"] = kwargs["shared"] metadata["device_type"] = device_type # run compilation pipeline and populate metadata @@ -318,7 +312,7 @@ def create_ttir(src): module = next_module # cache manager - so_path = _device_backend.make_launcher_stub(src, kwargs["configs"], metadata, name, signature, constants) + so_path = _device_backend.make_launcher_stub(src, configs, metadata, name, signature, constants) # write-back metadata, if it didn't come from the cache if metadata_path is None: metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, From 769c08a2379da7696f94c8cfb7ac20b57232e6ae Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 12 Nov 2023 20:18:20 -0800 Subject: [PATCH 04/64] . 
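Note: the commented-out opts dict removed here was superseded in patch 03 by the CUDAOptions dataclass returned from CUDABackend.parse_options(). Below is a trimmed, runnable sketch of that pattern: only a few of the dataclass fields are shown (defaults as in the patch), parse_options is written as a free function rather than a method, and the power-of-two num_warps assertion is omitted.

    from dataclasses import dataclass

    @dataclass
    class CUDAOptions:
        num_warps: int = 4
        num_ctas: int = 1
        num_stages: int = 3
        enable_fp_fusion: bool = True
        debug: bool = False

    def parse_options(**opts) -> CUDAOptions:
        # unrecognized keyword arguments raise TypeError, surfacing typos
        # that a plain kwargs.get(...) would silently swallow
        return CUDAOptions(**opts)

    options = parse_options(num_warps=8, debug=True)
    assert options.num_warps == 8 and options.num_stages == 3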
--- python/triton/compiler/compiler.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index ed1addf8ae3f..b0768901a6fb 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -210,20 +210,6 @@ def compile(src, device_type="cuda", signature=None, configs=None, device=None, assert _device_backend target = _device_backend.get_architecture_descriptor(**kwargs) - # compilation options - # opts = dict() - # opts["num_warps"] = kwargs.get("num_warps", None) - # assert opts["num_warps"] > 0 and (opts["num_warps"] & - # (opts["num_warps"] - 1)) == 0, "num_warps must be a power of 2" - # opts["num_ctas"] = kwargs.get("num_ctas", None) - # opts["num_stages"] = kwargs.get("num_stages", None) - # opts["enable_fp_fusion"] = kwargs.get("enable_fp_fusion", True) - # opts["enable_warp_specialization"] = kwargs.get("enable_warp_specialization", False) - # opts["enable_persistent"] = kwargs.get("enable_persistent", False) - # opts["optimize_epilogue"] = os.environ.get('OPTIMIZE_EPILOGUE', '') == '1' - # opts["cluster_dims"] = kwargs.get('clusterDims', None) - # opts["debug"] = kwargs.get("debug", False) - # build compilation stages context = ir.context() stages = dict() From a00020d4b6f9e62d3878261b5ecf6964ea57cfc3 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 12 Nov 2023 22:07:46 -0800 Subject: [PATCH 05/64] cleanup target somewhat --- python/triton/compiler/__init__.py | 7 +- python/triton/compiler/backends/cuda.py | 51 ++++------- python/triton/compiler/compiler.py | 107 ++++++++++-------------- python/triton/runtime/jit.py | 11 +-- 4 files changed, 71 insertions(+), 105 deletions(-) diff --git a/python/triton/compiler/__init__.py b/python/triton/compiler/__init__.py index d700656239a4..bdf64d0757de 100644 --- a/python/triton/compiler/__init__.py +++ b/python/triton/compiler/__init__.py @@ -1,12 +1,7 @@ -from .compiler import (CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps, - instance_descriptor) +from .compiler import (CompiledKernel, compile, instance_descriptor) from .errors import CompilationError -from .backends.cuda import CUDABackend -from ..common.backend import register_backend __all__ = [ "compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", "get_arch_default_num_stages" ] - -register_backend("cuda", CUDABackend) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index 7c3757471cea..df0e5faee363 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -14,17 +14,10 @@ from ...tools.disasm import get_sass -@dataclass -class CudaTargetDescriptor: - capability: int - num_warps: int - enable_fp_fusion: bool - - -def ttir_to_ttgir(mod, num_warps, num_ctas, target): +def ttir_to_ttgir(mod, num_warps, num_ctas, capability): pm = ir.pass_manager(mod.context) pm.enable_debug() - pm.add_convert_triton_to_tritongpu_pass(num_warps, 32, num_ctas, target.capability) + pm.add_convert_triton_to_tritongpu_pass(num_warps, 32, num_ctas, capability) pm.run(mod) return mod @@ -49,9 +42,8 @@ def get_kernel_name(src: str, pattern: str) -> str: return line.split()[-1] -def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, +def optimize_ttgir(mod, num_stages, num_warps, num_ctas, capability, cluster_info, enable_warp_specialization, enable_persistent, 
optimize_epilogue): - capability = target.capability pm = ir.pass_manager(mod.context) pm.enable_debug() pm.add_tritongpu_coalesce_pass() @@ -113,10 +105,10 @@ def _add_external_libs(mod, libs): add_external_libs(mod, list(libs.keys()), list(libs.values())) -def ttgir_to_llir(mod, extern_libs, target, tma_infos): +def ttgir_to_llir(mod, extern_libs, capability, tma_infos): if extern_libs: _add_external_libs(mod, extern_libs) - return translate_triton_gpu_to_llvmir(mod, target.capability, tma_infos, runtime.TARGET.NVVM) + return translate_triton_gpu_to_llvmir(mod, capability, tma_infos, runtime.TARGET.NVVM) # PTX translation @@ -138,7 +130,7 @@ def ptx_get_version(cuda_version) -> int: raise RuntimeError("Triton only support CUDA 10.0 or higher") -def llir_to_ptx(mod: Any, target: CudaTargetDescriptor, ptx_version: int = None) -> str: +def llir_to_ptx(mod: Any, enable_fp_fusion: bool, capability: int, ptx_version: int = None) -> str: ''' Translate TritonGPU module to PTX code. :param mod: a TritonGPU dialect module @@ -147,10 +139,10 @@ def llir_to_ptx(mod: Any, target: CudaTargetDescriptor, ptx_version: int = None) if ptx_version is None: _, cuda_version = path_to_ptxas() ptx_version = ptx_get_version(cuda_version) - return translate_llvmir_to_ptx(mod, target.capability, ptx_version, target.enable_fp_fusion) + return translate_llvmir_to_ptx(mod, capability, ptx_version, enable_fp_fusion) -def ptx_to_cubin(ptx: str, target: CudaTargetDescriptor): +def ptx_to_cubin(ptx: str, capability: int, enable_fp_fusion: bool): ''' Compile TritonGPU module to cubin. :param ptx: ptx code @@ -158,7 +150,7 @@ def ptx_to_cubin(ptx: str, target: CudaTargetDescriptor): :return: str ''' ptxas, _ = path_to_ptxas() - return compile_ptx_to_cubin(ptx, ptxas, target.capability, target.enable_fp_fusion) + return compile_ptx_to_cubin(ptx, ptxas, capability, enable_fp_fusion) @dataclass @@ -177,13 +169,15 @@ class CUDAOptions: class CUDABackend(BaseBackend): - def __init__(self, device_type: str) -> None: + def __init__(self, device_type: tuple) -> None: super().__init__(device_type) + self.capability = device_type[1] + assert isinstance(self.capability, int) def parse_options(self, **opts) -> Any: return CUDAOptions(**opts) - def add_stages(self, target, extern_libs, stages, opt, context): + def add_stages(self, extern_libs, stages, opt, context): cluster_info = ClusterInfo() if opt.cluster_dims is not None: cluster_info.clusterDimX = opt.cluster_dims[0] @@ -192,8 +186,8 @@ def add_stages(self, target, extern_libs, stages, opt, context): # TTIR -> TTGIR stage def create_ttgir(src): - ttgir = ttir_to_ttgir(src, opt.num_warps, opt.num_ctas, target) - return optimize_ttgir(ttgir, opt.num_stages, opt.num_warps, opt.num_ctas, target, cluster_info, + ttgir = ttir_to_ttgir(src, opt.num_warps, opt.num_ctas, self.capability) + return optimize_ttgir(ttgir, opt.num_stages, opt.num_warps, opt.num_ctas, self.capability, cluster_info, opt.enable_warp_specialization, opt.enable_persistent, opt.optimize_epilogue) stages["ttgir"] = (lambda path: parse_mlir_module(path, context), create_ttgir) @@ -201,19 +195,19 @@ def create_ttgir(src): tma_infos = TMAInfos() def create_llir(src): - return ttgir_to_llir(src, opt.extern_libs, target, tma_infos) + return ttgir_to_llir(src, opt.extern_libs, self.capability, tma_infos) stages["llir"] = (lambda path: Path(path).read_text(), create_llir) # LLIR -> PTX stage def create_ptx(src): - return llir_to_ptx(src, target) + return llir_to_ptx(src, opt.enable_fp_fusion, self.capability) stages["ptx"] 
= (lambda path: Path(path).read_text(), create_ptx) # PTx -> CUBIN stage def create_cubin(src): - return ptx_to_cubin(src, target) + return ptx_to_cubin(src, self.capability, opt.enable_fp_fusion) stages["cubin"] = (lambda path: Path(path).read_bytes(), create_cubin) self.tma_infos = tma_infos @@ -274,15 +268,6 @@ def make_launcher_stub(self, fn, configs, metadata, name, signature, constants): return make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization) - def get_architecture_descriptor(self, **kwargs): - capability = kwargs.get("cc", None) - if capability is None: - device = self.get_current_device() - capability = torch.cuda.get_device_capability(device) - capability = capability[0] * 10 + capability[1] - return CudaTargetDescriptor(capability=capability, num_warps=kwargs["num_warps"], - enable_fp_fusion=kwargs["enable_fp_fusion"]) - @classmethod def create_backend(cls, device_type: str): return cls(device_type) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index b0768901a6fb..123997df3981 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -8,20 +8,18 @@ from pathlib import Path from .._C.libtriton.triton import (get_env_vars, ir) -from ..common.backend import get_backend from ..common.build import is_hip # from ..runtime import driver, jit, JITFunction # TODO: runtime.errors from ..runtime.autotuner import OutOfResources from ..runtime.cache import get_cache_manager -from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device, get_device_capability) +from ..runtime.jit import (JITFunction, get_cuda_stream) from .code_generator import ast_to_ttir from .utils import (InfoFromBackendForTensorMap, TensorMapManager) -from .backends.cuda import CudaTargetDescriptor +from .backends.cuda import CUDABackend - -def _is_cuda(target): - return isinstance(target, CudaTargetDescriptor) +from ..runtime.driver import driver +import torch class LazyDict(dict): @@ -46,8 +44,7 @@ def ttir_compute_capability_rewrite(mod, target): # with block (tensor) pointers into tensors of pointers pm = ir.pass_manager(mod.context) pm.enable_debug() - if _is_cuda(target): - pm.add_rewrite_tensor_pointer_pass(target.capability) + pm.add_rewrite_tensor_pointer_pass(target.capability) pm.run(mod) return mod @@ -167,48 +164,16 @@ def parse_mlir_module(path, context): defaults=[set(), set(), set(), set()]) -# TODO: remove -def get_cuda_capability(capability): - if capability is None: - device = get_current_device() - capability = get_device_capability(device) - capability = capability[0] * 10 + capability[1] - return capability - - -def get_arch_default_num_warps(device_type): - if device_type in ["cuda", "hip"]: - num_warps = 4 - else: - _device_backend = get_backend(device_type) - assert _device_backend - arch = _device_backend.get_architecture_descriptor() - num_warps = arch["num_warps"] - return num_warps - - -def get_arch_default_num_stages(device_type, capability=None): - if device_type == "cuda": - num_stages = 3 if get_cuda_capability(capability) >= 75 else 2 - else: - _device_backend = get_backend(device_type) - assert _device_backend - arch = _device_backend.get_architecture_descriptor() - num_stages = arch["num_stages"] - - return num_stages - - -def compile(src, device_type="cuda", signature=None, configs=None, device=None, constants=None, extern_libs=None, - **kwargs): +def compile(src, device_type=("cuda", None), signature=None, configs=None, device=None, constants=None, + 
extern_libs=None, **kwargs): # Get device type to decide which backend should be used if constants is None: constants = dict() # create backend handler - _device_backend = get_backend(device_type) - options = _device_backend.parse_options(**kwargs) - assert _device_backend - target = _device_backend.get_architecture_descriptor(**kwargs) + backend = CUDABackend(device_type) + options = backend.parse_options(**kwargs) + target = namedtuple("target", ["capability", "num_warps", "num_stages"])(device_type[1], options.num_warps, + options.num_stages) # build compilation stages context = ir.context() @@ -217,11 +182,11 @@ def compile(src, device_type="cuda", signature=None, configs=None, device=None, stages["ast"] = (lambda path: src, None) def create_ttir(src): - ttir = ast_to_ttir(src, signature, configs[0], constants, debug=options.debug, target=target) + ttir = ast_to_ttir(src, signature, configs[0], constants, target=target, debug=options.debug) return optimize_ttir(ttir, target=target) stages["ttir"] = (lambda path: parse_mlir_module(path, context), create_ttir) - _device_backend.add_stages(target, extern_libs, stages, options, context) + backend.add_stages(extern_libs, stages, options, context) # find out the signature of the function if isinstance(src, JITFunction): @@ -254,8 +219,7 @@ def create_ttir(src): # create cache manager fn_cache_manager = get_cache_manager( - make_hash(src, target, get_env_vars(), _device_backend, configs=configs, signature=signature, - **options.__dict__)) + make_hash(src, target, get_env_vars(), backend, configs=configs, signature=signature, **options.__dict__)) # determine name and extension type of provided function if isinstance(src, JITFunction): name, ext = src.__name__, "ast" @@ -294,11 +258,11 @@ def create_ttir(src): continue next_module = compile_kernel(module) metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) - _device_backend.add_meta_info(ir_name, module, next_module, metadata, asm) + backend.add_meta_info(ir_name, module, next_module, metadata, asm) module = next_module # cache manager - so_path = _device_backend.make_launcher_stub(src, configs, metadata, name, signature, constants) + so_path = backend.make_launcher_stub(src, configs, metadata, name, signature, constants) # write-back metadata, if it didn't come from the cache if metadata_path is None: metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, @@ -308,6 +272,30 @@ def create_ttir(src): return CompiledKernel(src, so_path, metadata, asm) +class RuntimeCudaBackend: + + def __init__(self) -> None: + pass + + def get_load_binary_fn(self): + return driver.utils.load_binary + + def get_stream(self): + return get_cuda_stream() + + def get_device_properties(self, device): + return driver.utils.get_device_properties(device) + + def get_current_device(self): + return torch.cuda.current_device() + + def set_current_device(self, device): + torch.cuda.set_device(device) + + def get_kernel_bin(self): + return "cubin" + + class CompiledKernel: # Hooks for external tools to monitor the execution of triton kernels @@ -333,7 +321,6 @@ def __init__(self, fn, so_path, metadata, asm): self.tensormaps_info = metadata["tensormaps_info"] self.constants = metadata["constants"] self.device_type = metadata["device_type"] - self.device_backend = get_backend(self.device_type) # initialize asm dict self.asm = asm # binaries are lazily initialized @@ -342,15 +329,16 @@ def __init__(self, fn, so_path, metadata, asm): self.metadata = 
metadata self.cu_module = None self.cu_function = None + self.driver = RuntimeCudaBackend() def _init_handles(self): if self.cu_module is not None: return - device = self.device_backend.get_current_device() - bin_path = self.device_backend.get_kernel_bin() - max_shared = self.device_backend.get_device_properties(device)["max_shared_mem"] - fn_load_binary = self.device_backend.get_load_binary_fn() + device = self.driver.get_current_device() + bin_path = self.driver.get_kernel_bin() + max_shared = self.driver.get_device_properties(device)["max_shared_mem"] + fn_load_binary = self.driver.get_load_binary_fn() if self.shared > max_shared: raise OutOfResources(self.shared, max_shared, "shared memory") @@ -383,10 +371,7 @@ def __getitem__(self, grid): def runner(*args, stream=None): args_expand = self.assemble_tensormap_to_arg(args) if stream is None: - if self.device_type in ["cuda"]: - stream = get_cuda_stream() - else: - stream = get_backend(self.device_type).get_stream(None) + stream = self.driver.get_stream(None) self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.clusterDims[0], self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 286d069c15f1..b74fb7773e6a 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -400,7 +400,7 @@ def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: Li return device_types[0] if len(device_types) > 0 else "cuda" def run(self, *args, **kwargs): - from ..compiler import CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps + from ..compiler import CompiledKernel, compile # Get a compiler-flags arg like `num_warps` and remove it from kwargs. 
def get_special_arg(name: str, default=None): @@ -472,9 +472,9 @@ def get_special_arg(name: str, default=None): stream = device_backend.get_stream() if num_warps is None: - num_warps = get_arch_default_num_warps(device_type) + num_warps = 4 if num_stages is None: - num_stages = get_arch_default_num_stages(device_type) + num_stages = 3 if device_type in ["cuda"]: version_key = get_cuda_version_key() @@ -529,10 +529,12 @@ def get_special_arg(name: str, default=None): ): return None + capability = get_device_capability(device) + capability = capability[0] * 10 + capability[1] self.cache[device][key] = compile( self, signature=signature, - device=device, + device_type=(device_type, capability), constants=constants, num_warps=num_warps, num_ctas=num_ctas, @@ -542,7 +544,6 @@ def get_special_arg(name: str, default=None): extern_libs=extern_libs, configs=configs, debug=self.debug, - device_type=device_type, ) bin = self.cache[device][key] From 0b98e0c088f5c4b397d808b044fff6d9fc41d7f2 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 12 Nov 2023 22:08:52 -0800 Subject: [PATCH 06/64] more cleaning --- python/triton/compiler/backends/cuda.py | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index df0e5faee363..3ddf9ab8e0dd 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -1,16 +1,14 @@ from triton.common.backend import BaseBackend from pathlib import Path from dataclasses import dataclass -import torch from ..._C.libtriton.triton import ClusterInfo, get_num_warps, TMAInfos, translate_triton_gpu_to_llvmir, get_shared_memory_size, translate_llvmir_to_ptx, compile_ptx_to_cubin, add_external_libs from ...common.backend import get_cuda_version_key, path_to_ptxas from ..._C.libtriton.triton import ir, runtime import functools from typing import Any -from ...runtime.jit import JITFunction, get_cuda_stream +from ...runtime.jit import JITFunction from ..utils import get_ids_of_tensormaps, parse_tma_info from ..make_launcher import make_stub -from ...runtime.driver import driver from ...tools.disasm import get_sass @@ -224,27 +222,9 @@ def add_meta_info(self, ir_name, cur_module, next_module, metadata, asm): if ir_name == "ptx": metadata["name"] = get_kernel_name(next_module, pattern='// .globl') - def get_load_binary_fn(self): - return driver.utils.load_binary - - def get_stream(self): - return get_cuda_stream() - - def get_device_properties(self, device): - return driver.utils.get_device_properties(device) - def get_version_key(self): return get_cuda_version_key() - def get_current_device(self): - return torch.cuda.current_device() - - def set_current_device(self, device): - torch.cuda.set_device(device) - - def get_kernel_bin(self): - return "cubin" - def make_launcher_stub(self, fn, configs, metadata, name, signature, constants): ids_of_folded_args = tuple([int(k) for k in configs[0].ids_of_folded_args]) if isinstance(fn, JITFunction) else () From e6a9f8af5a1e54327cc4b06d0643e2f01897034b Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 12 Nov 2023 22:57:09 -0800 Subject: [PATCH 07/64] simplify hash --- python/triton/compiler/backends/cuda.py | 5 ++ python/triton/compiler/compiler.py | 90 +++++++------------------ python/triton/runtime/jit.py | 11 ++- 3 files changed, 36 insertions(+), 70 deletions(-) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index 3ddf9ab8e0dd..cb0b756f7985 
100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -10,6 +10,7 @@ from ..utils import get_ids_of_tensormaps, parse_tma_info from ..make_launcher import make_stub from ...tools.disasm import get_sass +import hashlib def ttir_to_ttgir(mod, num_warps, num_ctas, capability): @@ -164,6 +165,10 @@ class CUDAOptions: extern_libs = None debug: bool = False + def hash(self): + key = '-'.join([str(x) for x in self.__dict__.values()]) + return hashlib.md5(key.encode("utf-8")).hexdigest() + class CUDABackend(BaseBackend): diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 123997df3981..499218800baf 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -2,10 +2,8 @@ import hashlib import json -import os import re from collections import namedtuple -from pathlib import Path from .._C.libtriton.triton import (get_env_vars, ir) from ..common.build import is_hip @@ -13,13 +11,14 @@ # TODO: runtime.errors from ..runtime.autotuner import OutOfResources from ..runtime.cache import get_cache_manager -from ..runtime.jit import (JITFunction, get_cuda_stream) +from ..runtime.jit import (get_cuda_stream) from .code_generator import ast_to_ttir from .utils import (InfoFromBackendForTensorMap, TensorMapManager) from .backends.cuda import CUDABackend from ..runtime.driver import driver import torch +from dataclasses import dataclass class LazyDict(dict): @@ -79,28 +78,11 @@ def convert_type_repr(x): return x -def make_hash(fn, target, env_vars, device_backend, configs, signature, **kwargs): +def make_hash(fn, target, env_vars, device_backend, config, signature, constants, options): version_key = device_backend.get_version_key() - if isinstance(fn, JITFunction): - constants = kwargs.get("constants", dict()) - num_warps = kwargs.get("num_warps", 4) - num_ctas = kwargs.get("num_ctas", 1) - num_stages = kwargs.get("num_stages", 3) - enable_warp_specialization = kwargs.get("enable_warp_specialization", False) - enable_persistent = kwargs.get("enable_persistent", False) - debug = kwargs.get("debug", False) - # Get unique key for the compiled code - get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), - sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8)) - configs_key = [get_conf_key(conf) for conf in configs] - env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())] - key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}" - return hashlib.md5(key.encode("utf-8")).hexdigest() - assert isinstance(fn, str) - ignore_version = kwargs.get('ignore_version', False) - if (ignore_version): - return hashlib.md5((Path(fn).read_text()).encode("utf-8")).hexdigest() - return hashlib.md5((Path(fn).read_text() + version_key).encode("utf-8")).hexdigest() + env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())] + key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{config.hash()}-{constants}-{options.hash()}-{target}-{env_vars_list}" + return hashlib.md5(key.encode("utf-8")).hexdigest() # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, @@ -159,12 +141,19 @@ def parse_mlir_module(path, context): return module -instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", 
"divisible_by_8"], - defaults=[set(), set(), set(), set()]) +@dataclass +class instance_descriptor: + divisible_by_16: set = None + equal_to_1: set = None + ids_of_folded_args: set = None + divisible_by_8: set = None + + def hash(self): + key = str([sorted(x) for x in self.__dict__.values()]) + return hashlib.md5(key.encode("utf-8")).hexdigest() -def compile(src, device_type=("cuda", None), signature=None, configs=None, device=None, constants=None, +def compile(src, device_type=("cuda", None), signature=None, config=instance_descriptor(), constants=None, extern_libs=None, **kwargs): # Get device type to decide which backend should be used if constants is None: @@ -178,53 +167,26 @@ def compile(src, device_type=("cuda", None), signature=None, configs=None, devic # build compilation stages context = ir.context() stages = dict() - # TODO: CompilationStage object w/ both `parser` and `creator` attributes stages["ast"] = (lambda path: src, None) def create_ttir(src): - ttir = ast_to_ttir(src, signature, configs[0], constants, target=target, debug=options.debug) + ttir = ast_to_ttir(src, signature, config, constants, target=target, debug=options.debug) return optimize_ttir(ttir, target=target) stages["ttir"] = (lambda path: parse_mlir_module(path, context), create_ttir) backend.add_stages(extern_libs, stages, options, context) # find out the signature of the function - if isinstance(src, JITFunction): - # signature = kwargs["signature"] - if configs is None: - configs = [instance_descriptor()] - assert len(configs) == 1 - name = src.__name__ - first_stage = 0 - if isinstance(signature, str): - signature = {k: v.strip() for k, v in enumerate(signature.split(","))} - else: - assert isinstance(src, str) - _, ir_name = os.path.basename(src).split(".") - src = Path(src).read_text() - import re - match = re.search(prototype_pattern[ir_name], src, re.MULTILINE) - # TODO: support function attributes at group 3 (e.g., device function) - name, signature = match.group(1), match.group(2) - types = re.findall(arg_type_pattern[ir_name], signature) - if ir_name == 'ttgir': - num_warps_matches = re.findall(ttgir_num_warps_pattern, src) - assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps" - # assert "num_warps" not in kwargs or int( - # num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile" - # num_warps = int(num_warps_matches[0]) - param_tys = [convert_type_repr(ty) for ty in types] - signature = {k: v for k, v in enumerate(param_tys)} - first_stage = list(stages.keys()).index(ir_name) + name = src.__name__ + if isinstance(signature, str): + signature = {k: v.strip() for k, v in enumerate(signature.split(","))} # create cache manager - fn_cache_manager = get_cache_manager( - make_hash(src, target, get_env_vars(), backend, configs=configs, signature=signature, **options.__dict__)) + hash = make_hash(src, target, get_env_vars(), backend, config=config, constants=constants, signature=signature, + options=options) + fn_cache_manager = get_cache_manager(hash) # determine name and extension type of provided function - if isinstance(src, JITFunction): - name, ext = src.__name__, "ast" - else: - name, ext = os.path.basename(src).split(".") + name, ext = src.__name__, "ast" # load metadata if any metadata = None metadata_filename = f"{name}.json" @@ -262,7 +224,7 @@ def create_ttir(src): module = next_module # cache manager - so_path = backend.make_launcher_stub(src, configs, metadata, name, signature, constants) + so_path = 
backend.make_launcher_stub(src, [config], metadata, name, signature, constants) # write-back metadata, if it didn't come from the cache if metadata_path is None: metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index b74fb7773e6a..8583b5d7a82d 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -252,6 +252,7 @@ def _spec_of(arg): # TODO(jlebar): Fold this into the KernelArg class. def _get_config(self, *args): + from ..compiler import instance_descriptor def is_divisible_by_16(x): if hasattr(x, "data_ptr"): @@ -288,10 +289,8 @@ def is_divisible_by_8(x): # TODO: method to collect all folded args none_args = {param.num for param, arg in zip(self.params, args) if arg is None and not param.do_not_specialize} ids_of_folded_args = equal_to_1 | none_args - return namedtuple("instance_descriptor", - ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])( # - tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), - tuple(divisible_by_8)) + return instance_descriptor(tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), + tuple(divisible_by_8)) # return _triton.code_gen.instance_descriptor(divisible_by_16, # equal_to_1) @@ -400,7 +399,7 @@ def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: Li return device_types[0] if len(device_types) > 0 else "cuda" def run(self, *args, **kwargs): - from ..compiler import CompiledKernel, compile + from ..compiler import CompiledKernel, compile, instance_descriptor # Get a compiler-flags arg like `num_warps` and remove it from kwargs. def get_special_arg(name: str, default=None): @@ -542,7 +541,7 @@ def get_special_arg(name: str, default=None): enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs, - configs=configs, + config=configs[0], debug=self.debug, ) From 2a4143178c34c4fcd00fd102459fc8d2bdf1c9f2 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Mon, 13 Nov 2023 18:27:16 -0800 Subject: [PATCH 08/64] semantic analysis no longer get target --- python/triton/compiler/__init__.py | 2 +- python/triton/compiler/backends/cuda.py | 14 ++++-- python/triton/compiler/code_generator.py | 14 +++--- python/triton/compiler/compiler.py | 55 +++++++----------------- python/triton/language/extra/cuda.py | 2 +- python/triton/language/semantic.py | 29 +++---------- python/triton/runtime/jit.py | 8 ++-- 7 files changed, 45 insertions(+), 79 deletions(-) diff --git a/python/triton/compiler/__init__.py b/python/triton/compiler/__init__.py index bdf64d0757de..3795b5ec099b 100644 --- a/python/triton/compiler/__init__.py +++ b/python/triton/compiler/__init__.py @@ -1,4 +1,4 @@ -from .compiler import (CompiledKernel, compile, instance_descriptor) +from .compiler import (CompiledKernel, compile) from .errors import CompilationError __all__ = [ diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index cb0b756f7985..123f2add8a65 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -163,6 +163,9 @@ class CUDAOptions: optimize_epilogue: bool = False enable_fp_fusion: bool = True extern_libs = None + allow_fp8e4nv: bool = False + max_num_imprecise_acc: bool = None + debug: bool = False def hash(self): @@ -178,9 +181,14 @@ def __init__(self, device_type: tuple) -> None: assert isinstance(self.capability, 
int) def parse_options(self, **opts) -> Any: - return CUDAOptions(**opts) + options = CUDAOptions(**opts) + options.allow_fp8e4nv = self.capability >= 89 + options.max_num_imprecise_acc = 0 if self.capability >= 89 else None + return options + + def add_stages(self, extern_libs, stages, opt): - def add_stages(self, extern_libs, stages, opt, context): + context = ir.context() cluster_info = ClusterInfo() if opt.cluster_dims is not None: cluster_info.clusterDimX = opt.cluster_dims[0] @@ -228,7 +236,7 @@ def add_meta_info(self, ir_name, cur_module, next_module, metadata, asm): metadata["name"] = get_kernel_name(next_module, pattern='// .globl') def get_version_key(self): - return get_cuda_version_key() + return f'{get_cuda_version_key()}-{self.capability}' def make_launcher_stub(self, fn, configs, metadata, name, signature, constants): ids_of_folded_args = tuple([int(k) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 19017b844022..591dd7120304 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -208,8 +208,8 @@ def visit_Call(self, node: ast.Call) -> bool: class CodeGenerator(ast.NodeVisitor): - def __init__(self, context, prototype, gscope, attributes, constants, function_name, target, module=None, - is_kernel=False, function_types: Optional[Dict] = None, debug=False, noinline=False, + def __init__(self, context, prototype, gscope, attributes, constants, function_name, options, module=None, + is_kernel=False, function_types: Optional[Dict] = None, noinline=False, file_name: Optional[str] = None, begin_line=0): self.context = context self.builder = ir.builder(context) @@ -217,7 +217,7 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n # node.lineno starts from 1, so we need to subtract 1 self.begin_line = begin_line - 1 self.builder.set_loc(file_name, begin_line, 0) - self.builder.target = target + self.builder.options = options self.module = self.builder.create_module() if module is None else module self.function_ret_types = {} if function_types is None else function_types self.prototype = prototype @@ -228,7 +228,7 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n self.function_name = function_name self.is_kernel = is_kernel self.last_node = None - self.debug = debug + self.debug = options.debug self.noinline = noinline self.scf_stack = [] self.last_ret_type = None @@ -1188,7 +1188,7 @@ def kernel_suffix(signature, specialization): return suffix -def ast_to_ttir(fn, signature, specialization, constants, debug, target): +def ast_to_ttir(fn, signature, specialization, constants, options): # canonicalize signature if isinstance(signature, str): signature = {k: v.strip() for k, v in enumerate(signature.split(","))} @@ -1215,8 +1215,8 @@ def ast_to_ttir(fn, signature, specialization, constants, debug, target): prototype = language.function_type([], arg_types) generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, - attributes=new_attrs, is_kernel=True, debug=debug, file_name=file_name, - begin_line=begin_line, target=target) + attributes=new_attrs, is_kernel=True, file_name=file_name, begin_line=begin_line, + options=options) try: generator.visit(fn.parse()) except CompilationError as e: diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 499218800baf..0af9b1ce7b1c 100644 --- a/python/triton/compiler/compiler.py +++ 
b/python/triton/compiler/compiler.py @@ -3,7 +3,6 @@ import hashlib import json import re -from collections import namedtuple from .._C.libtriton.triton import (get_env_vars, ir) from ..common.build import is_hip @@ -12,13 +11,13 @@ from ..runtime.autotuner import OutOfResources from ..runtime.cache import get_cache_manager from ..runtime.jit import (get_cuda_stream) -from .code_generator import ast_to_ttir from .utils import (InfoFromBackendForTensorMap, TensorMapManager) from .backends.cuda import CUDABackend from ..runtime.driver import driver import torch from dataclasses import dataclass +from .code_generator import ast_to_ttir class LazyDict(dict): @@ -30,30 +29,16 @@ def __getitem__(self, key): return val -def inline_triton_ir(mod): - pm = ir.pass_manager(mod.context) - pm.enable_debug() - pm.add_inliner_pass() - pm.run(mod) - return mod - - -def ttir_compute_capability_rewrite(mod, target): - # For hardware without support, we must rewrite all load/store - # with block (tensor) pointers into tensors of pointers - pm = ir.pass_manager(mod.context) - pm.enable_debug() - pm.add_rewrite_tensor_pointer_pass(target.capability) - pm.run(mod) - return mod +# ------------------------------------------------------------------------------ +# compiler +# ------------------------------------------------------------------------------ -def optimize_ttir(mod, target): - mod = inline_triton_ir(mod) - mod = ttir_compute_capability_rewrite(mod, target) +def optimize_ttir(mod, capability): pm = ir.pass_manager(mod.context) pm.enable_debug() pm.add_inliner_pass() + pm.add_rewrite_tensor_pointer_pass(capability) pm.add_triton_combine_pass() pm.add_canonicalizer_pass() pm.add_reorder_broadcast_pass() @@ -64,11 +49,6 @@ def optimize_ttir(mod, target): return mod -# ------------------------------------------------------------------------------ -# compiler -# ------------------------------------------------------------------------------ - - def convert_type_repr(x): # Currently we only capture the pointer type and assume the pointer is on global memory. 
# TODO: Capture and support shared memory space @@ -78,10 +58,10 @@ def convert_type_repr(x): return x -def make_hash(fn, target, env_vars, device_backend, config, signature, constants, options): +def make_hash(fn, env_vars, device_backend, config, signature, constants, options): version_key = device_backend.get_version_key() env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())] - key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{config.hash()}-{constants}-{options.hash()}-{target}-{env_vars_list}" + key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{config.hash()}-{constants}-{options.hash()}-{env_vars_list}" return hashlib.md5(key.encode("utf-8")).hexdigest() @@ -142,7 +122,7 @@ def parse_mlir_module(path, context): @dataclass -class instance_descriptor: +class InstanceDescriptor: divisible_by_16: set = None equal_to_1: set = None ids_of_folded_args: set = None @@ -153,7 +133,7 @@ def hash(self): return hashlib.md5(key.encode("utf-8")).hexdigest() -def compile(src, device_type=("cuda", None), signature=None, config=instance_descriptor(), constants=None, +def compile(src, device_type=("cuda", None), signature=None, config=InstanceDescriptor(), constants=None, extern_libs=None, **kwargs): # Get device type to decide which backend should be used if constants is None: @@ -161,8 +141,6 @@ def compile(src, device_type=("cuda", None), signature=None, config=instance_des # create backend handler backend = CUDABackend(device_type) options = backend.parse_options(**kwargs) - target = namedtuple("target", ["capability", "num_warps", "num_stages"])(device_type[1], options.num_warps, - options.num_stages) # build compilation stages context = ir.context() @@ -170,11 +148,11 @@ def compile(src, device_type=("cuda", None), signature=None, config=instance_des stages["ast"] = (lambda path: src, None) def create_ttir(src): - ttir = ast_to_ttir(src, signature, config, constants, target=target, debug=options.debug) - return optimize_ttir(ttir, target=target) + ttir = ast_to_ttir(src, signature, config, constants, options=options) + return optimize_ttir(ttir, capability=device_type[1]) stages["ttir"] = (lambda path: parse_mlir_module(path, context), create_ttir) - backend.add_stages(extern_libs, stages, options, context) + backend.add_stages(extern_libs, stages, options) # find out the signature of the function name = src.__name__ @@ -182,7 +160,7 @@ def create_ttir(src): signature = {k: v.strip() for k, v in enumerate(signature.split(","))} # create cache manager - hash = make_hash(src, target, get_env_vars(), backend, config=config, constants=constants, signature=signature, + hash = make_hash(src, get_env_vars(), backend, config=config, constants=constants, signature=signature, options=options) fn_cache_manager = get_cache_manager(hash) # determine name and extension type of provided function @@ -200,12 +178,9 @@ def create_ttir(src): if 'tensormaps_info' in metadata: metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']] else: - metadata = {"constants": _get_jsonable_constants(constants), "target": target} + metadata = {"constants": _get_jsonable_constants(constants)} metadata.update(options.__dict__) metadata.update(get_env_vars()) - # if ext == "ptx": - # assert "shared" in kwargs, "ptx compilation must provide shared memory size" - # metadata["shared"] = kwargs["shared"] metadata["device_type"] = device_type # run compilation pipeline and populate metadata diff --git a/python/triton/language/extra/cuda.py 
b/python/triton/language/extra/cuda.py index 9400ae797887..1cb494d9fa7f 100644 --- a/python/triton/language/extra/cuda.py +++ b/python/triton/language/extra/cuda.py @@ -15,4 +15,4 @@ def smid(_builder=None): @core.builtin def num_threads(_builder=None): - return core.constexpr(_builder.target.num_warps * 32) + return core.constexpr(_builder.options.num_warps * 32) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index d74cbb150254..e9496c27a869 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -9,16 +9,6 @@ T = TypeVar('T') -# TODO: redundant code -- remove after 3P backend refactor - - -def _is_cuda(target): - from ..compiler.compiler import CudaTargetDescriptor - return isinstance(target, CudaTargetDescriptor) - - -# Create custom exception that prints message "hello" - class IncompatibleTypeErrorImpl(Exception): @@ -654,8 +644,7 @@ def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: src_sca_ty = src_ty.scalar dst_sca_ty = dst_ty.scalar - if _is_cuda(builder.target) and builder.target.capability < 89 and \ - (src_sca_ty.is_fp8e4nv() or dst_sca_ty.is_fp8e4nv()): + if not builder.options.allow_fp8e4nv and (src_sca_ty.is_fp8e4nv() or dst_sca_ty.is_fp8e4nv()): assert False, "fp8e4nv data type is not supported on CUDA arch < 89" # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 @@ -1188,13 +1177,8 @@ def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool: def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_num_imprecise_acc: int, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: - def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): - # Checks for non-cuda archs - if not _is_cuda(target): - assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" - return - # Checks for cuda arch - if target.capability < 90: + def assert_dtypes_valid(lhs_dtype, rhs_dtype, options): + if not options.allow_fp8e4nv: assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv( ), "Dot op does not support fp8e4nv on CUDA arch < 90" if lhs_dtype.is_fp8() and rhs_dtype.is_fp8(): @@ -1223,7 +1207,7 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): assert lhs.type.is_block() and rhs.type.is_block() - assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.target) + assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.options) assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!" assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!" @@ -1282,9 +1266,8 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): assert acc.type == ret_ty # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90 - if not (_is_cuda(builder.target) and builder.target.capability == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8() - and ret_scalar_ty.is_fp32()): - max_num_imprecise_acc = 0 + if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): + max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default if max_num_imprecise_acc is None: max_num_imprecise_acc = 2**30 diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 8583b5d7a82d..275d7f17f7cd 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -252,7 +252,7 @@ def _spec_of(arg): # TODO(jlebar): Fold this into the KernelArg class. 
def _get_config(self, *args): - from ..compiler import instance_descriptor + from ..compiler import InstanceDescriptor def is_divisible_by_16(x): if hasattr(x, "data_ptr"): @@ -289,8 +289,8 @@ def is_divisible_by_8(x): # TODO: method to collect all folded args none_args = {param.num for param, arg in zip(self.params, args) if arg is None and not param.do_not_specialize} ids_of_folded_args = equal_to_1 | none_args - return instance_descriptor(tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), - tuple(divisible_by_8)) + return InstanceDescriptor(tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), + tuple(divisible_by_8)) # return _triton.code_gen.instance_descriptor(divisible_by_16, # equal_to_1) @@ -399,7 +399,7 @@ def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: Li return device_types[0] if len(device_types) > 0 else "cuda" def run(self, *args, **kwargs): - from ..compiler import CompiledKernel, compile, instance_descriptor + from ..compiler import CompiledKernel, compile, InstanceDescriptor # Get a compiler-flags arg like `num_warps` and remove it from kwargs. def get_special_arg(name: str, default=None): From a414f6d0e96d8ddb0a65e13bcd5c3696fd8d51a3 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Wed, 22 Nov 2023 21:04:01 -0800 Subject: [PATCH 09/64] more cleaning --- python/triton/compiler/compiler.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 0af9b1ce7b1c..6281557b1d46 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -145,7 +145,6 @@ def compile(src, device_type=("cuda", None), signature=None, config=InstanceDesc # build compilation stages context = ir.context() stages = dict() - stages["ast"] = (lambda path: src, None) def create_ttir(src): ttir = ast_to_ttir(src, signature, config, constants, options=options) @@ -164,14 +163,14 @@ def create_ttir(src): options=options) fn_cache_manager = get_cache_manager(hash) # determine name and extension type of provided function - name, ext = src.__name__, "ast" + name, ext = src.__name__, "ttir" # load metadata if any - metadata = None - metadata_filename = f"{name}.json" # The group is addressed by the metadata + metadata_filename = f"{name}.json" metadata_group = fn_cache_manager.get_group(metadata_filename) or {} metadata_path = metadata_group.get(metadata_filename) # initialize metadata + metadata = None if metadata_path is not None: with open(metadata_path) as f: metadata = json.load(f) @@ -187,13 +186,10 @@ def create_ttir(src): first_stage = list(stages.keys()).index(ext) asm = LazyDict() module = src - for ir_name, (parse, compile_kernel) in list(stages.items())[first_stage:]: + for ir_name, (parse_ir, compile_ir) in list(stages.items())[first_stage:]: ir_filename = f"{name}.{ir_name}" path = metadata_group.get(ir_filename) - if ir_name == ext: - next_module = parse(src if name == ext else path) - continue - next_module = compile_kernel(module) + next_module = compile_ir(module) if path is None else parse_ir(path) metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) backend.add_meta_info(ir_name, module, next_module, metadata, asm) module = next_module From b40ae9797b572003f81e9e4e1801bcab62c2bc01 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Wed, 22 Nov 2023 22:53:22 -0800 Subject: [PATCH 10/64] . 
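For reference, the cache-key scheme after the patches above: argument-specialization info now lives in the InstanceDescriptor dataclass and compiler flags in CUDAOptions, and make_hash folds config.hash() and options.hash() into a single md5 key instead of formatting every flag by hand. A minimal self-contained sketch of that scheme (plain hashlib/dataclasses, not importing triton; the field subset, the tuple defaults, and the make_cache_key helper name are illustrative, not part of the patch):

    import hashlib
    from dataclasses import dataclass

    @dataclass
    class InstanceDescriptor:
        # same hashing idea as the patch: stringify the sorted field values, then md5
        divisible_by_16: tuple = ()
        equal_to_1: tuple = ()
        ids_of_folded_args: tuple = ()
        divisible_by_8: tuple = ()

        def hash(self):
            key = str([sorted(x) for x in self.__dict__.values()])
            return hashlib.md5(key.encode("utf-8")).hexdigest()

    def make_cache_key(fn_key, version_key, signature, config, constants, options_hash):
        # mirrors the simplified make_hash(): one flat f-string, then md5
        key = f"{fn_key}-{version_key}-{''.join(signature.values())}-{config.hash()}-{constants}-{options_hash}"
        return hashlib.md5(key.encode("utf-8")).hexdigest()

    config = InstanceDescriptor(divisible_by_16=(0, 1), equal_to_1=(3,))
    print(make_cache_key("kernel-sha", "cuda-12.1", {0: "*fp32", 1: "i32"}, config, {}, "0" * 32))

Any change to a specialization tuple or to an option value reaches the key only through those two sub-hashes, which is what lets the long hand-built format string in the old make_hash go away.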
--- python/triton/compiler/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/triton/compiler/__init__.py b/python/triton/compiler/__init__.py index 3795b5ec099b..e3bf82207b14 100644 --- a/python/triton/compiler/__init__.py +++ b/python/triton/compiler/__init__.py @@ -1,7 +1,7 @@ -from .compiler import (CompiledKernel, compile) +from .compiler import (CompiledKernel, compile, InstanceDescriptor) from .errors import CompilationError __all__ = [ - "compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", + "compile", "InstanceDescriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", "get_arch_default_num_stages" ] From 3d7b773d9dc7739e9566aeb1d991a9d00d97221d Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Wed, 22 Nov 2023 23:38:06 -0800 Subject: [PATCH 11/64] optimize_ttir no longer depends on target --- .../Transforms/RewriteTensorPointer.cpp | 13 ++------ python/src/triton.cc | 5 ++-- python/triton/compiler/backends/cuda.py | 1 + python/triton/compiler/compiler.py | 30 +++++-------------- 4 files changed, 13 insertions(+), 36 deletions(-) diff --git a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp index 31a53af78b69..58b6e73ce986 100644 --- a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp +++ b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp @@ -196,9 +196,7 @@ class RewriteTensorPointerPass DenseMap rewritedInfo; public: - explicit RewriteTensorPointerPass(int computeCapability) { - this->computeCapability = computeCapability; - } + explicit RewriteTensorPointerPass() {} static bool needRewrite(Operation *op) { return std::any_of(op->getOperands().begin(), op->getOperands().end(), @@ -473,10 +471,6 @@ class RewriteTensorPointerPass } void runOnOperation() override { - // Only rewrite if the hardware does not support - if (computeCapability >= 90) - return; - // NOTES(Chenggang): we don't use `ConversionPatternRewriter`, because // MLIR does not support one-multiple value mapping. 
For example, if we use // `ConversionPatternRewriter`, we can not make a type converter, which @@ -502,7 +496,6 @@ class RewriteTensorPointerPass } }; -std::unique_ptr -triton::createRewriteTensorPointerPass(int computeCapability) { - return std::make_unique(computeCapability); +std::unique_ptr triton::createRewriteTensorPointerPass() { + return std::make_unique(); } diff --git a/python/src/triton.cc b/python/src/triton.cc index 51307f445b70..8d0f832ef814 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1770,9 +1770,8 @@ void init_triton_ir(py::module &&m) { self.addPass(mlir::createTritonGPUReorderInstructionsPass()); }) .def("add_tritongpu_rewrite_tensor_pointer_pass", - [](mlir::PassManager &self, int computeCapability) { - self.addPass(mlir::createTritonGPURewriteTensorPointerPass( - computeCapability)); + [](mlir::PassManager &self) { + self.addPass(mlir::createTritonGPURewriteTensorPointerPass()); }) .def("add_tritongpu_decompose_conversions_pass", [](mlir::PassManager &self) { diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index 123f2add8a65..d23c37e0734f 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -164,6 +164,7 @@ class CUDAOptions: enable_fp_fusion: bool = True extern_libs = None allow_fp8e4nv: bool = False + rewrite_tensor_pointer: bool = True max_num_imprecise_acc: bool = None debug: bool = False diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 6281557b1d46..36bed4324002 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -34,11 +34,12 @@ def __getitem__(self, key): # ------------------------------------------------------------------------------ -def optimize_ttir(mod, capability): +def optimize_ttir(mod, options): pm = ir.pass_manager(mod.context) pm.enable_debug() pm.add_inliner_pass() - pm.add_rewrite_tensor_pointer_pass(capability) + if options.rewrite_tensor_pointer: + pm.add_rewrite_tensor_pointer_pass() pm.add_triton_combine_pass() pm.add_canonicalizer_pass() pm.add_reorder_broadcast_pass() @@ -98,22 +99,6 @@ def make_hash(fn, env_vars, device_backend, config, signature, constants, option ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' -def _get_jsonable_constants(constants): - - def _is_jsonable(x): - try: - json.dumps(x) - return True - except (TypeError, OverflowError): - return False - - serialized_constants = {} - for constant in constants: - if _is_jsonable(constants[constant]): - serialized_constants[constant] = constants[constant] - return serialized_constants - - def parse_mlir_module(path, context): module = ir.parse_mlir_module(path, context) # module takes ownership of the context @@ -148,13 +133,12 @@ def compile(src, device_type=("cuda", None), signature=None, config=InstanceDesc def create_ttir(src): ttir = ast_to_ttir(src, signature, config, constants, options=options) - return optimize_ttir(ttir, capability=device_type[1]) + return optimize_ttir(ttir, options=options) stages["ttir"] = (lambda path: parse_mlir_module(path, context), create_ttir) backend.add_stages(extern_libs, stages, options) # find out the signature of the function - name = src.__name__ if isinstance(signature, str): signature = {k: v.strip() for k, v in enumerate(signature.split(","))} @@ -162,10 +146,10 @@ def create_ttir(src): hash = make_hash(src, get_env_vars(), backend, config=config, constants=constants, signature=signature, options=options) fn_cache_manager = 
get_cache_manager(hash) - # determine name and extension type of provided function - name, ext = src.__name__, "ttir" + # load metadata if any # The group is addressed by the metadata + name, ext = src.__name__, "ttir" metadata_filename = f"{name}.json" metadata_group = fn_cache_manager.get_group(metadata_filename) or {} metadata_path = metadata_group.get(metadata_filename) @@ -177,7 +161,7 @@ def create_ttir(src): if 'tensormaps_info' in metadata: metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']] else: - metadata = {"constants": _get_jsonable_constants(constants)} + metadata = {"constants": constants} metadata.update(options.__dict__) metadata.update(get_env_vars()) metadata["device_type"] = device_type From 7274e2a24f2ceb50d222e51e2b0661c6b23368dc Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Thu, 23 Nov 2023 12:26:55 -0800 Subject: [PATCH 12/64] rewrite_tensor_pointer no longer depends on capability --- .../triton/Dialect/Triton/Transforms/Passes.h | 3 +- .../Dialect/Triton/Transforms/Passes.td | 6 +--- .../TritonNvidiaGPU/Transforms/Passes.h | 3 +- .../TritonNvidiaGPU/Transforms/Passes.td | 6 +--- .../Transforms/RewriteTensorPointer.cpp | 25 ++++++--------- python/src/triton.cc | 5 ++- python/triton/compiler/backends/cuda.py | 3 +- python/triton/compiler/compiler.py | 31 +++++++++---------- 8 files changed, 32 insertions(+), 50 deletions(-) diff --git a/include/triton/Dialect/Triton/Transforms/Passes.h b/include/triton/Dialect/Triton/Transforms/Passes.h index 1d1ef2615d83..fde54fe17125 100644 --- a/include/triton/Dialect/Triton/Transforms/Passes.h +++ b/include/triton/Dialect/Triton/Transforms/Passes.h @@ -9,8 +9,7 @@ namespace triton { std::unique_ptr createCombineOpsPass(); std::unique_ptr createReorderBroadcastPass(); -std::unique_ptr -createRewriteTensorPointerPass(int computeCapability = 80); +std::unique_ptr createRewriteTensorPointerPass(); } // namespace triton diff --git a/include/triton/Dialect/Triton/Transforms/Passes.td b/include/triton/Dialect/Triton/Transforms/Passes.td index 219e72b0950b..404e8896c062 100644 --- a/include/triton/Dialect/Triton/Transforms/Passes.td +++ b/include/triton/Dialect/Triton/Transforms/Passes.td @@ -40,11 +40,7 @@ def TritonRewriteTensorPointer : Pass - ]; + let options = []; } #endif diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h index 9d3fd70890c7..a9ac3ffeab89 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h +++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h @@ -70,8 +70,7 @@ createTritonNvidiaGPUWSMaterializationPass(int computeCapability = 90); std::unique_ptr createTritonNvidiaGPUFenceInsertionPass(int computeCapability = 90); -std::unique_ptr -createTritonGPURewriteTensorPointerPass(int computeCapability = 80); +std::unique_ptr createTritonGPURewriteTensorPointerPass(); std::unique_ptr createTritonNvidiaGPUWSFixupMissingAttrs(); diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td index d038c610f999..b94b8bcb64ca 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td @@ -218,11 +218,7 @@ def TritonGPURewriteTensorPointer : Pass - ]; + let options = []; } def TritonGPUWSFixupMissingAttrs : Pass<"triton-nvidia-gpu-ws-fixup-missing-attrs", "mlir::ModuleOp"> { diff --git 
a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp index ba42896eeb64..31ffdeb7fb29 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp @@ -71,8 +71,8 @@ bool isDivisible(Value v, unsigned divisor) { } } -bool shouldRemove(tt::MakeTensorPtrOp &op, int computeCapability) { - if (computeCapability < 90 || !::triton::tools::getBoolEnv("ENABLE_TMA")) +bool shouldRemove(tt::MakeTensorPtrOp &op) { + if (!::triton::tools::getBoolEnv("ENABLE_TMA")) return true; auto resType = op.getResult() .getType() @@ -357,13 +357,7 @@ class TritonGPURewriteTensorPointerPass DenseMap rewritedInfo; public: - // explicit TritonGPURewriteTensorPointerPass(int computeCapability) - // : computeCapability(computeCapability) {} - TritonGPURewriteTensorPointerPass() = default; - TritonGPURewriteTensorPointerPass(int computeCapability) { - this->computeCapability = computeCapability; - } static bool needRewrite(Operation *op, const DenseSet &valueToRemove) { if (auto ifOp = dyn_cast(op)) { @@ -765,14 +759,14 @@ class TritonGPURewriteTensorPointerPass DenseSet valueToRemove; mod.walk([&valueToRemove, this](Operation *op) { if (auto makeTensorPtrOp = dyn_cast(op)) { - if (shouldRemove(makeTensorPtrOp, this->computeCapability)) + if (shouldRemove(makeTensorPtrOp)) valueToRemove.insert(op->getResult(0)); } if (llvm::isa(op)) { auto src = op->getOperand(0); if (tt::isTensorPointerType(src.getType())) { auto makeTensorPtrOp = getMakeTensorPtrOp(src); - if (shouldRemove(makeTensorPtrOp, this->computeCapability)) { + if (shouldRemove(makeTensorPtrOp)) { valueToRemove.insert(op->getResult(0)); } } @@ -781,7 +775,7 @@ class TritonGPURewriteTensorPointerPass auto src = op->getOperand(0); if (tt::isTensorPointerType(src.getType())) { auto makeTensorPtrOp = getMakeTensorPtrOp(src); - if (shouldRemove(makeTensorPtrOp, this->computeCapability)) + if (shouldRemove(makeTensorPtrOp)) valueToRemove.insert(src); } } @@ -790,7 +784,7 @@ class TritonGPURewriteTensorPointerPass for (unsigned i = 0, size = forOp.getInitArgs().size(); i < size; ++i) { if (tt::isTensorPointerType(iterOperands[i].getType())) { auto makeTensorPtrOp = getMakeTensorPtrOp(iterOperands[i]); - if (shouldRemove(makeTensorPtrOp, this->computeCapability)) + if (shouldRemove(makeTensorPtrOp)) valueToRemove.insert(iterOperands[i]); } } @@ -799,7 +793,7 @@ class TritonGPURewriteTensorPointerPass for (unsigned i = 0, size = yieldOp.getNumOperands(); i < size; ++i) { if (tt::isTensorPointerType(operands[i].getType())) { auto makeTensorPtrOp = getMakeTensorPtrOp(operands[i]); - if (shouldRemove(makeTensorPtrOp, this->computeCapability)) + if (shouldRemove(makeTensorPtrOp)) valueToRemove.insert(operands[i]); } } @@ -832,7 +826,6 @@ class TritonGPURewriteTensorPointerPass } }; -std::unique_ptr -mlir::createTritonGPURewriteTensorPointerPass(int computeCapability) { - return std::make_unique(computeCapability); +std::unique_ptr mlir::createTritonGPURewriteTensorPointerPass() { + return std::make_unique(); } diff --git a/python/src/triton.cc b/python/src/triton.cc index 8d0f832ef814..3285b069d6c8 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1694,9 +1694,8 @@ void init_triton_ir(py::module &&m) { self.addPass(mlir::triton::createReorderBroadcastPass()); }) .def("add_rewrite_tensor_pointer_pass", - [](mlir::PassManager &self, int computeCapability) { - 
self.addPass(mlir::triton::createRewriteTensorPointerPass( - computeCapability)); + [](mlir::PassManager &self) { + self.addPass(mlir::triton::createRewriteTensorPointerPass()); }) .def("add_tritongpu_ws_feasibility_checking_pass", [](mlir::PassManager &self, int computeCapability) { diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index d23c37e0734f..786721a11cda 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -48,7 +48,8 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, capability, cluster_inf pm.add_tritongpu_coalesce_pass() # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass pm.add_plan_cta_pass(cluster_info) - pm.add_tritongpu_rewrite_tensor_pointer_pass(capability) + if capability // 10 < 9: + pm.add_tritongpu_rewrite_tensor_pointer_pass() pm.add_plan_cta_pass(cluster_info) pm.add_tritongpu_remove_layout_conversions_pass() pm.add_tritongpu_accelerate_matmul_pass(capability) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 36bed4324002..adb8a580c0da 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -120,12 +120,23 @@ def hash(self): def compile(src, device_type=("cuda", None), signature=None, config=InstanceDescriptor(), constants=None, extern_libs=None, **kwargs): - # Get device type to decide which backend should be used - if constants is None: - constants = dict() # create backend handler backend = CUDABackend(device_type) options = backend.parse_options(**kwargs) + # Get device type to decide which backend should be used + if constants is None: + constants = dict() + # find out the signature of the function + if isinstance(signature, str): + signature = {k: v.strip() for k, v in enumerate(signature.split(","))} + # create cache manager + hash = make_hash(src, get_env_vars(), backend, config=config, constants=constants, signature=signature, + options=options) + fn_cache_manager = get_cache_manager(hash) + name = src.__name__ + metadata_filename = f"{name}.json" + metadata_group = fn_cache_manager.get_group(metadata_filename) or {} + metadata_path = metadata_group.get(metadata_filename) # build compilation stages context = ir.context() @@ -138,21 +149,8 @@ def create_ttir(src): stages["ttir"] = (lambda path: parse_mlir_module(path, context), create_ttir) backend.add_stages(extern_libs, stages, options) - # find out the signature of the function - if isinstance(signature, str): - signature = {k: v.strip() for k, v in enumerate(signature.split(","))} - - # create cache manager - hash = make_hash(src, get_env_vars(), backend, config=config, constants=constants, signature=signature, - options=options) - fn_cache_manager = get_cache_manager(hash) - # load metadata if any # The group is addressed by the metadata - name, ext = src.__name__, "ttir" - metadata_filename = f"{name}.json" - metadata_group = fn_cache_manager.get_group(metadata_filename) or {} - metadata_path = metadata_group.get(metadata_filename) # initialize metadata metadata = None if metadata_path is not None: @@ -167,6 +165,7 @@ def create_ttir(src): metadata["device_type"] = device_type # run compilation pipeline and populate metadata + ext = "ttir" first_stage = list(stages.keys()).index(ext) asm = LazyDict() module = src From 4b44f8ca321218ebdf42670bf79e8dc81f5d8644 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Thu, 23 Nov 2023 15:41:06 -0800 Subject: [PATCH 13/64] cleaning --- python/triton/compiler/backends/cuda.py 
| 7 ++--- python/triton/compiler/compiler.py | 42 +++++++++++++------------ python/triton/runtime/jit.py | 6 ++-- 3 files changed, 28 insertions(+), 27 deletions(-) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index 786721a11cda..d2c830e4d3a1 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -165,7 +165,6 @@ class CUDAOptions: enable_fp_fusion: bool = True extern_libs = None allow_fp8e4nv: bool = False - rewrite_tensor_pointer: bool = True max_num_imprecise_acc: bool = None debug: bool = False @@ -218,7 +217,7 @@ def create_ptx(src): stages["ptx"] = (lambda path: Path(path).read_text(), create_ptx) - # PTx -> CUBIN stage + # PTX -> CUBIN stage def create_cubin(src): return ptx_to_cubin(src, self.capability, opt.enable_fp_fusion) @@ -243,8 +242,8 @@ def get_version_key(self): def make_launcher_stub(self, fn, configs, metadata, name, signature, constants): ids_of_folded_args = tuple([int(k) for k in configs[0].ids_of_folded_args]) if isinstance(fn, JITFunction) else () - if "clusterDims" not in metadata: - metadata["clusterDims"] = [1, 1, 1] + if "cluster_dims" not in metadata: + metadata["cluster_dims"] = [1, 1, 1] if len(self.tma_infos) > 0: metadata["tensormaps_info"] = parse_tma_info(self.tma_infos, ids_of_folded_args) # set constant diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index adb8a580c0da..634e9d0b05e2 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -18,6 +18,7 @@ import torch from dataclasses import dataclass from .code_generator import ast_to_ttir +from pathlib import Path class LazyDict(dict): @@ -38,8 +39,6 @@ def optimize_ttir(mod, options): pm = ir.pass_manager(mod.context) pm.enable_debug() pm.add_inliner_pass() - if options.rewrite_tensor_pointer: - pm.add_rewrite_tensor_pointer_pass() pm.add_triton_combine_pass() pm.add_canonicalizer_pass() pm.add_reorder_broadcast_pass() @@ -137,6 +136,8 @@ def compile(src, device_type=("cuda", None), signature=None, config=InstanceDesc metadata_filename = f"{name}.json" metadata_group = fn_cache_manager.get_group(metadata_filename) or {} metadata_path = metadata_group.get(metadata_filename) + if metadata_path is not None: + pass # build compilation stages context = ir.context() @@ -185,7 +186,7 @@ def create_ttir(src): binary=False) fn_cache_manager.put_group(metadata_filename, metadata_group) # return handle to compiled kernel - return CompiledKernel(src, so_path, metadata, asm) + return CompiledKernel(so_path, metadata_group.get(metadata_filename)) class RuntimeCudaBackend: @@ -219,33 +220,34 @@ class CompiledKernel: launch_exit_hook = None tensormap_manager = TensorMapManager() - def __init__(self, fn, so_path, metadata, asm): + @staticmethod + def read_text_or_bytes(path): + try: + return path.read_text() + except UnicodeDecodeError: + return path.read_bytes() + + def __init__(self, so_path, metadata_path): + metadata_path = Path(metadata_path) + self.driver = RuntimeCudaBackend() # initialize launcher import importlib.util spec = importlib.util.spec_from_file_location("__triton_launcher", so_path) mod = importlib.util.module_from_spec(spec) - self.fn = fn spec.loader.exec_module(mod) self.c_wrapper = getattr(mod, "launch") # initialize metadata - self.shared = metadata["shared"] - self.num_warps = metadata["num_warps"] - self.num_ctas = metadata["num_ctas"] - self.num_stages = metadata["num_stages"] - self.clusterDims = metadata["clusterDims"] - if 
"tensormaps_info" in metadata: - self.tensormaps_info = metadata["tensormaps_info"] - self.constants = metadata["constants"] - self.device_type = metadata["device_type"] - # initialize asm dict - self.asm = asm + self.metadata = json.loads(metadata_path.read_text()) + for key, val in self.metadata.items(): + setattr(self, key, val) + # stores the text of each level of IR that was generated during compilation + asm_files = [file for file in metadata_path.parent.glob(f'{metadata_path.stem}.*') if file.suffix != '.json'] + self.asm = {file.suffix[1:]: self.read_text_or_bytes(file) for file in asm_files} # binaries are lazily initialized # because it involves doing runtime things # (e.g., checking amount of shared memory on current device) - self.metadata = metadata self.cu_module = None self.cu_function = None - self.driver = RuntimeCudaBackend() def _init_handles(self): if self.cu_module is not None: @@ -288,8 +290,8 @@ def runner(*args, stream=None): args_expand = self.assemble_tensormap_to_arg(args) if stream is None: stream = self.driver.get_stream(None) - self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.clusterDims[0], - self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function, + self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.cluster_dims[0], + self.cluster_dims[1], self.cluster_dims[2], self.shared, stream, self.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand) return runner diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 275d7f17f7cd..b0f9186f5ec7 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -553,9 +553,9 @@ def get_special_arg(name: str, default=None): grid_2, bin.num_warps, bin.num_ctas, - bin.clusterDims[0], - bin.clusterDims[1], - bin.clusterDims[2], + bin.cluster_dims[0], + bin.cluster_dims[1], + bin.cluster_dims[2], bin.shared, stream, bin.cu_function, From 5f729eb2f40ca3cb39c211f455228865e5364d82 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Thu, 23 Nov 2023 15:57:01 -0800 Subject: [PATCH 14/64] cleaning --- python/triton/compiler/backends/cuda.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index d2c830e4d3a1..412354a7f645 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -158,7 +158,7 @@ class CUDAOptions: num_warps: int = 4 num_ctas: int = 1 num_stages: int = 3 - cluster_dims: list = None + cluster_dims: tuple = (1, 1, 1) enable_warp_specialization: bool = False enable_persistent: bool = False optimize_epilogue: bool = False From 3692c4a1dec984bcbd8c685c4b7288a1fe4bca92 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Thu, 23 Nov 2023 22:13:37 -0800 Subject: [PATCH 15/64] cleaning --- python/triton/compiler/backends/cuda.py | 53 ++++++++++--------------- python/triton/compiler/compiler.py | 34 +++++++--------- 2 files changed, 34 insertions(+), 53 deletions(-) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index 412354a7f645..1c05d75c2714 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -9,7 +9,6 @@ from ...runtime.jit import JITFunction from ..utils import get_ids_of_tensormaps, parse_tma_info from ..make_launcher import make_stub -from ...tools.disasm import get_sass import hashlib @@ -197,69 +196,57 @@ def add_stages(self, 
extern_libs, stages, opt): cluster_info.clusterDimZ = opt.cluster_dims[2] # TTIR -> TTGIR stage - def create_ttgir(src): + def create_ttgir(src, metadata): ttgir = ttir_to_ttgir(src, opt.num_warps, opt.num_ctas, self.capability) return optimize_ttgir(ttgir, opt.num_stages, opt.num_warps, opt.num_ctas, self.capability, cluster_info, opt.enable_warp_specialization, opt.enable_persistent, opt.optimize_epilogue) stages["ttgir"] = (lambda path: parse_mlir_module(path, context), create_ttgir) + # TTGIR -> LLIR stage - tma_infos = TMAInfos() - def create_llir(src): - return ttgir_to_llir(src, opt.extern_libs, self.capability, tma_infos) + def create_llir(src, metadata): + metadata["enable_warp_specialization"] = ir.is_ws_supported(src) + metadata["num_warps"] = get_num_warps(src) + tma_infos = TMAInfos() + ret = ttgir_to_llir(src, opt.extern_libs, self.capability, tma_infos) + if len(tma_infos) > 0: + metadata["tensormaps_info"] = parse_tma_info(tma_infos, metadata["ids_of_folded_args"]) + for i, _ in enumerate(metadata["tensormaps_info"]): + metadata["tensormaps_info"][i].ids_of_folded_args = metadata["ids_of_folded_args"] + metadata["ids_of_tensormaps"] = get_ids_of_tensormaps(metadata.get("tensormaps_info", None)) + metadata["shared"] = get_shared_memory_size(src) + return ret stages["llir"] = (lambda path: Path(path).read_text(), create_llir) # LLIR -> PTX stage - def create_ptx(src): + def create_ptx(src, metadata): return llir_to_ptx(src, opt.enable_fp_fusion, self.capability) stages["ptx"] = (lambda path: Path(path).read_text(), create_ptx) # PTX -> CUBIN stage - def create_cubin(src): + def create_cubin(src, metadata): + metadata["name"] = get_kernel_name(src, pattern='// .globl') return ptx_to_cubin(src, self.capability, opt.enable_fp_fusion) stages["cubin"] = (lambda path: Path(path).read_bytes(), create_cubin) - self.tma_infos = tma_infos - - def add_meta_info(self, ir_name, cur_module, next_module, metadata, asm): - if ir_name == "cubin": - asm[ir_name] = next_module - asm["sass"] = lambda: get_sass(next_module) - if ir_name == "llir" and "shared" not in metadata: - metadata["shared"] = get_shared_memory_size(cur_module) - if ir_name == "ttgir": - metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module) - metadata["num_warps"] = get_num_warps(next_module) - if ir_name == "ptx": - metadata["name"] = get_kernel_name(next_module, pattern='// .globl') def get_version_key(self): return f'{get_cuda_version_key()}-{self.capability}' def make_launcher_stub(self, fn, configs, metadata, name, signature, constants): - ids_of_folded_args = tuple([int(k) - for k in configs[0].ids_of_folded_args]) if isinstance(fn, JITFunction) else () - if "cluster_dims" not in metadata: - metadata["cluster_dims"] = [1, 1, 1] - if len(self.tma_infos) > 0: - metadata["tensormaps_info"] = parse_tma_info(self.tma_infos, ids_of_folded_args) - # set constant - if "tensormaps_info" in metadata: - for i, _ in enumerate(metadata["tensormaps_info"]): - metadata["tensormaps_info"][i].ids_of_folded_args = ids_of_folded_args - ids_of_tensormaps = get_ids_of_tensormaps(metadata.get("tensormaps_info", None)) if isinstance(fn, JITFunction) and "tensormaps_info" in metadata: fn.tensormaps_info = metadata["tensormaps_info"] ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else () ids = { - "ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs": - ids_of_const_exprs + "ids_of_tensormaps": metadata["ids_of_tensormaps"], "ids_of_folded_args": 
metadata["ids_of_folded_args"], + "ids_of_const_exprs": ids_of_const_exprs } enable_warp_specialization = False + # set constant return make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization) @classmethod diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 634e9d0b05e2..cccd72e3be84 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -11,7 +11,7 @@ from ..runtime.autotuner import OutOfResources from ..runtime.cache import get_cache_manager from ..runtime.jit import (get_cuda_stream) -from .utils import (InfoFromBackendForTensorMap, TensorMapManager) +from .utils import (TensorMapManager) from .backends.cuda import CUDABackend from ..runtime.driver import driver @@ -137,54 +137,48 @@ def compile(src, device_type=("cuda", None), signature=None, config=InstanceDesc metadata_group = fn_cache_manager.get_group(metadata_filename) or {} metadata_path = metadata_group.get(metadata_filename) if metadata_path is not None: - pass + metadata = json.loads(Path(metadata_path).read_text()) + so_path = backend.make_launcher_stub(src, [config], metadata, name, signature, constants) + return CompiledKernel(so_path, metadata_path) # build compilation stages context = ir.context() stages = dict() - def create_ttir(src): + def create_ttir(src, metadata): ttir = ast_to_ttir(src, signature, config, constants, options=options) return optimize_ttir(ttir, options=options) stages["ttir"] = (lambda path: parse_mlir_module(path, context), create_ttir) backend.add_stages(extern_libs, stages, options) - # load metadata if any - # The group is addressed by the metadata # initialize metadata - metadata = None - if metadata_path is not None: - with open(metadata_path) as f: - metadata = json.load(f) - if 'tensormaps_info' in metadata: - metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']] - else: - metadata = {"constants": constants} - metadata.update(options.__dict__) - metadata.update(get_env_vars()) - metadata["device_type"] = device_type + metadata = { + "constants": constants, + "device_type": device_type, + "ids_of_folded_args": tuple([int(k) for k in config.ids_of_folded_args]), + **options.__dict__, + **get_env_vars(), + } # run compilation pipeline and populate metadata ext = "ttir" first_stage = list(stages.keys()).index(ext) - asm = LazyDict() module = src for ir_name, (parse_ir, compile_ir) in list(stages.items())[first_stage:]: ir_filename = f"{name}.{ir_name}" path = metadata_group.get(ir_filename) - next_module = compile_ir(module) if path is None else parse_ir(path) + next_module = compile_ir(module, metadata) if path is None else parse_ir(path) metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) - backend.add_meta_info(ir_name, module, next_module, metadata, asm) module = next_module # cache manager - so_path = backend.make_launcher_stub(src, [config], metadata, name, signature, constants) # write-back metadata, if it didn't come from the cache if metadata_path is None: metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False) fn_cache_manager.put_group(metadata_filename, metadata_group) + so_path = backend.make_launcher_stub(src, [config], metadata, name, signature, constants) # return handle to compiled kernel return CompiledKernel(so_path, metadata_group.get(metadata_filename)) From 756512e0049147843f643fb12c3aa27dd2cbf5d8 Mon Sep 17 00:00:00 2001 
From: Phil Tillet Date: Thu, 23 Nov 2023 22:21:08 -0800 Subject: [PATCH 16/64] remove parser from stage --- python/triton/compiler/backends/cuda.py | 11 ++++------- python/triton/compiler/compiler.py | 14 +++++--------- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index 1c05d75c2714..c5e82eb1bcf4 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -1,5 +1,4 @@ from triton.common.backend import BaseBackend -from pathlib import Path from dataclasses import dataclass from ..._C.libtriton.triton import ClusterInfo, get_num_warps, TMAInfos, translate_triton_gpu_to_llvmir, get_shared_memory_size, translate_llvmir_to_ptx, compile_ptx_to_cubin, add_external_libs from ...common.backend import get_cuda_version_key, path_to_ptxas @@ -187,8 +186,6 @@ def parse_options(self, **opts) -> Any: return options def add_stages(self, extern_libs, stages, opt): - - context = ir.context() cluster_info = ClusterInfo() if opt.cluster_dims is not None: cluster_info.clusterDimX = opt.cluster_dims[0] @@ -201,7 +198,7 @@ def create_ttgir(src, metadata): return optimize_ttgir(ttgir, opt.num_stages, opt.num_warps, opt.num_ctas, self.capability, cluster_info, opt.enable_warp_specialization, opt.enable_persistent, opt.optimize_epilogue) - stages["ttgir"] = (lambda path: parse_mlir_module(path, context), create_ttgir) + stages["ttgir"] = create_ttgir # TTGIR -> LLIR stage @@ -218,20 +215,20 @@ def create_llir(src, metadata): metadata["shared"] = get_shared_memory_size(src) return ret - stages["llir"] = (lambda path: Path(path).read_text(), create_llir) + stages["llir"] = create_llir # LLIR -> PTX stage def create_ptx(src, metadata): return llir_to_ptx(src, opt.enable_fp_fusion, self.capability) - stages["ptx"] = (lambda path: Path(path).read_text(), create_ptx) + stages["ptx"] = create_ptx # PTX -> CUBIN stage def create_cubin(src, metadata): metadata["name"] = get_kernel_name(src, pattern='// .globl') return ptx_to_cubin(src, self.capability, opt.enable_fp_fusion) - stages["cubin"] = (lambda path: Path(path).read_bytes(), create_cubin) + stages["cubin"] = create_cubin def get_version_key(self): return f'{get_cuda_version_key()}-{self.capability}' diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index cccd72e3be84..3a13c447166d 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -142,14 +142,13 @@ def compile(src, device_type=("cuda", None), signature=None, config=InstanceDesc return CompiledKernel(so_path, metadata_path) # build compilation stages - context = ir.context() stages = dict() def create_ttir(src, metadata): ttir = ast_to_ttir(src, signature, config, constants, options=options) return optimize_ttir(ttir, options=options) - stages["ttir"] = (lambda path: parse_mlir_module(path, context), create_ttir) + stages["ttir"] = create_ttir backend.add_stages(extern_libs, stages, options) # initialize metadata @@ -162,14 +161,11 @@ def create_ttir(src, metadata): } # run compilation pipeline and populate metadata - ext = "ttir" - first_stage = list(stages.keys()).index(ext) + first_stage = list(stages.keys()).index("ttir") module = src - for ir_name, (parse_ir, compile_ir) in list(stages.items())[first_stage:]: - ir_filename = f"{name}.{ir_name}" - path = metadata_group.get(ir_filename) - next_module = compile_ir(module, metadata) if path is None else parse_ir(path) - metadata_group[ir_filename] = 
fn_cache_manager.put(next_module, ir_filename) + for ext, compile_ir in list(stages.items())[first_stage:]: + next_module = compile_ir(module, metadata) + metadata_group[f"{name}.{ext}"] = fn_cache_manager.put(next_module, f"{name}.{ext}") module = next_module # cache manager From 44ba22d7bd3008b92847173d674dc0266f1cbdbf Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Thu, 23 Nov 2023 22:37:19 -0800 Subject: [PATCH 17/64] cleaning --- python/triton/compiler/compiler.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 3a13c447166d..c53de0b8f006 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -122,12 +122,15 @@ def compile(src, device_type=("cuda", None), signature=None, config=InstanceDesc # create backend handler backend = CUDABackend(device_type) options = backend.parse_options(**kwargs) + # Get device type to decide which backend should be used if constants is None: constants = dict() + # find out the signature of the function if isinstance(signature, str): signature = {k: v.strip() for k, v in enumerate(signature.split(","))} + # create cache manager hash = make_hash(src, get_env_vars(), backend, config=config, constants=constants, signature=signature, options=options) @@ -143,12 +146,7 @@ def compile(src, device_type=("cuda", None), signature=None, config=InstanceDesc # build compilation stages stages = dict() - - def create_ttir(src, metadata): - ttir = ast_to_ttir(src, signature, config, constants, options=options) - return optimize_ttir(ttir, options=options) - - stages["ttir"] = create_ttir + stages["ttir"] = lambda src, metadata: optimize_ttir(src, options=options) backend.add_stages(extern_libs, stages, options) # initialize metadata @@ -162,7 +160,7 @@ def create_ttir(src, metadata): # run compilation pipeline and populate metadata first_stage = list(stages.keys()).index("ttir") - module = src + module = ast_to_ttir(src, signature, config, constants, options=options) for ext, compile_ir in list(stages.items())[first_stage:]: next_module = compile_ir(module, metadata) metadata_group[f"{name}.{ext}"] = fn_cache_manager.put(next_module, f"{name}.{ext}") From 9f62ce624d4d3f625cc4885a414c714be6b123d1 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Thu, 23 Nov 2023 22:49:48 -0800 Subject: [PATCH 18/64] more cleaning --- python/triton/compiler/backends/cuda.py | 16 ++++++++++++++++ python/triton/compiler/compiler.py | 20 ++------------------ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index c5e82eb1bcf4..60b71ae6cabd 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -11,6 +11,20 @@ import hashlib +def optimize_ttir(mod, options): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + pm.add_inliner_pass() + pm.add_triton_combine_pass() + pm.add_canonicalizer_pass() + pm.add_reorder_broadcast_pass() + pm.add_cse_pass() + pm.add_licm_pass() + pm.add_symbol_dce_pass() + pm.run(mod) + return mod + + def ttir_to_ttgir(mod, num_warps, num_ctas, capability): pm = ir.pass_manager(mod.context) pm.enable_debug() @@ -192,6 +206,8 @@ def add_stages(self, extern_libs, stages, opt): cluster_info.clusterDimY = opt.cluster_dims[1] cluster_info.clusterDimZ = opt.cluster_dims[2] + stages["ttir"] = lambda src, metadata: optimize_ttir(src, opt) + # TTIR -> TTGIR stage def create_ttgir(src, metadata): 
ttgir = ttir_to_ttgir(src, opt.num_warps, opt.num_ctas, self.capability) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index c53de0b8f006..ff0617b39729 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -35,20 +35,6 @@ def __getitem__(self, key): # ------------------------------------------------------------------------------ -def optimize_ttir(mod, options): - pm = ir.pass_manager(mod.context) - pm.enable_debug() - pm.add_inliner_pass() - pm.add_triton_combine_pass() - pm.add_canonicalizer_pass() - pm.add_reorder_broadcast_pass() - pm.add_cse_pass() - pm.add_licm_pass() - pm.add_symbol_dce_pass() - pm.run(mod) - return mod - - def convert_type_repr(x): # Currently we only capture the pointer type and assume the pointer is on global memory. # TODO: Capture and support shared memory space @@ -100,8 +86,7 @@ def make_hash(fn, env_vars, device_backend, config, signature, constants, option def parse_mlir_module(path, context): module = ir.parse_mlir_module(path, context) - # module takes ownership of the context - module.context = context + module.context = context # module takes ownership of the context return module @@ -140,13 +125,13 @@ def compile(src, device_type=("cuda", None), signature=None, config=InstanceDesc metadata_group = fn_cache_manager.get_group(metadata_filename) or {} metadata_path = metadata_group.get(metadata_filename) if metadata_path is not None: + # cache hit! metadata = json.loads(Path(metadata_path).read_text()) so_path = backend.make_launcher_stub(src, [config], metadata, name, signature, constants) return CompiledKernel(so_path, metadata_path) # build compilation stages stages = dict() - stages["ttir"] = lambda src, metadata: optimize_ttir(src, options=options) backend.add_stages(extern_libs, stages, options) # initialize metadata @@ -157,7 +142,6 @@ def compile(src, device_type=("cuda", None), signature=None, config=InstanceDesc **options.__dict__, **get_env_vars(), } - # run compilation pipeline and populate metadata first_stage = list(stages.keys()).index("ttir") module = ast_to_ttir(src, signature, config, constants, options=options) From 6ee5cee7e66f1679b2f0dbc74c46e2209f591154 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Fri, 24 Nov 2023 11:14:26 -0800 Subject: [PATCH 19/64] more cleaning --- python/triton/compiler/backends/cuda.py | 5 +-- python/triton/compiler/code_generator.py | 19 +++++----- python/triton/compiler/compiler.py | 44 ++++++++++++++++-------- 3 files changed, 41 insertions(+), 27 deletions(-) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index 60b71ae6cabd..236fab74f1c7 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -249,7 +249,7 @@ def create_cubin(src, metadata): def get_version_key(self): return f'{get_cuda_version_key()}-{self.capability}' - def make_launcher_stub(self, fn, configs, metadata, name, signature, constants): + def make_launcher_stub(self, fn, metadata, name, specialization): if isinstance(fn, JITFunction) and "tensormaps_info" in metadata: fn.tensormaps_info = metadata["tensormaps_info"] ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else () @@ -260,7 +260,8 @@ def make_launcher_stub(self, fn, configs, metadata, name, signature, constants): enable_warp_specialization = False # set constant - return make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization) + return 
make_stub(name, specialization.signature, specialization.constants, ids, + enable_warp_specialization=enable_warp_specialization) @classmethod def create_backend(cls, device_type: str): diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 591dd7120304..efbcedf7f621 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -1188,29 +1188,28 @@ def kernel_suffix(signature, specialization): return suffix -def ast_to_ttir(fn, signature, specialization, constants, options): +def ast_to_ttir(fn, specialization, options): + config = specialization.config # canonicalize signature - if isinstance(signature, str): - signature = {k: v.strip() for k, v in enumerate(signature.split(","))} context = ir.context() context.load_triton() # create kernel prototype cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i - constants = {cst_key(key): value for key, value in constants.items()} + constants = {cst_key(key): value for key, value in specialization.constants.items()} # visit kernel AST gscope = fn.__globals__.copy() - function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)]) - tys = list(signature.values()) - new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in specialization.equal_to_1} - new_attrs = {k: [("tt.divisibility", 16)] for k in specialization.divisible_by_16} - for k in specialization.divisible_by_8: + function_name = '_'.join([fn.__name__, kernel_suffix(specialization.signature.values(), config)]) + tys = list(specialization.signature.values()) + new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in config.equal_to_1} + new_attrs = {k: [("tt.divisibility", 16)] for k in config.divisible_by_16} + for k in config.divisible_by_8: attr = new_attrs[k] if k in new_attrs else [] attr.append(("tt.max_divisibility", 8)) new_attrs[k] = attr all_constants = constants.copy() all_constants.update(new_constants) - arg_types = [str_to_ty(v) for k, v in signature.items() if k not in constants] + arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in constants] file_name, begin_line = _get_fn_file_line(fn) prototype = language.function_type([], arg_types) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index ff0617b39729..5d0e11c0f3a3 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -44,10 +44,10 @@ def convert_type_repr(x): return x -def make_hash(fn, env_vars, device_backend, config, signature, constants, options): +def make_hash(fn, env_vars, device_backend, specialization, options): version_key = device_backend.get_version_key() env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())] - key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{config.hash()}-{constants}-{options.hash()}-{env_vars_list}" + key = f"{fn.cache_key}-{version_key}-{specialization.hash()}-{options.hash()}-{env_vars_list}" return hashlib.md5(key.encode("utf-8")).hexdigest() @@ -102,23 +102,37 @@ def hash(self): return hashlib.md5(key.encode("utf-8")).hexdigest() +@dataclass +class SpecializationDescriptor: + config: InstanceDescriptor + signature: dict + constants: dict + + def __post_init__(self): + if isinstance(self.signature, str): + self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} + if self.constants is None: + self.constants = dict() + + def hash(self): + key = 
f"{self.config.hash()}-{self.signature.values()}-{self.constants}" + return hashlib.md5(key.encode("utf-8")).hexdigest() + + def compile(src, device_type=("cuda", None), signature=None, config=InstanceDescriptor(), constants=None, extern_libs=None, **kwargs): + # TODO (backward-breaking): + # - merge InstanceDescriptor and SpecializationDescriptor + # - extern_libs => linker_flags: Dict + # - **kwargs -> compiler_flags: Dict + # create backend handler backend = CUDABackend(device_type) options = backend.parse_options(**kwargs) - - # Get device type to decide which backend should be used - if constants is None: - constants = dict() - - # find out the signature of the function - if isinstance(signature, str): - signature = {k: v.strip() for k, v in enumerate(signature.split(","))} + specialization = SpecializationDescriptor(config, signature, constants) # create cache manager - hash = make_hash(src, get_env_vars(), backend, config=config, constants=constants, signature=signature, - options=options) + hash = make_hash(src, get_env_vars(), backend, specialization, options=options) fn_cache_manager = get_cache_manager(hash) name = src.__name__ metadata_filename = f"{name}.json" @@ -127,7 +141,7 @@ def compile(src, device_type=("cuda", None), signature=None, config=InstanceDesc if metadata_path is not None: # cache hit! metadata = json.loads(Path(metadata_path).read_text()) - so_path = backend.make_launcher_stub(src, [config], metadata, name, signature, constants) + so_path = backend.make_launcher_stub(src, metadata, name, specialization) return CompiledKernel(so_path, metadata_path) # build compilation stages @@ -144,7 +158,7 @@ def compile(src, device_type=("cuda", None), signature=None, config=InstanceDesc } # run compilation pipeline and populate metadata first_stage = list(stages.keys()).index("ttir") - module = ast_to_ttir(src, signature, config, constants, options=options) + module = ast_to_ttir(src, specialization, options=options) for ext, compile_ir in list(stages.items())[first_stage:]: next_module = compile_ir(module, metadata) metadata_group[f"{name}.{ext}"] = fn_cache_manager.put(next_module, f"{name}.{ext}") @@ -156,7 +170,7 @@ def compile(src, device_type=("cuda", None), signature=None, config=InstanceDesc metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False) fn_cache_manager.put_group(metadata_filename, metadata_group) - so_path = backend.make_launcher_stub(src, [config], metadata, name, signature, constants) + so_path = backend.make_launcher_stub(src, metadata, name, specialization) # return handle to compiled kernel return CompiledKernel(so_path, metadata_group.get(metadata_filename)) From c0c1076470b0846221e47a97eb54d4f72a67f8a8 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Fri, 24 Nov 2023 11:17:20 -0800 Subject: [PATCH 20/64] . 
--- python/triton/compiler/compiler.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 5d0e11c0f3a3..dcd9a47ccb38 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -123,8 +123,8 @@ def compile(src, device_type=("cuda", None), signature=None, config=InstanceDesc extern_libs=None, **kwargs): # TODO (backward-breaking): # - merge InstanceDescriptor and SpecializationDescriptor - # - extern_libs => linker_flags: Dict - # - **kwargs -> compiler_flags: Dict + # - extern_libs => linker_flags: dict + # - **kwargs -> compiler_flags: dict # create backend handler backend = CUDABackend(device_type) @@ -144,10 +144,6 @@ def compile(src, device_type=("cuda", None), signature=None, config=InstanceDesc so_path = backend.make_launcher_stub(src, metadata, name, specialization) return CompiledKernel(so_path, metadata_path) - # build compilation stages - stages = dict() - backend.add_stages(extern_libs, stages, options) - # initialize metadata metadata = { "constants": constants, @@ -157,18 +153,17 @@ def compile(src, device_type=("cuda", None), signature=None, config=InstanceDesc **get_env_vars(), } # run compilation pipeline and populate metadata + stages = dict() + backend.add_stages(extern_libs, stages, options) first_stage = list(stages.keys()).index("ttir") module = ast_to_ttir(src, specialization, options=options) for ext, compile_ir in list(stages.items())[first_stage:]: next_module = compile_ir(module, metadata) metadata_group[f"{name}.{ext}"] = fn_cache_manager.put(next_module, f"{name}.{ext}") module = next_module - - # cache manager - # write-back metadata, if it didn't come from the cache - if metadata_path is None: - metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, - binary=False) + # write-back metadata + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, + binary=False) fn_cache_manager.put_group(metadata_filename, metadata_group) so_path = backend.make_launcher_stub(src, metadata, name, specialization) # return handle to compiled kernel From 78f9670e336a51a686d68c8e343e754a24bd2253 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Fri, 24 Nov 2023 16:28:08 -0800 Subject: [PATCH 21/64] removed more dead code --- python/triton/compiler/compiler.py | 111 +++++------------------------ python/triton/compiler/utils.py | 22 ------ python/triton/runtime/driver.py | 42 +++++++++++ python/triton/runtime/jit.py | 5 +- 4 files changed, 63 insertions(+), 117 deletions(-) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index dcd9a47ccb38..39c86c4148d6 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -2,7 +2,6 @@ import hashlib import json -import re from .._C.libtriton.triton import (get_env_vars, ir) from ..common.build import is_hip @@ -10,40 +9,19 @@ # TODO: runtime.errors from ..runtime.autotuner import OutOfResources from ..runtime.cache import get_cache_manager -from ..runtime.jit import (get_cuda_stream) -from .utils import (TensorMapManager) +from ..runtime.jit import (get_current_device) from .backends.cuda import CUDABackend from ..runtime.driver import driver -import torch from dataclasses import dataclass from .code_generator import ast_to_ttir from pathlib import Path - -class LazyDict(dict): - - def __getitem__(self, key): - val = 
dict.__getitem__(self, key) - if callable(val): - return val() - return val - - # ------------------------------------------------------------------------------ # compiler # ------------------------------------------------------------------------------ -def convert_type_repr(x): - # Currently we only capture the pointer type and assume the pointer is on global memory. - # TODO: Capture and support shared memory space - match = re.search(r'!tt\.ptr<([^,]+)', x) - if match is not None: - return '*' + convert_type_repr(match.group(1)) - return x - - def make_hash(fn, env_vars, device_backend, specialization, options): version_key = device_backend.get_version_key() env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())] @@ -170,47 +148,15 @@ def compile(src, device_type=("cuda", None), signature=None, config=InstanceDesc return CompiledKernel(so_path, metadata_group.get(metadata_filename)) -class RuntimeCudaBackend: - - def __init__(self) -> None: - pass - - def get_load_binary_fn(self): - return driver.utils.load_binary - - def get_stream(self): - return get_cuda_stream() - - def get_device_properties(self, device): - return driver.utils.get_device_properties(device) - - def get_current_device(self): - return torch.cuda.current_device() - - def set_current_device(self, device): - torch.cuda.set_device(device) - - def get_kernel_bin(self): - return "cubin" - - class CompiledKernel: # Hooks for external tools to monitor the execution of triton kernels + # TODO: move out of this namespace since it's a runtime thing launch_enter_hook = None launch_exit_hook = None - tensormap_manager = TensorMapManager() - - @staticmethod - def read_text_or_bytes(path): - try: - return path.read_text() - except UnicodeDecodeError: - return path.read_bytes() def __init__(self, so_path, metadata_path): metadata_path = Path(metadata_path) - self.driver = RuntimeCudaBackend() # initialize launcher import importlib.util spec = importlib.util.spec_from_file_location("__triton_launcher", so_path) @@ -223,56 +169,35 @@ def __init__(self, so_path, metadata_path): setattr(self, key, val) # stores the text of each level of IR that was generated during compilation asm_files = [file for file in metadata_path.parent.glob(f'{metadata_path.stem}.*') if file.suffix != '.json'] - self.asm = {file.suffix[1:]: self.read_text_or_bytes(file) for file in asm_files} + self.asm = { + file.suffix[1:]: file.read_bytes() if file.suffix[1:] == driver.binary_ext else file.read_text() + for file in asm_files + } + self.kernel = self.asm[driver.binary_ext] # binaries are lazily initialized # because it involves doing runtime things # (e.g., checking amount of shared memory on current device) - self.cu_module = None - self.cu_function = None + self.module = None + self.function = None def _init_handles(self): - if self.cu_module is not None: + if self.module is not None: return - - device = self.driver.get_current_device() - bin_path = self.driver.get_kernel_bin() - max_shared = self.driver.get_device_properties(device)["max_shared_mem"] - fn_load_binary = self.driver.get_load_binary_fn() - + device = get_current_device() + # not enough shared memory to run the kernel + max_shared = driver.utils.get_device_properties(device)["max_shared_mem"] if self.shared > max_shared: raise OutOfResources(self.shared, max_shared, "shared memory") - - mod, func, n_regs, n_spills = fn_load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device) - - self.n_spills = n_spills - self.n_regs = n_regs - self.cu_module = mod - 
self.cu_function = func + # TODO: n_regs, n_spills should be metadata generated when calling `ptxas` + self.module, self.function, self.n_regs, self.n_spills = driver.utils.load_binary( + self.name, self.kernel, self.shared, device) def __getattribute__(self, name): if name == 'c_wrapper': self._init_handles() return super().__getattribute__(name) - # capture args and expand args with cutensormap* - def assemble_tensormap_to_arg(self, args): - args_with_tma = list(args) - if hasattr(self, 'tensormaps_info'): - # tuple for hashable - args_ptr = tuple([arg.data_ptr() if hasattr(arg, 'data_ptr') else arg for arg in args]) - for i, e in enumerate(self.tensormaps_info): - args_with_tma.append(CompiledKernel.tensormap_manager[(e, args_ptr)]) - return args_with_tma - def __getitem__(self, grid): self._init_handles() - - def runner(*args, stream=None): - args_expand = self.assemble_tensormap_to_arg(args) - if stream is None: - stream = self.driver.get_stream(None) - self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.cluster_dims[0], - self.cluster_dims[1], self.cluster_dims[2], self.shared, stream, self.cu_function, - CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand) - - return runner + return lambda *args, stream=None: driver.launch_kernel(self, stream, grid, CompiledKernel.launch_enter_hook, + CompiledKernel.launch_exit_hook, *args) diff --git a/python/triton/compiler/utils.py b/python/triton/compiler/utils.py index ef629c75a6bc..6135e35d3a9e 100644 --- a/python/triton/compiler/utils.py +++ b/python/triton/compiler/utils.py @@ -280,25 +280,3 @@ def __eq__(self, other): other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims, other.elementStrides, other.interleave, other.swizzle, other.l2Promotion, other.oobFill) - - -class TensorMapManager: - - def __init__(self): - self.tensormaps_device = {} - - def __getitem__(self, key: tuple): - if key in self.tensormaps_device: - return int(self.tensormaps_device[key]) - else: - (e, args) = key - t_tensormap = e.tensormap(args) - TENSORMAP_SIZE_IN_BYTES = 128 - t_tensormap_device = driver.utils.cuMemAlloc(TENSORMAP_SIZE_IN_BYTES) - driver.utils.cuMemcpyHtoD(t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES) - self.tensormaps_device[key] = t_tensormap_device - return int(self.tensormaps_device[key]) - - def __del__(self): - for _, v in self.tensormaps_device.items(): - driver.utils.cuMemFree(v) diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index 767a567c452b..59a809fb861a 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -6,6 +6,7 @@ from ..common.build import _build from .cache import get_cache_manager +from ..runtime import driver class DriverBase(metaclass=abc.ABCMeta): @@ -65,7 +66,30 @@ def __init__(self): self.cuMemFree = mod.cuMemFree +class TensorMapManager: + + def __init__(self): + self.tensormaps_device = {} + + def __getitem__(self, key: tuple): + if key in self.tensormaps_device: + return int(self.tensormaps_device[key]) + else: + (e, args) = key + t_tensormap = e.tensormap(args) + TENSORMAP_SIZE_IN_BYTES = 128 + t_tensormap_device = driver.utils.cuMemAlloc(TENSORMAP_SIZE_IN_BYTES) + driver.utils.cuMemcpyHtoD(t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES) + self.tensormaps_device[key] = t_tensormap_device + return int(self.tensormaps_device[key]) + + def __del__(self): + for _, v in self.tensormaps_device.items(): + driver.utils.cuMemFree(v) + + class 
CudaDriver(DriverBase): + tensormap_manager = TensorMapManager() def __new__(cls): if not hasattr(cls, "instance"): @@ -75,6 +99,24 @@ def __new__(cls): def __init__(self): self.utils = CudaUtils() self.backend = self.CUDA + self.binary_ext = "cubin" + + def assemble_tensormap_to_arg(self, args): + args_with_tma = list(args) + if hasattr(self, 'tensormaps_info'): + # tuple for hashable + args_ptr = tuple([arg.data_ptr() if hasattr(arg, 'data_ptr') else arg for arg in args]) + for i, e in enumerate(self.tensormaps_info): + args_with_tma.append(CudaDriver.tensormap_manager[(e, args_ptr)]) + return args_with_tma + + def launch_kernel(self, kernel, stream, grid, launch_enter_hook, launch_exit_hook, *args): + args_expand = self.assemble_tensormap_to_arg(args) + if stream is None: + stream = self.get_stream(None) + kernel.c_wrapper(grid[0], grid[1], grid[2], kernel.num_warps, kernel.num_ctas, kernel.cluster_dims[0], + kernel.cluster_dims[1], kernel.cluster_dims[2], kernel.shared, stream, kernel.function, + launch_enter_hook, launch_exit_hook, kernel, *args_expand) # ----------------------------- diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index b0f9186f5ec7..bb185da30ca6 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -13,6 +13,7 @@ from .._C.libtriton.triton import TMAInfos from ..common.backend import get_backend, get_cuda_version_key from .interpreter import InterpretedFunction +from ..runtime.driver import driver def get_cuda_stream(idx=None): @@ -558,11 +559,11 @@ def get_special_arg(name: str, default=None): bin.cluster_dims[2], bin.shared, stream, - bin.cu_function, + bin.function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, - *bin.assemble_tensormap_to_arg(non_constexpr_arg_values), + *driver.assemble_tensormap_to_arg(non_constexpr_arg_values), ) return bin From 5c61a547e3400b710502fd9872007d6475e37bb3 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Fri, 24 Nov 2023 16:30:39 -0800 Subject: [PATCH 22/64] removed override-related code. Will re-add support for `ttgir` in `triton.compile` later. --- python/triton/compiler/compiler.py | 42 +----------------------------- 1 file changed, 1 insertion(+), 41 deletions(-) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 39c86c4148d6..98624fd0ca6b 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -3,8 +3,7 @@ import hashlib import json -from .._C.libtriton.triton import (get_env_vars, ir) -from ..common.build import is_hip +from .._C.libtriton.triton import (get_env_vars) # from ..runtime import driver, jit, JITFunction # TODO: runtime.errors from ..runtime.autotuner import OutOfResources @@ -29,45 +28,6 @@ def make_hash(fn, env_vars, device_backend, specialization, options): return hashlib.md5(key.encode("utf-8")).hexdigest() -# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, -# and any following whitespace -# - (public\s+)? : optionally match the keyword public and any following whitespace -# - (@\w+) : match an @ symbol followed by one or more word characters -# (letters, digits, or underscores), and capture it as group 1 (the function name) -# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing -# zero or more arguments separated by commas, and capture it as group 2 (the argument list) -# - (attributes \{[\S\s]+\})? 
: optionally match attributes enclosed in braces and capture it as group 3 -mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$" -ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" -prototype_pattern = { - "ttir": mlir_prototype_pattern, - "ttgir": mlir_prototype_pattern, - "ptx": ptx_prototype_pattern, -} - -# - ((?:[^,\s<]+|<[^>]+>)+): Capturing group that matches one or more of either: -# [^,\s<]+: One or more characters that are not a comma, whitespace, or the < symbol. -# |: OR -# <[^>]+>: A string that starts with < and ends with >, containing any characters except > in between. -mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<]+|<[^>]+>)+),?' -ptx_arg_type_pattern = r"\.param\s+\.(\w+)" -arg_type_pattern = { - "ttir": mlir_arg_type_pattern, - "ttgir": mlir_arg_type_pattern, - "ptx": ptx_arg_type_pattern, -} -if is_hip(): - ttgir_num_warps_pattern = r'"triton_gpu_rocm.num-warps"\s?=\s?(\d+)\s?:' -else: - ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' - - -def parse_mlir_module(path, context): - module = ir.parse_mlir_module(path, context) - module.context = context # module takes ownership of the context - return module - - @dataclass class InstanceDescriptor: divisible_by_16: set = None From 7a1f16734a9ff53c871cc65b2bbd657a30baf356 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Fri, 24 Nov 2023 22:06:49 -0800 Subject: [PATCH 23/64] more cleaning --- python/triton/compiler/backends/cuda.py | 2 +- python/triton/compiler/code_generator.py | 1 - python/triton/compiler/compiler.py | 21 +++++---------------- 3 files changed, 6 insertions(+), 18 deletions(-) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index 236fab74f1c7..20d47f07fa93 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -246,7 +246,7 @@ def create_cubin(src, metadata): stages["cubin"] = create_cubin - def get_version_key(self): + def hash(self): return f'{get_cuda_version_key()}-{self.capability}' def make_launcher_stub(self, fn, metadata, name, specialization): diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index efbcedf7f621..9ff2dbff8c3f 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -1190,7 +1190,6 @@ def kernel_suffix(signature, specialization): def ast_to_ttir(fn, specialization, options): config = specialization.config - # canonicalize signature context = ir.context() context.load_triton() # create kernel prototype diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 98624fd0ca6b..3d8150023e75 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -8,25 +8,13 @@ # TODO: runtime.errors from ..runtime.autotuner import OutOfResources from ..runtime.cache import get_cache_manager -from ..runtime.jit import (get_current_device) -from .backends.cuda import CUDABackend - +from ..runtime.jit import get_current_device from ..runtime.driver import driver +from .backends.cuda import CUDABackend from dataclasses import dataclass from .code_generator import ast_to_ttir from pathlib import Path -# ------------------------------------------------------------------------------ -# compiler -# ------------------------------------------------------------------------------ - - -def make_hash(fn, 
env_vars, device_backend, specialization, options): - version_key = device_backend.get_version_key() - env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())] - key = f"{fn.cache_key}-{version_key}-{specialization.hash()}-{options.hash()}-{env_vars_list}" - return hashlib.md5(key.encode("utf-8")).hexdigest() - @dataclass class InstanceDescriptor: @@ -64,13 +52,14 @@ def compile(src, device_type=("cuda", None), signature=None, config=InstanceDesc # - extern_libs => linker_flags: dict # - **kwargs -> compiler_flags: dict - # create backend handler + # create backend backend = CUDABackend(device_type) options = backend.parse_options(**kwargs) specialization = SpecializationDescriptor(config, signature, constants) # create cache manager - hash = make_hash(src, get_env_vars(), backend, specialization, options=options) + key = f"{src.cache_key}-{backend.hash()}-{specialization.hash()}-{options.hash()}-{frozenset(get_env_vars().items())}" + hash = hashlib.md5(key.encode("utf-8")).hexdigest() fn_cache_manager = get_cache_manager(hash) name = src.__name__ metadata_filename = f"{name}.json" From 478dce1e15194526e0c5456e5fa19fc75b363343 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Fri, 24 Nov 2023 22:14:56 -0800 Subject: [PATCH 24/64] . --- python/triton/compiler/compiler.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 3d8150023e75..a663d26dafae 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -145,8 +145,3 @@ def __getattribute__(self, name): if name == 'c_wrapper': self._init_handles() return super().__getattribute__(name) - - def __getitem__(self, grid): - self._init_handles() - return lambda *args, stream=None: driver.launch_kernel(self, stream, grid, CompiledKernel.launch_enter_hook, - CompiledKernel.launch_exit_hook, *args) From 1f64f5bd9b434258073c406b6cf944afc102ffdc Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Fri, 24 Nov 2023 23:24:12 -0800 Subject: [PATCH 25/64] partial support for ttgir input (not finished) --- python/triton/compiler/backends/cuda.py | 13 ++- python/triton/compiler/code_generator.py | 5 +- python/triton/compiler/compiler.py | 103 ++++++++++++++++++++--- python/triton/language/core.py | 1 - 4 files changed, 99 insertions(+), 23 deletions(-) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index 20d47f07fa93..fca2cce3a04e 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -5,7 +5,6 @@ from ..._C.libtriton.triton import ir, runtime import functools from typing import Any -from ...runtime.jit import JITFunction from ..utils import get_ids_of_tensormaps, parse_tma_info from ..make_launcher import make_stub import hashlib @@ -249,18 +248,16 @@ def create_cubin(src, metadata): def hash(self): return f'{get_cuda_version_key()}-{self.capability}' - def make_launcher_stub(self, fn, metadata, name, specialization): - if isinstance(fn, JITFunction) and "tensormaps_info" in metadata: - fn.tensormaps_info = metadata["tensormaps_info"] - ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else () + def make_launcher_stub(self, src, metadata, specialization): + ids_of_const_exprs = tuple(src.fn.constexprs) if hasattr(src, "fn") else () ids = { - "ids_of_tensormaps": metadata["ids_of_tensormaps"], "ids_of_folded_args": metadata["ids_of_folded_args"], - "ids_of_const_exprs": ids_of_const_exprs + "ids_of_tensormaps": 
metadata.get("ids_of_tensormaps", tuple()), "ids_of_folded_args": + metadata.get("ids_of_folded_args", tuple()), "ids_of_const_exprs": ids_of_const_exprs } enable_warp_specialization = False # set constant - return make_stub(name, specialization.signature, specialization.constants, ids, + return make_stub(src.name, specialization.signature, specialization.constants, ids, enable_warp_specialization=enable_warp_specialization) @classmethod diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 9ff2dbff8c3f..a6f498111c51 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -980,12 +980,11 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): prototype = language.function_type([], arg_types) gscope = sys.modules[fn.fn.__module__].__dict__ # If the callee is not set, we use the same debug setting as the caller - debug = self.debug if fn.debug is None else fn.debug file_name, begin_line = _get_fn_file_line(fn) generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, - function_name=fn_name, function_types=self.function_ret_types, debug=debug, + function_name=fn_name, function_types=self.function_ret_types, noinline=fn.noinline, file_name=file_name, begin_line=begin_line, - target=self.builder.target) + options=self.builder.options) generator.visit(fn.parse()) callee_ret_type = generator.last_ret_type self.function_ret_types[fn_name] = callee_ret_type diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index a663d26dafae..e95b59dce088 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -3,7 +3,7 @@ import hashlib import json -from .._C.libtriton.triton import (get_env_vars) +from .._C.libtriton.triton import (get_env_vars, ir) # from ..runtime import driver, jit, JITFunction # TODO: runtime.errors from ..runtime.autotuner import OutOfResources @@ -14,6 +14,7 @@ from dataclasses import dataclass from .code_generator import ast_to_ttir from pathlib import Path +import re @dataclass @@ -45,7 +46,73 @@ def hash(self): return hashlib.md5(key.encode("utf-8")).hexdigest() -def compile(src, device_type=("cuda", None), signature=None, config=InstanceDescriptor(), constants=None, +# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, +# and any following whitespace +# - (public\s+)? : optionally match the keyword public and any following whitespace +# - (@\w+) : match an @ symbol followed by one or more word characters +# (letters, digits, or underscores), and capture it as group 1 (the function name) +# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing +# zero or more arguments separated by commas, and capture it as group 2 (the argument list) +# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3 +mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$" +ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" +prototype_pattern = { + "ttir": mlir_prototype_pattern, + "ttgir": mlir_prototype_pattern, + "ptx": ptx_prototype_pattern, +} + +mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<]+|<[^>]+>)+),?' 
+ptx_arg_type_pattern = r"\.param\s+\.(\w+)" +arg_type_pattern = { + "ttir": mlir_arg_type_pattern, + "ttgir": mlir_arg_type_pattern, + "ptx": ptx_arg_type_pattern, +} + + +def convert_type_repr(x): + # Currently we only capture the pointer type and assume the pointer is on global memory. + # TODO: Capture and support shared memory space + match = re.search(r'!tt\.ptr<([^,]+)', x) + if match is not None: + return '*' + convert_type_repr(match.group(1)) + return x + + +class SourceDescriptor: + + def __init__(self, src): + if isinstance(src, str): + src_path = Path(src) + self.path = src + self.ext = src_path.suffix[1:] + self.src = src_path.read_text() + match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) + self.name = match.group(1) + # TODO: signature shouldn't be here + signature = match.group(2) + types = re.findall(arg_type_pattern[self.ext], signature) + self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} + # TODO: number of warps + # if ir_name == 'ttgir': + # num_warps_from_ir = _get_num_warps_from_ir_str(src) + # assert "num_warps" not in kwargs or num_warps_from_ir == num_warps, "num_warps in ttgir does not match num_warps in compile" + # num_warps = num_warps_from_ir + self.is_ast = False + + else: + self.fn = src + self.name = src.__name__ + self.is_ast = True + + def hash(self): + if self.is_ast: + return self.fn.cache_key + return hashlib.md5(self.src.encode("utf-8")).hexdigest() + + +def compile(src, device_type=("cuda", 80), signature=None, config=InstanceDescriptor(), constants=None, extern_libs=None, **kwargs): # TODO (backward-breaking): # - merge InstanceDescriptor and SpecializationDescriptor @@ -53,46 +120,60 @@ def compile(src, device_type=("cuda", None), signature=None, config=InstanceDesc # - **kwargs -> compiler_flags: dict # create backend + src = SourceDescriptor(src) backend = CUDABackend(device_type) options = backend.parse_options(**kwargs) specialization = SpecializationDescriptor(config, signature, constants) # create cache manager - key = f"{src.cache_key}-{backend.hash()}-{specialization.hash()}-{options.hash()}-{frozenset(get_env_vars().items())}" + key = f"{src.hash()}-{backend.hash()}-{options.hash()}-{frozenset(sorted(get_env_vars().items()))}" + if src.is_ast: + key = f"{key}-{specialization.hash()}" + else: + # TODO: clean up + specialization.signature = src.signature hash = hashlib.md5(key.encode("utf-8")).hexdigest() fn_cache_manager = get_cache_manager(hash) - name = src.__name__ - metadata_filename = f"{name}.json" + metadata_filename = f"{src.name}.json" metadata_group = fn_cache_manager.get_group(metadata_filename) or {} metadata_path = metadata_group.get(metadata_filename) if metadata_path is not None: # cache hit! 
metadata = json.loads(Path(metadata_path).read_text()) - so_path = backend.make_launcher_stub(src, metadata, name, specialization) + so_path = backend.make_launcher_stub(src, metadata, specialization) return CompiledKernel(so_path, metadata_path) # initialize metadata metadata = { "constants": constants, "device_type": device_type, - "ids_of_folded_args": tuple([int(k) for k in config.ids_of_folded_args]), **options.__dict__, **get_env_vars(), } + if signature is not None: + metadata["ids_of_folded_args"] = tuple([int(k) for k in config.ids_of_folded_args]) # run compilation pipeline and populate metadata stages = dict() backend.add_stages(extern_libs, stages, options) - first_stage = list(stages.keys()).index("ttir") - module = ast_to_ttir(src, specialization, options=options) + # TODO: clean up + if src.is_ast: + first_stage = list(stages.keys()).index("ttir") + module = ast_to_ttir(src.fn, specialization, options=options) + else: + context = ir.context() + first_stage = list(stages.keys()).index(src.ext) + module = ir.parse_mlir_module(src.path, context) + module.context = context + for ext, compile_ir in list(stages.items())[first_stage:]: next_module = compile_ir(module, metadata) - metadata_group[f"{name}.{ext}"] = fn_cache_manager.put(next_module, f"{name}.{ext}") + metadata_group[f"{src.name}.{ext}"] = fn_cache_manager.put(next_module, f"{src.name}.{ext}") module = next_module # write-back metadata metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False) fn_cache_manager.put_group(metadata_filename, metadata_group) - so_path = backend.make_launcher_stub(src, metadata, name, specialization) + so_path = backend.make_launcher_stub(src, metadata, specialization) # return handle to compiled kernel return CompiledKernel(so_path, metadata_group.get(metadata_filename)) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 60ee285bfa59..2a5a1e73d796 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1448,7 +1448,6 @@ def _promote_reduction_input(t, _builder=None): # hardware doesn't support FMAX, FMIN, CMP for bfloat16 if scalar_ty is bfloat16: return t.to(float32, _builder=_builder) - return t From eb61e52243d6548e6b34c8bfeba8ab88f645e6b3 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 25 Nov 2023 14:44:54 -0800 Subject: [PATCH 26/64] fixed bugs --- python/triton/compiler/backends/cuda.py | 9 +++-- python/triton/compiler/compiler.py | 49 ++++++++++++++++++++----- python/triton/language/semantic.py | 2 +- python/triton/runtime/driver.py | 8 ---- 4 files changed, 47 insertions(+), 21 deletions(-) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index fca2cce3a04e..4d6e2213c991 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -176,12 +176,15 @@ class CUDAOptions: enable_fp_fusion: bool = True extern_libs = None allow_fp8e4nv: bool = False - max_num_imprecise_acc: bool = None - + max_num_imprecise_acc_default: bool = None debug: bool = False + def __post_init__(self): + assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \ + "num_warps must be a power of 2" + def hash(self): - key = '-'.join([str(x) for x in self.__dict__]) + key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()]) return hashlib.md5(key.encode("utf-8")).hexdigest() diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 
e95b59dce088..cc73444bce44 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -8,7 +8,7 @@ # TODO: runtime.errors from ..runtime.autotuner import OutOfResources from ..runtime.cache import get_cache_manager -from ..runtime.jit import get_current_device +from ..runtime.jit import get_current_device, get_cuda_stream from ..runtime.driver import driver from .backends.cuda import CUDABackend from dataclasses import dataclass @@ -80,6 +80,26 @@ def convert_type_repr(x): return x +def _get_num_warps_from_ir_str(src: str): + ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' + # TODO(jlebar): Using a regex to get num-warps is a hack, and will break if + # e.g. someone has an instruction (not module) attribute named "num-warps". + num_warps_matches = re.findall(ttgir_num_warps_pattern, src) + assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps" + num_warps = int(num_warps_matches[0]) + + # If warp specialization is enabled, the true number of warps from + # the perspective of e.g. CUDA is num-warps times the number of + # specialized groups. + num_warp_groups_matches = re.findall(r'"triton_gpu.num-warp-groups-per-cta"\s?=\s?(\d+)\s?:', src) + assert len(num_warp_groups_matches) == 0 or len(num_warp_groups_matches) == 1, \ + "Expected triton_gpu.num-warp-groups-per-cta attribute to appear 0 or 1 times" + if num_warp_groups_matches: + num_warps *= int(num_warp_groups_matches[0]) + + return num_warps + + class SourceDescriptor: def __init__(self, src): @@ -94,11 +114,6 @@ def __init__(self, src): signature = match.group(2) types = re.findall(arg_type_pattern[self.ext], signature) self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} - # TODO: number of warps - # if ir_name == 'ttgir': - # num_warps_from_ir = _get_num_warps_from_ir_str(src) - # assert "num_warps" not in kwargs or num_warps_from_ir == num_warps, "num_warps in ttgir does not match num_warps in compile" - # num_warps = num_warps_from_ir self.is_ast = False else: @@ -106,6 +121,10 @@ def __init__(self, src): self.name = src.__name__ self.is_ast = True + def update_options(self, options): + if not self.is_ast and self.ext == "ttgir": + options.num_warps = _get_num_warps_from_ir_str(self.src) + def hash(self): if self.is_ast: return self.fn.cache_key @@ -124,6 +143,7 @@ def compile(src, device_type=("cuda", 80), signature=None, config=InstanceDescri backend = CUDABackend(device_type) options = backend.parse_options(**kwargs) specialization = SpecializationDescriptor(config, signature, constants) + src.update_options(options) # create cache manager key = f"{src.hash()}-{backend.hash()}-{options.hash()}-{frozenset(sorted(get_env_vars().items()))}" @@ -145,7 +165,6 @@ def compile(src, device_type=("cuda", 80), signature=None, config=InstanceDescri # initialize metadata metadata = { - "constants": constants, "device_type": device_type, **options.__dict__, **get_env_vars(), @@ -170,8 +189,7 @@ def compile(src, device_type=("cuda", 80), signature=None, config=InstanceDescri metadata_group[f"{src.name}.{ext}"] = fn_cache_manager.put(next_module, f"{src.name}.{ext}") module = next_module # write-back metadata - metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, - binary=False) + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata), metadata_filename, binary=False) fn_cache_manager.put_group(metadata_filename, metadata_group) so_path = 
backend.make_launcher_stub(src, metadata, specialization) # return handle to compiled kernel @@ -226,3 +244,16 @@ def __getattribute__(self, name): if name == 'c_wrapper': self._init_handles() return super().__getattribute__(name) + + def __getitem__(self, grid): + self._init_handles() + + def runner(*args, stream=None): + args_expand = driver.assemble_tensormap_to_arg(args) + if stream is None: + stream = get_cuda_stream() + self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.cluster_dims[0], + self.cluster_dims[1], self.cluster_dims[2], self.shared, stream, self.function, + CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand) + + return runner diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 09b91c2b2d81..9206f6d43c97 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1263,7 +1263,7 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, options): assert acc.type == ret_ty # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90 - if lhs.dype.is_fp8() and rhs.dtype.is_fp8(): + if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default if max_num_imprecise_acc is None: max_num_imprecise_acc = 2**30 diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index 59a809fb861a..61ce29521a02 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -110,14 +110,6 @@ def assemble_tensormap_to_arg(self, args): args_with_tma.append(CudaDriver.tensormap_manager[(e, args_ptr)]) return args_with_tma - def launch_kernel(self, kernel, stream, grid, launch_enter_hook, launch_exit_hook, *args): - args_expand = self.assemble_tensormap_to_arg(args) - if stream is None: - stream = self.get_stream(None) - kernel.c_wrapper(grid[0], grid[1], grid[2], kernel.num_warps, kernel.num_ctas, kernel.cluster_dims[0], - kernel.cluster_dims[1], kernel.cluster_dims[2], kernel.shared, stream, kernel.function, - launch_enter_hook, launch_exit_hook, kernel, *args_expand) - # ----------------------------- # HIP From c19a8f38d9a0733ee9e1009f89f80e752bab18ee Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 25 Nov 2023 14:49:52 -0800 Subject: [PATCH 27/64] . 
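Update the rewrite-tensor-pointer-tma lit test to match: the
`-tritongpu-rewrite-tensor-pointer` pass is now invoked without the explicit
`compute-capability=90` option, presumably because the capability is taken from
the `"triton_gpu.compute-capability" = 90` module attribute this test already
carries.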
--- test/TritonGPU/rewrite-tensor-pointer-tma.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/TritonGPU/rewrite-tensor-pointer-tma.mlir b/test/TritonGPU/rewrite-tensor-pointer-tma.mlir index f5bf34177d85..b15b9a77b4f2 100644 --- a/test/TritonGPU/rewrite-tensor-pointer-tma.mlir +++ b/test/TritonGPU/rewrite-tensor-pointer-tma.mlir @@ -1,4 +1,4 @@ -// RUN: ENABLE_TMA=1 triton-opt %s -split-input-file -tritongpu-rewrite-tensor-pointer=compute-capability=90 | FileCheck %s +// RUN: ENABLE_TMA=1 triton-opt %s -split-input-file -tritongpu-rewrite-tensor-pointer | FileCheck %s #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { From c87a324cb274c61001f0de9003b400e4385e0cb7 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 25 Nov 2023 17:59:31 -0800 Subject: [PATCH 28/64] some cleaning --- python/triton/compiler/backends/cuda.py | 10 +-- python/triton/compiler/compiler.py | 96 ++++++++++++++----------- 2 files changed, 58 insertions(+), 48 deletions(-) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index 4d6e2213c991..ecd9bbf3b401 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -251,17 +251,17 @@ def create_cubin(src, metadata): def hash(self): return f'{get_cuda_version_key()}-{self.capability}' - def make_launcher_stub(self, src, metadata, specialization): - ids_of_const_exprs = tuple(src.fn.constexprs) if hasattr(src, "fn") else () + def make_launcher_stub(self, src, metadata): ids = { "ids_of_tensormaps": metadata.get("ids_of_tensormaps", tuple()), "ids_of_folded_args": - metadata.get("ids_of_folded_args", tuple()), "ids_of_const_exprs": ids_of_const_exprs + metadata.get("ids_of_folded_args", + tuple()), "ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple() } + constants = src.constants if hasattr(src, "constants") else dict() enable_warp_specialization = False # set constant - return make_stub(src.name, specialization.signature, specialization.constants, ids, - enable_warp_specialization=enable_warp_specialization) + return make_stub(src.name, src.signature, constants, ids, enable_warp_specialization=enable_warp_specialization) @classmethod def create_backend(cls, device_type: str): diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index cc73444bce44..e9f60288919f 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -100,36 +100,59 @@ def _get_num_warps_from_ir_str(src: str): return num_warps -class SourceDescriptor: - - def __init__(self, src): - if isinstance(src, str): - src_path = Path(src) - self.path = src - self.ext = src_path.suffix[1:] - self.src = src_path.read_text() - match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) - self.name = match.group(1) - # TODO: signature shouldn't be here - signature = match.group(2) - types = re.findall(arg_type_pattern[self.ext], signature) - self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} - self.is_ast 
= False - - else: - self.fn = src - self.name = src.__name__ - self.is_ast = True +class ASTSource: + + def __init__(self, fn, signature, constants, config) -> None: + self.fn = fn + self.ext = "ttir" + self.name = fn.__name__ + self.signature = signature + self.constants = constants + self.config = config + if isinstance(self.signature, str): + self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} + if self.constants is None: + self.constants = dict() + + def hash(self): + key = f"{self.fn.cache_key}-{self.config.hash()}-{self.signature.values()}-{self.constants}" + return hashlib.md5(key.encode("utf-8")).hexdigest() + + def make_ir(self, options): + specialization = SpecializationDescriptor(self.config, self.signature, self.constants) + return ast_to_ttir(self.fn, specialization, options=options) def update_options(self, options): - if not self.is_ast and self.ext == "ttgir": - options.num_warps = _get_num_warps_from_ir_str(self.src) + pass + + +class IRSource: + + def __init__(self, path): + self.path = path + path = Path(path) + self.ext = path.suffix[1:] + self.src = path.read_text() + match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) + self.name = match.group(1) + # TODO: signature shouldn't be here + signature = match.group(2) + types = re.findall(arg_type_pattern[self.ext], signature) + self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} def hash(self): - if self.is_ast: - return self.fn.cache_key return hashlib.md5(self.src.encode("utf-8")).hexdigest() + def make_ir(self, options): + context = ir.context() + module = ir.parse_mlir_module(self.path, context) + module.context = context + return module + + def update_options(self, options): + if self.ext == "ttgir": + options.num_warps = _get_num_warps_from_ir_str(self.src) + def compile(src, device_type=("cuda", 80), signature=None, config=InstanceDescriptor(), constants=None, extern_libs=None, **kwargs): @@ -139,19 +162,13 @@ def compile(src, device_type=("cuda", 80), signature=None, config=InstanceDescri # - **kwargs -> compiler_flags: dict # create backend - src = SourceDescriptor(src) + src = IRSource(src) if isinstance(src, str) else ASTSource(src, signature, constants, config) backend = CUDABackend(device_type) options = backend.parse_options(**kwargs) - specialization = SpecializationDescriptor(config, signature, constants) src.update_options(options) # create cache manager key = f"{src.hash()}-{backend.hash()}-{options.hash()}-{frozenset(sorted(get_env_vars().items()))}" - if src.is_ast: - key = f"{key}-{specialization.hash()}" - else: - # TODO: clean up - specialization.signature = src.signature hash = hashlib.md5(key.encode("utf-8")).hexdigest() fn_cache_manager = get_cache_manager(hash) metadata_filename = f"{src.name}.json" @@ -160,7 +177,7 @@ def compile(src, device_type=("cuda", 80), signature=None, config=InstanceDescri if metadata_path is not None: # cache hit! 
metadata = json.loads(Path(metadata_path).read_text()) - so_path = backend.make_launcher_stub(src, metadata, specialization) + so_path = backend.make_launcher_stub(src, metadata) return CompiledKernel(so_path, metadata_path) # initialize metadata @@ -174,16 +191,9 @@ def compile(src, device_type=("cuda", 80), signature=None, config=InstanceDescri # run compilation pipeline and populate metadata stages = dict() backend.add_stages(extern_libs, stages, options) - # TODO: clean up - if src.is_ast: - first_stage = list(stages.keys()).index("ttir") - module = ast_to_ttir(src.fn, specialization, options=options) - else: - context = ir.context() - first_stage = list(stages.keys()).index(src.ext) - module = ir.parse_mlir_module(src.path, context) - module.context = context - + # + first_stage = list(stages.keys()).index(src.ext) + module = src.make_ir(options) for ext, compile_ir in list(stages.items())[first_stage:]: next_module = compile_ir(module, metadata) metadata_group[f"{src.name}.{ext}"] = fn_cache_manager.put(next_module, f"{src.name}.{ext}") @@ -191,7 +201,7 @@ def compile(src, device_type=("cuda", 80), signature=None, config=InstanceDescri # write-back metadata metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata), metadata_filename, binary=False) fn_cache_manager.put_group(metadata_filename, metadata_group) - so_path = backend.make_launcher_stub(src, metadata, specialization) + so_path = backend.make_launcher_stub(src, metadata) # return handle to compiled kernel return CompiledKernel(so_path, metadata_group.get(metadata_filename)) From d9f8e0cdcad145ad247c4b70125fae099fb9b00c Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 25 Nov 2023 21:20:32 -0800 Subject: [PATCH 29/64] more cleaning --- python/triton/compiler/backends/cuda.py | 10 +++++----- python/triton/compiler/compiler.py | 14 +++++++------- python/triton/runtime/jit.py | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index ecd9bbf3b401..b0840408c75a 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -202,16 +202,16 @@ def parse_options(self, **opts) -> Any: return options def add_stages(self, extern_libs, stages, opt): - cluster_info = ClusterInfo() - if opt.cluster_dims is not None: - cluster_info.clusterDimX = opt.cluster_dims[0] - cluster_info.clusterDimY = opt.cluster_dims[1] - cluster_info.clusterDimZ = opt.cluster_dims[2] stages["ttir"] = lambda src, metadata: optimize_ttir(src, opt) # TTIR -> TTGIR stage def create_ttgir(src, metadata): + cluster_info = ClusterInfo() + if opt.cluster_dims is not None: + cluster_info.clusterDimX = opt.cluster_dims[0] + cluster_info.clusterDimY = opt.cluster_dims[1] + cluster_info.clusterDimZ = opt.cluster_dims[2] ttgir = ttir_to_ttgir(src, opt.num_warps, opt.num_ctas, self.capability) return optimize_ttgir(ttgir, opt.num_stages, opt.num_warps, opt.num_ctas, self.capability, cluster_info, opt.enable_warp_specialization, opt.enable_persistent, opt.optimize_epilogue) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index e9f60288919f..70544ee898f1 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -135,7 +135,6 @@ def __init__(self, path): self.src = path.read_text() match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) self.name = match.group(1) - # TODO: signature shouldn't be here signature = match.group(2) types = 
re.findall(arg_type_pattern[self.ext], signature) self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} @@ -172,8 +171,8 @@ def compile(src, device_type=("cuda", 80), signature=None, config=InstanceDescri hash = hashlib.md5(key.encode("utf-8")).hexdigest() fn_cache_manager = get_cache_manager(hash) metadata_filename = f"{src.name}.json" - metadata_group = fn_cache_manager.get_group(metadata_filename) or {} - metadata_path = metadata_group.get(metadata_filename) + cache_group = fn_cache_manager.get_group(metadata_filename) or {} + metadata_path = cache_group.get(metadata_filename) if metadata_path is not None: # cache hit! metadata = json.loads(Path(metadata_path).read_text()) @@ -186,6 +185,7 @@ def compile(src, device_type=("cuda", 80), signature=None, config=InstanceDescri **options.__dict__, **get_env_vars(), } + # TODO: remove once TMA support is cleaned up if signature is not None: metadata["ids_of_folded_args"] = tuple([int(k) for k in config.ids_of_folded_args]) # run compilation pipeline and populate metadata @@ -196,14 +196,14 @@ def compile(src, device_type=("cuda", 80), signature=None, config=InstanceDescri module = src.make_ir(options) for ext, compile_ir in list(stages.items())[first_stage:]: next_module = compile_ir(module, metadata) - metadata_group[f"{src.name}.{ext}"] = fn_cache_manager.put(next_module, f"{src.name}.{ext}") + cache_group[f"{src.name}.{ext}"] = fn_cache_manager.put(next_module, f"{src.name}.{ext}") module = next_module # write-back metadata - metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata), metadata_filename, binary=False) - fn_cache_manager.put_group(metadata_filename, metadata_group) + cache_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata), metadata_filename, binary=False) + fn_cache_manager.put_group(metadata_filename, cache_group) so_path = backend.make_launcher_stub(src, metadata) # return handle to compiled kernel - return CompiledKernel(so_path, metadata_group.get(metadata_filename)) + return CompiledKernel(so_path, cache_group.get(metadata_filename)) class CompiledKernel: diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index bb185da30ca6..bb723a53057a 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -400,7 +400,7 @@ def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: Li return device_types[0] if len(device_types) > 0 else "cuda" def run(self, *args, **kwargs): - from ..compiler import CompiledKernel, compile, InstanceDescriptor + from ..compiler import CompiledKernel, compile # Get a compiler-flags arg like `num_warps` and remove it from kwargs. 
def get_special_arg(name: str, default=None): From 2776a832f74ff76fbb58f017fc0829663d9e1b50 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 25 Nov 2023 21:43:05 -0800 Subject: [PATCH 30/64] more cleaning --- python/triton/compiler/backends/cuda.py | 271 ++++++++++-------------- 1 file changed, 115 insertions(+), 156 deletions(-) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index b0840408c75a..15f61d60b7ae 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -10,28 +10,6 @@ import hashlib -def optimize_ttir(mod, options): - pm = ir.pass_manager(mod.context) - pm.enable_debug() - pm.add_inliner_pass() - pm.add_triton_combine_pass() - pm.add_canonicalizer_pass() - pm.add_reorder_broadcast_pass() - pm.add_cse_pass() - pm.add_licm_pass() - pm.add_symbol_dce_pass() - pm.run(mod) - return mod - - -def ttir_to_ttgir(mod, num_warps, num_ctas, capability): - pm = ir.pass_manager(mod.context) - pm.enable_debug() - pm.add_convert_triton_to_tritongpu_pass(num_warps, 32, num_ctas, capability) - pm.run(mod) - return mod - - def parse_mlir_module(path, context): module = ir.parse_mlir_module(path, context) # module takes ownership of the context @@ -52,63 +30,6 @@ def get_kernel_name(src: str, pattern: str) -> str: return line.split()[-1] -def optimize_ttgir(mod, num_stages, num_warps, num_ctas, capability, cluster_info, enable_warp_specialization, - enable_persistent, optimize_epilogue): - pm = ir.pass_manager(mod.context) - pm.enable_debug() - pm.add_tritongpu_coalesce_pass() - # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass - pm.add_plan_cta_pass(cluster_info) - if capability // 10 < 9: - pm.add_tritongpu_rewrite_tensor_pointer_pass() - pm.add_plan_cta_pass(cluster_info) - pm.add_tritongpu_remove_layout_conversions_pass() - pm.add_tritongpu_accelerate_matmul_pass(capability) - pm.add_tritongpu_remove_layout_conversions_pass() - if optimize_epilogue: - pm.add_tritongpu_optimize_epilogue_pass() - pm.add_tritongpu_optimize_dot_operands_pass() - pm.add_cse_pass() - ws_enabled = False - # `num_warps` does not mean the total number of warps of a CTA when - # warp specialization is enabled. - # it's the responsibility of the compiler to figure out the exact - # `num_warps` to use. - # TODO: support the case where `num_warps` from user is not 4. 
- if capability // 10 >= 9 and enable_warp_specialization and num_warps == 4: - pm.add_tritongpu_ws_feasibility_checking_pass(capability) - pm.run(mod) - ws_enabled = ir.is_ws_supported(mod) - pm = ir.pass_manager(mod.context) - pm.enable_debug() - if ws_enabled: - pm.add_tritongpu_wsdecomposing_pass(capability) - pm.add_tritongpu_wspipeline_pass(num_stages, num_warps, capability) - pm.add_tritongpu_wsmutex_pass(capability) - pm.add_tritongpu_wsmaterialization_pass(capability) - pm.add_licm_pass() - pm.add_cse_pass() - else: - pm.add_tritongpu_pipeline_pass(num_stages, num_warps, num_ctas, capability) - pm.add_tritongpu_materialize_load_store_pass(num_warps, capability) - if capability // 10 <= 8: - pm.add_tritongpu_prefetch_pass() - pm.add_tritongpu_optimize_dot_operands_pass() - pm.add_tritongpu_remove_layout_conversions_pass() - pm.add_tritongpu_decompose_conversions_pass() - pm.add_tritongpu_ws_fixup_missing_attrs_pass() - pm.add_tritongpu_reorder_instructions_pass() - pm.add_cse_pass() - pm.add_symbol_dce_pass() - if capability // 10 >= 9: - pm.add_tritongpu_fence_insertion_pass() - pm.add_tritongpu_ws_fixup_missing_attrs_pass() - pm.add_tritongpu_optimize_thread_locality_pass() - pm.add_canonicalizer_pass() - pm.run(mod) - return mod - - def _add_external_libs(mod, libs): for name, path in libs.items(): if len(name) == 0 or len(path) == 0: @@ -116,15 +37,6 @@ def _add_external_libs(mod, libs): add_external_libs(mod, list(libs.keys()), list(libs.values())) -def ttgir_to_llir(mod, extern_libs, capability, tma_infos): - if extern_libs: - _add_external_libs(mod, extern_libs) - return translate_triton_gpu_to_llvmir(mod, capability, tma_infos, runtime.TARGET.NVVM) - - -# PTX translation - - @functools.lru_cache() def ptx_get_version(cuda_version) -> int: ''' @@ -141,35 +53,13 @@ def ptx_get_version(cuda_version) -> int: raise RuntimeError("Triton only support CUDA 10.0 or higher") -def llir_to_ptx(mod: Any, enable_fp_fusion: bool, capability: int, ptx_version: int = None) -> str: - ''' - Translate TritonGPU module to PTX code. - :param mod: a TritonGPU dialect module - :return: PTX code - ''' - if ptx_version is None: - _, cuda_version = path_to_ptxas() - ptx_version = ptx_get_version(cuda_version) - return translate_llvmir_to_ptx(mod, capability, ptx_version, enable_fp_fusion) - - -def ptx_to_cubin(ptx: str, capability: int, enable_fp_fusion: bool): - ''' - Compile TritonGPU module to cubin. 
- :param ptx: ptx code - :param compute_capability: compute capability - :return: str - ''' - ptxas, _ = path_to_ptxas() - return compile_ptx_to_cubin(ptx, ptxas, capability, enable_fp_fusion) - - @dataclass class CUDAOptions: num_warps: int = 4 num_ctas: int = 1 num_stages: int = 3 cluster_dims: tuple = (1, 1, 1) + ptx_version: int = None enable_warp_specialization: bool = False enable_persistent: bool = False optimize_epilogue: bool = False @@ -201,52 +91,121 @@ def parse_options(self, **opts) -> Any: options.max_num_imprecise_acc = 0 if self.capability >= 89 else None return options - def add_stages(self, extern_libs, stages, opt): - - stages["ttir"] = lambda src, metadata: optimize_ttir(src, opt) - - # TTIR -> TTGIR stage - def create_ttgir(src, metadata): - cluster_info = ClusterInfo() - if opt.cluster_dims is not None: - cluster_info.clusterDimX = opt.cluster_dims[0] - cluster_info.clusterDimY = opt.cluster_dims[1] - cluster_info.clusterDimZ = opt.cluster_dims[2] - ttgir = ttir_to_ttgir(src, opt.num_warps, opt.num_ctas, self.capability) - return optimize_ttgir(ttgir, opt.num_stages, opt.num_warps, opt.num_ctas, self.capability, cluster_info, - opt.enable_warp_specialization, opt.enable_persistent, opt.optimize_epilogue) - - stages["ttgir"] = create_ttgir - - # TTGIR -> LLIR stage - - def create_llir(src, metadata): - metadata["enable_warp_specialization"] = ir.is_ws_supported(src) - metadata["num_warps"] = get_num_warps(src) - tma_infos = TMAInfos() - ret = ttgir_to_llir(src, opt.extern_libs, self.capability, tma_infos) - if len(tma_infos) > 0: - metadata["tensormaps_info"] = parse_tma_info(tma_infos, metadata["ids_of_folded_args"]) - for i, _ in enumerate(metadata["tensormaps_info"]): - metadata["tensormaps_info"][i].ids_of_folded_args = metadata["ids_of_folded_args"] - metadata["ids_of_tensormaps"] = get_ids_of_tensormaps(metadata.get("tensormaps_info", None)) - metadata["shared"] = get_shared_memory_size(src) - return ret - - stages["llir"] = create_llir - - # LLIR -> PTX stage - def create_ptx(src, metadata): - return llir_to_ptx(src, opt.enable_fp_fusion, self.capability) - - stages["ptx"] = create_ptx - - # PTX -> CUBIN stage - def create_cubin(src, metadata): - metadata["name"] = get_kernel_name(src, pattern='// .globl') - return ptx_to_cubin(src, self.capability, opt.enable_fp_fusion) + @staticmethod + def make_ttir(mod, metadata, opt): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + pm.add_inliner_pass() + pm.add_triton_combine_pass() + pm.add_canonicalizer_pass() + pm.add_reorder_broadcast_pass() + pm.add_cse_pass() + pm.add_licm_pass() + pm.add_symbol_dce_pass() + pm.run(mod) + return mod + + @staticmethod + def make_ttgir(mod, metadata, opt, capability): + cluster_info = ClusterInfo() + if opt.cluster_dims is not None: + cluster_info.clusterDimX = opt.cluster_dims[0] + cluster_info.clusterDimY = opt.cluster_dims[1] + cluster_info.clusterDimZ = opt.cluster_dims[2] + # TTIR -> TTGIR + pm = ir.pass_manager(mod.context) + pm.enable_debug() + pm.add_convert_triton_to_tritongpu_pass(opt.num_warps, 32, opt.num_ctas, capability) + # optimize TTGIR + pm.add_tritongpu_coalesce_pass() + # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass + pm.add_plan_cta_pass(cluster_info) + if capability // 10 < 9: + pm.add_tritongpu_rewrite_tensor_pointer_pass() + pm.add_plan_cta_pass(cluster_info) + pm.add_tritongpu_remove_layout_conversions_pass() + pm.add_tritongpu_accelerate_matmul_pass(capability) + pm.add_tritongpu_remove_layout_conversions_pass() + if 
opt.optimize_epilogue: + pm.add_tritongpu_optimize_epilogue_pass() + pm.add_tritongpu_optimize_dot_operands_pass() + pm.add_cse_pass() + ws_enabled = False + # `num_warps` does not mean the total number of warps of a CTA when + # warp specialization is enabled. + # it's the responsibility of the compiler to figure out the exact + # `num_warps` to use. + # TODO: support the case where `num_warps` from user is not 4. + if capability // 10 >= 9 and opt.enable_warp_specialization and opt.num_warps == 4: + pm.add_tritongpu_ws_feasibility_checking_pass(capability) + pm.run(mod) + ws_enabled = ir.is_ws_supported(mod) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + if ws_enabled: + pm.add_tritongpu_wsdecomposing_pass(capability) + pm.add_tritongpu_wspipeline_pass(opt.num_stages, opt.num_warps, capability) + pm.add_tritongpu_wsmutex_pass(capability) + pm.add_tritongpu_wsmaterialization_pass(capability) + pm.add_licm_pass() + pm.add_cse_pass() + else: + pm.add_tritongpu_pipeline_pass(opt.num_stages, opt.num_warps, opt.num_ctas, capability) + pm.add_tritongpu_materialize_load_store_pass(opt.num_warps, capability) + if capability // 10 <= 8: + pm.add_tritongpu_prefetch_pass() + pm.add_tritongpu_optimize_dot_operands_pass() + pm.add_tritongpu_remove_layout_conversions_pass() + pm.add_tritongpu_decompose_conversions_pass() + pm.add_tritongpu_ws_fixup_missing_attrs_pass() + pm.add_tritongpu_reorder_instructions_pass() + pm.add_cse_pass() + pm.add_symbol_dce_pass() + if capability // 10 >= 9: + pm.add_tritongpu_fence_insertion_pass() + pm.add_tritongpu_ws_fixup_missing_attrs_pass() + pm.add_tritongpu_optimize_thread_locality_pass() + pm.add_canonicalizer_pass() + pm.run(mod) + return mod + + @staticmethod + def make_llir(src, metadata, opt, capability): + metadata["enable_warp_specialization"] = ir.is_ws_supported(src) + metadata["num_warps"] = get_num_warps(src) + tma_infos = TMAInfos() + + if opt.extern_libs: + _add_external_libs(src, opt.extern_libs) + ret = translate_triton_gpu_to_llvmir(src, capability, tma_infos, runtime.TARGET.NVVM) + + if len(tma_infos) > 0: + metadata["tensormaps_info"] = parse_tma_info(tma_infos, metadata["ids_of_folded_args"]) + for i, _ in enumerate(metadata["tensormaps_info"]): + metadata["tensormaps_info"][i].ids_of_folded_args = metadata["ids_of_folded_args"] + metadata["ids_of_tensormaps"] = get_ids_of_tensormaps(metadata.get("tensormaps_info", None)) + metadata["shared"] = get_shared_memory_size(src) + return ret + + @staticmethod + def make_ptx(src, metadata, opt, capability): + ptx_version = opt.ptx_version + if ptx_version is None: + _, cuda_version = path_to_ptxas() + ptx_version = ptx_get_version(cuda_version) + return translate_llvmir_to_ptx(src, capability, ptx_version, opt.enable_fp_fusion) + + @staticmethod + def make_cubin(src, metadata, opt, capability): + ptxas, _ = path_to_ptxas() + return compile_ptx_to_cubin(src, ptxas, capability, opt.enable_fp_fusion) - stages["cubin"] = create_cubin + def add_stages(self, extern_libs, stages, opt): + stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, opt) + stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, opt, self.capability) + stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, opt, self.capability) + stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, opt, self.capability) + stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, opt, self.capability) def hash(self): return f'{get_cuda_version_key()}-{self.capability}' 
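Note on how the stages registered above are consumed: `add_stages` now just records one callable per IR level, and `compile()` walks them in insertion order starting from the source's extension. A condensed sketch of that driver loop (mirroring compiler.py, with caching elided; not the exact code):

    stages = dict()
    backend.add_stages(extern_libs, stages, options)    # ttir -> ttgir -> llir -> ptx -> cubin
    metadata = {}
    module = src.make_ir(options)                       # AST lowering or MLIR file parsing
    first_stage = list(stages.keys()).index(src.ext)
    for ext, compile_ir in list(stages.items())[first_stage:]:
        module = compile_ir(module, metadata)           # each stage may also fill in metadata
    # `module` now holds the final cubin and `metadata` the launch information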
From 33f0480660c4d5fe3b65b78a9f2ab60c31be193e Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 25 Nov 2023 21:47:40 -0800 Subject: [PATCH 31/64] cleaning --- python/triton/compiler/backends/cuda.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index 15f61d60b7ae..97c32b9c0035 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -197,6 +197,7 @@ def make_ptx(src, metadata, opt, capability): @staticmethod def make_cubin(src, metadata, opt, capability): + metadata["name"] = get_kernel_name(src, pattern='// .globl') ptxas, _ = path_to_ptxas() return compile_ptx_to_cubin(src, ptxas, capability, opt.enable_fp_fusion) From aeeae566183988b239355d964eab1a50dcc8457d Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 25 Nov 2023 21:48:33 -0800 Subject: [PATCH 32/64] more cleaning --- python/triton/compiler/backends/cuda.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index 97c32b9c0035..cedbc34753b0 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -10,13 +10,6 @@ import hashlib -def parse_mlir_module(path, context): - module = ir.parse_mlir_module(path, context) - # module takes ownership of the context - module.context = context - return module - - def get_kernel_name(src: str, pattern: str) -> str: ''' Get kernel name from PTX code. From 9eda8acc523037e46ef9487c593bf259f8e3c1e4 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 26 Nov 2023 14:50:07 -0800 Subject: [PATCH 33/64] temporary disable TMA tests to see if the rest works --- .github/workflows/integration-tests.yml | 15 +++++++-------- python/triton/compiler/backends/cuda.py | 9 +-------- python/triton/compiler/compiler.py | 12 ++++++------ 3 files changed, 14 insertions(+), 22 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 216609bb783a..04fab929fe26 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -34,7 +34,6 @@ jobs: echo '::set-output name=matrix-optional::["ubuntu-latest"]' fi - Integration-Tests: needs: Runner-Preparation @@ -49,7 +48,7 @@ jobs: - name: Checkout uses: actions/checkout@v3 with: - submodules: 'true' + submodules: "true" - name: Set CUDA ENV if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}} run: | @@ -99,15 +98,15 @@ jobs: - name: Run python tests on CUDA with ENABLE_TMA=1 if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1'}} run: | - cd python/test/unit - python3 -m pytest -n 8 --ignore=runtime --ignore=operators --ignore=language/test_line_info.py --ignore=language/test_subprocess.py - python3 -m pytest -n 8 language/test_subprocess.py + #cd python/test/unit + #python3 -m pytest -n 8 --ignore=runtime --ignore=operators --ignore=language/test_line_info.py --ignore=language/test_subprocess.py + #python3 -m pytest -n 8 language/test_subprocess.py # run runtime tests serially to avoid race condition with cache handling. 
- python3 -m pytest runtime/ + #python3 -m pytest runtime/ # run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0 - TRITON_DISABLE_LINE_INFO=0 python3 -m pytest language/test_line_info.py + #TRITON_DISABLE_LINE_INFO=0 python3 -m pytest language/test_line_info.py #run hopper/test_flashattention.py to avoid out of gpu memory - python3 -m pytest hopper/test_flashattention.py + #python3 -m pytest hopper/test_flashattention.py - name: Run python tests on CUDA with ENABLE_TMA=0 if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0'}} diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index cedbc34753b0..94057cc4f51b 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -23,13 +23,6 @@ def get_kernel_name(src: str, pattern: str) -> str: return line.split()[-1] -def _add_external_libs(mod, libs): - for name, path in libs.items(): - if len(name) == 0 or len(path) == 0: - return - add_external_libs(mod, list(libs.keys()), list(libs.values())) - - @functools.lru_cache() def ptx_get_version(cuda_version) -> int: ''' @@ -169,7 +162,7 @@ def make_llir(src, metadata, opt, capability): tma_infos = TMAInfos() if opt.extern_libs: - _add_external_libs(src, opt.extern_libs) + add_external_libs(src, list(opt.extern_libs.keys()), list(opt.extern_libs.values())) ret = translate_triton_gpu_to_llvmir(src, capability, tma_infos, runtime.TARGET.NVVM) if len(tma_infos) > 0: diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 70544ee898f1..156634fef28c 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -171,8 +171,8 @@ def compile(src, device_type=("cuda", 80), signature=None, config=InstanceDescri hash = hashlib.md5(key.encode("utf-8")).hexdigest() fn_cache_manager = get_cache_manager(hash) metadata_filename = f"{src.name}.json" - cache_group = fn_cache_manager.get_group(metadata_filename) or {} - metadata_path = cache_group.get(metadata_filename) + metadata_group = fn_cache_manager.get_group(metadata_filename) or {} + metadata_path = metadata_group.get(metadata_filename) if metadata_path is not None: # cache hit! metadata = json.loads(Path(metadata_path).read_text()) @@ -196,14 +196,14 @@ def compile(src, device_type=("cuda", 80), signature=None, config=InstanceDescri module = src.make_ir(options) for ext, compile_ir in list(stages.items())[first_stage:]: next_module = compile_ir(module, metadata) - cache_group[f"{src.name}.{ext}"] = fn_cache_manager.put(next_module, f"{src.name}.{ext}") + metadata_group[f"{src.name}.{ext}"] = fn_cache_manager.put(next_module, f"{src.name}.{ext}") module = next_module # write-back metadata - cache_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata), metadata_filename, binary=False) - fn_cache_manager.put_group(metadata_filename, cache_group) + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata), metadata_filename, binary=False) + fn_cache_manager.put_group(metadata_filename, metadata_group) so_path = backend.make_launcher_stub(src, metadata) # return handle to compiled kernel - return CompiledKernel(so_path, cache_group.get(metadata_filename)) + return CompiledKernel(so_path, metadata_group.get(metadata_filename)) class CompiledKernel: From bb849fba4608bafc27be7141e1e4adce66afdc2c Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 26 Nov 2023 15:07:29 -0800 Subject: [PATCH 34/64] . 
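This adds an empty __init__.py so that triton/compiler/backends is a regular, installable package (the matching setup.py entry follows in the next patch). The rest of the series imports it as, e.g.:

    from triton.compiler.backends.cuda import CUDABackend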
--- python/triton/compiler/backends/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 python/triton/compiler/backends/__init__.py diff --git a/python/triton/compiler/backends/__init__.py b/python/triton/compiler/backends/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 From 93796785a71ea6c3a2a20d423e4c14dd53a15578 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 26 Nov 2023 15:08:00 -0800 Subject: [PATCH 35/64] . --- python/setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/setup.py b/python/setup.py index 7b3b2521282c..f5344e6c53c8 100644 --- a/python/setup.py +++ b/python/setup.py @@ -353,6 +353,7 @@ def build_extension(self, ext): "triton/_C", "triton/common", "triton/compiler", + "triton/compiler/backends", "triton/language", "triton/language/extra", "triton/ops", From 3812aed008741818e3196a04ea14a849b9ca8a4a Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 26 Nov 2023 15:12:05 -0800 Subject: [PATCH 36/64] comment out interpreter tests --- .github/workflows/integration-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 04fab929fe26..ef92dcdb9874 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -128,7 +128,7 @@ jobs: CUA_VISIBLE_DEVICES: "" run: | cd python/test/unit - python3 -m pytest -vs operators/test_flash_attention.py + #python3 -m pytest -vs operators/test_flash_attention.py - name: Run partial tests on CUDA with ENABLE_TMA=1 if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1'}} From 4686fcc3fd05f8194cedeee2d2e694049d80c44b Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 26 Nov 2023 15:16:58 -0800 Subject: [PATCH 37/64] . --- .github/workflows/integration-tests.yml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index ef92dcdb9874..b19110ba8547 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -27,7 +27,7 @@ jobs: id: set-matrix run: | if [ x"${{ github.repository }}" == x"openai/triton" ]; then - echo '::set-output name=matrix-required::[["self-hosted", "A100"], ["self-hosted", "H100"]]' + echo '::set-output name=matrix-required::[["self-hosted", "A100"]]' echo '::set-output name=matrix-optional::[["self-hosted", "gfx908"], ["self-hosted", "arc770"]]' else echo '::set-output name=matrix-required::["ubuntu-latest"]' @@ -98,15 +98,15 @@ jobs: - name: Run python tests on CUDA with ENABLE_TMA=1 if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1'}} run: | - #cd python/test/unit - #python3 -m pytest -n 8 --ignore=runtime --ignore=operators --ignore=language/test_line_info.py --ignore=language/test_subprocess.py - #python3 -m pytest -n 8 language/test_subprocess.py + cd python/test/unit + python3 -m pytest -n 8 --ignore=runtime --ignore=operators --ignore=language/test_line_info.py --ignore=language/test_subprocess.py + python3 -m pytest -n 8 language/test_subprocess.py # run runtime tests serially to avoid race condition with cache handling. 
- #python3 -m pytest runtime/ + python3 -m pytest runtime/ # run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0 - #TRITON_DISABLE_LINE_INFO=0 python3 -m pytest language/test_line_info.py + TRITON_DISABLE_LINE_INFO=0 python3 -m pytest language/test_line_info.py #run hopper/test_flashattention.py to avoid out of gpu memory - #python3 -m pytest hopper/test_flashattention.py + python3 -m pytest hopper/test_flashattention.py - name: Run python tests on CUDA with ENABLE_TMA=0 if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0'}} @@ -128,7 +128,7 @@ jobs: CUA_VISIBLE_DEVICES: "" run: | cd python/test/unit - #python3 -m pytest -vs operators/test_flash_attention.py + python3 -m pytest -vs operators/test_flash_attention.py - name: Run partial tests on CUDA with ENABLE_TMA=1 if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1'}} From f1d2820a19ab7cd0ccfc3a13f653ae894b4ebaab Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 26 Nov 2023 15:50:02 -0800 Subject: [PATCH 38/64] . --- python/triton/__init__.py | 1 + python/triton/compiler/compiler.py | 18 +++++++++++++++--- python/triton/runtime/jit.py | 2 +- python/triton/tools/compile.py | 2 +- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 55484acd5bf2..3cda7668e50e 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -21,6 +21,7 @@ from . import language from . import testing +from . import tools __all__ = [ "autotune", diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 156634fef28c..0877f3b1b6fd 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -24,6 +24,16 @@ class InstanceDescriptor: ids_of_folded_args: set = None divisible_by_8: set = None + def __post_init__(self): + if self.divisible_by_16 is None: + self.divisible_by_16 = set() + if self.equal_to_1 is None: + self.equal_to_1 = set() + if self.ids_of_folded_args is None: + self.ids_of_folded_args = set() + if self.divisible_by_8 is None: + self.divisible_by_8 = set() + def hash(self): key = str([sorted(x) for x in self.__dict__.values()]) return hashlib.md5(key.encode("utf-8")).hexdigest() @@ -153,13 +163,15 @@ def update_options(self, options): options.num_warps = _get_num_warps_from_ir_str(self.src) -def compile(src, device_type=("cuda", 80), signature=None, config=InstanceDescriptor(), constants=None, - extern_libs=None, **kwargs): +def compile(src, device_type=("cuda", 80), signature=None, configs=None, constants=None, extern_libs=None, **kwargs): # TODO (backward-breaking): # - merge InstanceDescriptor and SpecializationDescriptor + # - no more configs # - extern_libs => linker_flags: dict # - **kwargs -> compiler_flags: dict - + configs = [InstanceDescriptor()] if configs is None else configs + assert len(configs) == 1 + config = configs[0] # create backend src = IRSource(src) if isinstance(src, str) else ASTSource(src, signature, constants, config) backend = CUDABackend(device_type) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index bb723a53057a..f43b83a9fe69 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -542,7 +542,7 @@ def get_special_arg(name: str, default=None): enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs, - config=configs[0], + configs=[configs[0]], debug=self.debug, ) diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py index 
a69c7100ddd0..02ca129e46e8 100644 --- a/python/triton/tools/compile.py +++ b/python/triton/tools/compile.py @@ -104,7 +104,7 @@ def constexpr(s): assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" divisible_by_16 = [i for i, h in hints.items() if h == 16] equal_to_1 = [i for i, h in hints.items() if h == 1] - config = triton.compiler.instance_descriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1) + config = triton.compiler.InstanceDescriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1) for i in equal_to_1: constexprs.update({i: 1}) ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config], From 693cfe2aa69e965fcf6ff4b1bd76e7aac0cbdec6 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 26 Nov 2023 16:11:05 -0800 Subject: [PATCH 39/64] . --- python/triton/compiler/code_generator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index a6f498111c51..b5153a5b8542 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -981,6 +981,8 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): gscope = sys.modules[fn.fn.__module__].__dict__ # If the callee is not set, we use the same debug setting as the caller file_name, begin_line = _get_fn_file_line(fn) + options = self.builder.options + options.debug = self.debug if fn.debug is None else fn.debug generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, noinline=fn.noinline, file_name=file_name, begin_line=begin_line, From 6a7dc51e07b413d28f6f0f6c84105411a0d4b8b6 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 26 Nov 2023 17:07:24 -0800 Subject: [PATCH 40/64] . --- python/test/unit/tools/test_aot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/test/unit/tools/test_aot.py b/python/test/unit/tools/test_aot.py index 92b5562e9527..b76ba71997e2 100644 --- a/python/test/unit/tools/test_aot.py +++ b/python/test/unit/tools/test_aot.py @@ -446,7 +446,7 @@ def test_ttgir_to_ptx(): kernel_path = os.path.join(tmp_dir, "empty_kernel.ttgir") with open(kernel_path, "w") as fp: fp.write(src) - k = triton.compile(kernel_path, cc=80) + k = triton.compile(kernel_path, device_type=("cuda", 80)) ptx = k.asm["ptx"] assert ".target sm_80" in ptx assert ".address_size 64" in ptx From 1e00e4a5aa4bd65f4f165a23a1ad0b8a475e9eb1 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 26 Nov 2023 21:50:54 -0800 Subject: [PATCH 41/64] . 
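Migrate the subprocess compilation tests to the reworked compile() signature: the kernel is passed as `src=`, the target as `device_type=("cuda", cc)`, and specialization hints via `configs=[InstanceDescriptor(...)]`; the old `fn=`, `device=`, `cc=` and `warm_cache_only=` keywords are dropped. End to end, an ahead-of-time compile now looks roughly like this (a minimal sketch assuming a working CUDA toolchain; the kernel itself is illustrative):

    import triton
    import triton.language as tl

    @triton.jit
    def kernel_sub(a, b, o, N: tl.constexpr):
        idx = tl.arange(0, N)
        tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx))

    config = triton.compiler.InstanceDescriptor(divisible_by_16=tuple(range(3)))
    compiled = triton.compile(
        src=kernel_sub,                                  # was fn=
        signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
        device_type=("cuda", 80),                        # was device=0 / cc=80
        constants={3: 32},
        configs=[config],                                # still a single-element list
    )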
--- python/test/unit/runtime/test_subproc.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index f1039d011e2c..c19fd88e1e36 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -1,7 +1,6 @@ import multiprocessing import os import shutil -from collections import namedtuple import torch @@ -17,10 +16,6 @@ def reset_tmp_dir(): shutil.rmtree(tmpdir, ignore_errors=True) -instance_descriptor = namedtuple("instance_descriptor", - ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"]) - - def compile_fn(config, cc): @triton.jit @@ -29,20 +24,18 @@ def kernel_sub(a, b, o, N: tl.constexpr): tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777) triton.compile( - fn=kernel_sub, + src=kernel_sub, signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, - device=0, + device_type=("cuda", cc), constants={3: 32}, configs=[config], - warm_cache_only=True, - cc=cc, ) def test_compile_in_subproc() -> None: major, minor = torch.cuda.get_device_capability(0) cc = major * 10 + minor - config = instance_descriptor(tuple(range(4)), (), (), ()) + config = triton.compiler.InstanceDescriptor(tuple(range(4)), (), (), ()) multiprocessing.set_start_method('fork') proc = multiprocessing.Process(target=compile_fn, args=(config, cc)) @@ -61,12 +54,10 @@ def kernel_dot(Z): tl.store(Z + offs, z) triton.compile( - fn=kernel_dot, + src=kernel_dot, signature={0: "*fp32"}, - device=0, + device_type=("cuda", cc), configs=[config], - warm_cache_only=True, - cc=cc, ) @@ -74,7 +65,7 @@ def test_compile_in_forked_subproc() -> None: reset_tmp_dir() major, minor = torch.cuda.get_device_capability(0) cc = major * 10 + minor - config = instance_descriptor(tuple(range(1)), (), (), ()) + config = triton.compiler.InstanceDescriptor(tuple(range(1)), (), (), ()) assert multiprocessing.get_start_method() == 'fork' proc = multiprocessing.Process(target=compile_fn_dot, args=(config, cc)) From 0c69111ad1ec25fe2c17ce7c5155341d99675807 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Mon, 27 Nov 2023 18:22:33 -0800 Subject: [PATCH 42/64] . 
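Fix the option name the frontend reads (`max_num_imprecise_acc_default`) and make sure non-fp8 dots get a defined value. Combined, the two hunks below give tl.dot this default behavior (sketch joining the cuda.py and semantic.py changes; the trailing comments are interpretive):

    # backend (cuda.py): default is 0 on sm_89+, left unset otherwise
    options.max_num_imprecise_acc_default = 0 if self.capability >= 89 else None

    # frontend (semantic.py): only fp8 x fp8 dots consult that default
    max_num_imprecise_acc = 0
    if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
        max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default
        if max_num_imprecise_acc is None:
            max_num_imprecise_acc = 2**30   # effectively unlimited when the backend left it unset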
--- python/triton/compiler/backends/cuda.py | 2 +- python/triton/language/semantic.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index 94057cc4f51b..65ee23054147 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -74,7 +74,7 @@ def __init__(self, device_type: tuple) -> None: def parse_options(self, **opts) -> Any: options = CUDAOptions(**opts) options.allow_fp8e4nv = self.capability >= 89 - options.max_num_imprecise_acc = 0 if self.capability >= 89 else None + options.max_num_imprecise_acc_default = 0 if self.capability >= 89 else None return options @staticmethod diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 9206f6d43c97..066d9d9a0be1 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1263,10 +1263,11 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, options): assert acc.type == ret_ty # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90 + max_num_imprecise_acc = 0 if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default - if max_num_imprecise_acc is None: - max_num_imprecise_acc = 2**30 + if max_num_imprecise_acc is None: + max_num_imprecise_acc = 2**30 return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc), ret_ty) From e5a74705b64e04ce53e1c6c738565b816af7305b Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Mon, 27 Nov 2023 20:47:41 -0800 Subject: [PATCH 43/64] . --- .github/workflows/integration-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index b19110ba8547..b13c74801ee3 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -27,7 +27,7 @@ jobs: id: set-matrix run: | if [ x"${{ github.repository }}" == x"openai/triton" ]; then - echo '::set-output name=matrix-required::[["self-hosted", "A100"]]' + echo '::set-output name=matrix-required::[["self-hosted", "A100"], ["self-hosted", "H100"]]' echo '::set-output name=matrix-optional::[["self-hosted", "gfx908"], ["self-hosted", "arc770"]]' else echo '::set-output name=matrix-required::["ubuntu-latest"]' From 2331078e249c286ae6a298f2b2929b3f851bd8a7 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Mon, 27 Nov 2023 21:29:52 -0800 Subject: [PATCH 44/64] fixup --- python/triton/language/semantic.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 066d9d9a0be1..b521c8483923 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -641,8 +641,7 @@ def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: src_sca_ty = src_ty.scalar dst_sca_ty = dst_ty.scalar - if builder.options.allow_fp8e4nv: - assert False, "fp8e4nv data type is not supported on CUDA arch < 89" + assert builder.options.allow_fp8e4nv, "fp8e4nv data type is not supported on CUDA arch < 89" # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ From d00a7b31e40c14f03731aa6f52ac6a332d167f2f Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Mon, 27 Nov 2023 22:12:56 -0800 Subject: [PATCH 45/64] . 
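Make the tensormap metadata round-trip through the JSON cache instead of living on the driver singleton. The new flow, roughly (sketch mirroring the hunks below):

    # compile time: backend objects are serialized via their __dict__
    json.dumps(metadata, default=vars)

    # load time (CompiledKernel.__init__): plain dicts are rehydrated
    meta = json.loads(metadata_path.read_text())
    meta["tensormaps_info"] = [InfoFromBackendForTensorMap(e) for e in meta.get("tensormaps_info", [])]

    # launch time (jit.py): the driver no longer caches the info itself
    args_expand = driver.assemble_tensormap_to_arg(bin.metadata["tensormaps_info"], non_constexpr_arg_values)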
--- python/triton/compiler/compiler.py | 8 +++++++- python/triton/compiler/utils.py | 1 + python/triton/runtime/driver.py | 6 +++--- python/triton/runtime/jit.py | 2 +- 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 0877f3b1b6fd..75dd69f7d74f 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -10,6 +10,7 @@ from ..runtime.cache import get_cache_manager from ..runtime.jit import get_current_device, get_cuda_stream from ..runtime.driver import driver +from .utils import InfoFromBackendForTensorMap from .backends.cuda import CUDABackend from dataclasses import dataclass from .code_generator import ast_to_ttir @@ -211,7 +212,8 @@ def compile(src, device_type=("cuda", 80), signature=None, configs=None, constan metadata_group[f"{src.name}.{ext}"] = fn_cache_manager.put(next_module, f"{src.name}.{ext}") module = next_module # write-back metadata - metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata), metadata_filename, binary=False) + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, + binary=False) fn_cache_manager.put_group(metadata_filename, metadata_group) so_path = backend.make_launcher_stub(src, metadata) # return handle to compiled kernel @@ -235,6 +237,10 @@ def __init__(self, so_path, metadata_path): self.c_wrapper = getattr(mod, "launch") # initialize metadata self.metadata = json.loads(metadata_path.read_text()) + self.metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in self.metadata['tensormaps_info'] + ] if 'tensormaps_info' in self.metadata else [] + for i, _ in enumerate(self.metadata["tensormaps_info"]): + self.metadata["tensormaps_info"][i].ids_of_folded_args = tuple(self.metadata["ids_of_folded_args"]) for key, val in self.metadata.items(): setattr(self, key, val) # stores the text of each level of IR that was generated during compilation diff --git a/python/triton/compiler/utils.py b/python/triton/compiler/utils.py index 48233afedd4c..65789129fb25 100644 --- a/python/triton/compiler/utils.py +++ b/python/triton/compiler/utils.py @@ -243,6 +243,7 @@ def getGlobalStrides(self, args): return strides_in_bytes def getOriginArgIdx(self, idx, args): + print(self.ids_of_folded_args) if self.ids_of_folded_args: ids_before_folding_arg = [i for i in range(len(args)) if i not in self.ids_of_folded_args] return ids_before_folding_arg[idx] diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index 61ce29521a02..36840eee9f8f 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -101,12 +101,12 @@ def __init__(self): self.backend = self.CUDA self.binary_ext = "cubin" - def assemble_tensormap_to_arg(self, args): + def assemble_tensormap_to_arg(self, tensormaps_info, args): args_with_tma = list(args) - if hasattr(self, 'tensormaps_info'): + if tensormaps_info is not None: # tuple for hashable args_ptr = tuple([arg.data_ptr() if hasattr(arg, 'data_ptr') else arg for arg in args]) - for i, e in enumerate(self.tensormaps_info): + for i, e in enumerate(tensormaps_info): args_with_tma.append(CudaDriver.tensormap_manager[(e, args_ptr)]) return args_with_tma diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index f43b83a9fe69..60f1610993a7 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -563,7 +563,7 @@ def get_special_arg(name: str, default=None): 
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, - *driver.assemble_tensormap_to_arg(non_constexpr_arg_values), + *driver.assemble_tensormap_to_arg(bin.metadata["tensormaps_info"], non_constexpr_arg_values), ) return bin From 0b983bdd6f7ce7e25905091869a15c716afabc83 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Mon, 27 Nov 2023 22:38:23 -0800 Subject: [PATCH 46/64] . --- python/triton/compiler/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/triton/compiler/utils.py b/python/triton/compiler/utils.py index 65789129fb25..48233afedd4c 100644 --- a/python/triton/compiler/utils.py +++ b/python/triton/compiler/utils.py @@ -243,7 +243,6 @@ def getGlobalStrides(self, args): return strides_in_bytes def getOriginArgIdx(self, idx, args): - print(self.ids_of_folded_args) if self.ids_of_folded_args: ids_before_folding_arg = [i for i in range(len(args)) if i not in self.ids_of_folded_args] return ids_before_folding_arg[idx] From 2e9e902fd33ce67aae6a104add3949eaeec4a0c9 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Mon, 27 Nov 2023 22:40:45 -0800 Subject: [PATCH 47/64] . --- python/triton/compiler/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 75dd69f7d74f..4092d4541879 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -255,6 +255,7 @@ def __init__(self, so_path, metadata_path): # (e.g., checking amount of shared memory on current device) self.module = None self.function = None + print(self.num_ctas, self.cluster_dims) def _init_handles(self): if self.module is not None: From 2563c0d3672cb4de9c6deb9519146b9a99402b16 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Mon, 27 Nov 2023 22:48:15 -0800 Subject: [PATCH 48/64] . --- python/triton/compiler/backends/cuda.py | 1 + python/triton/compiler/compiler.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index 65ee23054147..aae2dc93805e 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -153,6 +153,7 @@ def make_ttgir(mod, metadata, opt, capability): pm.add_tritongpu_optimize_thread_locality_pass() pm.add_canonicalizer_pass() pm.run(mod) + metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ) return mod @staticmethod diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 4092d4541879..75dd69f7d74f 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -255,7 +255,6 @@ def __init__(self, so_path, metadata_path): # (e.g., checking amount of shared memory on current device) self.module = None self.function = None - print(self.num_ctas, self.cluster_dims) def _init_handles(self): if self.module is not None: From 4d53e08ad19a93ec22b8b8719e9320e0d3d4b8ec Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Tue, 28 Nov 2023 17:49:07 -0800 Subject: [PATCH 49/64] . 
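Thread the compute capability into both RewriteTensorPointer passes so the decision no longer has to live in Python. With that, the capability guard in the TTGIR pipeline goes away (see the cuda.py hunk at the end of this patch); the pass is always scheduled and internally keeps block pointers only where the TMA path can use them (sm_90+ with ENABLE_TMA=1), e.g.:

    # cuda.py, make_ttgir: no capability check in Python anymore
    pm.add_tritongpu_rewrite_tensor_pointer_pass(capability)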
--- .../triton/Dialect/Triton/Transforms/Passes.h | 3 ++- .../Dialect/Triton/Transforms/Passes.td | 6 ++++- .../TritonNvidiaGPU/Transforms/Passes.h | 3 ++- .../TritonNvidiaGPU/Transforms/Passes.td | 6 ++++- .../Transforms/RewriteTensorPointer.cpp | 13 +++++++--- .../Transforms/RewriteTensorPointer.cpp | 25 ++++++++++++------- python/src/triton.cc | 10 +++++--- python/triton/compiler/backends/cuda.py | 3 +-- 8 files changed, 47 insertions(+), 22 deletions(-) diff --git a/include/triton/Dialect/Triton/Transforms/Passes.h b/include/triton/Dialect/Triton/Transforms/Passes.h index fde54fe17125..1d1ef2615d83 100644 --- a/include/triton/Dialect/Triton/Transforms/Passes.h +++ b/include/triton/Dialect/Triton/Transforms/Passes.h @@ -9,7 +9,8 @@ namespace triton { std::unique_ptr createCombineOpsPass(); std::unique_ptr createReorderBroadcastPass(); -std::unique_ptr createRewriteTensorPointerPass(); +std::unique_ptr +createRewriteTensorPointerPass(int computeCapability = 80); } // namespace triton diff --git a/include/triton/Dialect/Triton/Transforms/Passes.td b/include/triton/Dialect/Triton/Transforms/Passes.td index 404e8896c062..219e72b0950b 100644 --- a/include/triton/Dialect/Triton/Transforms/Passes.td +++ b/include/triton/Dialect/Triton/Transforms/Passes.td @@ -40,7 +40,11 @@ def TritonRewriteTensorPointer : Pass + ]; } #endif diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h index a9ac3ffeab89..9d3fd70890c7 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h +++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h @@ -70,7 +70,8 @@ createTritonNvidiaGPUWSMaterializationPass(int computeCapability = 90); std::unique_ptr createTritonNvidiaGPUFenceInsertionPass(int computeCapability = 90); -std::unique_ptr createTritonGPURewriteTensorPointerPass(); +std::unique_ptr +createTritonGPURewriteTensorPointerPass(int computeCapability = 80); std::unique_ptr createTritonNvidiaGPUWSFixupMissingAttrs(); diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td index b94b8bcb64ca..d038c610f999 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td @@ -218,7 +218,11 @@ def TritonGPURewriteTensorPointer : Pass + ]; } def TritonGPUWSFixupMissingAttrs : Pass<"triton-nvidia-gpu-ws-fixup-missing-attrs", "mlir::ModuleOp"> { diff --git a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp index 58b6e73ce986..31a53af78b69 100644 --- a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp +++ b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp @@ -196,7 +196,9 @@ class RewriteTensorPointerPass DenseMap rewritedInfo; public: - explicit RewriteTensorPointerPass() {} + explicit RewriteTensorPointerPass(int computeCapability) { + this->computeCapability = computeCapability; + } static bool needRewrite(Operation *op) { return std::any_of(op->getOperands().begin(), op->getOperands().end(), @@ -471,6 +473,10 @@ class RewriteTensorPointerPass } void runOnOperation() override { + // Only rewrite if the hardware does not support + if (computeCapability >= 90) + return; + // NOTES(Chenggang): we don't use `ConversionPatternRewriter`, because // MLIR does not support one-multiple value mapping. 
For example, if we use // `ConversionPatternRewriter`, we can not make a type converter, which @@ -496,6 +502,7 @@ class RewriteTensorPointerPass } }; -std::unique_ptr triton::createRewriteTensorPointerPass() { - return std::make_unique(); +std::unique_ptr +triton::createRewriteTensorPointerPass(int computeCapability) { + return std::make_unique(computeCapability); } diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp index 257ca151cd02..3a10fee57c51 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp @@ -71,8 +71,8 @@ bool isDivisible(Value v, unsigned divisor) { } } -bool shouldRemove(tt::MakeTensorPtrOp &op) { - if (!::triton::tools::getBoolEnv("ENABLE_TMA")) +bool shouldRemove(tt::MakeTensorPtrOp &op, int computeCapability) { + if (computeCapability < 90 || !::triton::tools::getBoolEnv("ENABLE_TMA")) return true; auto resType = op.getResult() .getType() @@ -395,7 +395,13 @@ class TritonGPURewriteTensorPointerPass DenseMap rewritedInfo; public: + // explicit TritonGPURewriteTensorPointerPass(int computeCapability) + // : computeCapability(computeCapability) {} + TritonGPURewriteTensorPointerPass() = default; + TritonGPURewriteTensorPointerPass(int computeCapability) { + this->computeCapability = computeCapability; + } static bool needRewrite(Operation *op, const DenseSet &valueToRemove) { if (auto ifOp = dyn_cast(op)) { @@ -797,14 +803,14 @@ class TritonGPURewriteTensorPointerPass DenseSet valueToRemove; mod.walk([&valueToRemove, this](Operation *op) { if (auto makeTensorPtrOp = dyn_cast(op)) { - if (shouldRemove(makeTensorPtrOp)) + if (shouldRemove(makeTensorPtrOp, this->computeCapability)) valueToRemove.insert(op->getResult(0)); } if (llvm::isa(op)) { auto src = op->getOperand(0); if (tt::isTensorPointerType(src.getType())) { auto makeTensorPtrOp = getMakeTensorPtrOp(src); - if (shouldRemove(makeTensorPtrOp)) { + if (shouldRemove(makeTensorPtrOp, this->computeCapability)) { valueToRemove.insert(op->getResult(0)); } } @@ -813,7 +819,7 @@ class TritonGPURewriteTensorPointerPass auto src = op->getOperand(0); if (tt::isTensorPointerType(src.getType())) { auto makeTensorPtrOp = getMakeTensorPtrOp(src); - if (shouldRemove(makeTensorPtrOp)) + if (shouldRemove(makeTensorPtrOp, this->computeCapability)) valueToRemove.insert(src); } } @@ -822,7 +828,7 @@ class TritonGPURewriteTensorPointerPass for (unsigned i = 0, size = forOp.getInitArgs().size(); i < size; ++i) { if (tt::isTensorPointerType(iterOperands[i].getType())) { auto makeTensorPtrOp = getMakeTensorPtrOp(iterOperands[i]); - if (shouldRemove(makeTensorPtrOp)) + if (shouldRemove(makeTensorPtrOp, this->computeCapability)) valueToRemove.insert(iterOperands[i]); } } @@ -831,7 +837,7 @@ class TritonGPURewriteTensorPointerPass for (unsigned i = 0, size = yieldOp.getNumOperands(); i < size; ++i) { if (tt::isTensorPointerType(operands[i].getType())) { auto makeTensorPtrOp = getMakeTensorPtrOp(operands[i]); - if (shouldRemove(makeTensorPtrOp)) + if (shouldRemove(makeTensorPtrOp, this->computeCapability)) valueToRemove.insert(operands[i]); } } @@ -864,6 +870,7 @@ class TritonGPURewriteTensorPointerPass } }; -std::unique_ptr mlir::createTritonGPURewriteTensorPointerPass() { - return std::make_unique(); +std::unique_ptr +mlir::createTritonGPURewriteTensorPointerPass(int computeCapability) { + return std::make_unique(computeCapability); } diff --git 
a/python/src/triton.cc b/python/src/triton.cc index 2f0bfd875552..e1f9f482e1a3 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1685,8 +1685,9 @@ void init_triton_ir(py::module &&m) { self.addPass(mlir::triton::createReorderBroadcastPass()); }) .def("add_rewrite_tensor_pointer_pass", - [](mlir::PassManager &self) { - self.addPass(mlir::triton::createRewriteTensorPointerPass()); + [](mlir::PassManager &self, int capability) { + self.addPass( + mlir::triton::createRewriteTensorPointerPass(capability)); }) .def("add_tritongpu_ws_feasibility_checking_pass", [](mlir::PassManager &self, int computeCapability) { @@ -1760,8 +1761,9 @@ void init_triton_ir(py::module &&m) { self.addPass(mlir::createTritonGPUReorderInstructionsPass()); }) .def("add_tritongpu_rewrite_tensor_pointer_pass", - [](mlir::PassManager &self) { - self.addPass(mlir::createTritonGPURewriteTensorPointerPass()); + [](mlir::PassManager &self, int capability) { + self.addPass( + mlir::createTritonGPURewriteTensorPointerPass(capability)); }) .def("add_tritongpu_decompose_conversions_pass", [](mlir::PassManager &self) { diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index aae2dc93805e..bd73e9280238 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -106,8 +106,7 @@ def make_ttgir(mod, metadata, opt, capability): pm.add_tritongpu_coalesce_pass() # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass pm.add_plan_cta_pass(cluster_info) - if capability // 10 < 9: - pm.add_tritongpu_rewrite_tensor_pointer_pass() + pm.add_tritongpu_rewrite_tensor_pointer_pass(capability) pm.add_plan_cta_pass(cluster_info) pm.add_tritongpu_remove_layout_conversions_pass() pm.add_tritongpu_accelerate_matmul_pass(capability) From ca5781ca0accbbe4f991422be7a7669dca0c2020 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Tue, 28 Nov 2023 17:53:03 -0800 Subject: [PATCH 50/64] . 
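
Note for this test update: -tritongpu-rewrite-tensor-pointer now carries a
compute-capability option (default 80) instead of being gated in the Python
pipeline, so the Hopper/TMA lit test has to request compute-capability=90
explicitly on its triton-opt RUN line. For reference, a minimal sketch of the
corresponding Python call site (it mirrors make_ttgir in
compiler/backends/cuda.py; the wrapper name run_rewrite_tensor_pointer is made
up for illustration and only the passes around the rewrite are shown):

    # sketch: drive the tensor-pointer rewrite from the TTGIR pipeline
    def run_rewrite_tensor_pointer(mod, capability, cluster_info):
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        pm.add_tritongpu_coalesce_pass()
        pm.add_plan_cta_pass(cluster_info)
        # the capability is passed straight through; the pass keeps tensor
        # pointers only when capability >= 90 and ENABLE_TMA=1, and rewrites
        # them into plain pointer arithmetic otherwise
        pm.add_tritongpu_rewrite_tensor_pointer_pass(capability)
        pm.run(mod)
        return mod
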
--- test/TritonGPU/rewrite-tensor-pointer-tma.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/TritonGPU/rewrite-tensor-pointer-tma.mlir b/test/TritonGPU/rewrite-tensor-pointer-tma.mlir index bb14ca6ad830..6f9d5f58763a 100644 --- a/test/TritonGPU/rewrite-tensor-pointer-tma.mlir +++ b/test/TritonGPU/rewrite-tensor-pointer-tma.mlir @@ -1,4 +1,4 @@ -// RUN: ENABLE_TMA=1 triton-opt %s -split-input-file -tritongpu-rewrite-tensor-pointer | FileCheck %s +// RUN: ENABLE_TMA=1 triton-opt %s -split-input-file -tritongpu-rewrite-tensor-pointer=compute-capability=90 | FileCheck %s #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { From 84c55980f8a763e3b50816453f15c8525fe3e2ca Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Tue, 28 Nov 2023 18:18:23 -0800 Subject: [PATCH 51/64] device_type -> target --- python/test/unit/runtime/test_subproc.py | 2 +- python/triton/compiler/compiler.py | 10 ++++++---- python/triton/runtime/jit.py | 10 +++++++++- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index c19fd88e1e36..a552342c0d35 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -26,7 +26,7 @@ def kernel_sub(a, b, o, N: tl.constexpr): triton.compile( src=kernel_sub, signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, - device_type=("cuda", cc), + target=("cuda", cc), constants={3: 32}, configs=[config], ) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 75dd69f7d74f..4feb5d41b944 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -8,7 +8,7 @@ # TODO: runtime.errors from ..runtime.autotuner import OutOfResources from ..runtime.cache import get_cache_manager -from ..runtime.jit import get_current_device, get_cuda_stream +from ..runtime.jit import get_current_device, get_cuda_stream, get_current_target from ..runtime.driver import driver from .utils import InfoFromBackendForTensorMap from .backends.cuda import CUDABackend @@ -164,18 +164,20 @@ def update_options(self, options): options.num_warps = _get_num_warps_from_ir_str(self.src) -def compile(src, device_type=("cuda", 80), signature=None, configs=None, constants=None, extern_libs=None, **kwargs): +def compile(src, target=None, signature=None, configs=None, constants=None, extern_libs=None, **kwargs): # TODO (backward-breaking): # - merge InstanceDescriptor and SpecializationDescriptor # - no more configs # - extern_libs => linker_flags: dict # - **kwargs -> compiler_flags: dict + if target is None: + target = get_current_target() + backend = CUDABackend(target) configs = [InstanceDescriptor()] if configs is None else configs assert len(configs) == 1 config = configs[0] # create backend src = IRSource(src) if isinstance(src, str) else ASTSource(src, signature, constants, config) - backend = CUDABackend(device_type) options = backend.parse_options(**kwargs) src.update_options(options) @@ -194,7 +196,7 @@ def compile(src, 
device_type=("cuda", 80), signature=None, configs=None, constan # initialize metadata metadata = { - "device_type": device_type, + "target": target, **options.__dict__, **get_env_vars(), } diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 60f1610993a7..908890562eb7 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -47,6 +47,14 @@ def get_device_capability(idx): return torch.cuda.get_device_capability(idx) +def get_current_target(): + import torch + device = get_current_device() + capability = get_device_capability(device) + capability = capability[0] * 10 + capability[1] + return ("cuda", capability) + + T = TypeVar("T") # ----------------------------------------------------------------------------- @@ -533,8 +541,8 @@ def get_special_arg(name: str, default=None): capability = capability[0] * 10 + capability[1] self.cache[device][key] = compile( self, + target=(device_type, capability), signature=signature, - device_type=(device_type, capability), constants=constants, num_warps=num_warps, num_ctas=num_ctas, From 2651cf8e18ef640c9975abc5fe4c70af1b044965 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Tue, 28 Nov 2023 18:28:34 -0800 Subject: [PATCH 52/64] . --- python/triton/compiler/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 4feb5d41b944..453f61794a5a 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -279,7 +279,7 @@ def __getitem__(self, grid): self._init_handles() def runner(*args, stream=None): - args_expand = driver.assemble_tensormap_to_arg(args) + args_expand = driver.assemble_tensormap_to_arg(self.tensormaps_info, args) if stream is None: stream = get_cuda_stream() self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.cluster_dims[0], From f10d336bfa9de1f7b1ecb89787842b4c887f66c3 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Tue, 28 Nov 2023 18:48:59 -0800 Subject: [PATCH 53/64] . --- python/test/unit/tools/test_aot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/test/unit/tools/test_aot.py b/python/test/unit/tools/test_aot.py index b76ba71997e2..613836d24e36 100644 --- a/python/test/unit/tools/test_aot.py +++ b/python/test/unit/tools/test_aot.py @@ -446,7 +446,7 @@ def test_ttgir_to_ptx(): kernel_path = os.path.join(tmp_dir, "empty_kernel.ttgir") with open(kernel_path, "w") as fp: fp.write(src) - k = triton.compile(kernel_path, device_type=("cuda", 80)) + k = triton.compile(kernel_path, target=("cuda", 80)) ptx = k.asm["ptx"] assert ".target sm_80" in ptx assert ".address_size 64" in ptx From 730cecdc7ca11d0a911b31f1eab2fe4449832bd9 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Tue, 28 Nov 2023 20:11:19 -0800 Subject: [PATCH 54/64] fixup --- python/test/unit/runtime/test_subproc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index a552342c0d35..1974328711ce 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -56,7 +56,7 @@ def kernel_dot(Z): triton.compile( src=kernel_dot, signature={0: "*fp32"}, - device_type=("cuda", cc), + target=("cuda", cc), configs=[config], ) From 5f84af476f7fd8351abddbc2e992b8e438df4c2c Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Tue, 28 Nov 2023 21:53:42 -0800 Subject: [PATCH 55/64] . 
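
The allow_fp8e4nv assertion used to fire for every cast on devices below
sm_89, even when neither side of the cast was an fp8e4nv value. It is now
checked only when an fp8e4nv type is actually involved. Roughly, as a sketch
using the names from semantic.cast() (allow_fp8e4nv is the flag the CUDA
backend derives from the compute capability):

    # only reject the cast if fp8e4nv actually appears on either side
    if src_sca_ty.is_fp8e4nv() or dst_sca_ty.is_fp8e4nv():
        assert builder.options.allow_fp8e4nv, \
            "fp8e4nv data type is not supported on CUDA arch < 89"

With the guard in place, e.g. an fp16 -> fp32 cast on sm_80 no longer trips
the assertion.
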
--- python/triton/language/semantic.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index b521c8483923..c1ee1036ba6f 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -641,7 +641,8 @@ def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: src_sca_ty = src_ty.scalar dst_sca_ty = dst_ty.scalar - assert builder.options.allow_fp8e4nv, "fp8e4nv data type is not supported on CUDA arch < 89" + if (src_sca_ty.is_fp8e4nv() or dst_sca_ty.is_fp8e4nv()): + assert builder.options.allow_fp8e4nv, "fp8e4nv data type is not supported on CUDA arch < 89" # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ From 7f4ff234cf9d1570ee24c422ae2dc7280c75691b Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Wed, 29 Nov 2023 14:38:05 -0800 Subject: [PATCH 56/64] cleaning --- python/triton/compiler/__init__.py | 4 +- python/triton/compiler/backends/cuda.py | 33 ++++++---- python/triton/compiler/code_generator.py | 12 ++-- python/triton/compiler/compiler.py | 76 +++++++++--------------- python/triton/runtime/jit.py | 30 +++++----- 5 files changed, 72 insertions(+), 83 deletions(-) diff --git a/python/triton/compiler/__init__.py b/python/triton/compiler/__init__.py index e3bf82207b14..0ced3fed6af1 100644 --- a/python/triton/compiler/__init__.py +++ b/python/triton/compiler/__init__.py @@ -1,7 +1,7 @@ -from .compiler import (CompiledKernel, compile, InstanceDescriptor) +from .compiler import (CompiledKernel, ASTSource, compile, AttrsDescriptor) from .errors import CompilationError __all__ = [ - "compile", "InstanceDescriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", + "compile", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", "get_arch_default_num_stages" ] diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index bd73e9280238..68e75bd379d4 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -50,7 +50,6 @@ class CUDAOptions: enable_persistent: bool = False optimize_epilogue: bool = False enable_fp_fusion: bool = True - extern_libs = None allow_fp8e4nv: bool = False max_num_imprecise_acc_default: bool = None debug: bool = False @@ -64,6 +63,11 @@ def hash(self): return hashlib.md5(key.encode("utf-8")).hexdigest() +@dataclass +class CUDALinkerOptions: + libs: dict = None + + class CUDABackend(BaseBackend): def __init__(self, device_type: tuple) -> None: @@ -71,12 +75,15 @@ def __init__(self, device_type: tuple) -> None: self.capability = device_type[1] assert isinstance(self.capability, int) - def parse_options(self, **opts) -> Any: + def parse_compiler_options(self, opts) -> Any: options = CUDAOptions(**opts) options.allow_fp8e4nv = self.capability >= 89 options.max_num_imprecise_acc_default = 0 if self.capability >= 89 else None return options + def parse_linker_options(self, opts) -> Any: + return CUDALinkerOptions(**opts) + @staticmethod def make_ttir(mod, metadata, opt): pm = ir.pass_manager(mod.context) @@ -156,15 +163,15 @@ def make_ttgir(mod, metadata, opt, capability): return mod @staticmethod - def make_llir(src, metadata, opt, capability): + def make_llir(src, metadata, linker_options, capability): metadata["enable_warp_specialization"] = ir.is_ws_supported(src) metadata["num_warps"] = 
get_num_warps(src) tma_infos = TMAInfos() - - if opt.extern_libs: - add_external_libs(src, list(opt.extern_libs.keys()), list(opt.extern_libs.values())) + # link libraries + if linker_options.libs: + add_external_libs(src, list(linker_options.libs.keys()), list(linker_options.libs.values())) + # TritonGPU -> LLVM-IR ret = translate_triton_gpu_to_llvmir(src, capability, tma_infos, runtime.TARGET.NVVM) - if len(tma_infos) > 0: metadata["tensormaps_info"] = parse_tma_info(tma_infos, metadata["ids_of_folded_args"]) for i, _ in enumerate(metadata["tensormaps_info"]): @@ -187,12 +194,12 @@ def make_cubin(src, metadata, opt, capability): ptxas, _ = path_to_ptxas() return compile_ptx_to_cubin(src, ptxas, capability, opt.enable_fp_fusion) - def add_stages(self, extern_libs, stages, opt): - stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, opt) - stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, opt, self.capability) - stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, opt, self.capability) - stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, opt, self.capability) - stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, opt, self.capability) + def add_stages(self, stages, compiler_options, linker_options): + stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, compiler_options) + stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, compiler_options, self.capability) + stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, linker_options, self.capability) + stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, compiler_options, self.capability) + stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, compiler_options, self.capability) def hash(self): return f'{get_cuda_version_key()}-{self.capability}' diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 86f361a48976..53812ca32ba7 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -1190,7 +1190,7 @@ def kernel_suffix(signature, specialization): def ast_to_ttir(fn, specialization, options): - config = specialization.config + attrs = specialization.attrs context = ir.context() context.load_triton() # create kernel prototype @@ -1198,13 +1198,13 @@ def ast_to_ttir(fn, specialization, options): constants = {cst_key(key): value for key, value in specialization.constants.items()} # visit kernel AST gscope = fn.__globals__.copy() - function_name = '_'.join([fn.__name__, kernel_suffix(specialization.signature.values(), config)]) + function_name = '_'.join([fn.__name__, kernel_suffix(specialization.signature.values(), attrs)]) tys = list(specialization.signature.values()) - new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in config.equal_to_1} - new_attrs = {k: [("tt.divisibility", 16)] for k in config.divisible_by_16} - for k in config.divisible_by_8: + new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in attrs.equal_to_1} + new_attrs = {k: [("tt.divisibility", 16)] for k in attrs.divisible_by_16} + for k in attrs.divisible_by_8: attr = new_attrs[k] if k in new_attrs else [] - if k in config.divisible_by_16: + if k in attrs.divisible_by_16: attr.append(("tt.max_divisibility", 16)) else: attr.append(("tt.max_divisibility", 8)) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 453f61794a5a..a927b74aea10 100644 --- 
a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -19,7 +19,7 @@ @dataclass -class InstanceDescriptor: +class AttrsDescriptor: divisible_by_16: set = None equal_to_1: set = None ids_of_folded_args: set = None @@ -40,23 +40,6 @@ def hash(self): return hashlib.md5(key.encode("utf-8")).hexdigest() -@dataclass -class SpecializationDescriptor: - config: InstanceDescriptor - signature: dict - constants: dict - - def __post_init__(self): - if isinstance(self.signature, str): - self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} - if self.constants is None: - self.constants = dict() - - def hash(self): - key = f"{self.config.hash()}-{self.signature.values()}-{self.constants}" - return hashlib.md5(key.encode("utf-8")).hexdigest() - - # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, # and any following whitespace # - (public\s+)? : optionally match the keyword public and any following whitespace @@ -113,25 +96,28 @@ def _get_num_warps_from_ir_str(src: str): class ASTSource: - def __init__(self, fn, signature, constants, config) -> None: + def __init__(self, fn, signature, constants, attrs) -> None: self.fn = fn self.ext = "ttir" self.name = fn.__name__ self.signature = signature self.constants = constants - self.config = config + self.attrs = attrs if isinstance(self.signature, str): self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} if self.constants is None: self.constants = dict() def hash(self): - key = f"{self.fn.cache_key}-{self.config.hash()}-{self.signature.values()}-{self.constants}" + key = f"{self.fn.cache_key}-{self.attrs.hash()}-{self.signature.values()}-{self.constants}" return hashlib.md5(key.encode("utf-8")).hexdigest() def make_ir(self, options): - specialization = SpecializationDescriptor(self.config, self.signature, self.constants) - return ast_to_ttir(self.fn, specialization, options=options) + return ast_to_ttir(self.fn, self, options=options) + + def metadata(self): + # TODO: remove once TMA support is cleaned up + return {"ids_of_folded_args": tuple([int(k) for k in self.attrs.ids_of_folded_args])} def update_options(self, options): pass @@ -159,30 +145,27 @@ def make_ir(self, options): module.context = context return module + def metadata(self): + return dict() + def update_options(self, options): if self.ext == "ttgir": options.num_warps = _get_num_warps_from_ir_str(self.src) -def compile(src, target=None, signature=None, configs=None, constants=None, extern_libs=None, **kwargs): - # TODO (backward-breaking): - # - merge InstanceDescriptor and SpecializationDescriptor - # - no more configs - # - extern_libs => linker_flags: dict - # - **kwargs -> compiler_flags: dict +def compile(src, target=None, compiler_options=None, linker_options=None): if target is None: target = get_current_target() backend = CUDABackend(target) - configs = [InstanceDescriptor()] if configs is None else configs - assert len(configs) == 1 - config = configs[0] # create backend - src = IRSource(src) if isinstance(src, str) else ASTSource(src, signature, constants, config) - options = backend.parse_options(**kwargs) - src.update_options(options) - + compiler_options = backend.parse_compiler_options(compiler_options) + linker_options = backend.parse_linker_options(linker_options) + if not isinstance(src, ASTSource): + assert isinstance(src, str), "source must be either AST or a filepath" + src = IRSource(src) + src.update_options(compiler_options) # create cache manager - 
key = f"{src.hash()}-{backend.hash()}-{options.hash()}-{frozenset(sorted(get_env_vars().items()))}" + key = f"{src.hash()}-{backend.hash()}-{compiler_options.hash()}-{frozenset(sorted(get_env_vars().items()))}" hash = hashlib.md5(key.encode("utf-8")).hexdigest() fn_cache_manager = get_cache_manager(hash) metadata_filename = f"{src.name}.json" @@ -193,22 +176,19 @@ def compile(src, target=None, signature=None, configs=None, constants=None, exte metadata = json.loads(Path(metadata_path).read_text()) so_path = backend.make_launcher_stub(src, metadata) return CompiledKernel(so_path, metadata_path) - # initialize metadata metadata = { "target": target, - **options.__dict__, - **get_env_vars(), + "compiler_options": compiler_options.__dict__, + "linker_options": linker_options.__dict__, + "environment": get_env_vars(), + **src.metadata(), } - # TODO: remove once TMA support is cleaned up - if signature is not None: - metadata["ids_of_folded_args"] = tuple([int(k) for k in config.ids_of_folded_args]) # run compilation pipeline and populate metadata stages = dict() - backend.add_stages(extern_libs, stages, options) - # + backend.add_stages(stages, compiler_options, linker_options) first_stage = list(stages.keys()).index(src.ext) - module = src.make_ir(options) + module = src.make_ir(compiler_options) for ext, compile_ir in list(stages.items())[first_stage:]: next_module = compile_ir(module, metadata) metadata_group[f"{src.name}.{ext}"] = fn_cache_manager.put(next_module, f"{src.name}.{ext}") @@ -243,8 +223,10 @@ def __init__(self, so_path, metadata_path): ] if 'tensormaps_info' in self.metadata else [] for i, _ in enumerate(self.metadata["tensormaps_info"]): self.metadata["tensormaps_info"][i].ids_of_folded_args = tuple(self.metadata["ids_of_folded_args"]) - for key, val in self.metadata.items(): + for key, val in self.metadata["compiler_options"].items(): setattr(self, key, val) + self.shared = self.metadata["shared"] + self.name = self.metadata["name"] # stores the text of each level of IR that was generated during compilation asm_files = [file for file in metadata_path.parent.glob(f'{metadata_path.stem}.*') if file.suffix != '.json'] self.asm = { diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 908890562eb7..42b16e3dba74 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -261,7 +261,7 @@ def _spec_of(arg): # TODO(jlebar): Fold this into the KernelArg class. 
def _get_config(self, *args): - from ..compiler import InstanceDescriptor + from ..compiler import AttrsDescriptor def is_divisible_by_16(x): if hasattr(x, "data_ptr"): @@ -298,8 +298,8 @@ def is_divisible_by_8(x): # TODO: method to collect all folded args none_args = {param.num for param, arg in zip(self.params, args) if arg is None and not param.do_not_specialize} ids_of_folded_args = equal_to_1 | none_args - return InstanceDescriptor(tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), - tuple(divisible_by_8)) + return AttrsDescriptor(tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), + tuple(divisible_by_8)) # return _triton.code_gen.instance_descriptor(divisible_by_16, # equal_to_1) @@ -408,7 +408,7 @@ def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: Li return device_types[0] if len(device_types) > 0 else "cuda" def run(self, *args, **kwargs): - from ..compiler import CompiledKernel, compile + from ..compiler import CompiledKernel, compile, ASTSource # Get a compiler-flags arg like `num_warps` and remove it from kwargs. def get_special_arg(name: str, default=None): @@ -539,19 +539,19 @@ def get_special_arg(name: str, default=None): capability = get_device_capability(device) capability = capability[0] * 10 + capability[1] + src = ASTSource(self, signature, constants, configs[0]) self.cache[device][key] = compile( - self, + src, target=(device_type, capability), - signature=signature, - constants=constants, - num_warps=num_warps, - num_ctas=num_ctas, - num_stages=num_stages, - enable_warp_specialization=enable_warp_specialization, - enable_fp_fusion=enable_fp_fusion, - extern_libs=extern_libs, - configs=[configs[0]], - debug=self.debug, + compiler_options={ + "num_warps": num_warps, + "num_ctas": num_ctas, + "num_stages": num_stages, + "enable_warp_specialization": enable_warp_specialization, + "enable_fp_fusion": enable_fp_fusion, + "debug": self.debug, + }, + linker_options={"libs": extern_libs}, ) bin = self.cache[device][key] From 15e8e4290485759226565b53acfac54d3d20d895 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Wed, 29 Nov 2023 14:53:34 -0800 Subject: [PATCH 57/64] more cleaning --- python/test/unit/runtime/test_subproc.py | 29 ++++++++++++------------ python/triton/tools/compile.py | 21 +++++++++-------- 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index 1974328711ce..e405bb8bff77 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -6,6 +6,7 @@ import triton import triton.language as tl +from triton.compiler import ASTSource tmpdir = ".tmp" @@ -16,20 +17,20 @@ def reset_tmp_dir(): shutil.rmtree(tmpdir, ignore_errors=True) -def compile_fn(config, cc): +def compile_fn(attrs, capability): @triton.jit def kernel_sub(a, b, o, N: tl.constexpr): idx = tl.arange(0, N) tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777) - triton.compile( - src=kernel_sub, - signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, - target=("cuda", cc), + src = ASTSource( + fn=kernel_sub, constants={3: 32}, - configs=[config], + signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, + attrs=attrs, ) + triton.compile(src=src, target=("cuda", capability)) def test_compile_in_subproc() -> None: @@ -44,7 +45,7 @@ def test_compile_in_subproc() -> None: assert proc.exitcode == 0 -def compile_fn_dot(config, cc): +def compile_fn_dot(attrs, capability): @triton.jit def kernel_dot(Z): @@ -53,22 +54,22 
@@ def kernel_dot(Z): z = tl.dot(z, z) tl.store(Z + offs, z) - triton.compile( - src=kernel_dot, + src = ASTSource( + fn=kernel_dot, signature={0: "*fp32"}, - target=("cuda", cc), - configs=[config], + attrs=attrs, ) + triton.compile(src=src, target=("cuda", capability)) def test_compile_in_forked_subproc() -> None: reset_tmp_dir() major, minor = torch.cuda.get_device_capability(0) - cc = major * 10 + minor - config = triton.compiler.InstanceDescriptor(tuple(range(1)), (), (), ()) + capability = major * 10 + minor + config = triton.compiler.AttrsDescriptor(tuple(range(1)), (), (), ()) assert multiprocessing.get_start_method() == 'fork' - proc = multiprocessing.Process(target=compile_fn_dot, args=(config, cc)) + proc = multiprocessing.Process(target=compile_fn_dot, args=(config, capability)) proc.start() proc.join() assert proc.exitcode == 0 diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py index 02ca129e46e8..e28a7a404c02 100644 --- a/python/triton/tools/compile.py +++ b/python/triton/tools/compile.py @@ -92,11 +92,11 @@ def constexpr(s): hints = {i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s} hints = {k: v for k, v in hints.items() if v is not None} - constexprs = {i: constexpr(s) for i, s in enumerate(signature)} - constexprs = {k: v for k, v in constexprs.items() if v is not None} - signature = {i: s.split(":")[0] for i, s in enumerate(signature) if i not in constexprs} - const_sig = 'x'.join([str(v) for v in constexprs.values()]) - doc_string = [f"{kernel.arg_names[i]}={constexprs[i]}" for i in constexprs.keys()] + constants = {i: constexpr(s) for i, s in enumerate(signature)} + constants = {k: v for k, v in constants.items() if v is not None} + signature = {i: s.split(":")[0] for i, s in enumerate(signature) if i not in constants} + const_sig = 'x'.join([str(v) for v in constants.values()]) + doc_string = [f"{kernel.arg_names[i]}={constants[i]}" for i in constants.keys()] doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"] # compile ast into cubin @@ -104,11 +104,12 @@ def constexpr(s): assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" divisible_by_16 = [i for i, h in hints.items() if h == 16] equal_to_1 = [i for i, h in hints.items() if h == 1] - config = triton.compiler.InstanceDescriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1) + attrs = triton.compiler.AttrsDescriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1) for i in equal_to_1: - constexprs.update({i: 1}) - ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config], - num_warps=args.num_warps, num_stages=args.num_stages) + constants.update({i: 1}) + src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs) + opts = {"num_warps": args.num_warps, "num_stages": args.num_stages} + ccinfo = triton.compile(src, compiler_options=opts) arg_names = [] arg_types = [] for i in signature.keys(): @@ -117,7 +118,7 @@ def constexpr(s): arg_types += [signature[i]] # dump C stub code - suffix = kernel_suffix(signature.values(), config) + suffix = kernel_suffix(signature.values(), attrs) func_name = '_'.join([out_name, sig_hash, suffix]) triton_kernel_name = '_'.join([args.kernel_name, suffix]) hex_ = str(binascii.hexlify(ccinfo.asm["cubin"]))[2:-1] From f0dfa236ffbd35fef3f58ae1803dbaf887aa5541 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Wed, 29 Nov 2023 15:08:31 -0800 Subject: [PATCH 58/64] cleaning --- python/triton/compiler/compiler.py | 5 
+++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index a927b74aea10..25c5528c0c0f 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -158,8 +158,8 @@ def compile(src, target=None, compiler_options=None, linker_options=None): target = get_current_target() backend = CUDABackend(target) # create backend - compiler_options = backend.parse_compiler_options(compiler_options) - linker_options = backend.parse_linker_options(linker_options) + compiler_options = backend.parse_compiler_options(compiler_options or dict()) + linker_options = backend.parse_linker_options(linker_options or dict()) if not isinstance(src, ASTSource): assert isinstance(src, str), "source must be either AST or a filepath" src = IRSource(src) @@ -225,6 +225,7 @@ def __init__(self, so_path, metadata_path): self.metadata["tensormaps_info"][i].ids_of_folded_args = tuple(self.metadata["ids_of_folded_args"]) for key, val in self.metadata["compiler_options"].items(): setattr(self, key, val) + self.tensormaps_info = self.metadata["tensormaps_info"] self.shared = self.metadata["shared"] self.name = self.metadata["name"] # stores the text of each level of IR that was generated during compilation From 2fd04d087548178e3cdc0d72874f0682a0ff3bb6 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Wed, 29 Nov 2023 15:55:03 -0800 Subject: [PATCH 59/64] fix linker option bug --- python/triton/compiler/backends/cuda.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index 68e75bd379d4..516c1c5921f4 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -67,6 +67,9 @@ def hash(self): class CUDALinkerOptions: libs: dict = None + def __post_init__(self): + self.libs = {k: v for k, v in self.libs.items() if v} + class CUDABackend(BaseBackend): From 6c8b3f0078c0709280a27796d63903ab4c35d3f7 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Wed, 29 Nov 2023 16:11:37 -0800 Subject: [PATCH 60/64] . 
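
The previous fixup filtered self.libs unconditionally, which raised an
AttributeError whenever no extern libs were passed (libs defaults to None).
The filter now runs only when a dict was actually supplied. The resulting
options class reads roughly as follows (sketch of backends/cuda.py after
this patch; the example key "x" is a placeholder):

    from dataclasses import dataclass

    @dataclass
    class CUDALinkerOptions:
        libs: dict = None

        def __post_init__(self):
            # drop libraries whose path is empty, but leave None alone
            if self.libs is not None:
                self.libs = {k: v for k, v in self.libs.items() if v}

With this, both CUDALinkerOptions() and CUDALinkerOptions(libs={"x": ""})
construct cleanly; the latter simply ends up with an empty dict.
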
--- python/triton/compiler/backends/cuda.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index 516c1c5921f4..a019539ad812 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -68,7 +68,8 @@ class CUDALinkerOptions: libs: dict = None def __post_init__(self): - self.libs = {k: v for k, v in self.libs.items() if v} + if self.libs is not None: + self.libs = {k: v for k, v in self.libs.items() if v} class CUDABackend(BaseBackend): From 0a4f518106945772f077bd999ea8ae710704a557 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Wed, 29 Nov 2023 16:31:33 -0800 Subject: [PATCH 61/64] fixup --- python/test/unit/runtime/test_subproc.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index e405bb8bff77..63401f28e42b 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -36,7 +36,7 @@ def kernel_sub(a, b, o, N: tl.constexpr): def test_compile_in_subproc() -> None: major, minor = torch.cuda.get_device_capability(0) cc = major * 10 + minor - config = triton.compiler.InstanceDescriptor(tuple(range(4)), (), (), ()) + config = triton.compiler.AttrsDescriptor(tuple(range(4)), (), (), ()) multiprocessing.set_start_method('fork') proc = multiprocessing.Process(target=compile_fn, args=(config, cc)) @@ -54,11 +54,7 @@ def kernel_dot(Z): z = tl.dot(z, z) tl.store(Z + offs, z) - src = ASTSource( - fn=kernel_dot, - signature={0: "*fp32"}, - attrs=attrs, - ) + src = ASTSource(fn=kernel_dot, signature={0: "*fp32"}, attrs=attrs, constants=dict()) triton.compile(src=src, target=("cuda", capability)) From 10f3c62bd6a34603c6938823b40a42d0a9f958cb Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Wed, 29 Nov 2023 18:17:09 -0800 Subject: [PATCH 62/64] fixup --- python/triton/compiler/compiler.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 25c5528c0c0f..cd096f97d9eb 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -179,9 +179,9 @@ def compile(src, target=None, compiler_options=None, linker_options=None): # initialize metadata metadata = { "target": target, - "compiler_options": compiler_options.__dict__, - "linker_options": linker_options.__dict__, - "environment": get_env_vars(), + **compiler_options.__dict__, + **linker_options.__dict__, + **get_env_vars(), **src.metadata(), } # run compilation pipeline and populate metadata @@ -223,11 +223,8 @@ def __init__(self, so_path, metadata_path): ] if 'tensormaps_info' in self.metadata else [] for i, _ in enumerate(self.metadata["tensormaps_info"]): self.metadata["tensormaps_info"][i].ids_of_folded_args = tuple(self.metadata["ids_of_folded_args"]) - for key, val in self.metadata["compiler_options"].items(): + for key, val in self.metadata.items(): setattr(self, key, val) - self.tensormaps_info = self.metadata["tensormaps_info"] - self.shared = self.metadata["shared"] - self.name = self.metadata["name"] # stores the text of each level of IR that was generated during compilation asm_files = [file for file in metadata_path.parent.glob(f'{metadata_path.stem}.*') if file.suffix != '.json'] self.asm = { From bac12843131c0a187167f0245d74a2e114bffd89 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Wed, 29 Nov 2023 18:32:07 -0800 
Subject: [PATCH 63/64] . From c525439d2ee4dbf9ede1b68149c03e0ab0732310 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Wed, 29 Nov 2023 18:49:55 -0800 Subject: [PATCH 64/64] fixup --- .../unit/hopper/test_persistent_warp_specialized_gemm.py | 7 ++++--- python/triton/compiler/compiler.py | 4 +++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py index abd5c5edcbc4..1101b8906688 100644 --- a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py +++ b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py @@ -900,11 +900,12 @@ def process_epilogue(d, bias, w, epilogue): num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count if NUM_CTAS > 1: device = get_current_device() - null_kernel = triton.compile(empty_kernel, signature="i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}) + src = triton.compiler.ASTSource(fn=empty_kernel, signature="i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}) + null_kernel = triton.compile(src) null_kernel._init_handles() max_shared_mem = driver.utils.get_device_properties(device)["max_shared_mem"] - num_clusters = driver.utils.cu_occupancy_max_active_clusters(null_kernel.cu_function, max_shared_mem, NUM_CTAS, - 1, 1) + num_clusters = driver.utils.cu_occupancy_max_active_clusters(null_kernel.function, max_shared_mem, NUM_CTAS, 1, + 1) num_SMs = num_clusters def grid(META): diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index cd096f97d9eb..528b824a61d5 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -96,7 +96,7 @@ def _get_num_warps_from_ir_str(src: str): class ASTSource: - def __init__(self, fn, signature, constants, attrs) -> None: + def __init__(self, fn, signature, constants=None, attrs=None) -> None: self.fn = fn self.ext = "ttir" self.name = fn.__name__ @@ -107,6 +107,8 @@ def __init__(self, fn, signature, constants, attrs) -> None: self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} if self.constants is None: self.constants = dict() + if self.attrs is None: + self.attrs = AttrsDescriptor() def hash(self): key = f"{self.fn.cache_key}-{self.attrs.hash()}-{self.signature.values()}-{self.constants}"