diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 4d60c3017287..5b1b86e5c1a7 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -149,6 +149,40 @@ def f(x): self.assertEqual(cache_miss_count[0], 0) def testAutoPgleWithAot(self): + mesh = jtu.create_mesh((2,), ('x',)) + + @partial( + jax.jit, + in_shardings=NamedSharding(mesh, PartitionSpec('x')), + out_shardings=NamedSharding(mesh, PartitionSpec('x')), + ) + def f(x): + return x * 2 + + shape = (16, 16) + x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) + expected = x * 2 + + with config.pgle_profiling_runs(1), config.enable_pgle(True): + compiled_f = f.lower(x).compile() + + # Run 1: Compiled module is launched and profiled. + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(compiled_f(x), expected) + self.assertEqual(cache_miss_count[0], 1) + + # Run 2: Second PGLE run should recompile the module with FDO + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(compiled_f(x), expected) + self.assertEqual(cache_miss_count[0], 1) + + # Run 3: Fast-path should be used after PGLE is done + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(compiled_f(x), expected) + self.assertEqual(cache_miss_count[0], 0) + + # Test AOT compilation when the module was serialized in another process. 
+ def testAutoPgleWithXAot(self): @jax.jit def f(x): return x * 2 @@ -156,11 +190,10 @@ def f(x): x = jnp.arange(1) expected = x * 2 - f_lowered = f.lower(x) - serialized, in_tree, out_tree = serialize(f_lowered.compile()) - compiled = deserialize_and_load(serialized, in_tree, out_tree) - with config.pgle_profiling_runs(1), config.enable_pgle(True): + f_lowered = f.lower(x) + serialized, in_tree, out_tree = serialize(f_lowered.compile()) + compiled = deserialize_and_load(serialized, in_tree, out_tree) # Run 1 with jtu.count_cached_compilation_cache_miss() as cache_miss_count: self.assertArraysEqual(compiled(x), expected)