Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Mosaic GPU] Clean up the module structure #23760

Merged
merged 1 commit into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions jax/_src/pallas/mosaic_gpu/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from jax._src import dtypes
from jax._src import tree_util
from jax._src.pallas import core as pallas_core
from jax.experimental.mosaic import gpu as mosaic_gpu
import jax.experimental.mosaic.gpu as mgpu
import jax.numpy as jnp


Expand Down Expand Up @@ -64,7 +64,7 @@ def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype):


class MemoryRefTransform(pallas_core.MemoryRefTransform, Protocol):
def to_gpu_transform(self) -> mosaic_gpu.MemRefTransform:
def to_gpu_transform(self) -> mgpu.MemRefTransform:
...


Expand Down Expand Up @@ -101,8 +101,8 @@ def __call__(
inner_aval=block_aval.inner_aval.update(shape=new_block_shape)
)

def to_gpu_transform(self) -> mosaic_gpu.MemRefTransform:
return mosaic_gpu.TileTransform(self.tiling)
def to_gpu_transform(self) -> mgpu.MemRefTransform:
return mgpu.TileTransform(self.tiling)


@dataclasses.dataclass(frozen=True)
Expand Down
13 changes: 6 additions & 7 deletions jax/_src/pallas/mosaic_gpu/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@
from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.mosaic_gpu import core as gpu_core
from jax._src.state import primitives as sp
from jax.experimental.mosaic import gpu as mosaic_gpu
from jax.experimental.mosaic.gpu import dsl as mgpu
import jax.experimental.mosaic.gpu as mgpu
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -160,7 +159,7 @@ def stack_free_smem(self, bytes: int):
@dataclasses.dataclass(frozen=True)
class LoweringRuleContext:
module_ctx: ModuleContext
launch_ctx: mosaic_gpu.LaunchContext
launch_ctx: mgpu.LaunchContext
avals_in: Sequence[jax_core.ShapedArray]
avals_out: Sequence[jax_core.ShapedArray]

Expand All @@ -180,7 +179,7 @@ class LoweringError(Exception): # pylint: disable=g-bad-exception-name

def _eval_index_map(
module_ctx: ModuleContext,
launch_ctx: mosaic_gpu.LaunchContext,
launch_ctx: mgpu.LaunchContext,
idx: ir.Value,
block_mapping: pallas_core.BlockMapping,
) -> Sequence[ir.Value]:
Expand Down Expand Up @@ -300,7 +299,7 @@ def lower_jaxpr_to_module(
)
]

def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers: ir.Value):
def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
*buffers_gmem, (
buffers_smem,
*scratch_buffers_smem,
Expand Down Expand Up @@ -494,7 +493,7 @@ def _(step, _):
jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8)
)

module, out_structs_smem, _ = mosaic_gpu._lower_as_gpu_kernel(
module, out_structs_smem, _ = mgpu._lower_as_gpu_kernel(
body,
grid=grid,
cluster=(),
Expand Down Expand Up @@ -528,7 +527,7 @@ def deco(fn):

def lower_jaxpr_to_mosaic_gpu(
module_ctx: ModuleContext,
launch_ctx: mosaic_gpu.LaunchContext,
launch_ctx: mgpu.LaunchContext,
jaxpr: jax_core.Jaxpr,
args: Sequence[ir.Value],
consts=(),
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from jax._src.interpreters import mlir
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic_gpu import lowering
from jax.experimental.mosaic import gpu as mosaic_gpu
import jax.experimental.mosaic.gpu.core as mosaic_core


def pallas_call_lowering(
Expand Down Expand Up @@ -67,7 +67,7 @@ def pallas_call_lowering(
print(lowering_result.module.operation)

module = lowering_result.module
return mosaic_gpu._mosaic_gpu_lowering_rule(
return mosaic_core._mosaic_gpu_lowering_rule(
ctx,
*args,
module=module.operation.get_asm(binary=True, enable_debug_info=True),
Expand Down
Loading