The `jnp.einsum` op provides a DSL-based unified interface to matmul and
tensordot ops. This `einshape` library is designed to offer a similar DSL-based
approach to unifying reshape, squeeze, expand_dims, and transpose operations.
Some examples:

* `einshape("n->n111", x)` is equivalent to `expand_dims(x, axis=1)` applied three times
* `einshape("a1b11->ab", x)` is equivalent to `squeeze(x, axis=[1,3,4])`
* `einshape("nhwc->nchw", x)` is equivalent to `transpose(x, perm=[0,3,1,2])`
* `einshape("mnhwc->(mn)hwc", x)` is equivalent to a reshape combining the two leading dimensions
* `einshape("(mn)hwc->mnhwc", x, n=batch_size)` is equivalent to a reshape splitting the leading dimension into two, using kwargs (`m` or `n` or both) to supply the necessary additional shape information
* `einshape("mn...->(mn)...", x)` combines the two leading dimensions without knowing the rank of `x`
* `einshape("n...->n(...)", x)` performs a 'batch flatten'
* `einshape("ij->ijk", x, k=3)` inserts a trailing dimension and tiles along it
* `einshape("ij->i(nj)", x, n=3)` tiles along the second dimension
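For readers who know NumPy, here is a sketch that spells out what each example above computes using plain NumPy calls (`np.expand_dims`, `np.squeeze`, `np.transpose`, `reshape`, `np.repeat`, `np.tile`). It illustrates the equivalences only; it is not how einshape itself is implemented, the concrete shapes are arbitrary, and it assumes the row-major (C-order) layout that `reshape` uses when merging or splitting axes.

```python
import numpy as np

# Rank-4 input read as "nhwc": n=2, h=3, w=4, c=5.
x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

# "n->n111": expand_dims(v, axis=1) applied three times.
v = np.arange(6)
n111 = np.expand_dims(np.expand_dims(np.expand_dims(v, axis=1), axis=1), axis=1)

# "a1b11->ab": squeeze the unit axes at positions 1, 3 and 4.
a1b11 = np.zeros((2, 1, 3, 1, 1))
ab = np.squeeze(a1b11, axis=(1, 3, 4))

# "nhwc->nchw": a pure transpose.
nchw = np.transpose(x, (0, 3, 1, 2))

# "mnhwc->(mn)hwc": merge the two leading axes (row-major reshape).
mnhwc = np.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
merged = mnhwc.reshape(2 * 3, 4, 5, 6)

# "(mn)hwc->mnhwc" with n=3: split the leading axis; m is inferred via -1.
split = merged.reshape(-1, 3, 4, 5, 6)

# "mn...->(mn)...": merge the two leading axes whatever the rank.
merged_any = mnhwc.reshape(-1, *mnhwc.shape[2:])

# "n...->n(...)": batch flatten, keeping only the leading axis.
flat = x.reshape(x.shape[0], -1)

# "ij->ijk" with k=3: add a trailing axis and tile along it.
m = np.array([[1, 2], [3, 4]])
ijk = np.repeat(m[:, :, np.newaxis], 3, axis=2)

# "ij->i(nj)" with n=3: tile along the second axis.
i_nj = np.tile(m, (1, 3))
```

For instance, `i_nj` is `[[1, 2, 1, 2, 1, 2], [3, 4, 3, 4, 3, 4]]`: each row of `m` is repeated three times along the second axis.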
See `jax_ops.py` for the JAX implementation of the `einshape` function.
Alternatively, the parser and engine are exposed in `engine.py`, allowing
analogous implementations in TensorFlow or other frameworks.
Einshape can be installed with the following command:

```shell
pip3 install git+https://github.com/deepmind/einshape
```
Einshape works with either JAX or TensorFlow. To support both, it lists neither as a requirement, so you must install JAX or TensorFlow separately.
JAX version:

```python
from einshape import jax_einshape as einshape
from jax import numpy as jnp

a = jnp.array([[1, 2], [3, 4]])
b = einshape("ij->(ij)", a)
# b is [1, 2, 3, 4]
```
TensorFlow version:

```python
from einshape import tf_einshape as einshape
import tensorflow as tf

a = tf.constant([[1, 2], [3, 4]])
b = einshape("ij->(ij)", a)
# b is [1, 2, 3, 4]
```
NumPy version:

```python
from einshape import numpy_einshape as einshape
import numpy as np

a = np.array([[1, 2], [3, 4]])
b = einshape("ij->(ij)", a)
# b is [1, 2, 3, 4]
```
An einshape equation is always of the form `{lhs}->{rhs}`, where `{lhs}` and
`{rhs}` both stand for expressions. An expression represents the axes of an
array; the relationship between the two expressions describes how the array
should be transformed.
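To make the `{lhs}->{rhs}` form concrete, here is a minimal sketch that splits an equation into its two expressions. The function `split_equation` is hypothetical and not part of einshape's API; it only checks for exactly one `->` and for non-empty expressions, and does not parse the expressions themselves.

```python
def split_equation(equation: str) -> tuple[str, str]:
    """Hypothetical helper (not part of einshape): split '{lhs}->{rhs}'
    into its two expressions, rejecting anything else."""
    lhs, arrow, rhs = equation.partition("->")
    # Exactly one arrow is allowed.
    if not arrow or "->" in rhs:
        raise ValueError(f"expected exactly one '->' in {equation!r}")
    # Both expressions must be non-empty.
    if not lhs or not rhs:
        raise ValueError(f"both expressions must be non-empty in {equation!r}")
    return lhs, rhs
```

For example, `split_equation("nhwc->nchw")` returns `("nhwc", "nchw")`, while `"ab->"` and `"a->b->c"` raise `ValueError`.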
An expression is a non-empty sequence of the following elements:
A single letter `a`-`z` (an index name) represents one axis of an array.
For example, the expressions `ab` and `jq` both represent an array of rank 2.

Every index name that is present on the left-hand side of an equation must
also be present on the right-hand side. So `ab->a` is not a valid equation,
but `a->ba` is valid (it tiles a vector `b` times, with `b` given as a keyword
argument such as `b=3`).
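The index-name rule can be checked mechanically, and the tiling performed by `a->ba` has a direct NumPy counterpart. This sketch assumes NumPy; `lhs_names_covered` is a hypothetical helper, not part of einshape, and it only compares the sets of letters on each side, ignoring groups, ellipses and unit axes.

```python
import numpy as np


def lhs_names_covered(equation: str) -> bool:
    """Hypothetical helper (not part of einshape): True if every index name
    a-z on the left of '->' also appears on the right."""
    lhs, rhs = equation.split("->")
    lhs_names = {ch for ch in lhs if "a" <= ch <= "z"}
    rhs_names = {ch for ch in rhs if "a" <= ch <= "z"}
    return lhs_names <= rhs_names


# What "a->ba" with b=3 computes, written with plain NumPy:
# out[b, a] = v[a], i.e. the vector stacked three times along a new leading axis.
v = np.array([10, 20, 30, 40])
tiled = np.tile(v, (3, 1))
```

Here `lhs_names_covered("ab->a")` is `False` (the name `b` is dropped), while `lhs_names_covered("a->ba")` is `True`.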
An ellipsis, `...`, represents any axes of an array that are not otherwise
represented in the expression. This is similar to the use of `-1` as a
dimension size in a `reshape` operation.

For example, `a...b` can represent any array of rank 2 or more: `a` refers to
the first axis and `b` to the last. The equation `...ab->...ba` swaps the last
two axes of an array.

An expression may not include more than one ellipsis (because that would be ambiguous). Like an index name, an ellipsis must be present in both halves of an equation or in neither.
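Both ellipsis rules, and the rank-independent behavior of `...ab->...ba`, can be illustrated in a few lines. This sketch assumes NumPy; `ellipsis_ok` and `swap_last_two` are hypothetical helpers, not part of einshape, and `ellipsis_ok` checks only the ellipsis rules, not the rest of the grammar.

```python
import numpy as np


def ellipsis_ok(equation: str) -> bool:
    """Hypothetical helper (not part of einshape): at most one '...' per side,
    and '...' on both sides or on neither."""
    lhs, rhs = equation.split("->")
    n_lhs, n_rhs = lhs.count("..."), rhs.count("...")
    return n_lhs <= 1 and n_rhs <= 1 and n_lhs == n_rhs


def swap_last_two(x: np.ndarray) -> np.ndarray:
    """NumPy counterpart of "...ab->...ba": the ellipsis absorbs all leading
    axes, so only the last two axes are exchanged."""
    return np.swapaxes(x, -2, -1)
```

The same function works for any rank of 2 or more: an input of shape `(2, 3)` becomes `(3, 2)`, and one of shape `(4, 5, 2, 3)` becomes `(4, 5, 3, 2)`.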
A group, `({components})`, where `components` is a sequence of index names and
ellipsis elements, corresponds to a single axis of the array; the group's
components represent factors of the axis size. Groups can be used to split one
axis into several (as in `(mn)hwc->mnhwc`) or to merge several axes into one
(as in `mnhwc->(mn)hwc`). When an axis is split, all the factors except at most
one must be specified using keyword arguments.

For example, `einshape('(ab)->ab', x, a=10)` reshapes an array of rank 1 (whose
length must be a multiple of 10) into an array of rank 2 (whose first dimension
is of length 10).

Groups may not be nested.
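Once the missing factor is known, `einshape('(ab)->ab', x, a=10)` amounts to a reshape. This sketch assumes NumPy and the row-major (C-order) layout of `reshape`; `infer_missing_factor` is a hypothetical helper, not einshape's API, showing how the one unspecified factor of a group can be derived from the axis size.

```python
import math

import numpy as np


def infer_missing_factor(axis_size: int, known_factors: list[int]) -> int:
    """Hypothetical helper (not part of einshape): size of the one group
    factor that was not given as a keyword argument."""
    known = math.prod(known_factors)
    # The axis must divide evenly by the product of the known factors.
    if axis_size % known != 0:
        raise ValueError(f"axis of size {axis_size} is not divisible by {known}")
    return axis_size // known


# What "(ab)->ab" with a=10 computes for a length-30 vector.
x = np.arange(30)
b = infer_missing_factor(x.shape[0], [10])
y = x.reshape(10, b)
```

With 30 elements and `a=10`, the inferred factor is `b=3`, so `y` has shape `(10, 3)`; an axis of length 25 is rejected because it is not a multiple of 10.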
The digit `1` represents a single axis of length 1. This is useful for
expanding and squeezing unit dimensions. For example, the equation `1...->...`
squeezes a leading axis (which must have length one).
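Unit axes map directly onto `np.squeeze`. This sketch assumes NumPy; `unit_axes` is a hypothetical helper, not part of einshape, and it only handles expressions made of letters and `1` (no groups or ellipsis), where each character is exactly one axis.

```python
import numpy as np


def unit_axes(expression: str) -> tuple[int, ...]:
    """Hypothetical helper (not part of einshape): positions of the '1'
    elements in an expression made only of letters a-z and '1'."""
    return tuple(i for i, ch in enumerate(expression) if ch == "1")


# "a1b11->ab": squeeze exactly the axes marked with 1.
x = np.zeros((2, 1, 3, 1, 1))
squeezed = np.squeeze(x, axis=unit_axes("a1b11"))

# "1...->...": squeeze a leading length-1 axis of an array of any rank.
z = np.zeros((1, 4, 5))
leading_squeezed = np.squeeze(z, axis=0)
```

`unit_axes("a1b11")` gives `(1, 3, 4)`, the same axes as the `squeeze(x, axis=[1,3,4])` example. Like the einshape equation, `np.squeeze` refuses to remove an axis whose length is not one (it raises `ValueError`).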
This is not an official Google product.