Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add pmap/laxmap #10

Merged
merged 1 commit into from
Jun 9, 2023
Merged

add pmap/laxmap #10

merged 1 commit into from
Jun 9, 2023

Conversation

ASEM000
Copy link
Owner

@ASEM000 ASEM000 commented Jun 9, 2023

Enable jax.lax.map/ jax.pmap in the kmap/smap interface
Example:

import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=200"

import jax
import jax
import kernex as kex


@kex.kmap(
    kernel_size=(2,),
    map_kind="pmap",
    map_kwargs={"axis_name": "i"},
)
def f(x):
    return x


print(f(jax.numpy.arange(5)))

# [[0 1]
#  [1 2]
#  [2 3]
#  [3 4]]

@ASEM000 ASEM000 merged commit d024e28 into main Jun 9, 2023
@ASEM000 ASEM000 deleted the pmap/map-to-kernel-map branch June 9, 2023 21:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant