Skip to content

Commit

Permalink
Fix places that were passing jnp.array where a shape was expected.
Browse files Browse the repository at this point in the history
For reference: jax-ml/jax#6400.

PiperOrigin-RevId: 369471455
  • Loading branch information
pschuh authored and copybara-github committed Apr 20, 2021
1 parent 3c9190b commit 28743ee
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions trax/models/research/transformer2.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,7 @@ def _UpdateRow(x):
# In `row_ed` start where encoder tokens/vecs end, i.e. are index `len_e`
# and pick up (L2, H) tensor slice from there.
zero = jnp.array(0, dtype=len_e.dtype) # avoid int32/int64 mismatch
l2_np = jnp.array(L2, dtype=len_e.dtype)
h_np = jnp.array(H, dtype=len_e.dtype)
return fastmath.dynamic_slice(row_ed, (len_e, zero), (l2_np, h_np))
return fastmath.dynamic_slice(row_ed, (len_e, zero), (L2, H))

return fastmath.map(_UpdateRow, [vec_ed, tok_e, tok_d])

Expand Down

0 comments on commit 28743ee

Please sign in to comment.