From a6ac5aa014c4c2d6668bbfa4d4a469a4c7489e0d Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Mon, 19 Apr 2021 20:38:17 -0700 Subject: [PATCH] Fix places that were passing jnp.array where a shape was expected. For reference: https://github.com/google/jax/pull/6400. --- trax/models/research/transformer2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/trax/models/research/transformer2.py b/trax/models/research/transformer2.py index 72007ed16..bab64672e 100644 --- a/trax/models/research/transformer2.py +++ b/trax/models/research/transformer2.py @@ -225,9 +225,7 @@ def _UpdateRow(x): # In `row_ed` start where encoder tokens/vecs end, i.e. are index `len_e` # and pick up (L2, H) tensor slice from there. zero = jnp.array(0, dtype=len_e.dtype) # avoid int32/int64 mismatch - l2_np = jnp.array(L2, dtype=len_e.dtype) - h_np = jnp.array(H, dtype=len_e.dtype) - return fastmath.dynamic_slice(row_ed, (len_e, zero), (l2_np, h_np)) + return fastmath.dynamic_slice(row_ed, (len_e, zero), (L2, H)) return fastmath.map(_UpdateRow, [vec_ed, tok_e, tok_d])