From 81b8b4b7b4aaa80e0a68d1747055306f948182aa Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 20 Sep 2024 08:41:30 -0700 Subject: [PATCH] [Mosaic GPU] Clean up the module structure Previously the code was awkwardly split between the `jax.experimental.mosaic.gpu` and `jax.experimental.mosaic.gpu.dsl` namespaces. I've now merged both so that all user-visible APIs are accessible from `jax.experimental.mosaic.gpu`. PiperOrigin-RevId: 676857257 --- jax/_src/pallas/mosaic_gpu/core.py | 8 +- jax/_src/pallas/mosaic_gpu/lowering.py | 13 +- .../mosaic_gpu/pallas_call_registration.py | 4 +- jax/experimental/mosaic/gpu/__init__.py | 1013 +---------------- jax/experimental/mosaic/gpu/core.py | 979 ++++++++++++++++ jax/experimental/mosaic/gpu/dsl.py | 58 - .../mosaic/gpu/examples/flash_attention.py | 28 +- .../mosaic/gpu/examples/matmul.py | 17 +- .../mosaic/gpu/fragmented_array.py | 2 +- jax/experimental/mosaic/gpu/wgmma.py | 4 +- tests/mosaic/gpu_test.py | 99 +- 11 files changed, 1111 insertions(+), 1114 deletions(-) create mode 100644 jax/experimental/mosaic/gpu/core.py delete mode 100644 jax/experimental/mosaic/gpu/dsl.py diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index dc698b8747d9..6ef4cd1621f4 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -23,7 +23,7 @@ from jax._src import dtypes from jax._src import tree_util from jax._src.pallas import core as pallas_core -from jax.experimental.mosaic import gpu as mosaic_gpu +import jax.experimental.mosaic.gpu as mgpu import jax.numpy as jnp @@ -64,7 +64,7 @@ def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype): class MemoryRefTransform(pallas_core.MemoryRefTransform, Protocol): - def to_gpu_transform(self) -> mosaic_gpu.MemRefTransform: + def to_gpu_transform(self) -> mgpu.MemRefTransform: ... 
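(Illustrative note, not part of the patch: the sketch below shows the import pattern this commit enables, using only names that the new jax/experimental/mosaic/gpu/__init__.py re-exports in this diff, such as TileTransform, as_gpu_kernel, FragmentedArray, and wgmma. It assumes a jaxlib build with the Mosaic GPU runtime available.)

    # Before this change, user code had to mix two namespaces:
    #   from jax.experimental.mosaic import gpu as mosaic_gpu
    #   from jax.experimental.mosaic.gpu import dsl as mgpu
    # After it, a single import exposes both the core entry points and the DSL:
    import jax.experimental.mosaic.gpu as mgpu

    # Everything user-visible now hangs off one module, e.g. (names taken from
    # the re-export list in the new __init__.py):
    #   mgpu.TileTransform((64, 32))      # GMEM->SMEM tiling transform
    #   mgpu.as_gpu_kernel(...)           # compile a Mosaic GPU kernel
    #   mgpu.FragmentedArray, mgpu.WGMMAAccumulator, mgpu.wgmma, ...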
@@ -101,8 +101,8 @@ def __call__( inner_aval=block_aval.inner_aval.update(shape=new_block_shape) ) - def to_gpu_transform(self) -> mosaic_gpu.MemRefTransform: - return mosaic_gpu.TileTransform(self.tiling) + def to_gpu_transform(self) -> mgpu.MemRefTransform: + return mgpu.TileTransform(self.tiling) @dataclasses.dataclass(frozen=True) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 4b8199b36105..6eae64b7affa 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -39,8 +39,7 @@ from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.state import primitives as sp -from jax.experimental.mosaic import gpu as mosaic_gpu -from jax.experimental.mosaic.gpu import dsl as mgpu +import jax.experimental.mosaic.gpu as mgpu import jax.numpy as jnp import numpy as np @@ -160,7 +159,7 @@ def stack_free_smem(self, bytes: int): @dataclasses.dataclass(frozen=True) class LoweringRuleContext: module_ctx: ModuleContext - launch_ctx: mosaic_gpu.LaunchContext + launch_ctx: mgpu.LaunchContext avals_in: Sequence[jax_core.ShapedArray] avals_out: Sequence[jax_core.ShapedArray] @@ -180,7 +179,7 @@ class LoweringError(Exception): # pylint: disable=g-bad-exception-name def _eval_index_map( module_ctx: ModuleContext, - launch_ctx: mosaic_gpu.LaunchContext, + launch_ctx: mgpu.LaunchContext, idx: ir.Value, block_mapping: pallas_core.BlockMapping, ) -> Sequence[ir.Value]: @@ -300,7 +299,7 @@ def lower_jaxpr_to_module( ) ] - def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers: ir.Value): + def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): *buffers_gmem, ( buffers_smem, *scratch_buffers_smem, @@ -494,7 +493,7 @@ def _(step, _): jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8) ) - module, out_structs_smem, _ = mosaic_gpu._lower_as_gpu_kernel( + module, out_structs_smem, _ = mgpu._lower_as_gpu_kernel( body, grid=grid, cluster=(), @@ -528,7 +527,7 @@ def deco(fn): def lower_jaxpr_to_mosaic_gpu( module_ctx: ModuleContext, - launch_ctx: mosaic_gpu.LaunchContext, + launch_ctx: mgpu.LaunchContext, jaxpr: jax_core.Jaxpr, args: Sequence[ir.Value], consts=(), diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 5b09cad176a6..510d4032f3dd 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -23,7 +23,7 @@ from jax._src.interpreters import mlir from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import lowering -from jax.experimental.mosaic import gpu as mosaic_gpu +import jax.experimental.mosaic.gpu.core as mosaic_core def pallas_call_lowering( @@ -67,7 +67,7 @@ def pallas_call_lowering( print(lowering_result.module.operation) module = lowering_result.module - return mosaic_gpu._mosaic_gpu_lowering_rule( + return mosaic_core._mosaic_gpu_lowering_rule( ctx, *args, module=module.operation.get_asm(binary=True, enable_debug_info=True), diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 0e263844b18e..21c7f666b233 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -13,967 +13,52 @@ # limitations under the License. 
# ============================================================================== -from collections.abc import Callable, Sequence -import contextlib -import ctypes -import dataclasses -import functools -import hashlib -import itertools -import math -import os -import pathlib -import subprocess -import tempfile -import time -from typing import Any, Generic, TypeVar -import weakref - -import jax -from jax._src import config -from jax._src import core as jax_core -from jax._src.interpreters import mlir -from jax._src.lib import xla_client -from jaxlib.mlir import ir -from jaxlib.mlir.dialects import arith -from jaxlib.mlir.dialects import builtin -from jaxlib.mlir.dialects import func -from jaxlib.mlir.dialects import gpu -from jaxlib.mlir.dialects import llvm -from jaxlib.mlir.dialects import memref -from jaxlib.mlir.dialects import nvvm -from jaxlib.mlir.passmanager import PassManager -import numpy as np - -from . import profiler -from . import utils - -# mypy: ignore-errors - -# MLIR can't find libdevice unless we point it to the CUDA path -# TODO(apaszke): Unify with jax._src.lib.cuda_path -CUDA_ROOT = "/usr/local/cuda" -if os.environ.get("CUDA_ROOT") is None: - os.environ["CUDA_ROOT"] = CUDA_ROOT -else: - CUDA_ROOT = os.environ["CUDA_ROOT"] - -PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas") -NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm") - -TMA_DESCRIPTOR_BYTES = 128 -TMA_DESCRIPTOR_ALIGNMENT = 64 - - -c = utils.c # This is too common to fully qualify. - - -RUNTIME_PATH = None -try: - from jax._src.lib import mosaic_gpu as mosaic_gpu_lib - - RUNTIME_PATH = ( - pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent - / "libmosaic_gpu_runtime.so" - ) -except ImportError: - pass - -if RUNTIME_PATH and RUNTIME_PATH.exists(): - # Set this so that the custom call can find it - os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH) - - -mosaic_gpu_p = jax.core.Primitive("mosaic_gpu_p") -mosaic_gpu_p.multiple_results = True - - -@mosaic_gpu_p.def_abstract_eval -def _mosaic_gpu_abstract_eval(*_, module, out_types): - del module # Unused. - return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types] - -# TODO(apaszke): Implement a proper system for managing kernel lifetimes -KNOWN_KERNELS = {} - -def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types): - del out_types # Unused. - kernel_id = hashlib.sha256(module).digest() - # Note that this is technically only a half measure. Someone might load a - # compiled module with a hash collision from disk. But that's so unlikely with - # SHA256 that it shouldn't be a problem. 
- if (kernel_text := KNOWN_KERNELS.get(kernel_id, None)) is not None: - if kernel_text != module: - raise RuntimeError("Hash collision!") - else: - KNOWN_KERNELS[kernel_id] = module - op = mlir.custom_call( - "mosaic_gpu", - result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - operands=args, - operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], - result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out], - backend_config=kernel_id + module, - ) - return op.results - -mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda") - - -@dataclasses.dataclass(frozen=True) -class MemRefTransform: - def apply(self, ref: ir.Value) -> ir.Value: - raise NotImplementedError("Subclasses should override this method") - - def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: - raise NotImplementedError("Subclasses should override this method") - - def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: - raise NotImplementedError("Subclasses should override this method") - - -@dataclasses.dataclass(frozen=True) -class TileTransform(MemRefTransform): - """Tiles a suffix of memref dimensions. - - For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32), - the shape of the result will be (5, 2, 4, 64, 32). The shape always ends with - the tile shape, and the size of tiled dimensions is divided by the tile size. - This is especially useful for swizzled WGMMA, which expect tiled layouts in - shared memory. - """ - tiling: tuple[int, ...] - - def apply(self, ref: ir.Value) -> ir.Value: - untiled_rank = ir.MemRefType(ref.type).rank - tiling_rank = len(self.tiling) - tiled_rank = untiled_rank + tiling_rank - for t, d in zip(self.tiling[::-1], range(untiled_rank)[::-1]): - s = ir.MemRefType(ref.type).shape[d] - if s % t and s > t: - raise ValueError( - f"Dimension {d} must have size smaller or a multiple of its tiling" - f" {t}, but got {s}" - ) - ref = utils.memref_unfold(ref, d, (None, min(t, s))) - permutation = ( - *range(untiled_rank - tiling_rank), - *range(untiled_rank - tiling_rank, tiled_rank, 2), - *range(untiled_rank - tiling_rank + 1, tiled_rank, 2), - ) - return utils.memref_transpose(ref, permutation) - - def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: - index = ir.IndexType.get() - tiling_rank = len(self.tiling) - return ( - *idx[:-tiling_rank], - *( - arith.divui(i, c(t, index)) - for i, t in zip(idx[-tiling_rank:], self.tiling) - ), - *( - arith.remui(i, c(t, index)) - for i, t in zip(idx[-tiling_rank:], self.tiling) - ), - ) - - def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: - # Note that this also checks that tiled dims are not squeezed. Their slice - # size would be 1 if so. - tiling_rank = len(self.tiling) - for size, tile_size in zip(shape[-tiling_rank:], self.tiling): - if size % tile_size: - raise ValueError( - f"Expected GMEM slice shape {shape} suffix to be a multiple of" - f" tiling {self.tiling}.\nIf you're using padded async copies, your" - " slice might need to extend out of bounds of the GMEM buffer (OOB" - " accesses will be skipped)." - ) - return ( - *shape[:-tiling_rank], - *(s // t for s, t in zip(shape[-tiling_rank:], self.tiling)), - *self.tiling, - ) - - -@dataclasses.dataclass(frozen=True) -class TransposeTransform(MemRefTransform): - """Transposes memref dimensions.""" - permutation: tuple[int, ...] 
- - def __post_init__(self): - if len(self.permutation) != len(set(self.permutation)): - raise ValueError("Permutation must be a permutation") - - def apply(self, ref: ir.Value) -> ir.Value: - return utils.memref_transpose(ref, self.permutation) - - def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: - return tuple(idx[p] for p in self.permutation) - - def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: - return tuple(shape[p] for p in self.permutation) - - -OnDeviceProfiler = profiler.OnDeviceProfiler - - -@dataclasses.dataclass() -class LaunchContext: - launch_op: gpu.LaunchOp - gmem_scratch_ptr: ir.Value - cluster_size: tuple[int, int, int] - profiler: OnDeviceProfiler | None = None - next_scratch_offset: int = 0 - host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field( - default_factory=list, init=False - ) - tma_descriptors: dict[ - tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]], - ir.Value, - ] = dataclasses.field(default_factory=dict, init=False) - - @contextlib.contextmanager - def named_region(self, *args, **kwargs): - if self.profiler is not None: - with self.profiler.record(*args, **kwargs): - yield - else: - yield - - def _alloc_scratch( - self, - size: int, - alignment: int | None = None, - host_init: Callable[[ir.Value], None] = lambda _: None, - device_init: Callable[[ir.Value], Any] = lambda x: x, - ) -> ir.Value: - """Allocates a GMEM scratch buffer. - - The buffer is initialized on the host and then copied to GMEM before the - kernel launch. - """ - i8 = ir.IntegerType.get_signless(8) - ptr_ty = ir.Type.parse("!llvm.ptr") - if alignment is None: - alignment = size - if self.next_scratch_offset % alignment: - raise NotImplementedError # TODO(apaszke): Pad to match alignment - alloc_base = self.next_scratch_offset - self.next_scratch_offset += size - def host_init_wrapped(host_ptr): - host_init( - llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8) - ) - self.host_scratch_init.append(host_init_wrapped) - # with ir.InsertionPoint(self.gmem_scratch_ptr.owner): - # There is no way to create an insertion point after an operation... - gep = llvm.GEPOp( - ptr_ty, self.gmem_scratch_ptr, [], [alloc_base], i8 - ) - gep.move_after(self.gmem_scratch_ptr.owner) - return device_init(gep.result) - - def _get_tma_desc( - self, - gmem_ref, - gmem_transform: tuple[MemRefTransform, ...], - transformed_slice_shape: tuple[int, ...], - swizzle: int | None, - ): - tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform) - if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None: - i64 = ir.IntegerType.get_signless(64) - ptr_ty = ir.Type.parse("!llvm.ptr") - def init_tma_desc(host_ptr): - ref = gmem_ref - for t in gmem_transform: - ref = t.apply(ref) - ref_ty = ir.MemRefType(ref.type) - # TODO(apaszke): Use utils.memref_ptr to compute base_ptr - _, offset, *sizes_and_strides = memref.extract_strided_metadata(ref) - aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref) - as_i64 = lambda i: arith.index_cast(i64, i) - alloc_ptr = llvm.inttoptr(ptr_ty, as_i64(aligned_ptr_idx)) - llvm_dyn = -2147483648 # TODO(apaszke): Improve the MLIR bindings... 
- base_ptr = llvm.getelementptr( - ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type, - ) - rank = ref_ty.rank - assert rank * 2 == len(sizes_and_strides) - args = [ - host_ptr, - base_ptr, - c(utils.bytewidth(ref_ty.element_type), i64), - c(rank, i64), - utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]), - utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]), - c(0 if swizzle is None else swizzle, i64), - utils.pack_array([c(v, i64) for v in transformed_slice_shape]), - ] - func.call([], "mosaic_gpu_init_tma_desc", args) - def cast_tma_desc(device_ptr): - # TODO(apaszke): Investigate why prefetching can cause launch failures - # nvvm.prefetch_tensormap(device_ptr) - return device_ptr - tma_desc = self._alloc_scratch( - TMA_DESCRIPTOR_BYTES, - alignment=TMA_DESCRIPTOR_ALIGNMENT, - host_init=init_tma_desc, - device_init=cast_tma_desc, - ) - self.tma_descriptors[tma_desc_key] = tma_desc - return tma_desc - - def async_copy( - self, - *, - src_ref, - dst_ref, - gmem_slice: Any = (), - gmem_transform: MemRefTransform | tuple[MemRefTransform, ...] = (), - barrier: utils.BarrierRef | None = None, - swizzle: int | None = None, - arrive: bool | None = None, - uniform: bool = True, - collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None, - ): - index = ir.IndexType.get() - i16 = ir.IntegerType.get_signless(16) - i32 = ir.IntegerType.get_signless(32) - smem = ir.Attribute.parse("#gpu.address_space") - src_ref_ty = ir.MemRefType(src_ref.type) - dst_ref_ty = ir.MemRefType(dst_ref.type) - element_type = src_ref_ty.element_type - element_bytewidth = utils.bytewidth(element_type) - if element_type != dst_ref_ty.element_type: - raise ValueError( - f"Expected same element type, got {element_type} and" - f" {dst_ref_ty.element_type}" - ) - if not isinstance(gmem_transform, tuple): - gmem_transform = (gmem_transform,) - - if src_ref_ty.memory_space is None and dst_ref_ty.memory_space == smem: - gmem_ref, smem_ref = src_ref, dst_ref - if barrier is None: - raise ValueError("Barriers are required for GMEM -> SMEM copies") - if arrive is None: - arrive = True # Arrive by default - elif src_ref_ty.memory_space == smem and dst_ref_ty.memory_space is None: - gmem_ref, smem_ref = dst_ref, src_ref - if barrier is not None: - raise ValueError("Barriers are unsupported for SMEM -> GMEM copies") - if arrive is not None: - raise ValueError("arrive is unsupported for SMEM -> GMEM copies") - else: - raise ValueError("Only SMEM <-> GMEM copies supported") - # TODO(apaszke): This is a very approximate check. Improve it! 
- expected_name = "builtin.unrealized_conversion_cast" - if ( - gmem_ref.owner is None - or gmem_ref.owner.opview.OPERATION_NAME != expected_name - ): - raise ValueError("GMEM reference in async_copy must be a kernel argument") - - base_indices, slice_shape, is_squeezed = utils.parse_indices( - gmem_slice, ir.MemRefType(gmem_ref.type).shape - ) - dyn_base_indices = tuple( - c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices - ) - slice_shape = tuple(slice_shape) - for t in gmem_transform: - dyn_base_indices = t.transform_index(dyn_base_indices) - slice_shape = t.transform_shape(slice_shape) - for dim, squeezed in enumerate(is_squeezed): - if squeezed: - smem_ref = utils.memref_unsqueeze(smem_ref, dim) - smem_ref_ty = ir.MemRefType(smem_ref.type) - - if slice_shape != tuple(smem_ref_ty.shape): - raise ValueError( - "Expected the SMEM reference to have the same shape as the" - f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}" - ) - - dyn_base_indices = list(dyn_base_indices) - slice_shape = list(slice_shape) - collective_size = 1 - if collective is not None: - if isinstance(collective, gpu.Dimension): - collective = (collective,) - collective_size = math.prod(self.cluster_size[d] for d in collective) - if collective_size > 1: - def partition_dim(dim: int, idx: ir.Value, num_chunks: int): - nonlocal smem_ref - slice_shape[dim] //= num_chunks - block_offset = arith.muli(idx, c(slice_shape[dim], index)) - dyn_base_indices[dim] = arith.addi(dyn_base_indices[dim], block_offset) - smem_ref = utils.memref_slice( - smem_ref, - (slice(None),) * dim + (utils.ds(block_offset, slice_shape[dim]),) - ) - stride = 1 - idx = c(0, index) - for d in sorted(collective): - if self.cluster_size[d] == 1: # Optimize a multiply by 0. - continue - idx = arith.addi(idx, arith.muli(gpu.cluster_block_id(d), c(stride, index))) - stride *= self.cluster_size[d] - rem_collective_size = collective_size - for dim, slice_size in enumerate(slice_shape[:-1]): - if slice_size % rem_collective_size == 0: - partition_dim(dim, idx, rem_collective_size) - rem_collective_size = 1 - break - elif rem_collective_size % slice_size == 0: - dim_idx = arith.remui(idx, c(slice_size, index)) - partition_dim(dim, dim_idx, slice_size) - idx = arith.divui(idx, c(slice_size, index)) - rem_collective_size //= slice_size - else: - break # We failed to partition the leading dimensions. - del idx # We overwrote the block index in the loop. - if rem_collective_size > 1: - raise ValueError( - "None of the leading dimensions in the transformed slice shape" - f" {slice_shape} is divisible by the collective size" - f" {collective_size}" - ) - # Make each block load a smaller slice, adjust the GMEM indices and slice - # the SMEM reference accordingly. - multicast_mask = arith.trunci( - i16, utils.cluster_collective_mask(self.cluster_size, collective) - ) - else: - multicast_mask = None - - tma_desc = self._get_tma_desc( - gmem_ref, gmem_transform, tuple(slice_shape), swizzle, - ) - - # We constuct TMA descriptors in column-major order. 
- rev_dyn_base_indices = [ - arith.index_cast(i32, idx) for idx in reversed(dyn_base_indices) - ] - - uniform_ctx = ( - functools.partial(utils.single_thread, per_block=False) - if uniform - else contextlib.nullcontext - ) - - rank = len(slice_shape) - if rank > 5: # TODO: apaszke - Implement stride compression - raise ValueError("Async copies only support striding up to 5 dimensions") - if max(slice_shape) > 256: - raise ValueError( - "Async copies only support copying <=256 elements along each" - " dimension" - ) - if (zeroth_bw := slice_shape[-1] * element_bytewidth) % 16 != 0: - raise ValueError( - "Async copies require the number of bytes copied along the last" - f" dimension to be divisible by 16, but got {zeroth_bw}" - ) - if swizzle is not None and slice_shape[-1] != swizzle // element_bytewidth: - raise ValueError( - f"Async copies with {swizzle=} require last dimension of the slice to" - f" be exactly {swizzle} bytes" - f" ({swizzle // element_bytewidth} elements), but got" - f" {slice_shape[-1]}" - ) - smem_ptr = utils.memref_ptr(smem_ref, memory_space=3) - if gmem_ref is src_ref: - assert barrier is not None # for pytype - transfer_bytes = c( - np.prod(slice_shape) * element_bytewidth * collective_size, i32 - ) - barrier_ptr = barrier.get_ptr() - with uniform_ctx(): - if arrive: - nvvm.mbarrier_arrive_expect_tx_shared(barrier_ptr, transfer_bytes) - nvvm.cp_async_bulk_tensor_shared_cluster_global( - smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [], multicast_mask=multicast_mask, - ) - else: - with uniform_ctx(): - nvvm.cp_async_bulk_tensor_global_shared_cta( - tma_desc, smem_ptr, rev_dyn_base_indices - ) - nvvm.cp_async_bulk_commit_group() - - def await_async_copy( - self, allow_groups: int, await_read_only: bool = False - ): - nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only) - utils.warpgroup_barrier() - - -# ShapeTrees currently can not contain unions. 
-ShapeTree = Any -RefTree = Any -T = TypeVar('T') - - -@dataclasses.dataclass(frozen=True) -class Union(Generic[T]): - members: Sequence[T] - - def __iter__(self): - return iter(self.members) - -@dataclasses.dataclass(frozen=True) -class TMABarrier: - num_barriers: int = 1 - -@dataclasses.dataclass(frozen=True) -class Barrier: - arrival_count: int - num_barriers: int = 1 - -@dataclasses.dataclass(frozen=True) -class ClusterBarrier: - collective_dims: Sequence[gpu.Dimension] - num_barriers: int = 1 - - -def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int: - return np.prod(shape_dtype.shape) * np.dtype(shape_dtype.dtype).itemsize - - -def _construct_smem_reftree( - cluster_shape: tuple[int, int, int], - dynamic_smem: ir.Value, - smem_buffers: ShapeTree, - dynamic_smem_offset: int = 0, -) -> RefTree: - index = ir.IndexType.get() - i8 = ir.IntegerType.get_signless(8) - ptr = ir.Type.parse("!llvm.ptr") - smem = ir.Attribute.parse("#gpu.address_space") - flat_ref_tys, smem_buffer_tree = jax.tree.flatten( - smem_buffers, is_leaf=lambda x: isinstance(x, Union) - ) - smem_refs = [] - for ref_ty in flat_ref_tys: - def get_barrier_ptr(num_barriers: int) -> ir.Value: - nonlocal dynamic_smem_offset - smem_base_ptr = utils.memref_ptr(dynamic_smem, memory_space=3) - barrier_base_ptr = llvm.getelementptr( - ptr, smem_base_ptr, [], [dynamic_smem_offset], i8 - ) - dynamic_smem_offset += num_barriers * MBARRIER_BYTES - return barrier_base_ptr - match ref_ty: - case Union(members): - member_trees = [ - _construct_smem_reftree(cluster_shape, dynamic_smem, m, dynamic_smem_offset) - for m in members - ] - # TODO(apaszke): This is quadratic, but it shouldn't matter for now... - dynamic_smem_offset += _smem_tree_size(ref_ty) - ref = Union(member_trees) - case TMABarrier(num_barriers): - ref = utils.BarrierRef.initialize( - get_barrier_ptr(num_barriers), num_barriers, arrival_count=1 - ) - case Barrier(arrival_count, num_barriers): - ref = utils.BarrierRef.initialize( - get_barrier_ptr(num_barriers), - num_barriers, - arrival_count=arrival_count, - ) - case ClusterBarrier(collective_dims, num_barriers): - ref = utils.CollectiveBarrierRef.initialize( - get_barrier_ptr(num_barriers), - num_barriers, - collective_dims, - cluster_shape, - ) - case _: - mlir_dtype = mlir.dtype_to_ir_type(ref_ty.dtype) - tile_smem = memref.view( - ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=smem), - dynamic_smem, c(dynamic_smem_offset, index), [], - ) - dynamic_smem_offset += _count_buffer_bytes(ref_ty) - ref = tile_smem - smem_refs.append(ref) - return jax.tree.unflatten(smem_buffer_tree, smem_refs) - - -MBARRIER_BYTES = 8 - - -def _smem_tree_size(smem_buffers: ShapeTree) -> int: - leaves = jax.tree.leaves( - smem_buffers, is_leaf=lambda x: isinstance(x, Union) - ) - size = 0 - for l in leaves: - match l: - case Union(members): - size += max(_smem_tree_size(s) for s in members) - case ( - TMABarrier(num_barriers) - | ClusterBarrier(_, num_barriers=num_barriers) - | Barrier(_, num_barriers=num_barriers) - ): - if size % MBARRIER_BYTES: - raise NotImplementedError("Misaligned barrier allocation") - size += num_barriers * MBARRIER_BYTES - case _: - size += _count_buffer_bytes(l) - return size - - -# TODO(apaszke): Inline this -@contextlib.contextmanager -def _launch( - token, - grid: tuple[int, int, int], - cluster: tuple[int, int, int], - block: tuple[int, int, int], - scratch_arr, - smem_buffers: ShapeTree | Union[ShapeTree], - profiler_spec: profiler.ProfilerSpec | None = None, - maybe_prof_buffer: ir.Value | 
None = None, -): - if (profiler_spec is None) != (maybe_prof_buffer is None): - raise ValueError - index = ir.IndexType.get() - i32 = ir.IntegerType.get_signless(32) - i8 = ir.IntegerType.get_signless(8) - grid_vals = [c(i, index) for i in grid] - block_vals = [c(i, index) for i in block] - - user_smem_bytes = _smem_tree_size(smem_buffers) - - smem_bytes = user_smem_bytes - if profiler_spec is not None: - smem_bytes += profiler_spec.smem_bytes(block=block) - - # TODO(cperivol): Query the shared memory size programmatically. - if smem_bytes > 228 * 1024: - raise ValueError(f"Mosaic GPU kernel exceeds available shared memory {smem_bytes=} > 228000") - if math.prod(cluster) != 1: - if len(cluster) != 3: - raise ValueError("Clusters must be 3D") - cluster_kwargs = { - "clusterSize" + d: c(s, index) for s, d in zip(cluster, "XYZ") - } - for d, grid_size, cluster_size in zip("xyz", grid, cluster): - if grid_size % cluster_size != 0: - raise ValueError( - f"Grid dimension {d} must be divisible by cluster dimension:" - f" {grid_size} % {cluster_size} != 0" - ) - else: - cluster_kwargs = {} - launch_op = gpu.LaunchOp( - token.type, [token], *grid_vals, *block_vals, - dynamicSharedMemorySize=c(smem_bytes, i32), **cluster_kwargs) - launch_op.body.blocks.append(*([index] * (12 + 2 * len(cluster_kwargs)))) # Append an empty block - smem = ir.Attribute.parse("#gpu.address_space") - with ir.InsertionPoint(launch_op.body.blocks[0]): - dynamic_smem = gpu.dynamic_shared_memory( - ir.MemRefType.get( - (ir.ShapedType.get_dynamic_size(),), i8, memory_space=smem - ) - ) - - smem_ref_tree = _construct_smem_reftree( - cluster, dynamic_smem, smem_buffers - ) - # TODO(apaszke): Skip the following if no barriers were initialized. - nvvm.fence_mbarrier_init() - if math.prod(cluster) != 1: - nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get()) - nvvm.cluster_wait(aligned=ir.UnitAttr.get()) - gpu.barrier() - - if profiler_spec: - prof_smem = memref.view( - ir.MemRefType.get( - (profiler_spec.smem_i32_elements(block=block),), - i32, memory_space=smem, - ), - dynamic_smem, c(user_smem_bytes, index), [], - ) - prof = profiler.OnDeviceProfiler( - profiler_spec, prof_smem, maybe_prof_buffer - ) - else: - prof = None - - ptr_ty = ir.Type.parse("!llvm.ptr") - scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr]) - yield LaunchContext(launch_op, scratch_ptr, cluster, prof), smem_ref_tree - if prof is not None: - prof.finalize(grid=grid, block=block) - gpu.terminator() - - -def _lower_as_gpu_kernel( - body, - grid: tuple[int, int, int], - cluster: tuple[int, int, int], - block: tuple[int, int, int], - in_shapes: tuple[Any, ...], - out_shape, - smem_scratch_shape: ShapeTree | Union[ShapeTree], - module_name: str, - prof_spec: profiler.ProfilerSpec | None = None, -): - ptr_ty = ir.Type.parse("!llvm.ptr") - token_ty = ir.Type.parse("!gpu.async.token") - i32 = ir.IntegerType.get_signless(32) - i64 = ir.IntegerType.get_signless(64) - - def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: - return ir.MemRefType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype)) - - in_ref_tys = [_shape_to_ref_ty(t) for t in in_shapes] - - unwrap_output_tuple = False - if isinstance(out_shape, list): - out_shape = tuple(out_shape) - elif not isinstance(out_shape, tuple): - out_shape = (out_shape,) - unwrap_output_tuple = True - out_ref_tys = [_shape_to_ref_ty(t) for t in out_shape] - if prof_spec is not None: - out_shape = (*out_shape, prof_spec.jax_buffer_type(grid, block)) - 
out_ref_tys.append(prof_spec.mlir_buffer_type(grid, block)) - - module = ir.Module.create() - attrs = module.operation.attributes - attrs["sym_name"] = ir.StringAttr.get(module_name) - with ir.InsertionPoint(module.body): - _declare_runtime_functions() - gmem_scratch_bytes = 0 - global_scratch = llvm.GlobalOp( - ir.Type.parse("!llvm.array<0 x i8>"), # We don't know the shape yet. - "global_scratch", - ir.Attribute.parse("#llvm.linkage"), - addr_space=ir.IntegerAttr.get(i32, 4), # GPU constant memory. - ) - @func.FuncOp.from_py_func(ptr_ty, ptr_ty) - def main(token_ptr, buffers): - nonlocal gmem_scratch_bytes - token = builtin.unrealized_conversion_cast([token_ty], [token_ptr]) - arg_refs = [] - for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]): - ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty)) - arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty))) - in_refs = arg_refs[:len(in_ref_tys)] - out_refs = arg_refs[len(in_ref_tys):] - prof_buffer = out_refs.pop() if prof_spec is not None else None - empty_arr_ty = ir.Type.parse("!llvm.array<0 x i8>") - scratch_alloc = llvm.AllocaOp( - ptr_ty, c(1, i64), empty_arr_ty, alignment=TMA_DESCRIPTOR_ALIGNMENT - ) - scratch_arr = llvm.load(empty_arr_ty, scratch_alloc.result) - with _launch( - token, grid, cluster, block, scratch_arr, smem_scratch_shape, - prof_spec, prof_buffer - ) as (launch_ctx, smem_refs): - body(launch_ctx, *in_refs, *out_refs, smem_refs) - gmem_scratch_bytes = launch_ctx.next_scratch_offset - # Allocate and initialize the host buffer right before the launch. - # Note that we couldn't do that before, because we had to run the body - # to learn what the scratch contains. - with ir.InsertionPoint(scratch_arr.owner): - scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>") - scratch_alloc.elem_type = ir.TypeAttr.get(scratch_arr_ty) - scratch_arr.set_type(scratch_arr_ty) - for init_callback in launch_ctx.host_scratch_init: - init_callback(scratch_alloc.result) - main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - sym_tab = ir.SymbolTable(module.operation) - sym_tab.insert(main.func_op) - sym_tab.insert(global_scratch) - module.operation.verify() - - return module, out_shape, unwrap_output_tuple - - -def _declare_runtime_functions(): - """Declares the runtime functions that can be used by the generated code.""" - ptr_ty = ir.Type.parse("!llvm.ptr") - i64 = ir.IntegerType.get_signless(64) - arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty] - init_tma_desc_type = ir.FunctionType.get(arg_tys, []) - func.FuncOp( - "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private" - ) - memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], []) - func.FuncOp( - "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private" - ) - - -def as_gpu_kernel( - body, - grid: tuple[int, int, int], - block: tuple[int, int, int], - in_shape, - out_shape, - smem_scratch_shape: ShapeTree | Union[ShapeTree], - prof_spec: profiler.ProfilerSpec | None = None, - cluster: tuple[int, int, int] = (1, 1, 1), - module_name: str = "unknown", -): - if isinstance(in_shape, list): - in_shape = tuple(in_shape) - elif not isinstance(in_shape, tuple): - in_shape = (in_shape,) - - module, out_shape, unwrap_output_tuple = ( - _lower_as_gpu_kernel( - body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, - module_name, prof_spec - ) - ) - - expected_arg_treedef = jax.tree.structure(in_shape) - def _check_args(*args): - arg_treedef = 
jax.tree.structure(args) - if arg_treedef != expected_arg_treedef: - raise ValueError( - f"Invalid argument structure: expected {expected_arg_treedef}, got" - f" {arg_treedef}, ({args=})" - ) - - module_asm = module.operation.get_asm(binary=True, enable_debug_info=True) - def bind(*args): - return mosaic_gpu_p.bind( - *args, - out_types=out_shape, - module=module_asm, - ) - - if prof_spec is not None: - @jax.jit - def prof_kernel(*args): - _check_args(*args) - *results, prof_buffer = bind(*args) - def dump_profile(prof_buffer): - out_file = os.path.join( - os.getenv("TEST_UNDECLARED_OUTPUTS_DIR"), - f"{time.time_ns()}-trace.json", - ) - try: - with open(out_file, "x") as f: - prof_spec.dump(prof_buffer, f, grid=grid, block=block) - except FileExistsError: - pass # TODO: Retry - jax.debug.callback(dump_profile, prof_buffer) - return results[0] if unwrap_output_tuple else results - return prof_kernel - else: - @jax.jit - def kernel(*args): - _check_args(*args) - results = bind(*args) - return results[0] if unwrap_output_tuple else results - return kernel - - -def as_torch_gpu_kernel( - body, - grid: tuple[int, int, int], - block: tuple[int, int, int], - in_shape, - out_shape, - smem_scratch_shape: ShapeTree | Union[ShapeTree], - prof_spec: profiler.ProfilerSpec | None = None, - cluster: tuple[int, int, int] = (1, 1, 1), - module_name: str = "unknown", -): - try: - import torch - except ImportError: - raise RuntimeError("as_torch_gpu_kernel requires PyTorch") - torch.cuda.init() # Make sure CUDA context is set up. - - if isinstance(in_shape, list): - in_shape = tuple(in_shape) - elif not isinstance(in_shape, tuple): - in_shape = (in_shape,) - - flat_out_types, out_treedef = jax.tree.flatten(out_shape) - expected_arg_treedef = jax.tree.structure(in_shape) - - module, out_shape, unwrap_output_tuple = ( - _lower_as_gpu_kernel( - body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, - module_name, prof_spec - ) - ) - - # Get our hands on the compilation and unload functions - try: - import jax_plugins.xla_cuda12 as cuda_plugin - except ImportError: - raise RuntimeError("as_torch_gpu_kernel only works with recent jaxlib builds " - "that use backend plugins") - dll = ctypes.CDLL(cuda_plugin._get_library_path()) - compile_func = dll.MosaicGpuCompile - compile_func.argtypes = [ctypes.c_void_p] - compile_func.restype = ctypes.POINTER(ctypes.c_void_p) - unload_func = dll.MosaicGpuUnload - unload_func.argtypes = [compile_func.restype] - unload_func.restype = None - - module_asm = module.operation.get_asm(binary=True, enable_debug_info=True) - compiled = compile_func(ctypes.c_char_p(module_asm)) - if compiled is None: - raise RuntimeError("Failed to compile the module") - ctx, launch_ptr = compiled[0], compiled[1] - ctx_ptr_ptr = ctypes.pointer(ctypes.c_void_p(ctx)) - launch = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(launch_ptr) - - def as_torch_dtype(dtype): - # torch contains NumPy-compatible dtypes in its top namespace - return getattr(torch, np.dtype(dtype).name) - - def apply(*args): - flat_args, arg_treedef = jax.tree.flatten(args) - if arg_treedef != expected_arg_treedef: - raise ValueError( - f"Invalid argument structure: expected {expected_arg_treedef}, got" - f" {arg_treedef}, ({args=})" - ) - - # Construct a device pointer list like in the XLA calling convention - buffers = (ctypes.c_void_p * (arg_treedef.num_leaves + out_treedef.num_leaves))() - i = -1 # Define i in case there are no args - device = 'cuda' - for i, arg in enumerate(flat_args): - buffers[i] = arg.data_ptr() - 
device = arg.device - flat_outs = [] - for i, t in enumerate(flat_out_types, i + 1): - out = torch.empty(t.shape, dtype=as_torch_dtype(t.dtype), device=device) - flat_outs.append(out) - buffers[i] = out.data_ptr() - # Allocate another buffer for args of the host-side program. This is sadly - # the default MLIR calling convention. - args_ptr = (ctypes.POINTER(ctypes.c_void_p) * 3)() - args_ptr[0] = ctx_ptr_ptr - args_ptr[1] = ctypes.pointer(torch.cuda.default_stream(device)._as_parameter_) - args_ptr[2] = ctypes.cast(ctypes.pointer(ctypes.pointer(buffers)), - ctypes.POINTER(ctypes.c_void_p)) - launch(args_ptr) - return jax.tree.unflatten(out_treedef, flat_outs) - - # Unload the compiled code when the Python function is destroyed. - def unload(_): - unload_func(compiled) - apply.destructor = weakref.ref(apply, unload) - - return apply +from jax import ShapeDtypeStruct +from .core import ( + Barrier, + ClusterBarrier, + LaunchContext, + MemRefTransform, + TMABarrier, + TileTransform, + TransposeTransform, + Union, + as_gpu_kernel, +) +from .fragmented_array import ( + FragmentedArray, + FragmentedLayout, + WGMMA_LAYOUT, + WGMMA_ROW_LAYOUT, + WGStridedFragLayout, +) +from .utils import ( + BarrierRef, + CollectiveBarrierRef, + DynamicSlice, + Partition, + Partition1D, + bytewidth, + c, + commit_shared, + debug_print, + ds, + fori, + memref_fold, + memref_slice, + memref_transpose, + memref_unfold, + memref_unsqueeze, + single_thread, + thread_idx, + tile_shape, + warp_idx, + warpgroup_barrier, + warpgroup_idx, + when, +) +from .wgmma import ( + WGMMAAccumulator, + WGMMALayout, + wgmma, +) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py new file mode 100644 index 000000000000..0e263844b18e --- /dev/null +++ b/jax/experimental/mosaic/gpu/core.py @@ -0,0 +1,979 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from collections.abc import Callable, Sequence +import contextlib +import ctypes +import dataclasses +import functools +import hashlib +import itertools +import math +import os +import pathlib +import subprocess +import tempfile +import time +from typing import Any, Generic, TypeVar +import weakref + +import jax +from jax._src import config +from jax._src import core as jax_core +from jax._src.interpreters import mlir +from jax._src.lib import xla_client +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import builtin +from jaxlib.mlir.dialects import func +from jaxlib.mlir.dialects import gpu +from jaxlib.mlir.dialects import llvm +from jaxlib.mlir.dialects import memref +from jaxlib.mlir.dialects import nvvm +from jaxlib.mlir.passmanager import PassManager +import numpy as np + +from . import profiler +from . 
import utils + +# mypy: ignore-errors + +# MLIR can't find libdevice unless we point it to the CUDA path +# TODO(apaszke): Unify with jax._src.lib.cuda_path +CUDA_ROOT = "/usr/local/cuda" +if os.environ.get("CUDA_ROOT") is None: + os.environ["CUDA_ROOT"] = CUDA_ROOT +else: + CUDA_ROOT = os.environ["CUDA_ROOT"] + +PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas") +NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm") + +TMA_DESCRIPTOR_BYTES = 128 +TMA_DESCRIPTOR_ALIGNMENT = 64 + + +c = utils.c # This is too common to fully qualify. + + +RUNTIME_PATH = None +try: + from jax._src.lib import mosaic_gpu as mosaic_gpu_lib + + RUNTIME_PATH = ( + pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent + / "libmosaic_gpu_runtime.so" + ) +except ImportError: + pass + +if RUNTIME_PATH and RUNTIME_PATH.exists(): + # Set this so that the custom call can find it + os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH) + + +mosaic_gpu_p = jax.core.Primitive("mosaic_gpu_p") +mosaic_gpu_p.multiple_results = True + + +@mosaic_gpu_p.def_abstract_eval +def _mosaic_gpu_abstract_eval(*_, module, out_types): + del module # Unused. + return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types] + +# TODO(apaszke): Implement a proper system for managing kernel lifetimes +KNOWN_KERNELS = {} + +def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types): + del out_types # Unused. + kernel_id = hashlib.sha256(module).digest() + # Note that this is technically only a half measure. Someone might load a + # compiled module with a hash collision from disk. But that's so unlikely with + # SHA256 that it shouldn't be a problem. + if (kernel_text := KNOWN_KERNELS.get(kernel_id, None)) is not None: + if kernel_text != module: + raise RuntimeError("Hash collision!") + else: + KNOWN_KERNELS[kernel_id] = module + op = mlir.custom_call( + "mosaic_gpu", + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + operands=args, + operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], + result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out], + backend_config=kernel_id + module, + ) + return op.results + +mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda") + + +@dataclasses.dataclass(frozen=True) +class MemRefTransform: + def apply(self, ref: ir.Value) -> ir.Value: + raise NotImplementedError("Subclasses should override this method") + + def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: + raise NotImplementedError("Subclasses should override this method") + + def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: + raise NotImplementedError("Subclasses should override this method") + + +@dataclasses.dataclass(frozen=True) +class TileTransform(MemRefTransform): + """Tiles a suffix of memref dimensions. + + For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32), + the shape of the result will be (5, 2, 4, 64, 32). The shape always ends with + the tile shape, and the size of tiled dimensions is divided by the tile size. + This is especially useful for swizzled WGMMA, which expect tiled layouts in + shared memory. + """ + tiling: tuple[int, ...] 
+ + def apply(self, ref: ir.Value) -> ir.Value: + untiled_rank = ir.MemRefType(ref.type).rank + tiling_rank = len(self.tiling) + tiled_rank = untiled_rank + tiling_rank + for t, d in zip(self.tiling[::-1], range(untiled_rank)[::-1]): + s = ir.MemRefType(ref.type).shape[d] + if s % t and s > t: + raise ValueError( + f"Dimension {d} must have size smaller or a multiple of its tiling" + f" {t}, but got {s}" + ) + ref = utils.memref_unfold(ref, d, (None, min(t, s))) + permutation = ( + *range(untiled_rank - tiling_rank), + *range(untiled_rank - tiling_rank, tiled_rank, 2), + *range(untiled_rank - tiling_rank + 1, tiled_rank, 2), + ) + return utils.memref_transpose(ref, permutation) + + def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: + index = ir.IndexType.get() + tiling_rank = len(self.tiling) + return ( + *idx[:-tiling_rank], + *( + arith.divui(i, c(t, index)) + for i, t in zip(idx[-tiling_rank:], self.tiling) + ), + *( + arith.remui(i, c(t, index)) + for i, t in zip(idx[-tiling_rank:], self.tiling) + ), + ) + + def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: + # Note that this also checks that tiled dims are not squeezed. Their slice + # size would be 1 if so. + tiling_rank = len(self.tiling) + for size, tile_size in zip(shape[-tiling_rank:], self.tiling): + if size % tile_size: + raise ValueError( + f"Expected GMEM slice shape {shape} suffix to be a multiple of" + f" tiling {self.tiling}.\nIf you're using padded async copies, your" + " slice might need to extend out of bounds of the GMEM buffer (OOB" + " accesses will be skipped)." + ) + return ( + *shape[:-tiling_rank], + *(s // t for s, t in zip(shape[-tiling_rank:], self.tiling)), + *self.tiling, + ) + + +@dataclasses.dataclass(frozen=True) +class TransposeTransform(MemRefTransform): + """Transposes memref dimensions.""" + permutation: tuple[int, ...] + + def __post_init__(self): + if len(self.permutation) != len(set(self.permutation)): + raise ValueError("Permutation must be a permutation") + + def apply(self, ref: ir.Value) -> ir.Value: + return utils.memref_transpose(ref, self.permutation) + + def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: + return tuple(idx[p] for p in self.permutation) + + def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: + return tuple(shape[p] for p in self.permutation) + + +OnDeviceProfiler = profiler.OnDeviceProfiler + + +@dataclasses.dataclass() +class LaunchContext: + launch_op: gpu.LaunchOp + gmem_scratch_ptr: ir.Value + cluster_size: tuple[int, int, int] + profiler: OnDeviceProfiler | None = None + next_scratch_offset: int = 0 + host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field( + default_factory=list, init=False + ) + tma_descriptors: dict[ + tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]], + ir.Value, + ] = dataclasses.field(default_factory=dict, init=False) + + @contextlib.contextmanager + def named_region(self, *args, **kwargs): + if self.profiler is not None: + with self.profiler.record(*args, **kwargs): + yield + else: + yield + + def _alloc_scratch( + self, + size: int, + alignment: int | None = None, + host_init: Callable[[ir.Value], None] = lambda _: None, + device_init: Callable[[ir.Value], Any] = lambda x: x, + ) -> ir.Value: + """Allocates a GMEM scratch buffer. + + The buffer is initialized on the host and then copied to GMEM before the + kernel launch. 
+ """ + i8 = ir.IntegerType.get_signless(8) + ptr_ty = ir.Type.parse("!llvm.ptr") + if alignment is None: + alignment = size + if self.next_scratch_offset % alignment: + raise NotImplementedError # TODO(apaszke): Pad to match alignment + alloc_base = self.next_scratch_offset + self.next_scratch_offset += size + def host_init_wrapped(host_ptr): + host_init( + llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8) + ) + self.host_scratch_init.append(host_init_wrapped) + # with ir.InsertionPoint(self.gmem_scratch_ptr.owner): + # There is no way to create an insertion point after an operation... + gep = llvm.GEPOp( + ptr_ty, self.gmem_scratch_ptr, [], [alloc_base], i8 + ) + gep.move_after(self.gmem_scratch_ptr.owner) + return device_init(gep.result) + + def _get_tma_desc( + self, + gmem_ref, + gmem_transform: tuple[MemRefTransform, ...], + transformed_slice_shape: tuple[int, ...], + swizzle: int | None, + ): + tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform) + if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None: + i64 = ir.IntegerType.get_signless(64) + ptr_ty = ir.Type.parse("!llvm.ptr") + def init_tma_desc(host_ptr): + ref = gmem_ref + for t in gmem_transform: + ref = t.apply(ref) + ref_ty = ir.MemRefType(ref.type) + # TODO(apaszke): Use utils.memref_ptr to compute base_ptr + _, offset, *sizes_and_strides = memref.extract_strided_metadata(ref) + aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref) + as_i64 = lambda i: arith.index_cast(i64, i) + alloc_ptr = llvm.inttoptr(ptr_ty, as_i64(aligned_ptr_idx)) + llvm_dyn = -2147483648 # TODO(apaszke): Improve the MLIR bindings... + base_ptr = llvm.getelementptr( + ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type, + ) + rank = ref_ty.rank + assert rank * 2 == len(sizes_and_strides) + args = [ + host_ptr, + base_ptr, + c(utils.bytewidth(ref_ty.element_type), i64), + c(rank, i64), + utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]), + utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]), + c(0 if swizzle is None else swizzle, i64), + utils.pack_array([c(v, i64) for v in transformed_slice_shape]), + ] + func.call([], "mosaic_gpu_init_tma_desc", args) + def cast_tma_desc(device_ptr): + # TODO(apaszke): Investigate why prefetching can cause launch failures + # nvvm.prefetch_tensormap(device_ptr) + return device_ptr + tma_desc = self._alloc_scratch( + TMA_DESCRIPTOR_BYTES, + alignment=TMA_DESCRIPTOR_ALIGNMENT, + host_init=init_tma_desc, + device_init=cast_tma_desc, + ) + self.tma_descriptors[tma_desc_key] = tma_desc + return tma_desc + + def async_copy( + self, + *, + src_ref, + dst_ref, + gmem_slice: Any = (), + gmem_transform: MemRefTransform | tuple[MemRefTransform, ...] 
= (), + barrier: utils.BarrierRef | None = None, + swizzle: int | None = None, + arrive: bool | None = None, + uniform: bool = True, + collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None, + ): + index = ir.IndexType.get() + i16 = ir.IntegerType.get_signless(16) + i32 = ir.IntegerType.get_signless(32) + smem = ir.Attribute.parse("#gpu.address_space") + src_ref_ty = ir.MemRefType(src_ref.type) + dst_ref_ty = ir.MemRefType(dst_ref.type) + element_type = src_ref_ty.element_type + element_bytewidth = utils.bytewidth(element_type) + if element_type != dst_ref_ty.element_type: + raise ValueError( + f"Expected same element type, got {element_type} and" + f" {dst_ref_ty.element_type}" + ) + if not isinstance(gmem_transform, tuple): + gmem_transform = (gmem_transform,) + + if src_ref_ty.memory_space is None and dst_ref_ty.memory_space == smem: + gmem_ref, smem_ref = src_ref, dst_ref + if barrier is None: + raise ValueError("Barriers are required for GMEM -> SMEM copies") + if arrive is None: + arrive = True # Arrive by default + elif src_ref_ty.memory_space == smem and dst_ref_ty.memory_space is None: + gmem_ref, smem_ref = dst_ref, src_ref + if barrier is not None: + raise ValueError("Barriers are unsupported for SMEM -> GMEM copies") + if arrive is not None: + raise ValueError("arrive is unsupported for SMEM -> GMEM copies") + else: + raise ValueError("Only SMEM <-> GMEM copies supported") + # TODO(apaszke): This is a very approximate check. Improve it! + expected_name = "builtin.unrealized_conversion_cast" + if ( + gmem_ref.owner is None + or gmem_ref.owner.opview.OPERATION_NAME != expected_name + ): + raise ValueError("GMEM reference in async_copy must be a kernel argument") + + base_indices, slice_shape, is_squeezed = utils.parse_indices( + gmem_slice, ir.MemRefType(gmem_ref.type).shape + ) + dyn_base_indices = tuple( + c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices + ) + slice_shape = tuple(slice_shape) + for t in gmem_transform: + dyn_base_indices = t.transform_index(dyn_base_indices) + slice_shape = t.transform_shape(slice_shape) + for dim, squeezed in enumerate(is_squeezed): + if squeezed: + smem_ref = utils.memref_unsqueeze(smem_ref, dim) + smem_ref_ty = ir.MemRefType(smem_ref.type) + + if slice_shape != tuple(smem_ref_ty.shape): + raise ValueError( + "Expected the SMEM reference to have the same shape as the" + f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}" + ) + + dyn_base_indices = list(dyn_base_indices) + slice_shape = list(slice_shape) + collective_size = 1 + if collective is not None: + if isinstance(collective, gpu.Dimension): + collective = (collective,) + collective_size = math.prod(self.cluster_size[d] for d in collective) + if collective_size > 1: + def partition_dim(dim: int, idx: ir.Value, num_chunks: int): + nonlocal smem_ref + slice_shape[dim] //= num_chunks + block_offset = arith.muli(idx, c(slice_shape[dim], index)) + dyn_base_indices[dim] = arith.addi(dyn_base_indices[dim], block_offset) + smem_ref = utils.memref_slice( + smem_ref, + (slice(None),) * dim + (utils.ds(block_offset, slice_shape[dim]),) + ) + stride = 1 + idx = c(0, index) + for d in sorted(collective): + if self.cluster_size[d] == 1: # Optimize a multiply by 0. 
+ continue + idx = arith.addi(idx, arith.muli(gpu.cluster_block_id(d), c(stride, index))) + stride *= self.cluster_size[d] + rem_collective_size = collective_size + for dim, slice_size in enumerate(slice_shape[:-1]): + if slice_size % rem_collective_size == 0: + partition_dim(dim, idx, rem_collective_size) + rem_collective_size = 1 + break + elif rem_collective_size % slice_size == 0: + dim_idx = arith.remui(idx, c(slice_size, index)) + partition_dim(dim, dim_idx, slice_size) + idx = arith.divui(idx, c(slice_size, index)) + rem_collective_size //= slice_size + else: + break # We failed to partition the leading dimensions. + del idx # We overwrote the block index in the loop. + if rem_collective_size > 1: + raise ValueError( + "None of the leading dimensions in the transformed slice shape" + f" {slice_shape} is divisible by the collective size" + f" {collective_size}" + ) + # Make each block load a smaller slice, adjust the GMEM indices and slice + # the SMEM reference accordingly. + multicast_mask = arith.trunci( + i16, utils.cluster_collective_mask(self.cluster_size, collective) + ) + else: + multicast_mask = None + + tma_desc = self._get_tma_desc( + gmem_ref, gmem_transform, tuple(slice_shape), swizzle, + ) + + # We constuct TMA descriptors in column-major order. + rev_dyn_base_indices = [ + arith.index_cast(i32, idx) for idx in reversed(dyn_base_indices) + ] + + uniform_ctx = ( + functools.partial(utils.single_thread, per_block=False) + if uniform + else contextlib.nullcontext + ) + + rank = len(slice_shape) + if rank > 5: # TODO: apaszke - Implement stride compression + raise ValueError("Async copies only support striding up to 5 dimensions") + if max(slice_shape) > 256: + raise ValueError( + "Async copies only support copying <=256 elements along each" + " dimension" + ) + if (zeroth_bw := slice_shape[-1] * element_bytewidth) % 16 != 0: + raise ValueError( + "Async copies require the number of bytes copied along the last" + f" dimension to be divisible by 16, but got {zeroth_bw}" + ) + if swizzle is not None and slice_shape[-1] != swizzle // element_bytewidth: + raise ValueError( + f"Async copies with {swizzle=} require last dimension of the slice to" + f" be exactly {swizzle} bytes" + f" ({swizzle // element_bytewidth} elements), but got" + f" {slice_shape[-1]}" + ) + smem_ptr = utils.memref_ptr(smem_ref, memory_space=3) + if gmem_ref is src_ref: + assert barrier is not None # for pytype + transfer_bytes = c( + np.prod(slice_shape) * element_bytewidth * collective_size, i32 + ) + barrier_ptr = barrier.get_ptr() + with uniform_ctx(): + if arrive: + nvvm.mbarrier_arrive_expect_tx_shared(barrier_ptr, transfer_bytes) + nvvm.cp_async_bulk_tensor_shared_cluster_global( + smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [], multicast_mask=multicast_mask, + ) + else: + with uniform_ctx(): + nvvm.cp_async_bulk_tensor_global_shared_cta( + tma_desc, smem_ptr, rev_dyn_base_indices + ) + nvvm.cp_async_bulk_commit_group() + + def await_async_copy( + self, allow_groups: int, await_read_only: bool = False + ): + nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only) + utils.warpgroup_barrier() + + +# ShapeTrees currently can not contain unions. 
+ShapeTree = Any
+RefTree = Any
+T = TypeVar('T')
+
+
+@dataclasses.dataclass(frozen=True)
+class Union(Generic[T]):
+  members: Sequence[T]
+
+  def __iter__(self):
+    return iter(self.members)
+
+@dataclasses.dataclass(frozen=True)
+class TMABarrier:
+  num_barriers: int = 1
+
+@dataclasses.dataclass(frozen=True)
+class Barrier:
+  arrival_count: int
+  num_barriers: int = 1
+
+@dataclasses.dataclass(frozen=True)
+class ClusterBarrier:
+  collective_dims: Sequence[gpu.Dimension]
+  num_barriers: int = 1
+
+
+def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int:
+  return np.prod(shape_dtype.shape) * np.dtype(shape_dtype.dtype).itemsize
+
+
+def _construct_smem_reftree(
+    cluster_shape: tuple[int, int, int],
+    dynamic_smem: ir.Value,
+    smem_buffers: ShapeTree,
+    dynamic_smem_offset: int = 0,
+) -> RefTree:
+  index = ir.IndexType.get()
+  i8 = ir.IntegerType.get_signless(8)
+  ptr = ir.Type.parse("!llvm.ptr")
+  smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
+  flat_ref_tys, smem_buffer_tree = jax.tree.flatten(
+      smem_buffers, is_leaf=lambda x: isinstance(x, Union)
+  )
+  smem_refs = []
+  for ref_ty in flat_ref_tys:
+    def get_barrier_ptr(num_barriers: int) -> ir.Value:
+      nonlocal dynamic_smem_offset
+      smem_base_ptr = utils.memref_ptr(dynamic_smem, memory_space=3)
+      barrier_base_ptr = llvm.getelementptr(
+          ptr, smem_base_ptr, [], [dynamic_smem_offset], i8
+      )
+      dynamic_smem_offset += num_barriers * MBARRIER_BYTES
+      return barrier_base_ptr
+    match ref_ty:
+      case Union(members):
+        member_trees = [
+            _construct_smem_reftree(cluster_shape, dynamic_smem, m, dynamic_smem_offset)
+            for m in members
+        ]
+        # TODO(apaszke): This is quadratic, but it shouldn't matter for now...
+        dynamic_smem_offset += _smem_tree_size(ref_ty)
+        ref = Union(member_trees)
+      case TMABarrier(num_barriers):
+        ref = utils.BarrierRef.initialize(
+            get_barrier_ptr(num_barriers), num_barriers, arrival_count=1
+        )
+      case Barrier(arrival_count, num_barriers):
+        ref = utils.BarrierRef.initialize(
+            get_barrier_ptr(num_barriers),
+            num_barriers,
+            arrival_count=arrival_count,
+        )
+      case ClusterBarrier(collective_dims, num_barriers):
+        ref = utils.CollectiveBarrierRef.initialize(
+            get_barrier_ptr(num_barriers),
+            num_barriers,
+            collective_dims,
+            cluster_shape,
+        )
+      case _:
+        mlir_dtype = mlir.dtype_to_ir_type(ref_ty.dtype)
+        tile_smem = memref.view(
+            ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=smem),
+            dynamic_smem, c(dynamic_smem_offset, index), [],
+        )
+        dynamic_smem_offset += _count_buffer_bytes(ref_ty)
+        ref = tile_smem
+    smem_refs.append(ref)
+  return jax.tree.unflatten(smem_buffer_tree, smem_refs)
+
+
+MBARRIER_BYTES = 8
+
+
+def _smem_tree_size(smem_buffers: ShapeTree) -> int:
+  leaves = jax.tree.leaves(
+      smem_buffers, is_leaf=lambda x: isinstance(x, Union)
+  )
+  size = 0
+  for l in leaves:
+    match l:
+      case Union(members):
+        size += max(_smem_tree_size(s) for s in members)
+      case (
+          TMABarrier(num_barriers)
+          | ClusterBarrier(_, num_barriers=num_barriers)
+          | Barrier(_, num_barriers=num_barriers)
+      ):
+        if size % MBARRIER_BYTES:
+          raise NotImplementedError("Misaligned barrier allocation")
+        size += num_barriers * MBARRIER_BYTES
+      case _:
+        size += _count_buffer_bytes(l)
+  return size
+
+
+# TODO(apaszke): Inline this
+@contextlib.contextmanager
+def _launch(
+    token,
+    grid: tuple[int, int, int],
+    cluster: tuple[int, int, int],
+    block: tuple[int, int, int],
+    scratch_arr,
+    smem_buffers: ShapeTree | Union[ShapeTree],
+    profiler_spec: profiler.ProfilerSpec | None = None,
+    maybe_prof_buffer: ir.Value | None = None,
+):
+  if (profiler_spec is None) != (maybe_prof_buffer is None):
+    raise ValueError
+  index = ir.IndexType.get()
+  i32 = ir.IntegerType.get_signless(32)
+  i8 = ir.IntegerType.get_signless(8)
+  grid_vals = [c(i, index) for i in grid]
+  block_vals = [c(i, index) for i in block]
+
+  user_smem_bytes = _smem_tree_size(smem_buffers)
+
+  smem_bytes = user_smem_bytes
+  if profiler_spec is not None:
+    smem_bytes += profiler_spec.smem_bytes(block=block)
+
+  # TODO(cperivol): Query the shared memory size programmatically.
+  if smem_bytes > 228 * 1024:
+    raise ValueError(f"Mosaic GPU kernel exceeds available shared memory {smem_bytes=} > 228000")
+  if math.prod(cluster) != 1:
+    if len(cluster) != 3:
+      raise ValueError("Clusters must be 3D")
+    cluster_kwargs = {
+        "clusterSize" + d: c(s, index) for s, d in zip(cluster, "XYZ")
+    }
+    for d, grid_size, cluster_size in zip("xyz", grid, cluster):
+      if grid_size % cluster_size != 0:
+        raise ValueError(
+            f"Grid dimension {d} must be divisible by cluster dimension:"
+            f" {grid_size} % {cluster_size} != 0"
+        )
+  else:
+    cluster_kwargs = {}
+  launch_op = gpu.LaunchOp(
+      token.type, [token], *grid_vals, *block_vals,
+      dynamicSharedMemorySize=c(smem_bytes, i32), **cluster_kwargs)
+  launch_op.body.blocks.append(*([index] * (12 + 2 * len(cluster_kwargs))))  # Append an empty block
+  smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
+  with ir.InsertionPoint(launch_op.body.blocks[0]):
+    dynamic_smem = gpu.dynamic_shared_memory(
+        ir.MemRefType.get(
+            (ir.ShapedType.get_dynamic_size(),), i8, memory_space=smem
+        )
+    )
+
+    smem_ref_tree = _construct_smem_reftree(
+        cluster, dynamic_smem, smem_buffers
+    )
+    # TODO(apaszke): Skip the following if no barriers were initialized.
+    nvvm.fence_mbarrier_init()
+    if math.prod(cluster) != 1:
+      nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get())
+      nvvm.cluster_wait(aligned=ir.UnitAttr.get())
+    gpu.barrier()
+
+    if profiler_spec:
+      prof_smem = memref.view(
+          ir.MemRefType.get(
+              (profiler_spec.smem_i32_elements(block=block),),
+              i32, memory_space=smem,
+          ),
+          dynamic_smem, c(user_smem_bytes, index), [],
+      )
+      prof = profiler.OnDeviceProfiler(
+          profiler_spec, prof_smem, maybe_prof_buffer
+      )
+    else:
+      prof = None
+
+    ptr_ty = ir.Type.parse("!llvm.ptr")
+    scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr])
+    yield LaunchContext(launch_op, scratch_ptr, cluster, prof), smem_ref_tree
+    if prof is not None:
+      prof.finalize(grid=grid, block=block)
+    gpu.terminator()
+
+
+def _lower_as_gpu_kernel(
+    body,
+    grid: tuple[int, int, int],
+    cluster: tuple[int, int, int],
+    block: tuple[int, int, int],
+    in_shapes: tuple[Any, ...],
+    out_shape,
+    smem_scratch_shape: ShapeTree | Union[ShapeTree],
+    module_name: str,
+    prof_spec: profiler.ProfilerSpec | None = None,
+):
+  ptr_ty = ir.Type.parse("!llvm.ptr")
+  token_ty = ir.Type.parse("!gpu.async.token")
+  i32 = ir.IntegerType.get_signless(32)
+  i64 = ir.IntegerType.get_signless(64)
+
+  def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType:
+    return ir.MemRefType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))
+
+  in_ref_tys = [_shape_to_ref_ty(t) for t in in_shapes]
+
+  unwrap_output_tuple = False
+  if isinstance(out_shape, list):
+    out_shape = tuple(out_shape)
+  elif not isinstance(out_shape, tuple):
+    out_shape = (out_shape,)
+    unwrap_output_tuple = True
+  out_ref_tys = [_shape_to_ref_ty(t) for t in out_shape]
+  if prof_spec is not None:
+    out_shape = (*out_shape, prof_spec.jax_buffer_type(grid, block))
+    out_ref_tys.append(prof_spec.mlir_buffer_type(grid, block))
+
+  module = ir.Module.create()
+  attrs = module.operation.attributes
+  attrs["sym_name"] = ir.StringAttr.get(module_name)
+  with ir.InsertionPoint(module.body):
+    _declare_runtime_functions()
+    gmem_scratch_bytes = 0
+    global_scratch = llvm.GlobalOp(
+        ir.Type.parse("!llvm.array<0 x i8>"),  # We don't know the shape yet.
+        "global_scratch",
+        ir.Attribute.parse("#llvm.linkage<public>"),
+        addr_space=ir.IntegerAttr.get(i32, 4),  # GPU constant memory.
+    )
+    @func.FuncOp.from_py_func(ptr_ty, ptr_ty)
+    def main(token_ptr, buffers):
+      nonlocal gmem_scratch_bytes
+      token = builtin.unrealized_conversion_cast([token_ty], [token_ptr])
+      arg_refs = []
+      for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]):
+        ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty))
+        arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty)))
+      in_refs = arg_refs[:len(in_ref_tys)]
+      out_refs = arg_refs[len(in_ref_tys):]
+      prof_buffer = out_refs.pop() if prof_spec is not None else None
+      empty_arr_ty = ir.Type.parse("!llvm.array<0 x i8>")
+      scratch_alloc = llvm.AllocaOp(
+          ptr_ty, c(1, i64), empty_arr_ty, alignment=TMA_DESCRIPTOR_ALIGNMENT
+      )
+      scratch_arr = llvm.load(empty_arr_ty, scratch_alloc.result)
+      with _launch(
+          token, grid, cluster, block, scratch_arr, smem_scratch_shape,
+          prof_spec, prof_buffer
+      ) as (launch_ctx, smem_refs):
+        body(launch_ctx, *in_refs, *out_refs, smem_refs)
+        gmem_scratch_bytes = launch_ctx.next_scratch_offset
+      # Allocate and initialize the host buffer right before the launch.
+      # Note that we couldn't do that before, because we had to run the body
+      # to learn what the scratch contains.
+      with ir.InsertionPoint(scratch_arr.owner):
+        scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>")
+        scratch_alloc.elem_type = ir.TypeAttr.get(scratch_arr_ty)
+        scratch_arr.set_type(scratch_arr_ty)
+        for init_callback in launch_ctx.host_scratch_init:
+          init_callback(scratch_alloc.result)
+    main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    sym_tab = ir.SymbolTable(module.operation)
+    sym_tab.insert(main.func_op)
+    sym_tab.insert(global_scratch)
+    module.operation.verify()
+
+  return module, out_shape, unwrap_output_tuple
+
+
+def _declare_runtime_functions():
+  """Declares the runtime functions that can be used by the generated code."""
+  ptr_ty = ir.Type.parse("!llvm.ptr")
+  i64 = ir.IntegerType.get_signless(64)
+  arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty]
+  init_tma_desc_type = ir.FunctionType.get(arg_tys, [])
+  func.FuncOp(
+      "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private"
+  )
+  memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], [])
+  func.FuncOp(
+      "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private"
+  )
+
+
+def as_gpu_kernel(
+    body,
+    grid: tuple[int, int, int],
+    block: tuple[int, int, int],
+    in_shape,
+    out_shape,
+    smem_scratch_shape: ShapeTree | Union[ShapeTree],
+    prof_spec: profiler.ProfilerSpec | None = None,
+    cluster: tuple[int, int, int] = (1, 1, 1),
+    module_name: str = "unknown",
+):
+  if isinstance(in_shape, list):
+    in_shape = tuple(in_shape)
+  elif not isinstance(in_shape, tuple):
+    in_shape = (in_shape,)
+
+  module, out_shape, unwrap_output_tuple = (
+      _lower_as_gpu_kernel(
+          body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
+          module_name, prof_spec
+      )
+  )
+
+  expected_arg_treedef = jax.tree.structure(in_shape)
+  def _check_args(*args):
+    arg_treedef = 
jax.tree.structure(args) + if arg_treedef != expected_arg_treedef: + raise ValueError( + f"Invalid argument structure: expected {expected_arg_treedef}, got" + f" {arg_treedef}, ({args=})" + ) + + module_asm = module.operation.get_asm(binary=True, enable_debug_info=True) + def bind(*args): + return mosaic_gpu_p.bind( + *args, + out_types=out_shape, + module=module_asm, + ) + + if prof_spec is not None: + @jax.jit + def prof_kernel(*args): + _check_args(*args) + *results, prof_buffer = bind(*args) + def dump_profile(prof_buffer): + out_file = os.path.join( + os.getenv("TEST_UNDECLARED_OUTPUTS_DIR"), + f"{time.time_ns()}-trace.json", + ) + try: + with open(out_file, "x") as f: + prof_spec.dump(prof_buffer, f, grid=grid, block=block) + except FileExistsError: + pass # TODO: Retry + jax.debug.callback(dump_profile, prof_buffer) + return results[0] if unwrap_output_tuple else results + return prof_kernel + else: + @jax.jit + def kernel(*args): + _check_args(*args) + results = bind(*args) + return results[0] if unwrap_output_tuple else results + return kernel + + +def as_torch_gpu_kernel( + body, + grid: tuple[int, int, int], + block: tuple[int, int, int], + in_shape, + out_shape, + smem_scratch_shape: ShapeTree | Union[ShapeTree], + prof_spec: profiler.ProfilerSpec | None = None, + cluster: tuple[int, int, int] = (1, 1, 1), + module_name: str = "unknown", +): + try: + import torch + except ImportError: + raise RuntimeError("as_torch_gpu_kernel requires PyTorch") + torch.cuda.init() # Make sure CUDA context is set up. + + if isinstance(in_shape, list): + in_shape = tuple(in_shape) + elif not isinstance(in_shape, tuple): + in_shape = (in_shape,) + + flat_out_types, out_treedef = jax.tree.flatten(out_shape) + expected_arg_treedef = jax.tree.structure(in_shape) + + module, out_shape, unwrap_output_tuple = ( + _lower_as_gpu_kernel( + body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, + module_name, prof_spec + ) + ) + + # Get our hands on the compilation and unload functions + try: + import jax_plugins.xla_cuda12 as cuda_plugin + except ImportError: + raise RuntimeError("as_torch_gpu_kernel only works with recent jaxlib builds " + "that use backend plugins") + dll = ctypes.CDLL(cuda_plugin._get_library_path()) + compile_func = dll.MosaicGpuCompile + compile_func.argtypes = [ctypes.c_void_p] + compile_func.restype = ctypes.POINTER(ctypes.c_void_p) + unload_func = dll.MosaicGpuUnload + unload_func.argtypes = [compile_func.restype] + unload_func.restype = None + + module_asm = module.operation.get_asm(binary=True, enable_debug_info=True) + compiled = compile_func(ctypes.c_char_p(module_asm)) + if compiled is None: + raise RuntimeError("Failed to compile the module") + ctx, launch_ptr = compiled[0], compiled[1] + ctx_ptr_ptr = ctypes.pointer(ctypes.c_void_p(ctx)) + launch = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(launch_ptr) + + def as_torch_dtype(dtype): + # torch contains NumPy-compatible dtypes in its top namespace + return getattr(torch, np.dtype(dtype).name) + + def apply(*args): + flat_args, arg_treedef = jax.tree.flatten(args) + if arg_treedef != expected_arg_treedef: + raise ValueError( + f"Invalid argument structure: expected {expected_arg_treedef}, got" + f" {arg_treedef}, ({args=})" + ) + + # Construct a device pointer list like in the XLA calling convention + buffers = (ctypes.c_void_p * (arg_treedef.num_leaves + out_treedef.num_leaves))() + i = -1 # Define i in case there are no args + device = 'cuda' + for i, arg in enumerate(flat_args): + buffers[i] = arg.data_ptr() + 
device = arg.device + flat_outs = [] + for i, t in enumerate(flat_out_types, i + 1): + out = torch.empty(t.shape, dtype=as_torch_dtype(t.dtype), device=device) + flat_outs.append(out) + buffers[i] = out.data_ptr() + # Allocate another buffer for args of the host-side program. This is sadly + # the default MLIR calling convention. + args_ptr = (ctypes.POINTER(ctypes.c_void_p) * 3)() + args_ptr[0] = ctx_ptr_ptr + args_ptr[1] = ctypes.pointer(torch.cuda.default_stream(device)._as_parameter_) + args_ptr[2] = ctypes.cast(ctypes.pointer(ctypes.pointer(buffers)), + ctypes.POINTER(ctypes.c_void_p)) + launch(args_ptr) + return jax.tree.unflatten(out_treedef, flat_outs) + + # Unload the compiled code when the Python function is destroyed. + def unload(_): + unload_func(compiled) + apply.destructor = weakref.ref(apply, unload) + + return apply diff --git a/jax/experimental/mosaic/gpu/dsl.py b/jax/experimental/mosaic/gpu/dsl.py deleted file mode 100644 index a12e5bc18803..000000000000 --- a/jax/experimental/mosaic/gpu/dsl.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2024 The JAX Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from . import ( - Barrier, - ClusterBarrier, - TMABarrier, - Union, -) -from .fragmented_array import ( - FragmentedArray, - FragmentedLayout, - WGMMA_LAYOUT, - WGMMA_ROW_LAYOUT, - WGStridedFragLayout, -) -from .utils import ( - BarrierRef, - CollectiveBarrierRef, - DynamicSlice, - Partition, - Partition1D, - bytewidth, - c, - commit_shared, - debug_print, - ds, - fori, - memref_fold, - memref_slice, - memref_transpose, - memref_unfold, - memref_unsqueeze, - single_thread, - thread_idx, - tile_shape, - warp_idx, - warpgroup_barrier, - warpgroup_idx, - when, -) -from .wgmma import ( - WGMMAAccumulator, - WGMMALayout, - wgmma, -) diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py index a9a533ca361c..99586875ae90 100644 --- a/jax/experimental/mosaic/gpu/examples/flash_attention.py +++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py @@ -24,9 +24,8 @@ from jax import random from jax._src.interpreters import mlir from jax._src import test_util as jtu -from jax.experimental.mosaic import gpu as mosaic_gpu from jax.experimental.mosaic.gpu import profiler -from jax.experimental.mosaic.gpu.dsl import * # noqa: F403 +from jax.experimental.mosaic.gpu import * # noqa: F403 import jax.numpy as jnp from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith @@ -144,7 +143,7 @@ def c(value, ty=index): return _utils_c(value, ty) def tma_wg_kernel( - ctx: mosaic_gpu.LaunchContext, + ctx: LaunchContext, q_gmem, k_gmem, v_gmem, @@ -190,7 +189,7 @@ def only_wg(idx): ctx.async_copy( src_ref=q_gmem, gmem_slice=(q_head_idx, ds(q_seq_base, blocks.q)), - gmem_transform=mosaic_gpu.TileTransform(tiling), + gmem_transform=TileTransform(tiling), dst_ref=qo_smem, barrier=q_barriers[wg_idx], swizzle=128, @@ 
-294,7 +293,7 @@ def kv_loop(kv_step, carry): src_ref=qo_smem, dst_ref=out_gmem, gmem_slice=(q_head_idx, ds(q_seq_base, blocks.q)), - gmem_transform=mosaic_gpu.TileTransform(tiling), + gmem_transform=TileTransform(tiling), swizzle=128, ) ctx.await_async_copy(0) @@ -304,10 +303,9 @@ def kv_loop(kv_step, carry): nvvm.setmaxregister(40, nvvm.SetMaxRegisterAction.decrease) with single_thread(per_block=False): k_tr = ( - mosaic_gpu.TileTransform(tiling), - mosaic_gpu.TransposeTransform((0, 2, 1, 3, 4)), + TileTransform(tiling), TransposeTransform((0, 2, 1, 3, 4)), ) - v_tr = mosaic_gpu.TileTransform(tiling) + v_tr = TileTransform(tiling) kv_head_idx = arith.divui(q_head_idx, c(q_heads_per_kv_head)) def start_kv_copy(slot, kv_seq_base, smem, gmem, barrier, transform): ctx.async_copy( @@ -350,7 +348,7 @@ def _kv_loop_memory(i, _): scf.yield_([]) def compute_only_kernel( - ctx: mosaic_gpu.LaunchContext, + ctx: LaunchContext, q_gmem, k_gmem, v_gmem, @@ -388,7 +386,7 @@ def only_wg(idx): ctx.async_copy( src_ref=q_gmem, gmem_slice=(q_head_idx, ds(q_seq_base, blocks.q)), - gmem_transform=mosaic_gpu.TileTransform(tiling), + gmem_transform=TileTransform(tiling), dst_ref=qo_smem, barrier=barriers[q_barrier], swizzle=128, @@ -401,10 +399,10 @@ def kv_copy_init(slot, kv_seq_base): txcount = 2 * blocks.kv * head_dim * bytewidth(f16) barriers[slot].arrive_expect_tx(txcount) k_tr = ( - mosaic_gpu.TileTransform(tiling), - mosaic_gpu.TransposeTransform((0, 2, 1, 3, 4)), + TileTransform(tiling), + TransposeTransform((0, 2, 1, 3, 4)), ) - v_tr = mosaic_gpu.TileTransform(tiling) + v_tr = TileTransform(tiling) for smem, gmem, t in ((k_smem, k_gmem, k_tr), (v_smem, v_gmem, v_tr)): ctx.async_copy( dst_ref=memref_slice(smem, slot), @@ -526,7 +524,7 @@ def kv_loop(kv_step, carry): src_ref=qo_smem, dst_ref=out_gmem, gmem_slice=(q_head_idx, ds(q_seq_base, blocks.q)), - gmem_transform=mosaic_gpu.TileTransform(tiling), + gmem_transform=TileTransform(tiling), swizzle=128, ) ctx.await_async_copy(0) @@ -551,7 +549,7 @@ def kv_loop(kv_step, carry): Barrier(arrival_count=256, num_barriers=2), Barrier(arrival_count=256, num_barriers=1), ) - return mosaic_gpu.as_gpu_kernel( + return as_gpu_kernel( kernel, grid, block, in_shape, out_shape, smem_scratch_shape, prof_spec ) diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py index 775b7c2ea898..c56c5cd6b982 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul.py +++ b/jax/experimental/mosaic/gpu/examples/matmul.py @@ -22,17 +22,14 @@ import jax from jax import random from jax._src.interpreters import mlir -from jax.experimental.mosaic import gpu as mosaic_gpu from jax.experimental.mosaic.gpu import profiler -from jax.experimental.mosaic.gpu.dsl import * # noqa: F403 +from jax.experimental.mosaic.gpu import * # noqa: F403 import jax.numpy as jnp from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import gpu -from jaxlib.mlir.dialects import memref from jaxlib.mlir.dialects import nvvm from jaxlib.mlir.dialects import scf -from jaxlib.mlir.dialects import vector import numpy as np # mypy: ignore-errors @@ -190,7 +187,7 @@ def safe_div(x, y): wgmma_impl.smem_shape_extra(block_tiling, tma_tiling, lhs_dtype, rhs_dtype, rhs_transpose), ) epilogue_scratch_shape = jax.ShapeDtypeStruct(out_tile.shape, out_tile.dtype) - smem_shape = mosaic_gpu.Union([compute_scratch_shape, epilogue_scratch_shape]) + smem_shape = Union([compute_scratch_shape, epilogue_scratch_shape]) def _main(ctx, a_device, 
b_device, c_device, smem): ((lhs_smem, rhs_smem, impl_smem), epilogue_smem), *barriers = smem @@ -218,15 +215,15 @@ def fetch(slot, ki): src_ref=a_device, dst_ref=memref_slice(lhs_smem, slot), gmem_slice=(ds(m_start, block_tiling.m), ds(k_start, block_tiling.k)), - gmem_transform=mosaic_gpu.TileTransform(tma_tiling.mk), + gmem_transform=TileTransform(tma_tiling.mk), collective=(gpu.Dimension.x, gpu.Dimension.z), **common_copy_args, ) rhs_slice = (ds(k_start, block_tiling.k), ds(n_start, block_tiling.n)) - rhs_transform = (mosaic_gpu.TileTransform(tma_tiling.kn),) + rhs_transform = (TileTransform(tma_tiling.kn),) if rhs_transpose: rhs_slice = rhs_slice[::-1] - rhs_transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),) + rhs_transform += (TransposeTransform((1, 0, 2, 3)),) assert tma_tiling.n == tma_tiling.k, block_tiling # No need to flip the tiling. ctx.async_copy( src_ref=b_device, @@ -292,7 +289,7 @@ def stage_loop_body(ki, accs): src_ref=epilogue_smem, dst_ref=c_device, gmem_slice=(ds(m_start, tile_m), ds(n_start, tile_n)), - gmem_transform=mosaic_gpu.TileTransform(out_tiling), + gmem_transform=TileTransform(out_tiling), swizzle=out_swizzle, ) ctx.await_async_copy(0) @@ -304,7 +301,7 @@ def stage_loop_body(ki, accs): f" {grid_tile_n=})" ) cluster = (cluster_tile_n, cluster_m, cluster_n // cluster_tile_n) - return mosaic_gpu.as_gpu_kernel( + return as_gpu_kernel( _main, grid, block, diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 502373bdc91e..d5a6e9eb69d1 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -29,7 +29,7 @@ from jaxlib.mlir.dialects import vector import numpy as np -from . import dsl as mgpu +import jax.experimental.mosaic.gpu as mgpu from . import utils # mypy: ignore-errors diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py index b64418022d0e..5b0282080c55 100644 --- a/jax/experimental/mosaic/gpu/wgmma.py +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -21,13 +21,11 @@ import jax from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith -from jaxlib.mlir.dialects import builtin from jaxlib.mlir.dialects import llvm -from jaxlib.mlir.dialects import nvvm from jaxlib.mlir.dialects import vector import numpy as np -from . import dsl as mgpu +import jax.experimental.mosaic.gpu as mgpu from . 
import utils # mypy: ignore-errors diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 1a29bbb5736d..9f2f1222bde4 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -43,8 +43,7 @@ class Dimension(enum.IntEnum): # Just to make parameterized tests expand ok y = 1 z = 2 else: - from jax.experimental.mosaic import gpu as mosaic_gpu - from jax.experimental.mosaic.gpu import dsl as mgpu + import jax.experimental.mosaic.gpu as mgpu from jax.experimental.mosaic.gpu import utils as utils from jax.experimental.mosaic.gpu import profiler from jax.experimental.mosaic.gpu.utils import * # noqa: F403 @@ -171,14 +170,14 @@ def test_copy_basic(self): def kernel(ctx, src, dst, _): copy(src, dst) x = jnp.arange(2 * 3 * 5).reshape(2, 5, 3) - y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x) np.testing.assert_array_equal(y, x) def test_copy_swizzle(self): def kernel(ctx, src, dst, _): copy(src, dst, swizzle=128) x = jnp.arange(8 * 32, dtype=jnp.float32).reshape(8, 32) - y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x) expected = np.zeros_like(y) for i in range(8): for j in range(8): @@ -192,7 +191,7 @@ def kernel(ctx, src, dst, smem): copy(src, smem, swizzle=128) copy(smem, dst, swizzle=128) x = jnp.arange(8 * 32, dtype=jnp.float32).reshape(8, 32) - y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) np.testing.assert_array_equal(y, x) def test_iota_tensor(self): @@ -209,7 +208,7 @@ def kernel(ctx, dst, _): reg, dst, [gpu.thread_id(gpu.Dimension.x), c(2 * i + j, index)] ) out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) - regs = mosaic_gpu.as_gpu_kernel( + regs = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() thread_ids = np.arange(128) @@ -248,7 +247,7 @@ def kernel(ctx, inp, out, _): out_shape = list(x.shape) out_shape.insert(dim, 1) out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), x, out_ty, () )(x) np.testing.assert_array_equal(y, x.reshape(out_shape)) @@ -276,7 +275,7 @@ def kernel(ctx, inp, out, _): out_shape = list(in_shape) out_shape[dim:dim + 1] = [2, 2, out_shape[dim] // 4] out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), x, out_ty, () )(x) np.testing.assert_array_equal(y, x.reshape(out_ty.shape)) @@ -290,7 +289,7 @@ def kernel(ctx, inp, out, _): x = np.arange(8 * 2 * 8, dtype=jnp.float32).reshape(8, 2, 8) out_ty = jax.ShapeDtypeStruct((16, 8) if dim == 0 else (8, 16), jnp.float32) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), x, out_ty, () )(x) np.testing.assert_array_equal(y, x.reshape(out_ty.shape)) @@ -329,7 +328,7 @@ def kernel(ctx, inp, out, _): copy(memref_fold(memref_slice(inp, index), dim, fold_rank), out) out = np_fold(np_inp[index], dim, fold_rank) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), np_inp, out, () )(np_inp) assert ( @@ -371,7 +370,7 @@ def kernel(ctx, out, _): del ctx iota_tensor(64, 64, mlir_dtype).store_untiled(out) expected = np.arange(64 * 64, dtype=jax_dtype).reshape(64, 64) - iota = mosaic_gpu.as_gpu_kernel( + iota = mgpu.as_gpu_kernel( kernel, (1, 1, 
1), (128, 1, 1), (), expected, () )() np.testing.assert_array_equal(iota, expected) @@ -387,7 +386,7 @@ def kernel(ctx, out, _): del ctx mgpu.FragmentedArray.splat(c(1., mlir_dtype), (size,)).store_untiled(out) expected = np.ones((size,), jax_dtype) - mosaic_ones = mosaic_gpu.as_gpu_kernel( + mosaic_ones = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), expected, () )() np.testing.assert_array_equal(mosaic_ones, expected) @@ -419,7 +418,7 @@ def kernel(ctx, out, smem): .reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1]) .transpose(0, 2, 1, 3) ) - iota = mosaic_gpu.as_gpu_kernel( + iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), expected, expected )() np.testing.assert_array_equal(iota, expected) @@ -445,12 +444,12 @@ def kernel(ctx, out, smem): dst_ref=out, swizzle=swizzle, gmem_slice=(ds(0, m), ds(0, col_tiling)), - gmem_transform=mosaic_gpu.TileTransform(tiling), + gmem_transform=mgpu.TileTransform(tiling), ) ctx.await_async_copy(0) smem_shape = jax.ShapeDtypeStruct((m // tiling[0], 1, *tiling), jax_dtype) expected = np.arange(m * n, dtype=jax_dtype).reshape(m, n) - iota = mosaic_gpu.as_gpu_kernel( + iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), expected, smem_shape )() np.testing.assert_array_equal(iota, expected) @@ -493,7 +492,7 @@ def kernel(ctx, inp, out, smem): expected_from = expected(jax_dtype_from, from_tiling) expected_to = expected(jax_dtype_to, to_tiling) - res = mosaic_gpu.as_gpu_kernel( + res = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), @@ -525,7 +524,7 @@ def kernel(ctx, in_, out, smem): .reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1]) .transpose(0, 2, 1, 3) ) - iota = mosaic_gpu.as_gpu_kernel( + iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), expected, expected, (expected,) * 2 )(expected) np.testing.assert_array_equal(iota, expected) @@ -593,13 +592,13 @@ def test_wgmma_basic( def kernel(ctx, lhs, rhs, out, scratch): lhs_smem, rhs_smem, barriers = scratch if tma_inputs: - lhs_transform = (mosaic_gpu.TileTransform((64, nk_tile)),) + lhs_transform = (mgpu.TileTransform((64, nk_tile)),) if lhs_transpose: assert nk_tile == 64 # Make sure we didn't have to transpose tiling. 
- lhs_transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),) - rhs_transform = (mosaic_gpu.TileTransform((nk_tile, nk_tile)),) + lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) + rhs_transform = (mgpu.TileTransform((nk_tile, nk_tile)),) if rhs_transpose: - rhs_transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),) + rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) ctx.async_copy( src_ref=lhs, dst_ref=lhs_smem, @@ -666,7 +665,7 @@ def quantize(x): ), mgpu.TMABarrier(2), ] - z = mosaic_gpu.as_gpu_kernel( + z = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape )(x, y) x32, y32 = x.astype(np.float32), y.astype(np.float32) @@ -724,7 +723,7 @@ def kernel(ctx, rhs, out, rhs_smem): scratch_shape = jax.ShapeDtypeStruct( (k_steps, n // nk_tile, nk_tile, nk_tile), jax_dtype ) - z = mosaic_gpu.as_gpu_kernel( + z = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), y, out_shape, scratch_shape )(y) x = np.arange(m * k, dtype=jax_dtype).reshape(m, k) @@ -752,11 +751,11 @@ def kernel(ctx, rhs, out, smem): rhs_smem, barrier = smem gmem_slice = (ds(0, k), ds(0, nk_tile)) smem_slice = (slice(None), slice(None), slice(None), ds(0, n)) - transform = (mosaic_gpu.TileTransform((nk_tile, nk_tile)),) + transform = (mgpu.TileTransform((nk_tile, nk_tile)),) if rhs_transpose: gmem_slice = gmem_slice[::-1] smem_slice = (slice(None), slice(None), ds(0, n), slice(None)) - transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),) + transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) ctx.async_copy( src_ref=rhs, dst_ref=rhs_smem, @@ -781,7 +780,7 @@ def kernel(ctx, rhs, out, smem): rhs_scratch_shape = jax.ShapeDtypeStruct( (k_steps, 1, nk_tile, nk_tile), jax_dtype ) - z = mosaic_gpu.as_gpu_kernel( + z = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), y, out_shape, (rhs_scratch_shape, mgpu.TMABarrier()), )(y) x = np.arange(m * k, dtype=jax_dtype).reshape(m, k) @@ -823,7 +822,7 @@ def kernel(ctx, dst, scratch): final_arr.store_untiled(memref_slice(dst, 1)) scf.yield_([]) out_shape = jax.ShapeDtypeStruct((2, 128), jnp.int32) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (2 * 128, 1, 1), @@ -892,7 +891,7 @@ def kernel(ctx, dst, mask, collective_barrier): if group_dims: barrier_dims = (collective_dims[:2], *collective_dims[2:]) scratch = mgpu.ClusterBarrier(barrier_dims) - y, mask = mosaic_gpu.as_gpu_kernel( + y, mask = mgpu.as_gpu_kernel( kernel, cluster, (128, 1, 1), (), (out_shape, mask_shape), scratch, cluster=cluster, )() np.testing.assert_array_equal( @@ -931,7 +930,7 @@ def kernel(ctx, src, dst, smem): copy(tmp, dst, swizzle=swizzle) x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) smem = (x, mgpu.TMABarrier()) - y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x) np.testing.assert_array_equal(y, x) @parameterized.named_parameters( @@ -1009,7 +1008,7 @@ def kernel(ctx, src, dst, scratch): ) x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) smem_shape = (jax.ShapeDtypeStruct(shape[1:], dtype), mgpu.TMABarrier()) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, cluster, (128, 1, 1), x, x, smem_shape, cluster=cluster )(x) np.testing.assert_array_equal(y, x) @@ -1033,7 +1032,7 @@ def kernel(ctx, src, dst, scratch): dst_ref=tmp, swizzle=swizzle, barrier=barrier, - gmem_transform=mosaic_gpu.TileTransform(tiling), + gmem_transform=mgpu.TileTransform(tiling), ) barrier.wait_parity(c(0, i1)) for 
idxs in np.ndindex(tiled_shape): @@ -1048,7 +1047,7 @@ def kernel(ctx, src, dst, scratch): jax.ShapeDtypeStruct(tile_shape(shape, tiling), dtype), mgpu.TMABarrier(), ) - f = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem) + f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem) y = f(x) np.testing.assert_array_equal(y, x) @@ -1075,7 +1074,7 @@ def kernel(ctx, src, dst, smem): copy(tmp, dst, swizzle=swizzle) x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) smem = (x, mgpu.TMABarrier()) - y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x) np.testing.assert_array_equal(y, x) def test_parity_tracking(self): @@ -1091,7 +1090,7 @@ def kernel(ctx, src, dst, smem): barrier.wait() copy(tmp, memref_slice(dst, s)) x = np.arange(np.prod(shape), dtype=jnp.float16).reshape(shape) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), x, x, (x[0:1], mgpu.TMABarrier()) )(x) np.testing.assert_array_equal(y, x) @@ -1109,7 +1108,7 @@ def kernel(ctx, src, dst, tmp): ctx.async_copy(src_ref=tmp, dst_ref=dst, swizzle=swizzle) ctx.await_async_copy(0) x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) - y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) np.testing.assert_array_equal(y, x) @parameterized.parameters(0, 1) @@ -1128,7 +1127,7 @@ def kernel(ctx, src, dst, smem): src_ref=src, dst_ref=tmp, swizzle=128, - gmem_transform=mosaic_gpu.TileTransform((64, 64)), + gmem_transform=mgpu.TileTransform((64, 64)), gmem_slice=(ds(0, padded_shape[0]), ds(0, padded_shape[1])), barrier=barrier, ) @@ -1136,7 +1135,7 @@ def kernel(ctx, src, dst, smem): copy(tmp, dst, swizzle=128) x = np.arange(np.prod(shape), dtype=jnp.float16).reshape(shape) tiled = jax.ShapeDtypeStruct(tiled_shape, jnp.float16) - y_tiled = mosaic_gpu.as_gpu_kernel( + y_tiled = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), x, tiled, (tiled, mgpu.TMABarrier()), )(x) y = y_tiled.swapaxes(1, 2).reshape(padded_shape) @@ -1165,13 +1164,13 @@ def kernel(ctx, dst, tmp): src_ref=tmp, dst_ref=dst, swizzle=128, - gmem_transform=mosaic_gpu.TileTransform((64, 64)), + gmem_transform=mgpu.TileTransform((64, 64)), gmem_slice=(ds(0, padded_shape[0]), ds(0, padded_shape[1])), ) ctx.await_async_copy(0) tiled = jax.ShapeDtypeStruct(tiled_shape, jnp.float16) out = jax.ShapeDtypeStruct(shape, jnp.float16) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out, tiled, )() iota = np.arange(np.prod(padded_shape), dtype=jnp.float16).reshape( @@ -1187,7 +1186,7 @@ def kernel(ctx, src, dst, tmp): def run_kernel(shape): x = np.arange(np.prod(shape)).reshape(shape) - _ = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) + _ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) with self.assertRaisesRegex(ValueError, "only support striding up to 5"): run_kernel([1] * 6) @@ -1224,7 +1223,7 @@ def kernel(ctx, dst, _): rhs = iota if scalar_rhs is None else c(scalar_rhs, iota.mlir_dtype) op(iota, rhs).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() ref_x = np.arange(m * n, dtype=jnp.float32).reshape(m, n) @@ -1257,7 +1256,7 @@ def kernel(ctx, dst, _): iota = iota_tensor(m=m, n=n, mlir_dtype=f32) 
op(iota).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() x = np.arange(m * n, dtype=jnp.float32).reshape(m, n) @@ -1276,7 +1275,7 @@ def kernel(ctx, dst, _): iota = iota_tensor(m=m, n=n, mlir_dtype=f32) iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() x = np.arange(m * n, dtype=jnp.float32).reshape(m, n) @@ -1298,7 +1297,7 @@ def kernel(ctx, dst, _): cte_arr = cte_arr.reshape((1, 1)).broadcast((m, n)) (iota + cte_arr).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() expected = np.arange(m * n, dtype=jnp.float32).reshape(m, n) + 1 @@ -1311,7 +1310,7 @@ def kernel(ctx, dst, _): t = mgpu.FragmentedArray.splat(v, (128,), mgpu.WGMMA_ROW_LAYOUT) t.broadcast_minor(32).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() np.testing.assert_array_equal(result, np.full((128, 32), 3.14, np.float32)) @@ -1326,7 +1325,7 @@ def kernel(ctx, *args): copy(smem_output, gmem_output) inp = out = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out], )(inp) np.testing.assert_array_equal(inp, result) @@ -1341,7 +1340,7 @@ def kernel(ctx, out, *_): memref.store(grp, out, [tid]) x = np.arange(128, dtype=jnp.int32) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), x, [], )() for i in range(0, 128, 4): @@ -1364,7 +1363,7 @@ def kernel(ctx, inp, out, smem): x = jnp.arange(-128, 128, dtype=jax_dtype_from) reference = x.astype(jax_dtype_to) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), x, reference, None, )(x) np.testing.assert_array_equal(result, reference) @@ -1382,7 +1381,7 @@ def test_multigpu(self): def kernel(ctx, src, dst, _): mgpu.FragmentedArray.load_strided(src).store_untiled(dst) x = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64) - f = jax.jit(mosaic_gpu.as_gpu_kernel( + f = jax.jit(mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), x, x, () )) # Make sure we can invoke the same program on different devices. @@ -1407,7 +1406,7 @@ def kernel(ctx, i_gmem, o_gmem, _): ty = jax.ShapeDtypeStruct((128, 128), jnp.float32) x = self.torch.randn((128, 128), dtype=self.torch.float, device='cuda') - f = mosaic_gpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ()) + f = mgpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ()) y = f(x) np.testing.assert_allclose(y.cpu(), x.cpu() * 2) del y # Make sure the destructor runs successfully.
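
For reference, a minimal usage sketch of the consolidated namespace after this change, adapted
from the updated test_multigpu above (everything is now imported from
jax.experimental.mosaic.gpu directly, with no dsl submodule). It assumes a CUDA-capable GPU
and a jaxlib built with Mosaic GPU support:

# Sketch based on the updated tests in this patch, not an additional API.
import jax
import jax.numpy as jnp
import numpy as np

import jax.experimental.mosaic.gpu as mgpu


def kernel(ctx, src, dst, _):
  # Load the input from GMEM into registers and store it back out unchanged.
  # The trailing argument is the (empty) SMEM scratch tree requested below.
  mgpu.FragmentedArray.load_strided(src).store_untiled(dst)


x = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64)
f = jax.jit(mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ()))
np.testing.assert_array_equal(f(x), x)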