From a6381cd16884537f771aba7e932f0dff66b0d8d9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 6 Dec 2023 15:01:57 -0800 Subject: [PATCH] Replace references to deprecated device_buffer attributes `jax.Array.device_buffer` and `jax.Array.device_buffers` will be deprecated as of jax version 0.4.22; see https://github.com/google/jax/pull/18844. PiperOrigin-RevId: 588553845 --- chex/_src/asserts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index 4f1e86d5..4aafe566 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -1226,7 +1226,7 @@ def _assert_fn(path, leaf): # Check that the leaf is a ShardedArray. if isinstance(leaf, jax.Array): if _check_sharding(leaf): - shards = tuple(buf.device() for buf in leaf.device_buffers) + shards = tuple(shard.device for shard in leaf.addressable_shards) if shards != devices: errors.append( f"Tree leaf '{_ai.format_tree_path(path)}' is sharded "