diff --git a/trax/models/research/transformer2.py b/trax/models/research/transformer2.py index 72007ed16..bab64672e 100644 --- a/trax/models/research/transformer2.py +++ b/trax/models/research/transformer2.py @@ -225,9 +225,7 @@ def _UpdateRow(x): # In `row_ed` start where encoder tokens/vecs end, i.e. are index `len_e` # and pick up (L2, H) tensor slice from there. zero = jnp.array(0, dtype=len_e.dtype) # avoid int32/int64 mismatch - l2_np = jnp.array(L2, dtype=len_e.dtype) - h_np = jnp.array(H, dtype=len_e.dtype) - return fastmath.dynamic_slice(row_ed, (len_e, zero), (l2_np, h_np)) + return fastmath.dynamic_slice(row_ed, (len_e, zero), (L2, H)) return fastmath.map(_UpdateRow, [vec_ed, tok_e, tok_d])