Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JAX] Add test-case on AOT compilation on the same process #23752

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions tests/pgle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,18 +149,51 @@ def f(x):
self.assertEqual(cache_miss_count[0], 0)

def testAutoPgleWithAot(self):
mesh = jtu.create_mesh((2,), ('x',))

@partial(
jax.jit,
in_shardings=NamedSharding(mesh, PartitionSpec('x')),
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
)
def f(x):
return x * 2

shape = (16, 16)
x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
expected = x * 2

with config.pgle_profiling_runs(1), config.enable_pgle(True):
compiled_f = f.lower(x).compile()

# Run 1: Compiled module is launched and profiled.
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(compiled_f(x), expected)
self.assertEqual(cache_miss_count[0], 1)

# Run 2: Second PGLE run should recompile the module with FDO
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(compiled_f(x), expected)
self.assertEqual(cache_miss_count[0], 1)

# Run 3: Fast-path should be used after PGLE is done
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(compiled_f(x), expected)
self.assertEqual(cache_miss_count[0], 0)

# Test AOT compilation when module were serialized on the other process.
def testAutoPgleWithXAot(self):
@jax.jit
def f(x):
return x * 2

x = jnp.arange(1)
expected = x * 2

f_lowered = f.lower(x)
serialized, in_tree, out_tree = serialize(f_lowered.compile())
compiled = deserialize_and_load(serialized, in_tree, out_tree)

with config.pgle_profiling_runs(1), config.enable_pgle(True):
f_lowered = f.lower(x)
serialized, in_tree, out_tree = serialize(f_lowered.compile())
compiled = deserialize_and_load(serialized, in_tree, out_tree)
# Run 1
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(compiled(x), expected)
Expand Down