Skip to content

Commit

Permalink
Replace references to deprecated jax array attributes device_buffer a…
Browse files Browse the repository at this point in the history
…nd device_buffers

PiperOrigin-RevId: 588553845
  • Loading branch information
Jake VanderPlas authored and ChexDev committed Dec 6, 2023
1 parent cbac0f9 commit d4414e2
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,7 @@ def _assert_fn(path, leaf):
# Check that the leaf is a ShardedArray.
if isinstance(leaf, jax.Array):
if _check_sharding(leaf):
shards = tuple(buf.device() for buf in leaf.device_buffers)
shards = tuple(shard.device for shard in leaf.addressable_shards)
if shards != devices:
errors.append(
f"Tree leaf '{_ai.format_tree_path(path)}' is sharded "
Expand Down

0 comments on commit d4414e2

Please sign in to comment.