[shmap/partial-auto] Fixes lowering for jax.lax.axis_index in shard_map for degenerated shmaps. #1067