Skip to content

Commit

Permalink
Internal bug fix.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 590251937
  • Loading branch information
T5X Team authored and t5-copybara committed Dec 12, 2023
1 parent 77f4664 commit 49a4e50
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion t5x/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def bounds_from_last_device(last_device: jax.Device) -> HardwareMesh:
# Must be passed the device at the highest-coordinate corner of the
# relevant mesh, which is a requirement we know is satisfied by the last
# device in jax.devices().
if hasattr(last_device, 'coords'):
if hasattr(last_device, 'coords') and last_device.coords.shape == (3,):
x, y, z = last_device.coords
return x + 1, y + 1, z + 1, last_device.core_on_chip + 1
else:
Expand Down

0 comments on commit 49a4e50

Please sign in to comment.