Skip to content

Commit

Permalink
[Mosaic GPU] Clean up the module structure
Browse files Browse the repository at this point in the history
Previously the code was awkwardly split between the `jax.experimental.mosaic.gpu`
and `jax.experimental.mosaic.gpu.dsl` namespaces. I've now merged both so that
all user-visible APIs are accessible from `jax.experimental.mosaic.gpu`.

PiperOrigin-RevId: 676857257
  • Loading branch information
apaszke authored and Google-ML-Automation committed Sep 20, 2024
1 parent 99195ea commit 81b8b4b
Show file tree
Hide file tree
Showing 11 changed files with 1,111 additions and 1,114 deletions.
8 changes: 4 additions & 4 deletions jax/_src/pallas/mosaic_gpu/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from jax._src import dtypes
from jax._src import tree_util
from jax._src.pallas import core as pallas_core
from jax.experimental.mosaic import gpu as mosaic_gpu
import jax.experimental.mosaic.gpu as mgpu
import jax.numpy as jnp


Expand Down Expand Up @@ -64,7 +64,7 @@ def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype):


class MemoryRefTransform(pallas_core.MemoryRefTransform, Protocol):
def to_gpu_transform(self) -> mosaic_gpu.MemRefTransform:
def to_gpu_transform(self) -> mgpu.MemRefTransform:
...


Expand Down Expand Up @@ -101,8 +101,8 @@ def __call__(
inner_aval=block_aval.inner_aval.update(shape=new_block_shape)
)

def to_gpu_transform(self) -> mosaic_gpu.MemRefTransform:
return mosaic_gpu.TileTransform(self.tiling)
def to_gpu_transform(self) -> mgpu.MemRefTransform:
return mgpu.TileTransform(self.tiling)


@dataclasses.dataclass(frozen=True)
Expand Down
13 changes: 6 additions & 7 deletions jax/_src/pallas/mosaic_gpu/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@
from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.mosaic_gpu import core as gpu_core
from jax._src.state import primitives as sp
from jax.experimental.mosaic import gpu as mosaic_gpu
from jax.experimental.mosaic.gpu import dsl as mgpu
import jax.experimental.mosaic.gpu as mgpu
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -160,7 +159,7 @@ def stack_free_smem(self, bytes: int):
@dataclasses.dataclass(frozen=True)
class LoweringRuleContext:
module_ctx: ModuleContext
launch_ctx: mosaic_gpu.LaunchContext
launch_ctx: mgpu.LaunchContext
avals_in: Sequence[jax_core.ShapedArray]
avals_out: Sequence[jax_core.ShapedArray]

Expand All @@ -180,7 +179,7 @@ class LoweringError(Exception): # pylint: disable=g-bad-exception-name

def _eval_index_map(
module_ctx: ModuleContext,
launch_ctx: mosaic_gpu.LaunchContext,
launch_ctx: mgpu.LaunchContext,
idx: ir.Value,
block_mapping: pallas_core.BlockMapping,
) -> Sequence[ir.Value]:
Expand Down Expand Up @@ -300,7 +299,7 @@ def lower_jaxpr_to_module(
)
]

def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers: ir.Value):
def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
*buffers_gmem, (
buffers_smem,
*scratch_buffers_smem,
Expand Down Expand Up @@ -494,7 +493,7 @@ def _(step, _):
jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8)
)

module, out_structs_smem, _ = mosaic_gpu._lower_as_gpu_kernel(
module, out_structs_smem, _ = mgpu._lower_as_gpu_kernel(
body,
grid=grid,
cluster=(),
Expand Down Expand Up @@ -528,7 +527,7 @@ def deco(fn):

def lower_jaxpr_to_mosaic_gpu(
module_ctx: ModuleContext,
launch_ctx: mosaic_gpu.LaunchContext,
launch_ctx: mgpu.LaunchContext,
jaxpr: jax_core.Jaxpr,
args: Sequence[ir.Value],
consts=(),
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from jax._src.interpreters import mlir
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic_gpu import lowering
from jax.experimental.mosaic import gpu as mosaic_gpu
import jax.experimental.mosaic.gpu.core as mosaic_core


def pallas_call_lowering(
Expand Down Expand Up @@ -67,7 +67,7 @@ def pallas_call_lowering(
print(lowering_result.module.operation)

module = lowering_result.module
return mosaic_gpu._mosaic_gpu_lowering_rule(
return mosaic_core._mosaic_gpu_lowering_rule(
ctx,
*args,
module=module.operation.get_asm(binary=True, enable_debug_info=True),
Expand Down
Loading

0 comments on commit 81b8b4b

Please sign in to comment.