[Pallas] Introduce make_kernel_from_pallas (#6713)
Summary:
This pull request introduces the make_kernel_from_pallas API, the top-level API for interacting with the Pallas integration. It takes a pallas_call wrapper and turns it into a custom PyTorch op.
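
For context, the resulting usage pattern, as a minimal sketch distilled from the test below (add_vectors is the jitted pallas_call wrapper defined there):

from torch_xla.experimental.custom_kernel import make_kernel_from_pallas

# The second argument maps the input tensors to the output's (shape, dtype).
pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: (x.shape, x.dtype))

x = torch.randn(8, 8).to("xla")
y = torch.randn(8, 8).to("xla")
out = pt_kernel(x, y)  # executes the Pallas kernel as a custom PyTorch op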

Test Plan:
python test/test_pallas.py
alanwaketan authored Mar 13, 2024
1 parent cc55d9e commit 1bbe333
Showing 2 changed files with 96 additions and 1 deletion.
42 changes: 42 additions & 0 deletions test/test_pallas.py
@@ -135,6 +135,48 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    # the most important fields are present.
    self.assertIn("custom_call_config", payload)

  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
  # TODO: This test cannot be run individually, let's fix it.
  def test_tpu_custom_call_pallas_wrap_add_payload(self):
    import jax
    import jax.numpy as jnp
    import jax._src.pallas.mosaic.pallas_call_registration

    from jax.experimental import pallas as pl

    def add_vectors_kernel(x_ref, y_ref, o_ref):
      x, y = x_ref[...], y_ref[...]
      o_ref[...] = x + y

    @jax.jit
    def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
      return pl.pallas_call(
          add_vectors_kernel,
          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x, y)

    from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
    pt_kernel = make_kernel_from_pallas(add_vectors,
                                        lambda x, y: (x.shape, x.dtype))

    # TODO: torch.float64, torch.bfloat16, torch.float16 don't work.
    dtypes = [torch.float32, torch.float]
    for i in range(len(dtypes)):
      x = torch.randn((i + 1, i + 1), dtype=dtypes[i]).to("xla")
      y = torch.randn((i + 1, i + 1), dtype=dtypes[i]).to("xla")
      expected_output = x + y
      output = pt_kernel(x, y)
      self.assertTrue(torch.allclose(output.cpu(), expected_output.cpu()))

    # TODO: torch.int64, torch.int16, torch.int8, torch.uint8 don't work.
    dtypes = [torch.int32, torch.int]
    for i in range(len(dtypes)):
      x = torch.arange(i + 1, dtype=dtypes[i]).to("xla")
      y = torch.arange(i + 1, dtype=dtypes[i]).to("xla")
      expected_output = x + y
      output = pt_kernel(x, y)
      self.assertTrue(torch.allclose(output.cpu(), expected_output.cpu()))


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
55 changes: 54 additions & 1 deletion torch_xla/experimental/custom_kernel.py
@@ -1,7 +1,13 @@
import functools
import jax
import jax.numpy as jnp
import jax._src.pallas.mosaic.pallas_call_registration
import torch
import torch_xla
import torch_xla.core.xla_model as xm

from jax.experimental import pallas as pl
from typing import List, Callable
from torch.library import impl
from torch_xla.core.xla_model import XLA_LIB

@@ -56,3 +62,50 @@ def _extract_backend_config(
    if op.name == "stablehlo.custom_call":
      return op.backend_config.value
  return None


def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
  if dtype == torch.float32:
    return jnp.float32
  elif dtype == torch.float64:
    return jnp.float64
  elif dtype == torch.float16:
    return jnp.float16
  elif dtype == torch.bfloat16:
    return jnp.bfloat16
  elif dtype == torch.int32:
    return jnp.int32
  elif dtype == torch.int64:
    return jnp.int64
  elif dtype == torch.int16:
    return jnp.int16
  elif dtype == torch.int8:
    return jnp.int8
  elif dtype == torch.uint8:
    return jnp.uint8
  else:
    raise ValueError(f"Unsupported dtype: {dtype}")
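
As an aside, since torch dtypes are hashable, the elif chain above could equally be table-driven; a hypothetical alternative, not part of this commit:

# Hypothetical table-driven variant of convert_torch_dtype_to_jax.
_TORCH_TO_JAX_DTYPE = {
    torch.float32: jnp.float32,
    torch.float64: jnp.float64,
    torch.float16: jnp.float16,
    torch.bfloat16: jnp.bfloat16,
    torch.int32: jnp.int32,
    torch.int64: jnp.int64,
    torch.int16: jnp.int16,
    torch.int8: jnp.int8,
    torch.uint8: jnp.uint8,
}

def convert_torch_dtype_to_jax_alt(dtype: torch.dtype) -> jnp.dtype:
  if dtype not in _TORCH_TO_JAX_DTYPE:
    raise ValueError(f"Unsupported dtype: {dtype}")
  return _TORCH_TO_JAX_DTYPE[dtype]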


def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
  # TODO: Maybe we can cache the payload for the same input.
  def wrapped_kernel(kernel: Callable, output_shape_dtype_fn: Callable, *args):
    jax_args = []
    for i, arg in enumerate(args):
      if torch.is_tensor(arg):
        # ShapedArray doesn't have any storage and thus is very suitable for
        # generating the payload.
        jax_meta_tensor = jax.core.ShapedArray(
            arg.shape, convert_torch_dtype_to_jax(arg.dtype))
        jax_args.append(jax_meta_tensor)
      else:
        # TODO: We can support more types here.
        assert False, f"Unsupported argument type: {type(arg)}"

    ir = jax.jit(kernel).lower(*jax_args).compiler_ir()
    payload = _extract_backend_config(ir)
    output_shape, output_dtype = output_shape_dtype_fn(*args)
    output = torch.empty(output_shape, dtype=output_dtype).to(xm.xla_device())
    torch_xla._XLAC._xla_tpu_custom_call_(output, args, payload)
    return output

  return functools.partial(wrapped_kernel, kernel, output_shape_dtype_fn)
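
To illustrate the tracing step inside wrapped_kernel: lowering a jitted function against abstract ShapedArrays yields the StableHLO module without allocating any device memory, and for a pallas_call that module carries the stablehlo.custom_call whose backend_config becomes the payload. A standalone sketch, assuming add_vectors from the test above:

import jax
import jax.numpy as jnp

# Abstract stand-ins for the real tensors: shape and dtype only, no storage.
meta = jax.core.ShapedArray((8, 8), jnp.float32)

# Lower without concrete inputs; compiler_ir() returns the MLIR module that
# _extract_backend_config walks to find the custom call's backend_config.
ir = jax.jit(add_vectors).lower(meta, meta).compiler_ir()
payload = _extract_backend_config(ir)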
