diff --git a/jax/test_util.py b/jax/test_util.py index a5e9eb38b4cd..2e9652a26c70 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -314,6 +314,22 @@ def f_vjp(*args): _check_grads(f, args, order) +@contextmanager +def count_device_put(): + device_put = xla.device_put + count = [0] + + def device_put_and_count(*args, **kwargs): + count[0] += 1 + return device_put(*args, **kwargs) + + xla.device_put = device_put_and_count + try: + yield count + finally: + xla.device_put = device_put + + @contextmanager def count_primitive_compiles(): xla.xla_primitive_callable.cache_clear() diff --git a/tests/api_test.py b/tests/api_test.py index f3f562e5c4f7..451794c98bf1 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2140,17 +2140,17 @@ def test_xla_computation_zeros_doesnt_device_put(self): if not config.omnistaging_enabled: raise unittest.SkipTest("test is omnistaging-specific") - count = 0 - def device_put_and_count(*args, **kwargs): - nonlocal count - count += 1 - return orig_device_put(*args, **kwargs) - orig_device_put, xla.device_put = xla.device_put, device_put_and_count - try: + with jtu.count_device_put() as count: api.xla_computation(lambda: jnp.zeros(3))() - finally: - xla.device_put = orig_device_put - self.assertEqual(count, 0) + self.assertEqual(count[0], 0) + + # TODO(mattjj): Enable this after fixing convert_element_type. + @unittest.skipIf(True, "broken by convert_element_type.") + def test_random_split_doesnt_device_put(self): + key = jax.random.PRNGKey(1) + with jtu.count_device_put() as count: + key, _ = jax.random.split(key, 2) + self.assertEqual(count[0], 0) def test_join_concrete_arrays_with_omnistaging(self): # https://github.com/google/jax/issues/4622