Skip to content

Commit

Permalink
custom_parititioning: in lower sharding, Sharding should be XLACompat…
Browse files Browse the repository at this point in the history
…ibleSharding.

PiperOrigin-RevId: 537077304
  • Loading branch information
pschuh authored and jax authors committed Jun 1, 2023
1 parent c8311c6 commit 5c2070c
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions jax/experimental/custom_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,10 @@ def _custom_partitioning_propagate_user_sharding(user_sharding, shape,


def _to_hlo_sharding(sharding, num_dimensions):
if not isinstance(sharding, jax.sharding.Sharding):
raise ValueError("Custom Partitioning rules must return shardings.")
if not isinstance(sharding, jax.sharding.XLACompatibleSharding):
raise ValueError(
"Custom Partitioning rules must return XLACompatibleShardings."
)
return xc.HloSharding.from_proto(sharding._to_xla_op_sharding(num_dimensions))


Expand Down

0 comments on commit 5c2070c

Please sign in to comment.