Skip to content

Commit

Permalink
[JAX] Add test-case on AOT compilation on the same process
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 676304268
  • Loading branch information
Google-ML-Automation committed Sep 19, 2024
1 parent 9d2e9c6 commit 7cad989
Showing 1 changed file with 37 additions and 4 deletions.
41 changes: 37 additions & 4 deletions tests/pgle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,18 +149,51 @@ def f(x):
self.assertEqual(cache_miss_count[0], 0)

def testAutoPgleWithAot(self):
  """Check compilation-cache misses for an AOT-compiled function under PGLE.

  A sharded, jitted doubling function is lowered and compiled ahead of time
  with PGLE enabled and a single profiling run configured. The compiled
  executable is then called three times; every call must return the doubled
  input, and the number of compilation-cache misses recorded around each
  call must match the expected sequence 1, 1, 0.
  """
  mesh = jtu.create_mesh((2,), ('x',))
  # The same partitioning along mesh axis 'x' is used for inputs and outputs.
  sharding = NamedSharding(mesh, PartitionSpec('x'))

  @partial(jax.jit, in_shardings=sharding, out_shardings=sharding)
  def f(x):
    return x * 2

  shape = (16, 16)
  flat = jnp.arange(math.prod(shape))
  x = flat.reshape(shape).astype(np.float32)
  expected = x * 2

  # Expected cache misses per call, in call order:
  #   call 1 -> 1: compiled module is launched and profiled.
  #   call 2 -> 1: second PGLE run should recompile the module with FDO.
  #   call 3 -> 0: fast path should be used after PGLE is done.
  expected_misses_per_call = (1, 1, 0)

  with config.pgle_profiling_runs(1), config.enable_pgle(True):
    compiled_f = f.lower(x).compile()

    for expected_misses in expected_misses_per_call:
      with jtu.count_cached_compilation_cache_miss() as miss_counter:
        self.assertArraysEqual(compiled_f(x), expected)
      self.assertEqual(miss_counter[0], expected_misses)

# Test AOT compilation when the module was serialized in another process.
def testAutoPgleWithXAot(self):
@jax.jit
def f(x):
return x * 2

x = jnp.arange(1)
expected = x * 2

f_lowered = f.lower(x)
serialized, in_tree, out_tree = serialize(f_lowered.compile())
compiled = deserialize_and_load(serialized, in_tree, out_tree)

with config.pgle_profiling_runs(1), config.enable_pgle(True):
f_lowered = f.lower(x)
serialized, in_tree, out_tree = serialize(f_lowered.compile())
compiled = deserialize_and_load(serialized, in_tree, out_tree)
# Run 1
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(compiled(x), expected)
Expand Down

0 comments on commit 7cad989

Please sign in to comment.