Skip to content

Commit

Permalink
Merge pull request jax-ml#12308 from jakevdp:fix-shape-error
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 473319485
  • Loading branch information
jax authors committed Sep 9, 2022
2 parents 4746a39 + fcac395 commit 056f400
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,7 @@ def _broadcast_shapes_uncached(*shapes):
shape_list = [(1,) * (ndim - len(shape)) + shape for shape in shapes]
result_shape = _try_broadcast_shapes(shape_list)
if result_shape is None:
raise ValueError("Incompatible shapes for broadcasting: {}"
.format(tuple(shape_list)))
raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
return result_shape

def _broadcast_ranks(s1, s2):
Expand Down

0 comments on commit 056f400

Please sign in to comment.