Skip to content

Commit

Permalink
Add the unstack operation to keras core (#597)
Browse files Browse the repository at this point in the history
* Add the unstack operation to keras core

* Add a symbolic version of the op and add tests

* Correct/Add docs for the x/num argument.

[skip ci]

* Use assertRaisesRegex and test for error msg
  • Loading branch information
tirthasheshpatel authored Jul 26, 2023
1 parent b6b4376 commit 2601ce3
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 0 deletions.
7 changes: 7 additions & 0 deletions keras_core/backend/jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,10 @@ def fori_loop(lower, upper, body_fun, init_val):

def stop_gradient(variable):
return jax.lax.stop_gradient(variable)


def unstack(x, num=None, axis=0):
return [
jax.lax.index_in_dim(x, i, axis, keepdims=False)
for i in range(x.shape[axis])
]
5 changes: 5 additions & 0 deletions keras_core/backend/numpy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,8 @@ def fori_loop(lower, upper, body_fun, init_val):

def stop_gradient(x):
pass


def unstack(x, num=None, axis=0):
x = np.moveaxis(x, axis, 0)
return [x[i] for i in range(x.shape[0])]
4 changes: 4 additions & 0 deletions keras_core/backend/tensorflow/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,7 @@ def fori_loop(lower, upper, body_fun, init_val):

def stop_gradient(variable):
return tf.stop_gradient(variable)


def unstack(x, num=None, axis=0):
return tf.unstack(x, num=num, axis=axis)
4 changes: 4 additions & 0 deletions keras_core/backend/torch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,3 +353,7 @@ def stop_gradient(variable):
# We can't use `.requires_grad_(False)` here since it only
# works when the tensor is a leaf node in the graph.
return variable.detach()


def unstack(x, num=None, axis=0):
return x.unbind(axis)
54 changes: 54 additions & 0 deletions keras_core/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,60 @@ def fori_loop(lower, upper, body_fun, init_val):
return backend.core.fori_loop(lower, upper, body_fun, init_val)


class Unstack(Operation):
def __init__(self, num=None, axis=0):
super().__init__()
self.num = num
self.axis = axis

def call(self, x):
return backend.core.unstack(x, self.num, self.axis)

def compute_output_spec(self, x):
axis = self.axis
if axis < 0:
axis = len(x.shape) + axis
output_shapes = x.shape[:axis] + x.shape[axis + 1 :]
num = self.num
if num is None:
num = x.shape[axis]
if num is None:
raise ValueError(
"Cannot infer argument `num` from shape "
f"{x.shape}. Either provide a tensor with a "
"concrete shape in the `axis` dimension or "
"explicitly pass the `num` argument."
)
output = [
KerasTensor(shape=output_shapes, dtype=x.dtype) for _ in range(num)
]
return output


@keras_core_export("keras_core.ops.unstack")
def unstack(x, num=None, axis=0):
"""Unpacks the given dimension of a rank-R tensor into rank-(R-1) tensors.
Args:
x: The input tensor.
num: The length of the dimension axis. Automatically inferred
if `None`.
axis: The axis along which to unpack.
Returns:
A list of tensors unpacked along the given axis.
Example:
>>> x = keras_core.ops.array([[1, 2], [3, 4]])
>>> keras_core.ops.unstack(x, axis=0)
[array([1, 2]), array([3, 4])]
"""
if any_symbolic_tensors((x,)):
return Unstack(num, axis).symbolic_call(x)
return backend.core.unstack(x, num=num, axis=axis)


@keras_core_export("keras_core.ops.shape")
def shape(x):
"""Gets the shape of the tensor input.
Expand Down
32 changes: 32 additions & 0 deletions keras_core/ops/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,26 @@ def body_fun(i, x):
result = core.fori_loop(0, 10, body_fun, initial_value)
self.assertEqual(result.shape, (3, 5, 7))

def test_unstack(self):
x = KerasTensor((2, 3, 4))
axis = 1
out = core.unstack(x, axis=axis)
self.assertEqual(len(out), 3)
for o in out:
self.assertEqual(o.shape, (2, 4))

x = KerasTensor((2, None, None))
axis, num = 1, 3
out = core.unstack(x, num=num, axis=axis)
self.assertEqual(len(out), 3)
for o in out:
self.assertEqual(o.shape, (2, None))

with self.assertRaisesRegex(
ValueError, r"Cannot infer argument `num` from shape"
):
core.unstack(x, axis=axis)


class CoreOpsCorrectnessTest(testing.TestCase):
def test_scatter(self):
Expand Down Expand Up @@ -298,3 +318,15 @@ def test_cond(self):
lambda: KerasTensor((3,)),
lambda: KerasTensor((4,)),
)

def test_unstack(self):
rng = np.random.default_rng(0)
x = rng.uniform(size=(2, 3, 4))
x_tensor = ops.convert_to_tensor(x)
axis = 1
out = ops.unstack(x_tensor, axis=axis)
out_ex = [x[:, i, :] for i in range(x.shape[axis])]
self.assertEqual(len(out), len(out_ex))
for o, o_e in zip(out, out_ex):
o = ops.convert_to_numpy(o)
self.assertAllClose(o, o_e)

0 comments on commit 2601ce3

Please sign in to comment.