diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py index cf7de61fc37b..319434dbd459 100644 --- a/jax/experimental/custom_partitioning.py +++ b/jax/experimental/custom_partitioning.py @@ -121,8 +121,10 @@ def _custom_partitioning_propagate_user_sharding(user_sharding, shape, def _to_hlo_sharding(sharding, num_dimensions): - if not isinstance(sharding, jax.sharding.Sharding): - raise ValueError("Custom Partitioning rules must return shardings.") + if not isinstance(sharding, jax.sharding.XLACompatibleSharding): + raise ValueError( + "Custom Partitioning rules must return XLACompatibleShardings." + ) return xc.HloSharding.from_proto(sharding._to_xla_op_sharding(num_dimensions))