Skip to content

Commit

Permalink
Replace references to deprecated device_buffer attributes
Browse files Browse the repository at this point in the history
`jax.Array.device_buffer` and `jax.Array.device_buffers` will be deprecated as of jax version 0.4.22; see jax-ml/jax#18844.

PiperOrigin-RevId: 588553845
  • Loading branch information
Jake VanderPlas authored and ChexDev committed Dec 7, 2023
1 parent cbac0f9 commit a6381cd
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,7 @@ def _assert_fn(path, leaf):
# Check that the leaf is a ShardedArray.
if isinstance(leaf, jax.Array):
if _check_sharding(leaf):
shards = tuple(buf.device() for buf in leaf.device_buffers)
shards = tuple(shard.device for shard in leaf.addressable_shards)
if shards != devices:
errors.append(
f"Tree leaf '{_ai.format_tree_path(path)}' is sharded "
Expand Down

0 comments on commit a6381cd

Please sign in to comment.