diff --git a/libcpab/tensorflow/interpolation.py b/libcpab/tensorflow/interpolation.py index cde6a5f..1b2c8fd 100644 --- a/libcpab/tensorflow/interpolation.py +++ b/libcpab/tensorflow/interpolation.py @@ -93,7 +93,7 @@ def interpolate2D(data, grid, outsize): # Batch effect batch_size = out_width*out_height - batch_idx = tf.tile(tf.range(n_batch), (batch_size,)) + batch_idx = tf.reshape(tf.tile(tf.expand_dims(tf.range(n_batch), -1), (1, batch_size)), (-1,)) # Index c00 = tf.gather_nd(data, tf.stack([batch_idx, x0, y0], axis=1))