From ebb1615d3b31283d3da18b04bbe985614301f187 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 8 Dec 2023 14:09:26 -0800 Subject: [PATCH] Avoid use of deprecated `device_buffer` attriutes of jax.Array These have been deprecated as of JAX v0.4.22 PiperOrigin-RevId: 589237893 --- chex/_src/asserts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index 4f1e86d5..4aafe566 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -1226,7 +1226,7 @@ def _assert_fn(path, leaf): # Check that the leaf is a ShardedArray. if isinstance(leaf, jax.Array): if _check_sharding(leaf): - shards = tuple(buf.device() for buf in leaf.device_buffers) + shards = tuple(shard.device for shard in leaf.addressable_shards) if shards != devices: errors.append( f"Tree leaf '{_ai.format_tree_path(path)}' is sharded "