From ffd736e27ac5d5c9889c0b756b10055cb4528d7b Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Tue, 19 Sep 2023 12:47:41 +0800 Subject: [PATCH] Makes `ops.split` in torch consistent with other backends (#914) * Makes split in torch consistent with other backends * Update error msg --- keras_core/backend/torch/numpy.py | 18 +++++++++-- keras_core/ops/numpy_test.py | 53 +++++++++++-------------------- 2 files changed, 34 insertions(+), 37 deletions(-) diff --git a/keras_core/backend/torch/numpy.py b/keras_core/backend/torch/numpy.py index ba613cfcd..4986320f6 100644 --- a/keras_core/backend/torch/numpy.py +++ b/keras_core/backend/torch/numpy.py @@ -800,22 +800,34 @@ def sort(x, axis=-1): def split(x, indices_or_sections, axis=0): x = convert_to_tensor(x) + dim = x.shape[axis] if isinstance(indices_or_sections, (list, tuple)): idxs = convert_to_tensor(indices_or_sections) start_size = indices_or_sections[0] - end_size = x.shape[axis] - indices_or_sections[-1] + end_size = dim - indices_or_sections[-1] chunk_sizes = ( [start_size] + torch.diff(idxs).type(torch.int).tolist() + [end_size] ) else: - chunk_sizes = x.shape[axis] // indices_or_sections - return torch.split( + if dim % indices_or_sections != 0: + raise ValueError( + f"Received indices_or_sections={indices_or_sections} " + f"(interpreted as a number of sections) and axis={axis}, " + f"but input dimension x.shape[{axis}]={x.shape[axis]} " + f"is not divisible by {indices_or_sections}. " + f"Full input shape: x.shape={x.shape}" + ) + chunk_sizes = dim // indices_or_sections + out = torch.split( tensor=x, split_size_or_sections=chunk_sizes, dim=axis, ) + if dim == 0 and isinstance(indices_or_sections, int): + out = tuple(out[0].clone() for _ in range(indices_or_sections)) + return out def stack(x, axis=0): diff --git a/keras_core/ops/numpy_test.py b/keras_core/ops/numpy_test.py index 0ff0d364b..2ae7f12c6 100644 --- a/keras_core/ops/numpy_test.py +++ b/keras_core/ops/numpy_test.py @@ -3283,40 +3283,25 @@ def test_sort(self): def test_split(self): x = np.array([[1, 2, 3], [3, 2, 1]]) - if backend.backend() == "torch": - self.assertAllClose( - [backend.convert_to_numpy(t) for t in knp.split(x, 2)], - np.split(x, 2), - ) - self.assertAllClose( - [backend.convert_to_numpy(t) for t in knp.Split(2)(x)], - np.split(x, 2), - ) - self.assertAllClose( - [ - backend.convert_to_numpy(t) - for t in knp.split(x, [1, 2], axis=1) - ], - np.split(x, [1, 2], axis=1), - ) - self.assertAllClose( - [ - backend.convert_to_numpy(t) - for t in knp.Split([1, 2], axis=1)(x) - ], - np.split(x, [1, 2], axis=1), - ) - else: - self.assertAllClose(knp.split(x, 2), np.split(x, 2)) - self.assertAllClose(knp.Split(2)(x), np.split(x, 2)) - self.assertAllClose( - knp.split(x, [1, 2], axis=1), - np.split(x, [1, 2], axis=1), - ) - self.assertAllClose( - knp.Split([1, 2], axis=1)(x), - np.split(x, [1, 2], axis=1), - ) + self.assertAllClose(knp.split(x, 2), np.split(x, 2)) + self.assertAllClose(knp.Split(2)(x), np.split(x, 2)) + self.assertAllClose( + knp.split(x, [1, 2], axis=1), + np.split(x, [1, 2], axis=1), + ) + self.assertAllClose( + knp.Split([1, 2], axis=1)(x), + np.split(x, [1, 2], axis=1), + ) + + # test invalid indices_or_sections + with self.assertRaises(Exception): + knp.split(x, 3) + + # test zero dimension + x = np.ones(shape=(0,)) + self.assertEqual(len(knp.split(x, 2)), 2) + self.assertEqual(len(knp.Split(2)(x)), 2) def test_sqrt(self): x = np.array([[1, 4, 9], [16, 25, 36]])