From 2601ce318d05de8e84f449501e79f2723c6e2795 Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Wed, 26 Jul 2023 15:57:25 +0000 Subject: [PATCH] Add the unstack operation to keras core (#597) * Add the unstack operation to keras core * Add a symbolic version of the op and add tests * Correct/Add docs for the x/num argument. [skip ci] * Use assertRaisesRegex and test for error msg --- keras_core/backend/jax/core.py | 7 ++++ keras_core/backend/numpy/core.py | 5 +++ keras_core/backend/tensorflow/core.py | 4 ++ keras_core/backend/torch/core.py | 4 ++ keras_core/ops/core.py | 54 +++++++++++++++++++++++++++ keras_core/ops/core_test.py | 32 ++++++++++++++++ 6 files changed, 106 insertions(+) diff --git a/keras_core/backend/jax/core.py b/keras_core/backend/jax/core.py index f0cb022fe..019aff77a 100644 --- a/keras_core/backend/jax/core.py +++ b/keras_core/backend/jax/core.py @@ -287,3 +287,10 @@ def fori_loop(lower, upper, body_fun, init_val): def stop_gradient(variable): return jax.lax.stop_gradient(variable) + + +def unstack(x, num=None, axis=0): + return [ + jax.lax.index_in_dim(x, i, axis, keepdims=False) + for i in range(x.shape[axis]) + ] diff --git a/keras_core/backend/numpy/core.py b/keras_core/backend/numpy/core.py index dd1a91130..880e00548 100644 --- a/keras_core/backend/numpy/core.py +++ b/keras_core/backend/numpy/core.py @@ -217,3 +217,8 @@ def fori_loop(lower, upper, body_fun, init_val): def stop_gradient(x): pass + + +def unstack(x, num=None, axis=0): + x = np.moveaxis(x, axis, 0) + return [x[i] for i in range(x.shape[0])] diff --git a/keras_core/backend/tensorflow/core.py b/keras_core/backend/tensorflow/core.py index 8085e9659..bf2496d1e 100644 --- a/keras_core/backend/tensorflow/core.py +++ b/keras_core/backend/tensorflow/core.py @@ -183,3 +183,7 @@ def fori_loop(lower, upper, body_fun, init_val): def stop_gradient(variable): return tf.stop_gradient(variable) + + +def unstack(x, num=None, axis=0): + return tf.unstack(x, num=num, axis=axis) diff --git a/keras_core/backend/torch/core.py b/keras_core/backend/torch/core.py index 22f7dfff6..1b6afa7f5 100644 --- a/keras_core/backend/torch/core.py +++ b/keras_core/backend/torch/core.py @@ -353,3 +353,7 @@ def stop_gradient(variable): # We can't use `.requires_grad_(False)` here since it only # works when the tensor is a leaf node in the graph. return variable.detach() + + +def unstack(x, num=None, axis=0): + return x.unbind(axis) diff --git a/keras_core/ops/core.py b/keras_core/ops/core.py index 9b7d297da..37e1a9461 100644 --- a/keras_core/ops/core.py +++ b/keras_core/ops/core.py @@ -344,6 +344,60 @@ def fori_loop(lower, upper, body_fun, init_val): return backend.core.fori_loop(lower, upper, body_fun, init_val) +class Unstack(Operation): + def __init__(self, num=None, axis=0): + super().__init__() + self.num = num + self.axis = axis + + def call(self, x): + return backend.core.unstack(x, self.num, self.axis) + + def compute_output_spec(self, x): + axis = self.axis + if axis < 0: + axis = len(x.shape) + axis + output_shapes = x.shape[:axis] + x.shape[axis + 1 :] + num = self.num + if num is None: + num = x.shape[axis] + if num is None: + raise ValueError( + "Cannot infer argument `num` from shape " + f"{x.shape}. Either provide a tensor with a " + "concrete shape in the `axis` dimension or " + "explicitly pass the `num` argument." + ) + output = [ + KerasTensor(shape=output_shapes, dtype=x.dtype) for _ in range(num) + ] + return output + + +@keras_core_export("keras_core.ops.unstack") +def unstack(x, num=None, axis=0): + """Unpacks the given dimension of a rank-R tensor into rank-(R-1) tensors. + + Args: + x: The input tensor. + num: The length of the dimension axis. Automatically inferred + if `None`. + axis: The axis along which to unpack. + + Returns: + A list of tensors unpacked along the given axis. + + Example: + + >>> x = keras_core.ops.array([[1, 2], [3, 4]]) + >>> keras_core.ops.unstack(x, axis=0) + [array([1, 2]), array([3, 4])] + """ + if any_symbolic_tensors((x,)): + return Unstack(num, axis).symbolic_call(x) + return backend.core.unstack(x, num=num, axis=axis) + + @keras_core_export("keras_core.ops.shape") def shape(x): """Gets the shape of the tensor input. diff --git a/keras_core/ops/core_test.py b/keras_core/ops/core_test.py index b248c3dab..f6e6d4e15 100644 --- a/keras_core/ops/core_test.py +++ b/keras_core/ops/core_test.py @@ -56,6 +56,26 @@ def body_fun(i, x): result = core.fori_loop(0, 10, body_fun, initial_value) self.assertEqual(result.shape, (3, 5, 7)) + def test_unstack(self): + x = KerasTensor((2, 3, 4)) + axis = 1 + out = core.unstack(x, axis=axis) + self.assertEqual(len(out), 3) + for o in out: + self.assertEqual(o.shape, (2, 4)) + + x = KerasTensor((2, None, None)) + axis, num = 1, 3 + out = core.unstack(x, num=num, axis=axis) + self.assertEqual(len(out), 3) + for o in out: + self.assertEqual(o.shape, (2, None)) + + with self.assertRaisesRegex( + ValueError, r"Cannot infer argument `num` from shape" + ): + core.unstack(x, axis=axis) + class CoreOpsCorrectnessTest(testing.TestCase): def test_scatter(self): @@ -298,3 +318,15 @@ def test_cond(self): lambda: KerasTensor((3,)), lambda: KerasTensor((4,)), ) + + def test_unstack(self): + rng = np.random.default_rng(0) + x = rng.uniform(size=(2, 3, 4)) + x_tensor = ops.convert_to_tensor(x) + axis = 1 + out = ops.unstack(x_tensor, axis=axis) + out_ex = [x[:, i, :] for i in range(x.shape[axis])] + self.assertEqual(len(out), len(out_ex)) + for o, o_e in zip(out, out_ex): + o = ops.convert_to_numpy(o) + self.assertAllClose(o, o_e)