Add a sharding rule for reduce_precision_p
and properly thread eqn.ctx in loops.py where we create pe.new_jaxpr_eqn
's#25974
Merged
copybara-service[bot] merged 1 commit intomainfrom test_716841435Jan 18, 2025
+13-7