Skip to content

Commit

Permalink
Avoid use of deprecated device_buffer attriutes of jax.Array
Browse files Browse the repository at this point in the history
These have been deprecated as of JAX v0.4.22

PiperOrigin-RevId: 589237893
  • Loading branch information
Jake VanderPlas authored and ChexDev committed Dec 9, 2023
1 parent cbac0f9 commit 226f810
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,7 @@ def _assert_fn(path, leaf):
# Check that the leaf is a ShardedArray.
if isinstance(leaf, jax.Array):
if _check_sharding(leaf):
shards = tuple(buf.device() for buf in leaf.device_buffers)
shards = tuple(shard.device for shard in leaf.addressable_shards)
if shards != devices:
errors.append(
f"Tree leaf '{_ai.format_tree_path(path)}' is sharded "
Expand Down

0 comments on commit 226f810

Please sign in to comment.