From 8a5020b0ca8b135f1b057bdda07129277aa936c4 Mon Sep 17 00:00:00 2001
From: Kacper-W-Kozdon
Date: Thu, 30 Nov 2023 21:21:35 +0100
Subject: [PATCH 01/95] Draft for the ivy.unflatten | not finished

---
 ivy/functional/backends/torch/manipulation.py | 10 ++++
 .../frontends/torch/miscellaneous_ops.py      |  5 ++
 ivy/functional/ivy/manipulation.py            | 48 +++++++++++++++++++
 3 files changed, 63 insertions(+)

diff --git a/ivy/functional/backends/torch/manipulation.py b/ivy/functional/backends/torch/manipulation.py
index 4cfb62addec6e..18fe925de1ad1 100644
--- a/ivy/functional/backends/torch/manipulation.py
+++ b/ivy/functional/backends/torch/manipulation.py
@@ -39,6 +39,16 @@ def concat(
 concat.support_native_out = True
 
 
+def unflatten(
+    x: torch.Tensor,
+    /,
+    *,
+    axis: int = 0,
+    sizes: Optional[Tuple[int]] = None,
+) -> torch.Tensor:
+    return torch.unflatten(input=x, dim=axis, sizes=sizes)
+
+
 def expand_dims(
     x: torch.Tensor,
     /,
diff --git a/ivy/functional/frontends/torch/miscellaneous_ops.py b/ivy/functional/frontends/torch/miscellaneous_ops.py
index 23eeef185e73e..ecf5e18c5d110 100644
--- a/ivy/functional/frontends/torch/miscellaneous_ops.py
+++ b/ivy/functional/frontends/torch/miscellaneous_ops.py
@@ -502,6 +502,11 @@ def triu_indices(row, col, offset=0, dtype="int64", device="cpu", layout=None):
     return ivy.stack(ivy.nonzero(sample_matrix)).astype(dtype)
 
 
+@to_ivy_arrays_and_back
+def unflatten(x, /, *, axis, sizes):
+    return ivy.unflatten(x, axis=axis, sizes=sizes)
+
+
 @to_ivy_arrays_and_back
 def vander(x, N=None, increasing=False):
     # if N == 0:
diff --git a/ivy/functional/ivy/manipulation.py b/ivy/functional/ivy/manipulation.py
index 318c663b739b7..44eb341fac211 100644
--- a/ivy/functional/ivy/manipulation.py
+++ b/ivy/functional/ivy/manipulation.py
@@ -330,6 +330,54 @@ def flip(
     return current_backend(x).flip(x, copy=copy, axis=axis, out=out)
 
 
+@handle_exceptions
+@handle_backend_invalid
+@handle_nestable
+@handle_array_like_without_promotion
+@handle_view
+@handle_out_argument
+@to_native_arrays_and_back
+@handle_array_function
+@handle_device
+def unflatten(
+    x: Union[ivy.Array, ivy.NativeArray],
+    /,
+    *,
+    axis: int,
+    sizes: Tuple[int],
+) -> ivy.Array:
+    """
+    Expand a dimension of the input array over multiple dimensions.
+
+    Parameters
+    ----------
+    x
+        The input array.
+    axis
+        Dimension to be unflattened, specified as an index into ``x.shape``.
+    sizes
+        New shape of the unflattened dimension. One of its elements can be -1,
+        in which case the corresponding output dimension is inferred.
+        Otherwise, the product of ``sizes`` must equal ``x.shape[axis]``.
+
+    Returns
+    -------
+    ret
+        A view of the input array with the specified dimension unflattened.
+
+    Examples
+    --------
+    >>> ivy.unflatten(ivy.random_normal(shape=(3, 4, 1)), axis=1, sizes=(2, 2)).shape
+    ivy.Shape(3, 2, 2, 1)
+    >>> ivy.unflatten(ivy.random_normal(shape=(3, 4, 1)), axis=1, sizes=(-1, 2)).shape
+    ivy.Shape(3, 2, 2, 1)
+    >>> ivy.unflatten(ivy.random_normal(shape=(5, 12, 3)), axis=-2, sizes=(2, 2, 3, 1, 1)).shape
+    ivy.Shape(5, 2, 2, 3, 1, 1, 3)
+    """
+    return current_backend(x).unflatten(x, axis=axis, sizes=sizes)
+
+
 @handle_exceptions
 @handle_backend_invalid
 @handle_nestable
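
For reference, a minimal doctest-style check of the draft op above: a sketch
that assumes the torch backend and the dispatch wiring from this patch are in
place. `unflatten` is just a reshape that splits a single axis:

>>> import ivy
>>> ivy.set_backend("torch")
>>> x = ivy.reshape(ivy.arange(12), (3, 4))
>>> tuple(ivy.unflatten(x, axis=1, sizes=(2, 2)).shape)
(3, 2, 2)
>>> tuple(ivy.reshape(x, (3, 2, 2)).shape)  # the equivalent single-axis reshape
(3, 2, 2)
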
From 48df5d3ea2ffb21b83524f56b19efc533e9bd74d Mon Sep 17 00:00:00 2001
From: Abdullah Sabry
Date: Fri, 1 Dec 2023 04:51:55 -0800
Subject: [PATCH 02/95] feat: Added `tridiagonal_solve` in tensorflow frontend
 (#23279)

Co-authored-by: NripeshN
---
 ivy/functional/frontends/tensorflow/linalg.py |  49 ++++++++
 .../test_tensorflow/test_linalg.py            | 106 ++++++++++++++++++
 2 files changed, 155 insertions(+)

diff --git a/ivy/functional/frontends/tensorflow/linalg.py b/ivy/functional/frontends/tensorflow/linalg.py
index b0e8bdfdc7c5f..c0dcad4dd5906 100644
--- a/ivy/functional/frontends/tensorflow/linalg.py
+++ b/ivy/functional/frontends/tensorflow/linalg.py
@@ -458,3 +458,52 @@ def tensorsolve(a, b, axes):
 @to_ivy_arrays_and_back
 def trace(x, name=None):
     return ivy.trace(x, axis1=-2, axis2=-1)
+
+
+@to_ivy_arrays_and_back
+@with_supported_dtypes(
+    {
+        "2.13.0 and below": (
+            "float32",
+            "float64",
+            "complex64",
+            "complex128",
+        )
+    },
+    "tensorflow",
+)
+def tridiagonal_solve(
+    diagonals,
+    rhs,
+    diagonals_format="compact",
+    transpose_rhs=False,
+    conjugate_rhs=False,
+    name=None,
+    partial_pivoting=True,
+    perturb_singular=False,
+):
+    if transpose_rhs is True:
+        rhs_copy = ivy.matrix_transpose(rhs)
+    else:
+        rhs_copy = ivy.array(rhs)
+    if conjugate_rhs is True:
+        rhs_copy = ivy.conj(rhs_copy)
+
+    if diagonals_format == "matrix":
+        return ivy.solve(diagonals, rhs_copy)
+    elif diagonals_format in ["sequence", "compact"]:
+        diagonals = ivy.array(diagonals)
+        dim = diagonals[0].shape[0]
+        diagonals[[0, -1], [-1, 0]] = 0
+        dummy_idx = [0, 0]
+        indices = ivy.array([
+            [(i, i + 1) for i in range(dim - 1)] + [dummy_idx],
+            [(i, i) for i in range(dim)],
+            [dummy_idx] + [(i + 1, i) for i in range(dim - 1)],
+        ])
+        constructed_matrix = ivy.scatter_nd(
+            indices, diagonals, shape=ivy.array([dim, dim])
+        )
+        return ivy.solve(constructed_matrix, rhs_copy)
+    else:
+        raise ValueError("Unexpected diagonals_format")
diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_linalg.py
index fdc93ffc80933..e3d22b6f3f4ad 100644
--- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_linalg.py
+++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_linalg.py
@@ -153,6 +153,41 @@ def _get_second_matrix(draw):
     )
 
 
+@st.composite
+def _get_tridiagonal_dtype_matrix_format(draw):
+    input_dtype_strategy = st.shared(
+        st.sampled_from(draw(helpers.get_dtypes("float_and_complex"))),
+        key="shared_dtype",
+    )
+    input_dtype = draw(input_dtype_strategy)
+    shared_size = draw(
+        st.shared(helpers.ints(min_value=2, max_value=4), key="shared_size")
+    )
+    diagonals_format = draw(st.sampled_from(["compact", "sequence", "matrix"]))
+    if diagonals_format == "matrix":
+        matrix = draw(
+            helpers.array_values(
+                dtype=input_dtype,
+                shape=tuple([shared_size, shared_size]),
+                min_value=2,
+                
max_value=5, + ).filter(tridiagonal_compact_filter) + ) + if diagonals_format == "sequence": + matrix = list(matrix) + + return input_dtype, matrix, diagonals_format + + # --- Main --- # # ------------ # @@ -1207,3 +1242,74 @@ def test_tensorflow_trace( fn_tree=fn_tree, x=x[0], ) + + +# tridiagonal_solve +@handle_frontend_test( + fn_tree="tensorflow.linalg.tridiagonal_solve", + x=_get_tridiagonal_dtype_matrix_format(), + y=_get_second_matrix(), + transpose_rhs=st.just(False), + conjugate_rhs=st.booleans(), +) +def test_tensorflow_tridiagonal_solve( + *, + x, + y, + transpose_rhs, + conjugate_rhs, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + input_dtype1, x1, diagonals_format = x + input_dtype2, x2 = y + helpers.test_frontend_function( + input_dtypes=[input_dtype1, input_dtype2], + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-3, + atol=1e-3, + diagonals=x1, + rhs=x2, + diagonals_format=diagonals_format, + transpose_rhs=transpose_rhs, + conjugate_rhs=conjugate_rhs, + ) + + +def tridiagonal_compact_filter(x): + diagonals = ivy.array(x) + dim = diagonals[0].shape[0] + diagonals[[0, -1], [-1, 0]] = 0 + dummy_idx = [0, 0] + indices = ivy.array([ + [(i, i + 1) for i in range(dim - 1)] + [dummy_idx], + [(i, i) for i in range(dim)], + [dummy_idx] + [(i + 1, i) for i in range(dim - 1)], + ]) + matrix = ivy.scatter_nd( + indices, diagonals, ivy.array([dim, dim]), reduction="replace" + ) + return tridiagonal_matrix_filter(matrix) + + +def tridiagonal_matrix_filter(x): + dim = x.shape[0] + if ivy.abs(ivy.det(x)) < 1e-3: + return False + for i in range(dim): + for j in range(dim): + cell = x[i][j] + if i == j or i == j - 1 or i == j + 1: + if cell == 0: + return False + else: + if cell != 0: + return False + return True From ed14e0e58717e14cd4ef4c887587f3347b8ec859 Mon Sep 17 00:00:00 2001 From: Abdullah Sabry Date: Fri, 1 Dec 2023 04:54:01 -0800 Subject: [PATCH 03/95] feat: added `floor_mod` in paddle frontend (#26064) Co-authored-by: NripeshN --- .../frontends/paddle/tensor/tensor.py | 6 ++++ .../test_paddle/test_tensor/test_tensor.py | 36 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/ivy/functional/frontends/paddle/tensor/tensor.py b/ivy/functional/frontends/paddle/tensor/tensor.py index 2d3e814391442..530b36b1087f4 100644 --- a/ivy/functional/frontends/paddle/tensor/tensor.py +++ b/ivy/functional/frontends/paddle/tensor/tensor.py @@ -781,6 +781,12 @@ def floor_divide(self, y, name=None): def mod(self, y, name=None): return paddle_frontend.Tensor(ivy.fmod(self._ivy_array, _to_ivy_array(y))) + @with_supported_dtypes( + {"2.5.2 and below": ("float32", "float64", "int32", "int64")}, "paddle" + ) + def floor_mod(self, y, name=None): + return paddle_frontend.remainder(self, y) + # cond @with_supported_dtypes({"2.5.2 and below": ("float32", "float64")}, "paddle") def cond(self, p=None, name=None): diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py index 3ab33f46e4b0f..345fe4cbc04a5 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py @@ -5392,6 +5392,42 @@ def test_paddle_tensor_expand( ) +# floor_mod +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="paddle.to_tensor", + method_name="floor_mod", + dtypes_and_x=helpers.dtype_and_values( + 
available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + min_value=2, + shared_dtype=True, + ), +) +def test_paddle_tensor_floor_mod( + dtypes_and_x, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtype, x = dtypes_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={"data": x[0]}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"y": x[1]}, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", From f0e2366201d2b2ad48dfbd4e2609842e86922291 Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Fri, 1 Dec 2023 16:57:46 +0000 Subject: [PATCH 04/95] feat: add base Splitter class for decision trees in sklearn front --- .../frontends/sklearn/tree/_splitter.py | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 ivy/functional/frontends/sklearn/tree/_splitter.py diff --git a/ivy/functional/frontends/sklearn/tree/_splitter.py b/ivy/functional/frontends/sklearn/tree/_splitter.py new file mode 100644 index 0000000000000..e199271527150 --- /dev/null +++ b/ivy/functional/frontends/sklearn/tree/_splitter.py @@ -0,0 +1,80 @@ +import ivy + + +class Splitter: + def __init__( + self, + criterion, + max_features: int, + min_samples_leaf: int, + min_weight_leaf: float, + random_state, + *args, + ): + self.criterion = criterion + self.n_samples = 0 + self.n_features = 0 + self.max_features = max_features + self.min_samples_leaf = min_samples_leaf + self.min_weight_leaf = min_weight_leaf + self.random_state = random_state + + def init( + self, + X, + y, + sample_weight, + missing_values_in_feature_mask, + *args, + ): + + n_samples = X.shape[0] + self.samples = ivy.empty(n_samples, dtype=ivy.int32) + samples = self.samples + j = 0 + weighted_n_samples = 0.0 + + for i in range(n_samples): + if sample_weight is None or sample_weight[i] != 0.0: + samples[j] = i + j += 1 + if sample_weight is not None: + weighted_n_samples += sample_weight[i] + else: + weighted_n_samples += 1.0 + + self.n_samples = j + self.weighted_n_samples = weighted_n_samples + n_features = X.shape[1] + self.features = ivy.arange(n_features, dtype=ivy.int32) + self.n_features = n_features + self.feature_values = ivy.empty(n_samples, dtype=ivy.float32) + self.constant_features = ivy.empty(n_features, dtype=ivy.int32) + self.y = y + self.sample_weight = sample_weight + if missing_values_in_feature_mask is not None: + self.criterion.init_sum_missing() + return 0 + + def node_reset(self, start, end, weighted_n_node_samples): + self.start = start + self.end = end + self.criterion.init( + self.y, + self.sample_weight, + self.weighted_n_samples, + self.samples, + start, + end, + ) + weighted_n_node_samples = self.criterion.weighted_n_node_samples + return 0, weighted_n_node_samples + + def node_split(self, impurity, split, n_constant_features): + pass + + def node_value(self, dest, node_id): + return self.criterion.node_value(dest, node_id) + + def node_impurity(self): + return self.criterion.node_impurity() From b892dd4044bde6f01bb30a0132ead19e0e4ccc27 Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Fri, 1 Dec 2023 17:00:21 +0000 Subject: [PATCH 05/95] =?UTF-8?q?=F0=9F=A4=96=20Lint=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit --- ivy/functional/frontends/sklearn/tree/_splitter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ivy/functional/frontends/sklearn/tree/_splitter.py b/ivy/functional/frontends/sklearn/tree/_splitter.py index e199271527150..ecf72db430ef6 100644 --- a/ivy/functional/frontends/sklearn/tree/_splitter.py +++ b/ivy/functional/frontends/sklearn/tree/_splitter.py @@ -27,7 +27,6 @@ def init( missing_values_in_feature_mask, *args, ): - n_samples = X.shape[0] self.samples = ivy.empty(n_samples, dtype=ivy.int32) samples = self.samples From ffdf7f3d8ca4b8d597590a6e2ef541cda3a2aa88 Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Fri, 1 Dec 2023 17:12:59 +0000 Subject: [PATCH 06/95] feat: add base structure for BestSplitter method which overrides init and node_split from Splitter --- .../frontends/sklearn/tree/_splitter.py | 40 +++++++++++++++++-- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/ivy/functional/frontends/sklearn/tree/_splitter.py b/ivy/functional/frontends/sklearn/tree/_splitter.py index ecf72db430ef6..33f4da240449c 100644 --- a/ivy/functional/frontends/sklearn/tree/_splitter.py +++ b/ivy/functional/frontends/sklearn/tree/_splitter.py @@ -5,9 +5,9 @@ class Splitter: def __init__( self, criterion, - max_features: int, - min_samples_leaf: int, - min_weight_leaf: float, + max_features, + min_samples_leaf, + min_weight_leaf, random_state, *args, ): @@ -77,3 +77,37 @@ def node_value(self, dest, node_id): def node_impurity(self): return self.criterion.node_impurity() + + +class BestSplitter(Splitter): + def init( + self, + X, + y, + sample_weight, + missing_values_in_feature_mask, + *args, + ): + Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask, *args) + self.partitioner = None + + def node_split(self, impurity, split, n_constant_features): + return node_split_best( + self, + self.partitioner, + self.criterion, + impurity, + split, + n_constant_features, + ) + + +def node_split_best( + splitter: Splitter, + partitioner, + criterion, + impurity, + split, + n_constant_features +): + pass From fa6ab8e967b73125e25f85996bc3b52e535fa474 Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Fri, 1 Dec 2023 17:14:38 +0000 Subject: [PATCH 07/95] =?UTF-8?q?=F0=9F=A4=96=20Lint=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ivy/functional/frontends/sklearn/tree/_splitter.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/ivy/functional/frontends/sklearn/tree/_splitter.py b/ivy/functional/frontends/sklearn/tree/_splitter.py index 33f4da240449c..6ecd472d2b971 100644 --- a/ivy/functional/frontends/sklearn/tree/_splitter.py +++ b/ivy/functional/frontends/sklearn/tree/_splitter.py @@ -103,11 +103,6 @@ def node_split(self, impurity, split, n_constant_features): def node_split_best( - splitter: Splitter, - partitioner, - criterion, - impurity, - split, - n_constant_features + splitter: Splitter, partitioner, criterion, impurity, split, n_constant_features ): pass From 00a781f20fbfce95495e283285ac2fdda4c527a6 Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Fri, 1 Dec 2023 18:10:59 +0000 Subject: [PATCH 08/95] feat: add DensePartitioner required for both the best splitter and random splitter for decision trees. 
check https://github.com/scikit-learn/scikit-learn/blob/a6603bcd8953e6bcc94081fa5c5d5741eb408927/sklearn/tree/_splitter.pyx#L860
---
 .../frontends/sklearn/tree/_splitter.py      | 169 ++++++++++++++++++
 1 file changed, 169 insertions(+)

diff --git a/ivy/functional/frontends/sklearn/tree/_splitter.py b/ivy/functional/frontends/sklearn/tree/_splitter.py
index 6ecd472d2b971..c222d7361d41f 100644
--- a/ivy/functional/frontends/sklearn/tree/_splitter.py
+++ b/ivy/functional/frontends/sklearn/tree/_splitter.py
@@ -1,5 +1,7 @@
 import ivy
 
+FEATURE_THRESHOLD = 1e-7
+
 
 class Splitter:
     def __init__(
@@ -79,6 +81,169 @@ def node_impurity(self):
         return self.criterion.node_impurity()
 
 
+class DensePartitioner:
+    X = []
+    samples = []
+    feature_values = []
+    start = 0
+    end = 0
+    n_missing = 0
+    missing_values_in_feature_mask = []
+
+    def __init__(
+        self,
+        X,
+        samples,
+        feature_values,
+        missing_values_in_feature_mask,
+    ):
+        self.X = X
+        self.samples = samples
+        self.feature_values = feature_values
+        self.missing_values_in_feature_mask = missing_values_in_feature_mask
+
+    def init_node_split(self, start, end):
+        self.start = start
+        self.end = end
+        self.n_missing = 0
+
+    def sort_samples_and_feature_values(self, current_feature):
+        feature_values = self.feature_values
+        X = self.X
+        samples = self.samples
+        n_missing = 0
+        missing_values_in_feature_mask = self.missing_values_in_feature_mask
+        if (
+            missing_values_in_feature_mask is not None
+            and missing_values_in_feature_mask[current_feature]
+        ):
+            i, current_end = self.start, self.end - 1
+            while i <= current_end:
+                if ivy.isnan(X[samples[current_end], current_feature]):
+                    n_missing += 1
+                    current_end -= 1
+                    continue
+                if ivy.isnan(X[samples[i], current_feature]):
+                    samples[i], samples[current_end] = samples[current_end], samples[i]
+                    n_missing += 1
+                    current_end -= 1
+                feature_values[i] = X[samples[i], current_feature]
+                i += 1
+        else:
+            for i in range(self.start, self.end):
+                feature_values[i] = X[int(samples[i]), int(current_feature)]
+        (
+            self.feature_values[self.start : self.end],
+            self.samples[self.start : self.end],
+        ) = sort(
+            feature_values[self.start : self.end],
+            samples[self.start : self.end],
+            self.end - self.start - n_missing,
+        )
+        self.n_missing = n_missing
+
+    def find_min_max(
+        self,
+        current_feature: int,
+        min_feature_value_out: float,
+        max_feature_value_out: float,
+    ):
+        X = self.X
+        samples = self.samples
+        min_feature_value = X[samples[self.start], current_feature]
+        max_feature_value = min_feature_value
+        feature_values = self.feature_values
+        feature_values[self.start] = min_feature_value
+        for p in range(self.start + 1, self.end):
+            current_feature_value = X[samples[p], current_feature]
+            feature_values[p] = current_feature_value
+
+            if current_feature_value < min_feature_value:
+                min_feature_value = current_feature_value
+            elif current_feature_value > max_feature_value:
+                max_feature_value = current_feature_value
+        return min_feature_value, max_feature_value
+
+    def next_p(self, p_prev: int, p: int):
+        feature_values = self.feature_values
+        end_non_missing = self.end - self.n_missing
+
+        while (
+            p + 1 < end_non_missing
+            and feature_values[p + 1] <= feature_values[p] + FEATURE_THRESHOLD
+        ):
+            p += 1
+        p_prev = p
+        p += 1
+        return p_prev, p
+
+    def partition_samples(self, current_threshold: float):
+        p = self.start
+        partition_end = self.end
+        samples = self.samples
+        feature_values = self.feature_values
+        while p < partition_end:
+            if feature_values[p] <= current_threshold:
+                p += 1
+            else:
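+                # shrink the partition and swap below; index p is re-examined on the next pass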
+ partition_end -= 1 + + feature_values[p], feature_values[partition_end] = ( + feature_values[partition_end], + feature_values[p], + ) + samples[p], samples[partition_end] = ( + samples[partition_end], + samples[p], + ) + return partition_end + + def partition_samples_final( + self, + best_pos, + best_threshold, + best_feature, + best_n_missing, + ): + start = self.start + p = start + end = self.end - 1 + partition_end = end - best_n_missing + samples = self.samples + X = self.X + + if best_n_missing != 0: + while p < partition_end: + if ivy.isnan(X[samples[end], best_feature]): + end -= 1 + continue + current_value = X[samples[p], best_feature] + if ivy.isnan(current_value): + samples[p], samples[end] = samples[end], samples[p] + end -= 1 + current_value = X[samples[p], best_feature] + if current_value <= best_threshold: + p += 1 + else: + samples[p], samples[partition_end] = ( + samples[partition_end], + samples[p], + ) + partition_end -= 1 + else: + while p < partition_end: + if X[samples[p], best_feature] <= best_threshold: + p += 1 + else: + samples[p], samples[partition_end] = ( + samples[partition_end], + samples[p], + ) + partition_end -= 1 + self.samples = samples + + class BestSplitter(Splitter): def init( self, @@ -106,3 +271,7 @@ def node_split_best( splitter: Splitter, partitioner, criterion, impurity, split, n_constant_features ): pass + + +def sort(feature_values, samples, n): + return 0, 0 From 17325edbd276921eacd2ed3e1a8c2624dd720356 Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Fri, 1 Dec 2023 18:20:56 +0000 Subject: [PATCH 09/95] feat: add custom sort function for DensePartitions's sort_samples_and_feature_values --- ivy/functional/frontends/sklearn/tree/_splitter.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ivy/functional/frontends/sklearn/tree/_splitter.py b/ivy/functional/frontends/sklearn/tree/_splitter.py index c222d7361d41f..037233038533f 100644 --- a/ivy/functional/frontends/sklearn/tree/_splitter.py +++ b/ivy/functional/frontends/sklearn/tree/_splitter.py @@ -274,4 +274,7 @@ def node_split_best( def sort(feature_values, samples, n): - return 0, 0 + if n == 0: + return + idx = ivy.argsort(feature_values) + return feature_values[idx], samples[idx] From e66e1d527be1f463c842482a1711572b1b524ec4 Mon Sep 17 00:00:00 2001 From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com> Date: Sat, 2 Dec 2023 01:52:57 +0400 Subject: [PATCH 10/95] test commit (#27421) From cbeaa5e8b50658819922606e3616b9d5981ef821 Mon Sep 17 00:00:00 2001 From: Abdullah Sabry Date: Fri, 1 Dec 2023 13:55:39 -0800 Subject: [PATCH 11/95] feat: Added `torchTensor.frac` to torch frontend (#27417) Co-authored-by: NripeshN --- ivy/functional/frontends/torch/tensor.py | 4 ++ .../test_frontends/test_torch/test_tensor.py | 38 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index c6bf10cd8df2b..66193ab80f063 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -2017,6 +2017,10 @@ def uniform_(self, from_=0, to=1, *, generator=None): ) return self + @with_supported_dtypes({"2.1.1 and below": ("float32", "float64")}, "torch") + def frac(self, name=None): + return torch_frontend.frac(self._ivy_array) + @with_unsupported_dtypes( { "2.1.1 and below": ( diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py 
b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index a4b88c6137f96..8ab5c29226aee 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -7088,6 +7088,44 @@ def test_torch_fmod_( ) +# frac +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="frac", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes(kind="valid"), + num_arrays=1, + max_value=1e6, + min_value=-1e6, + ), +) +def test_torch_frac( + dtype_and_x, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", From 6b880f917076c273213aac9704fafe2a8e39fa0c Mon Sep 17 00:00:00 2001 From: Haris Mahmood <70361308+hmahmood24@users.noreply.github.com> Date: Sat, 2 Dec 2023 07:29:16 +0000 Subject: [PATCH 12/95] Add __getstate__ and __setstate__ methods to frontend torch.nn.Module --- ivy/functional/frontends/torch/nn/modules/module.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ivy/functional/frontends/torch/nn/modules/module.py b/ivy/functional/frontends/torch/nn/modules/module.py index bba71b98e2a6e..319a9989ac9d9 100644 --- a/ivy/functional/frontends/torch/nn/modules/module.py +++ b/ivy/functional/frontends/torch/nn/modules/module.py @@ -312,3 +312,11 @@ def __dir__(self): keys = [key for key in keys if not key[0].isdigit()] return sorted(keys) + + def __getstate__(self): + state = self.__dict__.copy() + state.pop("_compiled_call_impl", None) + return state + + def __setstate__(self, state): + self.__dict__.update(state) From a39a99a89bec47a18bcc8082bc5625835c3dcb88 Mon Sep 17 00:00:00 2001 From: Sai-Suraj-27 Date: Sun, 3 Dec 2023 19:17:33 +0530 Subject: [PATCH 13/95] fix: Fixed typos (used `codespell` pre-commit hook) (#27426) Co-authored-by: Bhushan Srivastava <59949692+he11owthere@users.noreply.github.com> --- ivy/data_classes/array/experimental/creation.py | 2 +- ivy/data_classes/container/base.py | 8 ++++---- ivy/data_classes/container/experimental/creation.py | 4 ++-- ivy/engines/XLA/rust_api/xla_rs/xla_rs.cc | 2 +- ivy/functional/ivy/experimental/activations.py | 2 +- ivy/functional/ivy/experimental/creation.py | 2 +- ivy/functional/ivy/general.py | 2 +- ivy/stateful/module.py | 2 +- .../helpers/hypothesis_helpers/general_helpers.py | 2 +- scripts/backend_generation/generate.py | 6 +++--- 10 files changed, 16 insertions(+), 16 deletions(-) diff --git a/ivy/data_classes/array/experimental/creation.py b/ivy/data_classes/array/experimental/creation.py index fdca1bcffabf3..6d814ee70f85b 100644 --- a/ivy/data_classes/array/experimental/creation.py +++ b/ivy/data_classes/array/experimental/creation.py @@ -335,7 +335,7 @@ def polyval( Returns ------- ret - Simplified result of substituing x in the coefficients - final value of + Simplified result of substituting x in the coefficients - final value of polynomial. 
Examples diff --git a/ivy/data_classes/container/base.py b/ivy/data_classes/container/base.py index 7addfcff8c52c..b7b35452b3cde 100644 --- a/ivy/data_classes/container/base.py +++ b/ivy/data_classes/container/base.py @@ -1602,7 +1602,7 @@ def _cont_prune_key_chains_input_as_dict(self, key_chains, return_cont=None): # ---------------# def cont_duplicate_array_keychains(self): - duplciates = () + duplicates = () key_chains = self.cont_all_key_chains() skips = set() for i in range(len(key_chains)): @@ -1618,9 +1618,9 @@ def cont_duplicate_array_keychains(self): if key_chains[j] not in temp_duplicates: temp_duplicates += (key_chains[j],) if len(temp_duplicates) > 0: - duplciates += (temp_duplicates,) - skips = chain.from_iterable(duplciates) - return duplciates + duplicates += (temp_duplicates,) + skips = chain.from_iterable(duplicates) + return duplicates def cont_update_config(self, **config): new_config = {} diff --git a/ivy/data_classes/container/experimental/creation.py b/ivy/data_classes/container/experimental/creation.py index 03a0184d81eaf..540162d93e6fe 100644 --- a/ivy/data_classes/container/experimental/creation.py +++ b/ivy/data_classes/container/experimental/creation.py @@ -1290,7 +1290,7 @@ def static_polyval( Returns ------- ret - Output container containing simplified result of substituing x in the + Output container containing simplified result of substituting x in the coefficients - final value of polynomial. """ return ContainerBase.cont_multi_map_in_function( @@ -1385,7 +1385,7 @@ def polyval( Returns ------- ret - Output container containing simplified result of substituing x in the + Output container containing simplified result of substituting x in the coefficients - final value of polynomial. """ return self.static_polyval(self, coeffs, x) diff --git a/ivy/engines/XLA/rust_api/xla_rs/xla_rs.cc b/ivy/engines/XLA/rust_api/xla_rs/xla_rs.cc index 6fa315da050fe..e4f19d746b0e6 100644 --- a/ivy/engines/XLA/rust_api/xla_rs/xla_rs.cc +++ b/ivy/engines/XLA/rust_api/xla_rs/xla_rs.cc @@ -1188,7 +1188,7 @@ status execute(const pjrt_loaded_executable exe, const literal *inputs, ASSIGN_OR_RETURN_STATUS(buffer, client->BufferFromHostLiteral(*inputs[i], device)); // Wait for the transfer to have completed to avoid the literal potentially - // getting out of scope before it has been transfered. + // getting out of scope before it has been transferred. MAYBE_RETURN_STATUS(buffer->GetReadyFuture().Await()); input_buffer_ptrs.push_back(buffer.release()); } diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index fe15ccc31721f..1b1fda6b32a08 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -799,7 +799,7 @@ def celu( out: Optional[ivy.Array] = None, ) -> ivy.Array: """ - Apply the Continously Differentiable Exponential Linear Unit (CELU) activation + Apply the Continuously Differentiable Exponential Linear Unit (CELU) activation function to each element of the input. Parameters diff --git a/ivy/functional/ivy/experimental/creation.py b/ivy/functional/ivy/experimental/creation.py index bbda425e2ccbe..ff9ed655f1592 100644 --- a/ivy/functional/ivy/experimental/creation.py +++ b/ivy/functional/ivy/experimental/creation.py @@ -1197,7 +1197,7 @@ def polyval( Returns ------- ret - Simplified result of substituing x in the coefficients - final value + Simplified result of substituting x in the coefficients - final value of polynomial. 
Examples diff --git a/ivy/functional/ivy/general.py b/ivy/functional/ivy/general.py index f0e7ad7af4f76..5406a86e8367f 100644 --- a/ivy/functional/ivy/general.py +++ b/ivy/functional/ivy/general.py @@ -2841,7 +2841,7 @@ def get_item( if query.ndim == 0: if query is False: return ivy.zeros(shape=(0,) + x.shape, dtype=x.dtype) - return x[None] # eqivalent to ivy.expand_dims(x, axis=0) + return x[None] # equivalent to ivy.expand_dims(x, axis=0) query = ivy.nonzero(query, as_tuple=False) ret = ivy.gather_nd(x, query) else: diff --git a/ivy/stateful/module.py b/ivy/stateful/module.py index 41d5be1525d7b..b6f21663635b3 100644 --- a/ivy/stateful/module.py +++ b/ivy/stateful/module.py @@ -571,7 +571,7 @@ def __repr__(self): main_str += ")" return main_str - # Methods to be Optionally Overriden # + # Methods to be Optionally Overridden # # -----------------------------------# def _create_variables(self, *, device=None, dtype=None): diff --git a/ivy_tests/test_ivy/helpers/hypothesis_helpers/general_helpers.py b/ivy_tests/test_ivy/helpers/hypothesis_helpers/general_helpers.py index 971a19bdee0b2..7662f4d2753f5 100644 --- a/ivy_tests/test_ivy/helpers/hypothesis_helpers/general_helpers.py +++ b/ivy_tests/test_ivy/helpers/hypothesis_helpers/general_helpers.py @@ -165,7 +165,7 @@ def general_helpers_dtype_info_helper(backend, kind_dtype, dtype): # from array-api repo class BroadcastError(ValueError): - """Shapes do not broadcast with eachother.""" + """Shapes do not broadcast with each other.""" # from array-api repo diff --git a/scripts/backend_generation/generate.py b/scripts/backend_generation/generate.py index ef2a7af9789f9..f10355a0148c2 100644 --- a/scripts/backend_generation/generate.py +++ b/scripts/backend_generation/generate.py @@ -268,9 +268,9 @@ def _update_valid_config_value(key): ret = ret.strip("") if ret == "": return True - indicies = ret.split(" ") - indicies = [int(item.strip(" ")) for item in indicies] - for i in sorted(indicies, reverse=True): + indices = ret.split(" ") + indices = [int(item.strip(" ")) for item in indices] + for i in sorted(indices, reverse=True): del config_valids[key][i] return True From f56f91d9d74b0df45b85f92800c9ad765ea08a28 Mon Sep 17 00:00:00 2001 From: Sai-Suraj-27 Date: Sun, 3 Dec 2023 20:32:26 +0530 Subject: [PATCH 14/95] fix: Fixed passing of arguments in `get_referrers_recursive()` function call (#27428) Co-authored-by: NripeshN --- ivy/functional/ivy/general.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ivy/functional/ivy/general.py b/ivy/functional/ivy/general.py index 5406a86e8367f..879b3919c52d1 100644 --- a/ivy/functional/ivy/general.py +++ b/ivy/functional/ivy/general.py @@ -254,7 +254,11 @@ def get_referrers_recursive( def get_referrers_recursive_inner(): return get_referrers_recursive( - ref, depth + 1, max_depth, seen_set, local_set + ref, + depth=depth + 1, + max_depth=max_depth, + seen_set=seen_set, + local_set=local_set, ) this_repr = "tracked" if seen else str(ref).replace(" ", "") From 392ec7c133593486ba68846b43484e66ab6f8e0a Mon Sep 17 00:00:00 2001 From: G544 <55620913+G544@users.noreply.github.com> Date: Mon, 4 Dec 2023 12:08:31 +0300 Subject: [PATCH 15/95] feat: add erfc_ to torch frontend (#27291) Co-authored-by: NripeshN --- ivy/functional/frontends/torch/tensor.py | 8 ++++ .../test_frontends/test_torch/test_tensor.py | 37 +++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 66193ab80f063..01b9dd26665fe 
100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -11,6 +11,7 @@ ) from ivy.func_wrapper import with_unsupported_dtypes from ivy.func_wrapper import with_supported_dtypes +from ivy.func_wrapper import with_supported_device_and_dtypes from ivy.functional.frontends.torch.func_wrapper import ( _to_ivy_array, numpy_to_torch_style_args, @@ -518,6 +519,13 @@ def erf_(self, *, out=None): self.ivy_array = self.erf(out=out).ivy_array return self + @with_supported_device_and_dtypes( + {"2.1.0 and below": {"cpu": ("float32", "float64")}}, + "torch", + ) + def erfc_(self, *, out=None): + return torch_frontend.erfc(self, out=out) + def new_zeros( self, *args, diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index 8ab5c29226aee..ec5010e2f2b4c 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -12791,6 +12791,43 @@ def test_torch_tensor_corrcoef( ) +# erfc_ +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="erfc_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_torch_tensor_erfc_( + dtype_and_x, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + rtol_=1e-2, + atol_=1e-2, + ) + + # positive @handle_frontend_method( class_tree=CLASS_TREE, From 7d775835da26dbd6a1952d63004ed3e2ea52d49c Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Mon, 4 Dec 2023 09:27:22 +0000 Subject: [PATCH 16/95] feat: add node_split_best function in sklearn frontend splitter submodule derived from https://github.com/scikit-learn/scikit-learn/blob/c6654f919601ec54436f7a24a94e0874a763a8a1/sklearn/tree/_splitter.pyx#L289C20-L289C20 --- .../frontends/sklearn/tree/_splitter.py | 210 +++++++++++++++++- 1 file changed, 208 insertions(+), 2 deletions(-) diff --git a/ivy/functional/frontends/sklearn/tree/_splitter.py b/ivy/functional/frontends/sklearn/tree/_splitter.py index 037233038533f..4ece5b2c180a5 100644 --- a/ivy/functional/frontends/sklearn/tree/_splitter.py +++ b/ivy/functional/frontends/sklearn/tree/_splitter.py @@ -1,6 +1,7 @@ import ivy FEATURE_THRESHOLD = 1e-7 +INFINITY = ivy.inf class Splitter: @@ -244,6 +245,9 @@ def partition_samples_final( self.samples = samples +class SplitRecord: ... 
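+# placeholder split-bookkeeping record; node_split_best below attaches its
+# fields (feature, pos, threshold, impurities) dynamically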
+ + class BestSplitter(Splitter): def init( self, @@ -267,10 +271,212 @@ def node_split(self, impurity, split, n_constant_features): ) +# --- Helpers --- # +# --------------- # + + +def _init_split(split_record, start_pos): + return + + +# --- Main --- # +# ------------ # + + def node_split_best( - splitter: Splitter, partitioner, criterion, impurity, split, n_constant_features + splitter, partitioner, criterion, impurity, split, n_constant_features ): - pass + start = splitter.start + end = splitter.end + splitter.samples + features = splitter.features + constant_features = splitter.constant_features + n_features = splitter.n_features + + feature_values = splitter.feature_values + max_features = splitter.max_features + min_samples_leaf = splitter.min_samples_leaf + min_weight_leaf = splitter.min_weight_leaf + + best_split = SplitRecord() + current_split = SplitRecord() + current_proxy_improvement = -INFINITY + best_proxy_improvement = -INFINITY + + f_i = n_features + f_j = 0 + p = 0 + p_prev = 0 + + n_visited_features = 0 + # Number of features discovered to be constant during the split search + n_found_constants = 0 + # Number of features known to be constant and drawn without replacement + n_drawn_constants = 0 + n_known_constants = n_constant_features + # n_total_constants = n_known_constants + n_found_constants + n_total_constants = n_known_constants + best_split = _init_split(best_split, end) + partitioner.init_node_split(start, end) + while f_i > n_total_constants and ( + n_visited_features < max_features + or n_visited_features <= n_found_constants + n_drawn_constants + ): + n_visited_features += 1 + f_j = ivy.randint(n_drawn_constants, f_i - n_found_constants) + + if f_j < n_known_constants: + features[n_drawn_constants], features[f_j] = ( + features[f_j], + features[n_drawn_constants], + ) + + n_drawn_constants += 1 + continue + + # f_j in the interval [n_known_constants, f_i - n_found_constants[ + f_j += n_found_constants + # f_j in the interval [n_total_constants, f_i[ + current_split.feature = features[f_j] + partitioner.sort_samples_and_feature_values(current_split.feature) + n_missing = partitioner.n_missing + end_non_missing = end - n_missing + + if ( + end_non_missing == start + or feature_values[end_non_missing - 1] + <= feature_values[start] + FEATURE_THRESHOLD + ): + features[f_j], features[n_total_constants] = ( + features[n_total_constants], + features[f_j], + ) + + n_found_constants += 1 + n_total_constants += 1 + continue + + f_i -= 1 + features[f_i], features[f_j] = features[f_j], features[f_i] + has_missing = n_missing != 0 + criterion.init_missing(n_missing) + + n_searches = 2 if has_missing else 1 + + for i in range(n_searches): + missing_go_to_left = i == 1 + criterion.missing_go_to_left = missing_go_to_left + criterion.reset() + p = start + + while p < end_non_missing: + p_prev, p = partitioner.next_p(p_prev, p) + + if p >= end_non_missing: + continue + + if missing_go_to_left: + n_left = p - start + n_missing + n_right = end_non_missing - p + else: + n_left = p - start + n_right = end_non_missing - p + n_missing + + if n_left < min_samples_leaf or n_right < min_samples_leaf: + continue + + current_split.pos = p + criterion.update(current_split.pos) + + if ( + criterion.weighted_n_left < min_weight_leaf + or criterion.weighted_n_right < min_weight_leaf + ): + continue + + current_proxy_improvement = criterion.proxy_impurity_improvement() + + if current_proxy_improvement > best_proxy_improvement: + best_proxy_improvement = current_proxy_improvement + 
current_split.threshold = ( + feature_values[p_prev] / 2.0 + feature_values[p] / 2.0 + ) + + if ( + current_split.threshold == feature_values[p] + or current_split.threshold == INFINITY + or current_split.threshold == -INFINITY + ): + current_split.threshold = feature_values[p_prev] + + current_split.n_missing = n_missing + if n_missing == 0: + current_split.missing_go_to_left = n_left > n_right + else: + current_split.missing_go_to_left = missing_go_to_left + + best_split = SplitRecord(**current_split.__dict__) + + if has_missing: + n_left, n_right = end - start - n_missing, n_missing + p = end - n_missing + missing_go_to_left = 0 + + if not ((n_left < min_samples_leaf) or (n_right < min_samples_leaf)): + criterion.missing_go_to_left = missing_go_to_left + criterion.update(p) + + if not ( + criterion.weighted_n_left < min_weight_leaf + or criterion.weighted_n_right < min_weight_leaf + ): + current_proxy_improvement = criterion.proxy_impurity_improvement() + + if current_proxy_improvement > best_proxy_improvement: + best_proxy_improvement = current_proxy_improvement + current_split.threshold = INFINITY + current_split.missing_go_to_left = missing_go_to_left + current_split.n_missing = n_missing + current_split.pos = p + best_split = current_split + + # Reorganize into samples[start:best_split.pos] + samples[best_split.pos:end] + if best_split.pos < end: + partitioner.partition_samples_final( + best_split.pos, + best_split.threshold, + best_split.feature, + best_split.n_missing, + ) + + if best_split.n_missing != 0: + criterion.init_missing(best_split.n_missing) + + criterion.missing_go_to_left = best_split.missing_go_to_left + criterion.reset() + criterion.update(best_split.pos) + + ( + best_split.impurity_left, + best_split.impurity_right, + ) = criterion.children_impurity( + best_split.impurity_left, best_split.impurity_right + ) + + best_split.improvement = criterion.impurity_improvement( + impurity, best_split.impurity_left, best_split.impurity_right + ) + + # best_split, samples = shift_missing_values_to_left_if_required(best_split, samples, end) + # todo : implement shift_missing_values_to_left_if_required + features[0:n_known_constants] = constant_features[0:n_known_constants] + constant_features[n_known_constants:n_found_constants] = features[ + n_known_constants:n_found_constants + ] + + split = best_split + n_constant_features = n_total_constants + return 0, n_constant_features, split def sort(feature_values, samples, n): From 888d200537405e8712ce6e1f106a09e274471abd Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Mon, 4 Dec 2023 09:40:08 +0000 Subject: [PATCH 17/95] feat: define _init_split helper function used in node_split_best derived from https://github.com/scikit-learn/scikit-learn/blob/c6654f919601ec54436f7a24a94e0874a763a8a1/sklearn/tree/_splitter.pyx#L43 --- ivy/functional/frontends/sklearn/tree/_splitter.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/ivy/functional/frontends/sklearn/tree/_splitter.py b/ivy/functional/frontends/sklearn/tree/_splitter.py index 4ece5b2c180a5..5a7e5e324cc90 100644 --- a/ivy/functional/frontends/sklearn/tree/_splitter.py +++ b/ivy/functional/frontends/sklearn/tree/_splitter.py @@ -276,7 +276,15 @@ def node_split(self, impurity, split, n_constant_features): def _init_split(split_record, start_pos): - return + split_record.impurity_left = INFINITY + split_record.impurity_right = INFINITY + split_record.pos = start_pos + split_record.feature = 0 + 
split_record.threshold = 0.0 + split_record.improvement = -INFINITY + split_record.missing_go_to_left = False + split_record.n_missing = 0 + return split_record # --- Main --- # From 27a5f9863e07bd8d5a303f2f0209836d8165b08f Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Mon, 4 Dec 2023 09:47:07 +0000 Subject: [PATCH 18/95] feat: define SplitRecord class used for split structure in node best splitter function --- .../frontends/sklearn/tree/_splitter.py | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/ivy/functional/frontends/sklearn/tree/_splitter.py b/ivy/functional/frontends/sklearn/tree/_splitter.py index 5a7e5e324cc90..0960eeef81009 100644 --- a/ivy/functional/frontends/sklearn/tree/_splitter.py +++ b/ivy/functional/frontends/sklearn/tree/_splitter.py @@ -245,7 +245,26 @@ def partition_samples_final( self.samples = samples -class SplitRecord: ... +class SplitRecord: + def __init__( + self, + feature=0, + pos=0, + threshold=0.0, + improvement=-INFINITY, + impurity_left=0.0, + impurity_right=0.0, + missing_go_to_left=False, + n_missing=0, + ): + self.feature = feature + self.pos = pos + self.threshold = threshold + self.improvement = improvement + self.impurity_left = impurity_left + self.impurity_right = impurity_right + self.missing_go_to_left = missing_go_to_left + self.n_missing = n_missing class BestSplitter(Splitter): From 827902bfd12b6edc17061b0db7923a93bb502b2c Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Mon, 4 Dec 2023 09:51:07 +0000 Subject: [PATCH 19/95] =?UTF-8?q?=F0=9F=A4=96=20Lint=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ivy/functional/frontends/sklearn/tree/_splitter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/frontends/sklearn/tree/_splitter.py b/ivy/functional/frontends/sklearn/tree/_splitter.py index 0960eeef81009..a2c333d1ee675 100644 --- a/ivy/functional/frontends/sklearn/tree/_splitter.py +++ b/ivy/functional/frontends/sklearn/tree/_splitter.py @@ -1,7 +1,6 @@ import ivy FEATURE_THRESHOLD = 1e-7 -INFINITY = ivy.inf class Splitter: @@ -511,3 +510,4 @@ def sort(feature_values, samples, n): return idx = ivy.argsort(feature_values) return feature_values[idx], samples[idx] +INFINITY = ivy.inf From f3986b049762645e944eb151b6c9a16a15bb4ea8 Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Mon, 4 Dec 2023 09:52:44 +0000 Subject: [PATCH 20/95] =?UTF-8?q?=F0=9F=A4=96=20Lint=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ivy/functional/frontends/sklearn/tree/_splitter.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ivy/functional/frontends/sklearn/tree/_splitter.py b/ivy/functional/frontends/sklearn/tree/_splitter.py index a2c333d1ee675..6ea7e810e4022 100644 --- a/ivy/functional/frontends/sklearn/tree/_splitter.py +++ b/ivy/functional/frontends/sklearn/tree/_splitter.py @@ -510,4 +510,6 @@ def sort(feature_values, samples, n): return idx = ivy.argsort(feature_values) return feature_values[idx], samples[idx] + + INFINITY = ivy.inf From dd2a3245f98bf7a4e75bcdfe11f706a68cb3cc71 Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Mon, 4 Dec 2023 10:29:24 +0000 Subject: [PATCH 21/95] fix: use ivy.inf temporarily cutting out INFINITY constant to avoid undefined behaviour caused by the lintbot --- .../frontends/sklearn/tree/_splitter.py | 21 ++++++++----------- 1 file changed, 9 
insertions(+), 12 deletions(-) diff --git a/ivy/functional/frontends/sklearn/tree/_splitter.py b/ivy/functional/frontends/sklearn/tree/_splitter.py index 6ea7e810e4022..af172e77e3646 100644 --- a/ivy/functional/frontends/sklearn/tree/_splitter.py +++ b/ivy/functional/frontends/sklearn/tree/_splitter.py @@ -250,7 +250,7 @@ def __init__( feature=0, pos=0, threshold=0.0, - improvement=-INFINITY, + improvement=-ivy.inf, impurity_left=0.0, impurity_right=0.0, missing_go_to_left=False, @@ -294,12 +294,12 @@ def node_split(self, impurity, split, n_constant_features): def _init_split(split_record, start_pos): - split_record.impurity_left = INFINITY - split_record.impurity_right = INFINITY + split_record.impurity_left = ivy.inf + split_record.impurity_right = ivy.inf split_record.pos = start_pos split_record.feature = 0 split_record.threshold = 0.0 - split_record.improvement = -INFINITY + split_record.improvement = -ivy.inf split_record.missing_go_to_left = False split_record.n_missing = 0 return split_record @@ -326,8 +326,8 @@ def node_split_best( best_split = SplitRecord() current_split = SplitRecord() - current_proxy_improvement = -INFINITY - best_proxy_improvement = -INFINITY + current_proxy_improvement = -ivy.inf + best_proxy_improvement = -ivy.inf f_i = n_features f_j = 0 @@ -430,8 +430,8 @@ def node_split_best( if ( current_split.threshold == feature_values[p] - or current_split.threshold == INFINITY - or current_split.threshold == -INFINITY + or current_split.threshold == ivy.inf + or current_split.threshold == -ivy.inf ): current_split.threshold = feature_values[p_prev] @@ -460,7 +460,7 @@ def node_split_best( if current_proxy_improvement > best_proxy_improvement: best_proxy_improvement = current_proxy_improvement - current_split.threshold = INFINITY + current_split.threshold = ivy.inf current_split.missing_go_to_left = missing_go_to_left current_split.n_missing = n_missing current_split.pos = p @@ -510,6 +510,3 @@ def sort(feature_values, samples, n): return idx = ivy.argsort(feature_values) return feature_values[idx], samples[idx] - - -INFINITY = ivy.inf From 09f96bd7f0d30ae2faedc3243be418f36b821c50 Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Mon, 4 Dec 2023 11:14:31 +0000 Subject: [PATCH 22/95] refactor: remove unwanted variables in node_split_best in sklearn frontend splitter moduel --- ivy/functional/frontends/sklearn/tree/_splitter.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ivy/functional/frontends/sklearn/tree/_splitter.py b/ivy/functional/frontends/sklearn/tree/_splitter.py index af172e77e3646..0461682204e72 100644 --- a/ivy/functional/frontends/sklearn/tree/_splitter.py +++ b/ivy/functional/frontends/sklearn/tree/_splitter.py @@ -314,7 +314,6 @@ def node_split_best( ): start = splitter.start end = splitter.end - splitter.samples features = splitter.features constant_features = splitter.constant_features n_features = splitter.n_features @@ -326,12 +325,9 @@ def node_split_best( best_split = SplitRecord() current_split = SplitRecord() - current_proxy_improvement = -ivy.inf best_proxy_improvement = -ivy.inf f_i = n_features - f_j = 0 - p = 0 p_prev = 0 n_visited_features = 0 From b48a92eef2825684e03a6348936933bbc0884358 Mon Sep 17 00:00:00 2001 From: Sai-Suraj-27 Date: Mon, 4 Dec 2023 21:34:01 +0530 Subject: [PATCH 23/95] fix: Fixed order of functions and added missing `staticmethod` decorator (#27436) --- .../container/experimental/creation.py | 105 +++++++++--------- 1 file changed, 53 insertions(+), 52 
deletions(-) diff --git a/ivy/data_classes/container/experimental/creation.py b/ivy/data_classes/container/experimental/creation.py index 540162d93e6fe..fc861420bc843 100644 --- a/ivy/data_classes/container/experimental/creation.py +++ b/ivy/data_classes/container/experimental/creation.py @@ -1252,57 +1252,6 @@ def static_unsorted_segment_mean( map_sequences=map_sequences, ) - def static_polyval( - coeffs: ivy.Container, - x: Union[ivy.Container, int, float], - *, - key_chains: Optional[Union[List[str], Dict[str, str]]] = None, - to_apply: bool = True, - prune_unapplied: bool = False, - map_sequences: bool = False, - ) -> ivy.Container: - r""" - ivy.Container static method variant of ivy.polyval. This method simply wraps the - function, and so the docstring for ivy.polyval also applies to this method with - minimal changes. - - Evaluate and return a polynomial at specific given values. - - Parameters - ---------- - coeffs - Polynomial coefficients (including zero) from highest degree - to constant term. - x - The value of the indeterminate variable at which to evaluate the polynomial. - key_chains - The key-chains to apply or not apply the method to. Default is ``None``. - to_apply - If True, the method will be applied to key_chains, otherwise key_chains - will be skipped. Default is ``True``. - prune_unapplied - Whether to prune key_chains for which the function was not applied. - Default is ``False``. - map_sequences - Whether to also map method to sequences (lists, tuples). - Default is ``False``. - - Returns - ------- - ret - Output container containing simplified result of substituting x in the - coefficients - final value of polynomial. - """ - return ContainerBase.cont_multi_map_in_function( - "polyval", - coeffs, - x, - key_chains=key_chains, - to_apply=to_apply, - prune_unapplied=prune_unapplied, - map_sequences=map_sequences, - ) - def unsorted_segment_mean( self: ivy.Container, segment_ids: Union[ivy.Array, ivy.Container], @@ -1359,6 +1308,58 @@ def unsorted_segment_mean( num_segments, ) + @staticmethod + def static_polyval( + coeffs: ivy.Container, + x: Union[ivy.Container, int, float], + *, + key_chains: Optional[Union[List[str], Dict[str, str]]] = None, + to_apply: bool = True, + prune_unapplied: bool = False, + map_sequences: bool = False, + ) -> ivy.Container: + r""" + ivy.Container static method variant of ivy.polyval. This method simply wraps the + function, and so the docstring for ivy.polyval also applies to this method with + minimal changes. + + Evaluate and return a polynomial at specific given values. + + Parameters + ---------- + coeffs + Polynomial coefficients (including zero) from highest degree + to constant term. + x + The value of the indeterminate variable at which to evaluate the polynomial. + key_chains + The key-chains to apply or not apply the method to. Default is ``None``. + to_apply + If True, the method will be applied to key_chains, otherwise key_chains + will be skipped. Default is ``True``. + prune_unapplied + Whether to prune key_chains for which the function was not applied. + Default is ``False``. + map_sequences + Whether to also map method to sequences (lists, tuples). + Default is ``False``. + + Returns + ------- + ret + Output container containing simplified result of substituting x in the + coefficients - final value of polynomial. 
+ """ + return ContainerBase.cont_multi_map_in_function( + "polyval", + coeffs, + x, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + ) + def polyval( self: ivy.Container, coeffs: ivy.Container, @@ -1388,4 +1389,4 @@ def polyval( Output container containing simplified result of substituting x in the coefficients - final value of polynomial. """ - return self.static_polyval(self, coeffs, x) + return self.static_polyval(coeffs, x) From 3543bece6588809a670cb2eaa83dde77e00a7ab7 Mon Sep 17 00:00:00 2001 From: Sai-Suraj-27 Date: Mon, 4 Dec 2023 21:34:40 +0530 Subject: [PATCH 24/95] fix: Fixed raising `TypeError` exception instead or `ValueError` exception for invalid type (#27439) --- ivy/compiler/replace_with.py | 2 +- ivy/functional/frontends/jax/numpy/fft.py | 2 +- ivy/functional/frontends/jax/numpy/indexing.py | 2 +- .../frontends/numpy/fft/discrete_fourier_transform.py | 4 ++-- ivy/functional/frontends/paddle/nn/functional/vision.py | 4 ++-- ivy/functional/frontends/torch/nn/modules/module.py | 2 +- ivy/utils/assertions.py | 2 +- 7 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ivy/compiler/replace_with.py b/ivy/compiler/replace_with.py index 3bbab72988447..b19d8b84dc03c 100644 --- a/ivy/compiler/replace_with.py +++ b/ivy/compiler/replace_with.py @@ -15,7 +15,7 @@ def replace_with(new_func): def decorator(original_func): if not callable(original_func) or not callable(new_func): - raise ValueError( + raise TypeError( f"Both '{original_func.__name__}' and '{new_func.__name__}' should be" " callable." ) diff --git a/ivy/functional/frontends/jax/numpy/fft.py b/ivy/functional/frontends/jax/numpy/fft.py index 69b9415b6176b..5732511ac8838 100644 --- a/ivy/functional/frontends/jax/numpy/fft.py +++ b/ivy/functional/frontends/jax/numpy/fft.py @@ -24,7 +24,7 @@ def fftfreq(n, d=1.0, *, dtype=None): if not isinstance( n, (int, type(ivy.int8), type(ivy.int16), type(ivy.int32), type(ivy.int64)) ): - raise ValueError("n should be an integer") + raise TypeError("n should be an integer") dtype = ivy.float64 if dtype is None else ivy.as_ivy_dtype(dtype) diff --git a/ivy/functional/frontends/jax/numpy/indexing.py b/ivy/functional/frontends/jax/numpy/indexing.py index 54e7ee79028ae..1e0b778bd9aa6 100644 --- a/ivy/functional/frontends/jax/numpy/indexing.py +++ b/ivy/functional/frontends/jax/numpy/indexing.py @@ -52,7 +52,7 @@ def __getitem__(self, key): newobj = _make_1d_grid_from_slice(item) item_ndim = 0 elif isinstance(item, str): - raise ValueError("string directive must be placed at the beginning") + raise TypeError("string directive must be placed at the beginning") else: newobj = array(item, copy=False) item_ndim = newobj.ndim diff --git a/ivy/functional/frontends/numpy/fft/discrete_fourier_transform.py b/ivy/functional/frontends/numpy/fft/discrete_fourier_transform.py index 609d4e303072b..d98660c654e4b 100644 --- a/ivy/functional/frontends/numpy/fft/discrete_fourier_transform.py +++ b/ivy/functional/frontends/numpy/fft/discrete_fourier_transform.py @@ -31,7 +31,7 @@ def fftfreq(n, d=1.0): if not isinstance( n, (int, type(ivy.int8), type(ivy.int16), type(ivy.int32), type(ivy.int64)) ): - raise ValueError("n should be an integer") + raise TypeError("n should be an integer") N = (n - 1) // 2 + 1 val = 1.0 / (n * d) @@ -135,7 +135,7 @@ def rfftfreq(n, d=1.0): if not isinstance( n, (int, type(ivy.int8), type(ivy.int16), type(ivy.int32), type(ivy.int64)) ): - raise ValueError("n should be an integer") + raise TypeError("n should be 
an integer") val = 1.0 / (n * d) N = n // 2 + 1 diff --git a/ivy/functional/frontends/paddle/nn/functional/vision.py b/ivy/functional/frontends/paddle/nn/functional/vision.py index 5f25344dd0a34..935ce239a7dbb 100644 --- a/ivy/functional/frontends/paddle/nn/functional/vision.py +++ b/ivy/functional/frontends/paddle/nn/functional/vision.py @@ -118,7 +118,7 @@ def pixel_shuffle(x, upscale_factor, data_format="NCHW"): ) if not isinstance(upscale_factor, int): - raise ValueError("upscale factor must be int type") + raise TypeError("upscale factor must be int type") if data_format not in ["NCHW", "NHWC"]: raise ValueError( @@ -172,7 +172,7 @@ def pixel_unshuffle(x, downscale_factor, data_format="NCHW"): ) if not isinstance(downscale_factor, int): - raise ValueError("Downscale factor must be int type") + raise TypeError("Downscale factor must be int type") if downscale_factor <= 0: raise ValueError("Downscale factor must be positive") diff --git a/ivy/functional/frontends/torch/nn/modules/module.py b/ivy/functional/frontends/torch/nn/modules/module.py index 319a9989ac9d9..31f30e4ad0b1e 100644 --- a/ivy/functional/frontends/torch/nn/modules/module.py +++ b/ivy/functional/frontends/torch/nn/modules/module.py @@ -147,7 +147,7 @@ def get_submodule(self, target: str) -> "Module": mod = getattr(mod, item) if not isinstance(mod, Module): - raise AttributeError("`" + item + "` is not an nn.Module") + raise TypeError("`" + item + "` is not an nn.Module") return mod diff --git a/ivy/utils/assertions.py b/ivy/utils/assertions.py index 3ce9cb927f8be..77f4cce892390 100644 --- a/ivy/utils/assertions.py +++ b/ivy/utils/assertions.py @@ -184,7 +184,7 @@ def check_same_dtype(x1, x2, message=""): def check_unsorted_segment_valid_params(data, segment_ids, num_segments): if not isinstance(num_segments, int): - raise ValueError("num_segments must be of integer type") + raise TypeError("num_segments must be of integer type") valid_dtypes = [ ivy.int32, From 070bcd603fb258c0de2a568331c1dd065cc43343 Mon Sep 17 00:00:00 2001 From: Haris Mahmood <70361308+hmahmood24@users.noreply.github.com> Date: Mon, 4 Dec 2023 16:08:13 +0000 Subject: [PATCH 25/95] fix: Add a check in the converters to filter any NoneType model parameters when doing a to_keras_module on an ivy.Module instance --- ivy/stateful/converters.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/ivy/stateful/converters.py b/ivy/stateful/converters.py index 599b5532d0a24..b3d78f296f6cf 100644 --- a/ivy/stateful/converters.py +++ b/ivy/stateful/converters.py @@ -476,13 +476,19 @@ def _assign_variables(self): ), ) self._ivy_module.v.cont_map( - lambda x, kc: self.add_weight( - name=kc, shape=x.shape, dtype=x.dtype, trainable=True + lambda x, kc: ( + self.add_weight( + name=kc, shape=x.shape, dtype=x.dtype, trainable=True + ) + if x is not None + else x ) ) model_weights = [] self._ivy_module.v.cont_map( - lambda x, kc: model_weights.append(ivy.to_numpy(x)) + lambda x, kc: ( + model_weights.append(ivy.to_numpy(x)) if x is not None else x + ) ) self.set_weights(model_weights) From 110fd6852b223938a9eb7f4d5b154edb8747a7b7 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Mon, 4 Dec 2023 20:41:26 +0100 Subject: [PATCH 26/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 8395caf0efd10..ac67db9e4e08e 100644 --- a/ivy/functional/backends/torch/general.py +++ 
b/ivy/functional/backends/torch/general.py @@ -199,7 +199,7 @@ def gather( for z in zip_list: p, i = z r = torch.gather( - p, (axis - batch_dims) % p.ndim, i, sparse_grad=False, out=False + p, (axis - batch_dims) % p.ndim, i, sparse_grad=False, out=None ) result.append(r) From 64ca8d7d120c009501f614d59a5afc48036786d2 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Mon, 4 Dec 2023 20:57:19 +0100 Subject: [PATCH 27/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index ac67db9e4e08e..21da5e7726197 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -199,7 +199,7 @@ def gather( for z in zip_list: p, i = z r = torch.gather( - p, (axis - batch_dims) % p.ndim, i, sparse_grad=False, out=None + p, (axis - batch_dims) % p.ndim, torch.reshape(i, (-1,)), sparse_grad=False, out=None ) result.append(r) From 45c3ada8339b9326e8cbc7310c64dd433711db43 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Mon, 4 Dec 2023 21:11:21 +0100 Subject: [PATCH 28/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 21da5e7726197..4b73c798a1ab1 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -199,7 +199,8 @@ def gather( for z in zip_list: p, i = z r = torch.gather( - p, (axis - batch_dims) % p.ndim, torch.reshape(i, (-1,)), sparse_grad=False, out=None + p, (axis - batch_dims) % p.ndim, torch.reshape(i, (-1,)), + sparse_grad=False, out=None ) result.append(r) From 58b4b2e0d27d873f00881f58a7aa6a813044448a Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Mon, 4 Dec 2023 21:51:05 +0100 Subject: [PATCH 29/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 4b73c798a1ab1..77271f088105e 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -186,6 +186,7 @@ def gather( batch_dims %= len(params.shape) ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims) result = [] + indices.to(torch.int64) if batch_dims == 0: result = torch.gather(params, axis, indices, sparse_grad=False, out=out) else: @@ -204,8 +205,8 @@ def gather( ) result.append(r) - result = torch.stack(result) - result = result.reshape([*params.shape[0:batch_dims], *result.shape[1:]]) + result = torch.cat(result, dim=0) + result = result.reshape((*params.shape[0:batch_dims], *result.shape[1:])) if ivy.exists(out): return ivy.inplace_update(out, result) From 176af0c8255a7942d385e22f6cd7cf19a52d0f93 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Mon, 4 Dec 2023 22:28:29 +0100 Subject: [PATCH 30/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 77271f088105e..894c54aedbd07 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -186,7 +186,7 @@ def gather( batch_dims %= len(params.shape) 
ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims) result = [] - indices.to(torch.int64) + indices.to("int64") if batch_dims == 0: result = torch.gather(params, axis, indices, sparse_grad=False, out=out) else: @@ -206,7 +206,9 @@ def gather( result.append(r) result = torch.cat(result, dim=0) - result = result.reshape((*params.shape[0:batch_dims], *result.shape[1:])) + result = result.reshape((params.shape[:axis] + indices.shape[batch_dims:] + + params.shape[axis + 1 :]) + ) if ivy.exists(out): return ivy.inplace_update(out, result) From 368410dd88c80a91b8ec498c5b80060c65a0ed34 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Mon, 4 Dec 2023 22:37:54 +0100 Subject: [PATCH 31/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 894c54aedbd07..cb9d95251b8ac 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -186,7 +186,7 @@ def gather( batch_dims %= len(params.shape) ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims) result = [] - indices.to("int64") + indices.to(torch.int64) if batch_dims == 0: result = torch.gather(params, axis, indices, sparse_grad=False, out=out) else: @@ -209,8 +209,8 @@ def gather( result = result.reshape((params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :]) ) - if ivy.exists(out): - return ivy.inplace_update(out, result) + if ivy.exists(out): + return ivy.inplace_update(out, result) return result From fdca35a8343db9a3d04736b5c2db65062952f233 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Mon, 4 Dec 2023 22:40:31 +0100 Subject: [PATCH 32/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index cb9d95251b8ac..e16f925ec50b9 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -186,9 +186,8 @@ def gather( batch_dims %= len(params.shape) ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims) result = [] - indices.to(torch.int64) if batch_dims == 0: - result = torch.gather(params, axis, indices, sparse_grad=False, out=out) + result = torch.gather(params, axis, indices.long(), sparse_grad=False, out=out) else: for b in range(batch_dims): if b == 0: @@ -200,7 +199,7 @@ def gather( for z in zip_list: p, i = z r = torch.gather( - p, (axis - batch_dims) % p.ndim, torch.reshape(i, (-1,)), + p, (axis - batch_dims) % p.ndim, torch.reshape(i, (-1,)).long(), sparse_grad=False, out=None ) From 25a826892827b8c09d13d5db3f5fb0f58f8eaafb Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Mon, 4 Dec 2023 22:48:14 +0100 Subject: [PATCH 33/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index e16f925ec50b9..c5ec5b444a266 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -187,7 +187,9 @@ def gather( ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims) result = [] if batch_dims == 0: - result = torch.gather(params, axis, indices.long(), 
sparse_grad=False, out=out) + result = torch.gather( + params, axis, torch.reshape(indices, (-1,)).long(), + sparse_grad=False, out=out) else: for b in range(batch_dims): if b == 0: From 77d2cadda34f7c61858ee0b925e32e740cc6d885 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Mon, 4 Dec 2023 23:37:28 +0100 Subject: [PATCH 34/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index c5ec5b444a266..7d32d641dacf3 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -188,7 +188,7 @@ def gather( result = [] if batch_dims == 0: result = torch.gather( - params, axis, torch.reshape(indices, (-1,)).long(), + params, axis, torch.reshape(indices, (-1)).long(), sparse_grad=False, out=out) else: for b in range(batch_dims): @@ -201,7 +201,7 @@ def gather( for z in zip_list: p, i = z r = torch.gather( - p, (axis - batch_dims) % p.ndim, torch.reshape(i, (-1,)).long(), + p, (axis - batch_dims) % p.ndim, torch.reshape(i, (-1)).long(), sparse_grad=False, out=None ) From 3d6e04854a1fe89196f4c462d95d30f7236a9d54 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Mon, 4 Dec 2023 23:48:13 +0100 Subject: [PATCH 35/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 7d32d641dacf3..c5ec5b444a266 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -188,7 +188,7 @@ def gather( result = [] if batch_dims == 0: result = torch.gather( - params, axis, torch.reshape(indices, (-1)).long(), + params, axis, torch.reshape(indices, (-1,)).long(), sparse_grad=False, out=out) else: for b in range(batch_dims): @@ -201,7 +201,7 @@ def gather( for z in zip_list: p, i = z r = torch.gather( - p, (axis - batch_dims) % p.ndim, torch.reshape(i, (-1)).long(), + p, (axis - batch_dims) % p.ndim, torch.reshape(i, (-1,)).long(), sparse_grad=False, out=None ) From b0da19d6b16f4d453414505aa04fa2130ac93921 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Mon, 4 Dec 2023 23:57:01 +0100 Subject: [PATCH 36/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index c5ec5b444a266..5ae64e15a8461 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -188,7 +188,7 @@ def gather( result = [] if batch_dims == 0: result = torch.gather( - params, axis, torch.reshape(indices, (-1,)).long(), + params, axis, torch.reshape(indices.long(), params.shape), sparse_grad=False, out=out) else: for b in range(batch_dims): @@ -201,7 +201,7 @@ def gather( for z in zip_list: p, i = z r = torch.gather( - p, (axis - batch_dims) % p.ndim, torch.reshape(i, (-1,)).long(), + p, (axis - batch_dims), torch.reshape(i.long(), p.shape), sparse_grad=False, out=None ) From 541d857464f68701814a78aa9d8cef45f64402dc Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Tue, 5 Dec 2023 00:11:39 +0100 Subject: [PATCH 37/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 5ae64e15a8461..5d47b7351e9f7 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -188,7 +188,7 @@ def gather( result = [] if batch_dims == 0: result = torch.gather( - params, axis, torch.reshape(indices.long(), params.shape), + params, axis, torch.reshape(indices.long(), params.size()), sparse_grad=False, out=out) else: for b in range(batch_dims): @@ -201,7 +201,7 @@ def gather( for z in zip_list: p, i = z r = torch.gather( - p, (axis - batch_dims), torch.reshape(i.long(), p.shape), + p, (axis - batch_dims), torch.reshape(i.long(), p.size()), sparse_grad=False, out=None ) From 7d4f58e6623fa99b34c1a8f446273c578773bb4d Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Tue, 5 Dec 2023 00:23:13 +0100 Subject: [PATCH 38/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 5d47b7351e9f7..db4c91438a150 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -188,7 +188,7 @@ def gather( result = [] if batch_dims == 0: result = torch.gather( - params, axis, torch.reshape(indices.long(), params.size()), + params, axis, torch.reshape(indices.long(), list(params.size())), sparse_grad=False, out=out) else: for b in range(batch_dims): @@ -201,7 +201,7 @@ def gather( for z in zip_list: p, i = z r = torch.gather( - p, (axis - batch_dims), torch.reshape(i.long(), p.size()), + p, (axis - batch_dims), torch.reshape(i.long(), list(p.size())), sparse_grad=False, out=None ) From 4707d0f02ef0f73f879150adade4b7ddbfa419c7 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Tue, 5 Dec 2023 00:28:55 +0100 Subject: [PATCH 39/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index db4c91438a150..a334f7eb39621 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -188,7 +188,7 @@ def gather( result = [] if batch_dims == 0: result = torch.gather( - params, axis, torch.reshape(indices.long(), list(params.size())), + params, axis, torch.reshape(indices.long(), tuple(params.size())), sparse_grad=False, out=out) else: for b in range(batch_dims): @@ -201,7 +201,7 @@ def gather( for z in zip_list: p, i = z r = torch.gather( - p, (axis - batch_dims), torch.reshape(i.long(), list(p.size())), + p, (axis - batch_dims), torch.reshape(i.long(), tuple(p.size())), sparse_grad=False, out=None ) From 1f3cee477d84312734157d9d5d4f65fc0887e4bf Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Tue, 5 Dec 2023 01:13:32 +0100 Subject: [PATCH 40/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index a334f7eb39621..7f2fb80b204be 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -188,7 +188,7 @@ def gather( result = [] if batch_dims == 0: result = torch.gather( - params, axis, torch.reshape(indices.long(), tuple(params.size())), + params, axis, torch.reshape(indices.long(), params.shape), sparse_grad=False, 
out=out) else: for b in range(batch_dims): @@ -201,7 +201,7 @@ def gather( for z in zip_list: p, i = z r = torch.gather( - p, (axis - batch_dims), torch.reshape(i.long(), tuple(p.size())), + p, (axis - batch_dims), torch.reshape(i.long(), (-1, )), sparse_grad=False, out=None ) From 11d00556d16d776e49c95929a3142c669fd118be Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Tue, 5 Dec 2023 01:26:24 +0100 Subject: [PATCH 41/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 7f2fb80b204be..fb4a9e3138077 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -188,7 +188,7 @@ def gather( result = [] if batch_dims == 0: result = torch.gather( - params, axis, torch.reshape(indices.long(), params.shape), + params, axis, torch.reshape(indices.long(), (-1, )), sparse_grad=False, out=out) else: for b in range(batch_dims): From 2ca112cc9fdbd0437fbb8f7e00bd6ed37ff6e9d9 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Tue, 5 Dec 2023 01:45:54 +0100 Subject: [PATCH 42/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index fb4a9e3138077..9c5650ac6c061 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -189,7 +189,10 @@ def gather( if batch_dims == 0: result = torch.gather( params, axis, torch.reshape(indices.long(), (-1, )), - sparse_grad=False, out=out) + sparse_grad=False, out=out).reshape((params.shape[:axis] + + indices.shape[batch_dims:] + + params.shape[axis + 1 :]) + ) else: for b in range(batch_dims): if b == 0: From 9c38328c6068c20bfce74af51d732eccc0cd090a Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Tue, 5 Dec 2023 02:18:23 +0100 Subject: [PATCH 43/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 9c5650ac6c061..f882ef37c4e03 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -187,16 +187,14 @@ def gather( ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims) result = [] if batch_dims == 0: - result = torch.gather( - params, axis, torch.reshape(indices.long(), (-1, )), - sparse_grad=False, out=out).reshape((params.shape[:axis] - + indices.shape[batch_dims:] - + params.shape[axis + 1 :]) - ) + result = torch.take( + params, indices) else: + params_slices = torch.unbind(params, axis=0) + indices_slices = torch.unbind(indices, axis=0) for b in range(batch_dims): if b == 0: - zip_list = [(p, i) for p, i in zip(params, indices)] + zip_list = [(p, i) for p, i in zip(params_slices, indices_slices)] else: zip_list = [ (p, i) for z in [zip(p1, i1) for p1, i1 in zip_list] for p, i in z From 5c5935e15a71e31b2105380d5778a9c6974e9dc1 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Tue, 5 Dec 2023 02:31:37 +0100 Subject: [PATCH 44/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 
f882ef37c4e03..5c4b834007bde 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -211,8 +211,8 @@ def gather( result = result.reshape((params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :]) ) - if ivy.exists(out): - return ivy.inplace_update(out, result) + if ivy.exists(out): + return ivy.inplace_update(out, result) return result From f13c16f7d69bc29682edc3d15a89210539810d69 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Tue, 5 Dec 2023 02:34:50 +0100 Subject: [PATCH 45/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 5c4b834007bde..60ec6119c8f4f 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -188,7 +188,7 @@ def gather( result = [] if batch_dims == 0: result = torch.take( - params, indices) + params, indices, out=out) else: params_slices = torch.unbind(params, axis=0) indices_slices = torch.unbind(indices, axis=0) @@ -211,8 +211,8 @@ def gather( result = result.reshape((params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :]) ) - if ivy.exists(out): - return ivy.inplace_update(out, result) + if ivy.exists(out): + return ivy.inplace_update(out, result) return result From 7d6ad4ab322199bb62b028e73fef7c580674d7a1 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Tue, 5 Dec 2023 02:39:28 +0100 Subject: [PATCH 46/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 60ec6119c8f4f..40f85220a3e26 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -207,7 +207,7 @@ def gather( ) result.append(r) - result = torch.cat(result, dim=0) + result = torch.stack(result) result = result.reshape((params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :]) ) From c05b919d6464c61837e1c1ed08fe2ef86ced87bb Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Tue, 5 Dec 2023 02:44:04 +0100 Subject: [PATCH 47/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 40f85220a3e26..1174248e801cc 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -188,7 +188,7 @@ def gather( result = [] if batch_dims == 0: result = torch.take( - params, indices, out=out) + params, indices.long(), out=out) else: params_slices = torch.unbind(params, axis=0) indices_slices = torch.unbind(indices, axis=0) From f155aa6569a4ae45e7728519d7f82a518c52b22c Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Tue, 5 Dec 2023 18:34:18 +0100 Subject: [PATCH 48/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 1174248e801cc..828bf4eb261bd 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -188,7 +188,10 @@ def gather( result = [] if batch_dims == 0: result = torch.take( - params, indices.long(), 
out=out) + params, indices.long(), out=None) + result = result.reshape((params.shape[:axis] + indices.shape[batch_dims:] + + params.shape[axis + 1 :]) + ) else: params_slices = torch.unbind(params, axis=0) indices_slices = torch.unbind(indices, axis=0) @@ -211,8 +214,8 @@ def gather( result = result.reshape((params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :]) ) - if ivy.exists(out): - return ivy.inplace_update(out, result) + if ivy.exists(out): + return ivy.inplace_update(out, result) return result From a132f7005c9084b32063b25aff443afa295e5c31 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Tue, 5 Dec 2023 20:06:31 +0100 Subject: [PATCH 49/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 828bf4eb261bd..db9f03f3de6ef 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -187,8 +187,10 @@ def gather( ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims) result = [] if batch_dims == 0: - result = torch.take( - params, indices.long(), out=None) + params_slices = torch.unbind(params, axis=0) + result = [torch.take( + p, indices.long(), out=None) for p in params_slices] + result = torch.stack(result) result = result.reshape((params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :]) ) From 8170ad67a5600140f48970930248f7496ee37542 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Tue, 5 Dec 2023 20:38:27 +0100 Subject: [PATCH 50/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index db9f03f3de6ef..2c68c84ebabc4 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -187,10 +187,14 @@ def gather( ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims) result = [] if batch_dims == 0: - params_slices = torch.unbind(params, axis=0) - result = [torch.take( - p, indices.long(), out=None) for p in params_slices] - result = torch.stack(result) + if params.dim() > indices.dim(): + params_slices = torch.unbind(params, axis=0) + result = [torch.take( + p, indices.long(), out=None) for p in params_slices] + result = torch.stack(result) + else: + result = torch.take( + params, indices.long(), out=None) result = result.reshape((params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :]) ) From a6d2676487f408b9ef51deb270af43396e94521b Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Tue, 5 Dec 2023 21:08:39 +0100 Subject: [PATCH 51/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 2c68c84ebabc4..8b837acd0cd37 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -189,15 +189,17 @@ def gather( if batch_dims == 0: if params.dim() > indices.dim(): params_slices = torch.unbind(params, axis=0) + indices_expanded = indices.expand(params_slices[0].shape) result = [torch.take( - p, indices.long(), out=None) for p in params_slices] + p, indices_expanded.long(), out=None) for p in 
params_slices] result = torch.stack(result) + result = result.reshape((params.shape[:axis] + indices.shape[batch_dims:] + + params.shape[axis + 1 :]) + ) else: result = torch.take( params, indices.long(), out=None) - result = result.reshape((params.shape[:axis] + indices.shape[batch_dims:] - + params.shape[axis + 1 :]) - ) + else: params_slices = torch.unbind(params, axis=0) indices_slices = torch.unbind(indices, axis=0) From e6be13c082ff444d19fed5ebd3a7011dff6ffd84 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Tue, 5 Dec 2023 21:54:45 +0100 Subject: [PATCH 52/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 8b837acd0cd37..515f03de887a9 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -199,6 +199,9 @@ def gather( else: result = torch.take( params, indices.long(), out=None) + result = result.reshape((params.shape[:axis] + indices.shape[batch_dims:] + + params.shape[axis + 1 :]) + ) else: params_slices = torch.unbind(params, axis=0) From c2985f99d7399b7722761673e546e57f5b4bcfbc Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Tue, 5 Dec 2023 23:44:56 +0100 Subject: [PATCH 53/95] fixing torch_gather backend --- ivy/functional/backends/torch/general.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 515f03de887a9..c8748af09c46f 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -187,15 +187,16 @@ def gather( ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims) result = [] if batch_dims == 0: - if params.dim() > indices.dim(): - params_slices = torch.unbind(params, axis=0) - indices_expanded = indices.expand(params_slices[0].shape) - result = [torch.take( - p, indices_expanded.long(), out=None) for p in params_slices] - result = torch.stack(result) - result = result.reshape((params.shape[:axis] + indices.shape[batch_dims:] - + params.shape[axis + 1 :]) - ) + dim_diff = params.dim() - indices.dim() + if dim_diff > 0: + params_squeezed = params + for d in range(dim_diff): + params_squeezed.squeeze(-1) + result = torch.take( + params_squeezed, indices.long(), out=None) + for d in range(dim_diff): + result = result.expand((params.shape[:axis] + indices.shape[batch_dims:] + + params.shape[axis + 1 :])) else: result = torch.take( params, indices.long(), out=None) From ed7fd7cb6fb5a3c5db727f75dfa8143ad9189193 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Thu, 7 Dec 2023 18:33:00 +0100 Subject: [PATCH 54/95] fixing torch_gather backend, add reshape and expand --- ivy/functional/backends/torch/general.py | 57 +++++++++++++++++------- 1 file changed, 42 insertions(+), 15 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index c8748af09c46f..ef0f3f27c3e01 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -186,27 +186,54 @@ def gather( batch_dims %= len(params.shape) ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims) result = [] + + def expand_p_i(params, indices): + + dim_helper_table = [1 for dim in list(params.shape[:axis] + + indices.shape[batch_dims:] + + params.shape[axis + 1:])] + singleton_dims = 
torch.Size(dim_helper_table) + params_expanded = params.reshape(params.shape[:axis] + + singleton_dims[batch_dims:] + + params.shape[axis + 1:]) + params_expanded = params_expanded.expand(params.shape[:axis] + + indices.shape[batch_dims:] + + params.shape[axis + 1:]) + indices_expanded = indices.reshape.reshape(singleton_dims[:axis] + + indices.shape[batch_dims:] + + singleton_dims[axis + 1:]) + indices_expanded = indices_expanded.expand(params.shape[:axis] + + indices.shape[batch_dims:] + + params.shape[axis + 1:]) + return params_expanded, indices_expanded + if batch_dims == 0: dim_diff = params.dim() - indices.dim() - if dim_diff > 0: - params_squeezed = params - for d in range(dim_diff): - params_squeezed.squeeze(-1) - result = torch.take( - params_squeezed, indices.long(), out=None) - for d in range(dim_diff): - result = result.expand((params.shape[:axis] + indices.shape[batch_dims:] - + params.shape[axis + 1 :])) + if dim_diff != 0: + params_expanded, indices_expanded = expand_p_i(params, indices) + result = torch.gather( + params_expanded, axis, + indices_expanded.long(), + sparse_grad=False, out=out + ) + return result else: - result = torch.take( - params, indices.long(), out=None) - result = result.reshape((params.shape[:axis] + indices.shape[batch_dims:] - + params.shape[axis + 1 :]) + result = torch.gather( + params, axis, indices.long(), sparse_grad=False, out=out) + result = result.reshape((params.shape[:axis] + + indices.shape[batch_dims:] + + params.shape[axis + 1:]) ) else: - params_slices = torch.unbind(params, axis=0) - indices_slices = torch.unbind(indices, axis=0) + dim_diff = params.dim() - indices.dim() + params_expanded = params + indices_expanded = indices + if dim_diff != 0: + params_expanded, indices_expanded = expand_p_i(params, indices) + + params_slices = torch.unbind(params_expanded, axis=0) + indices_slices = torch.unbind(indices_expanded, axis=0) for b in range(batch_dims): if b == 0: zip_list = [(p, i) for p, i in zip(params_slices, indices_slices)] From 07f7e59629d0532cbf218d6b308c58e2f1b8eb35 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Thu, 7 Dec 2023 18:41:18 +0100 Subject: [PATCH 55/95] fixing torch_gather backend, add reshape and expand --- ivy/functional/backends/torch/general.py | 53 ++++++++++++++---------- 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index ef0f3f27c3e01..f7c07330884ff 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -184,35 +184,38 @@ def gather( ) -> torch.Tensor: axis %= len(params.shape) batch_dims %= len(params.shape) - ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims) + ivy.utils.assertions.check_gather_input_valid( + params, + indices, + axis, + batch_dims + ) result = [] def expand_p_i(params, indices): - - dim_helper_table = [1 for dim in list(params.shape[:axis] - + indices.shape[batch_dims:] - + params.shape[axis + 1:])] + dim_helper_table = [1 for dim in list(params.shape[:axis] + + indices.shape[batch_dims:] + + params.shape[axis + 1:])] singleton_dims = torch.Size(dim_helper_table) - params_expanded = params.reshape(params.shape[:axis] + params_expanded = params.reshape(params.shape[:axis] + singleton_dims[batch_dims:] + params.shape[axis + 1:]) params_expanded = params_expanded.expand(params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:]) - indices_expanded = 
indices.reshape.reshape(singleton_dims[:axis] + indices_expanded = indices.reshape.reshape(singleton_dims[:axis] + indices.shape[batch_dims:] + singleton_dims[axis + 1:]) - indices_expanded = indices_expanded.expand(params.shape[:axis] + indices_expanded = indices_expanded.expand(params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:]) - return params_expanded, indices_expanded - + return params_expanded, indices_expanded if batch_dims == 0: dim_diff = params.dim() - indices.dim() if dim_diff != 0: params_expanded, indices_expanded = expand_p_i(params, indices) result = torch.gather( - params_expanded, axis, + params_expanded, axis, indices_expanded.long(), sparse_grad=False, out=out ) @@ -220,38 +223,42 @@ def expand_p_i(params, indices): else: result = torch.gather( params, axis, indices.long(), sparse_grad=False, out=out) - result = result.reshape((params.shape[:axis] + result = result.reshape((params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:]) - ) - + ) else: dim_diff = params.dim() - indices.dim() params_expanded = params indices_expanded = indices if dim_diff != 0: params_expanded, indices_expanded = expand_p_i(params, indices) - params_slices = torch.unbind(params_expanded, axis=0) indices_slices = torch.unbind(indices_expanded, axis=0) for b in range(batch_dims): if b == 0: - zip_list = [(p, i) for p, i in zip(params_slices, indices_slices)] + zip_list = [(p, i) for + p, i in + zip(params_slices, indices_slices) + ] else: zip_list = [ - (p, i) for z in [zip(p1, i1) for p1, i1 in zip_list] for p, i in z - ] + (p, i) for z in [zip(p1, i1) for p1, i1 in zip_list] + for p, i in z + ] for z in zip_list: p, i = z r = torch.gather( - p, (axis - batch_dims), torch.reshape(i.long(), (-1, )), - sparse_grad=False, out=None - ) + p, (axis - batch_dims), + torch.reshape(i.long(), (-1, )), + sparse_grad=False, out=None + ) result.append(r) result = torch.stack(result) - result = result.reshape((params.shape[:axis] + indices.shape[batch_dims:] - + params.shape[axis + 1 :]) + result = result.reshape((params.shape[:axis] + + indices.shape[batch_dims:] + + params.shape[axis + 1:]) ) if ivy.exists(out): return ivy.inplace_update(out, result) From 9e5c796c22aac4a988be5ac76f1bf669d8c7bc89 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Thu, 7 Dec 2023 18:50:36 +0100 Subject: [PATCH 56/95] fixing torch_gather backend, add reshape and expand --- ivy/functional/backends/torch/general.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index f7c07330884ff..124870e22494f 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -203,7 +203,7 @@ def expand_p_i(params, indices): params_expanded = params_expanded.expand(params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:]) - indices_expanded = indices.reshape.reshape(singleton_dims[:axis] + indices_expanded = indices.reshape(singleton_dims[:axis] + indices.shape[batch_dims:] + singleton_dims[axis + 1:]) indices_expanded = indices_expanded.expand(params.shape[:axis] From f1c6cbff4da1b173ee9da2d5643578a0064f1d0b Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Thu, 7 Dec 2023 19:16:14 +0100 Subject: [PATCH 57/95] fixing torch_gather backend, add reshape and expand --- ivy/functional/backends/torch/general.py | 34 ++++++++++++++---------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/ivy/functional/backends/torch/general.py 
b/ivy/functional/backends/torch/general.py index 124870e22494f..a2aa9d3a1f8e2 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -197,19 +197,25 @@ def expand_p_i(params, indices): + indices.shape[batch_dims:] + params.shape[axis + 1:])] singleton_dims = torch.Size(dim_helper_table) - params_expanded = params.reshape(params.shape[:axis] - + singleton_dims[batch_dims:] - + params.shape[axis + 1:]) - params_expanded = params_expanded.expand(params.shape[:axis] - + indices.shape[batch_dims:] - + params.shape[axis + 1:]) - indices_expanded = indices.reshape(singleton_dims[:axis] - + indices.shape[batch_dims:] - + singleton_dims[axis + 1:]) - indices_expanded = indices_expanded.expand(params.shape[:axis] - + indices.shape[batch_dims:] - + params.shape[axis + 1:]) - return params_expanded, indices_expanded + params_ex = params.reshape(params.shape[:axis] + + singleton_dims[batch_dims:] + + params.shape[axis + 1:]) + indices_ex = indices.reshape(singleton_dims[:axis] + + indices.shape[batch_dims:] + + singleton_dims[axis + 1:]) + if (params.shape[:axis] + + indices.shape[batch_dims:] + + params.shape[axis + 1:]) != params.expanded.shape: + params_expanded = params_ex.expand(params.shape[:axis] + + indices.shape[batch_dims:] + + params.shape[axis + 1:]) + if (indices.shape[:axis] + + indices.shape[batch_dims:] + + params.shape[axis + 1:]) != params.expanded.shape: + indices_expanded = indices_ex.expand(params.shape[:axis] + + indices.shape[batch_dims:] + + params.shape[axis + 1:]) + return params_expanded, indices_expanded if batch_dims == 0: dim_diff = params.dim() - indices.dim() if dim_diff != 0: @@ -226,7 +232,7 @@ def expand_p_i(params, indices): result = result.reshape((params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:]) - ) + ) else: dim_diff = params.dim() - indices.dim() params_expanded = params From 4cf9645c70e3f75d5597e032dde9dae62b20e70a Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Thu, 7 Dec 2023 19:26:38 +0100 Subject: [PATCH 58/95] fixing torch_gather backend, add reshape and expand --- ivy/functional/backends/torch/general.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index a2aa9d3a1f8e2..cb1be8f1949c4 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -205,17 +205,17 @@ def expand_p_i(params, indices): + singleton_dims[axis + 1:]) if (params.shape[:axis] + indices.shape[batch_dims:] - + params.shape[axis + 1:]) != params.expanded.shape: - params_expanded = params_ex.expand(params.shape[:axis] - + indices.shape[batch_dims:] - + params.shape[axis + 1:]) - if (indices.shape[:axis] + + params.shape[axis + 1:]) != params_ex.shape: + params_ex = params_ex.expand(params.shape[:axis] + + indices.shape[batch_dims:] + + params.shape[axis + 1:]) + if (indices.shape[:axis] + indices.shape[batch_dims:] - + params.shape[axis + 1:]) != params.expanded.shape: - indices_expanded = indices_ex.expand(params.shape[:axis] - + indices.shape[batch_dims:] - + params.shape[axis + 1:]) - return params_expanded, indices_expanded + + params.shape[axis + 1:]) != indices_ex.shape: + indices_ex = indices_ex.expand(params.shape[:axis] + + indices.shape[batch_dims:] + + params.shape[axis + 1:]) + return params_ex, indices_ex if batch_dims == 0: dim_diff = params.dim() - indices.dim() if dim_diff != 0: @@ -243,8 +243,8 @@ def expand_p_i(params, 
indices): indices_slices = torch.unbind(indices_expanded, axis=0) for b in range(batch_dims): if b == 0: - zip_list = [(p, i) for - p, i in + zip_list = [(p, i) for + p, i in zip(params_slices, indices_slices) ] else: From de6e629bb4dafed453f203262d585742430f60e1 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Thu, 7 Dec 2023 19:47:26 +0100 Subject: [PATCH 59/95] fixing torch_gather backend, add reshape and expand --- ivy/functional/backends/torch/general.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index cb1be8f1949c4..4336456ee4553 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -183,6 +183,7 @@ def gather( out: Optional[torch.Tensor] = None, ) -> torch.Tensor: axis %= len(params.shape) + axis = abs(len(params.shape) + axis) if axis < 0 else axis batch_dims %= len(params.shape) ivy.utils.assertions.check_gather_input_valid( params, From 3ee224182a412b67a79b745f23c8eff64bfa4148 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Thu, 7 Dec 2023 20:15:01 +0100 Subject: [PATCH 60/95] fixing torch_gather backend, add reshape and expand --- ivy/functional/backends/torch/general.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 4336456ee4553..d1952d7c91c5c 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -194,16 +194,18 @@ def gather( result = [] def expand_p_i(params, indices): - dim_helper_table = [1 for dim in list(params.shape[:axis] - + indices.shape[batch_dims:] - + params.shape[axis + 1:])] - singleton_dims = torch.Size(dim_helper_table) + res_dim_helper_table = [1 for dim in list(params.shape[:axis] + + indices.shape[batch_dims:] + + params.shape[axis + 1:])] + res_singleton_dims = torch.Size(res_dim_helper_table) + param_singleton_dims_table = [1 for dim in list(params.shape)] + param_singleton_dims = torch.Size(param_singleton_dims_table) params_ex = params.reshape(params.shape[:axis] - + singleton_dims[batch_dims:] + + res_singleton_dims[batch_dims:] + params.shape[axis + 1:]) - indices_ex = indices.reshape(singleton_dims[:axis] + indices_ex = indices.reshape(param_singleton_dims[:axis] + indices.shape[batch_dims:] - + singleton_dims[axis + 1:]) + + param_singleton_dims[axis + 1:]) if (params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:]) != params_ex.shape: From 8c8f83bd62ea94d441d9a3ebcbf3445a21811712 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Mon, 11 Dec 2023 14:36:24 +0100 Subject: [PATCH 61/95] gather_backend fixes --- ivy/functional/backends/torch/general.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index d1952d7c91c5c..a99a0e836591f 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -224,10 +224,10 @@ def expand_p_i(params, indices): if dim_diff != 0: params_expanded, indices_expanded = expand_p_i(params, indices) result = torch.gather( - params_expanded, axis, - indices_expanded.long(), - sparse_grad=False, out=out - ) + params_expanded, axis, + indices_expanded.long(), + sparse_grad=False, out=out + ) return result else: result = torch.gather( From 49c1e8f0f66f0b4dba82a36fabfc9bdb6f2acfee Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Thu, 14 
Dec 2023 18:27:00 +0100 Subject: [PATCH 62/95] torch_gather_fix copied, not cleaned --- ivy/functional/backends/torch/general.py | 133 +++++++++++++---------- 1 file changed, 75 insertions(+), 58 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index a99a0e836591f..1397aca399778 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -185,90 +185,107 @@ def gather( axis %= len(params.shape) axis = abs(len(params.shape) + axis) if axis < 0 else axis batch_dims %= len(params.shape) - ivy.utils.assertions.check_gather_input_valid( - params, - indices, - axis, - batch_dims - ) + ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims) result = [] def expand_p_i(params, indices): - res_dim_helper_table = [1 for dim in list(params.shape[:axis] - + indices.shape[batch_dims:] - + params.shape[axis + 1:])] - res_singleton_dims = torch.Size(res_dim_helper_table) + ind_dim_helper_table = [1 for dim in list(indices.shape)] + ind_singleton_dims = torch.Size(ind_dim_helper_table) param_singleton_dims_table = [1 for dim in list(params.shape)] param_singleton_dims = torch.Size(param_singleton_dims_table) - params_ex = params.reshape(params.shape[:axis] - + res_singleton_dims[batch_dims:] - + params.shape[axis + 1:]) - indices_ex = indices.reshape(param_singleton_dims[:axis] - + indices.shape[batch_dims:] - + param_singleton_dims[axis + 1:]) - if (params.shape[:axis] - + indices.shape[batch_dims:] - + params.shape[axis + 1:]) != params_ex.shape: - params_ex = params_ex.expand(params.shape[:axis] - + indices.shape[batch_dims:] - + params.shape[axis + 1:]) - if (indices.shape[:axis] - + indices.shape[batch_dims:] - + params.shape[axis + 1:]) != indices_ex.shape: - indices_ex = indices_ex.expand(params.shape[:axis] - + indices.shape[batch_dims:] - + params.shape[axis + 1:]) + + params_ex = ( + params + if (indices.dim() - batch_dims <= 1) + else params.reshape( + params.shape[:axis] + + ind_singleton_dims[batch_dims:] + + params.shape[axis + 1 :] + ) + ) + indices_ex = ( + indices + if (params.dim() <= 1) + else indices.reshape( + param_singleton_dims[:axis] + + indices.shape[batch_dims:] + + param_singleton_dims[axis + 1 :] + ) + ) + + if ( + params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] + ) >= params_ex.shape: + params_ex = params_ex.expand( + params.shape[:axis] + + indices.shape[batch_dims:] + + params.shape[axis + 1 :] + ) + if ( + params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] + ) >= indices_ex.shape: + indices_ex = indices_ex.expand( + params.shape[:axis] + + indices.shape[batch_dims:] + + params.shape[axis + 1 :] + ) return params_ex, indices_ex + if batch_dims == 0: dim_diff = params.dim() - indices.dim() if dim_diff != 0: params_expanded, indices_expanded = expand_p_i(params, indices) result = torch.gather( - params_expanded, axis, - indices_expanded.long(), - sparse_grad=False, out=out + params_expanded, + axis, + indices_expanded.long(), + sparse_grad=False, + out=out, ) return result else: + params_expanded, indices_expanded = expand_p_i(params, indices) result = torch.gather( - params, axis, indices.long(), sparse_grad=False, out=out) - result = result.reshape((params.shape[:axis] - + indices.shape[batch_dims:] - + params.shape[axis + 1:]) - ) + params_expanded, + axis, + indices_expanded.long(), + sparse_grad=False, + out=out, + ) + else: - dim_diff = params.dim() - indices.dim() - params_expanded = params 
- indices_expanded = indices - if dim_diff != 0: - params_expanded, indices_expanded = expand_p_i(params, indices) - params_slices = torch.unbind(params_expanded, axis=0) - indices_slices = torch.unbind(indices_expanded, axis=0) + params_slices = torch.unbind(params, axis=0) if params.dim() > 1 else params + indices_slices = torch.unbind(indices, axis=0) if indices.dim() > 1 else indices for b in range(batch_dims): if b == 0: - zip_list = [(p, i) for - p, i in - zip(params_slices, indices_slices) - ] + zip_list = [(p, i) for p, i in zip(params_slices, indices_slices)] else: zip_list = [ - (p, i) for z in [zip(p1, i1) for p1, i1 in zip_list] - for p, i in z - ] + (p, i) for z in [zip(p1, i1) for p1, i1 in zip_list] for p, i in z + ] for z in zip_list: p, i = z + dim_diff = p.dim() - i.dim() + p_expanded = p + i_expanded = i + if dim_diff != 0: + p_expanded, i_expanded = expand_p_i(p, i) r = torch.gather( - p, (axis - batch_dims), - torch.reshape(i.long(), (-1, )), - sparse_grad=False, out=None - ) + p_expanded, + (axis - batch_dims), + i_expanded.long(), + sparse_grad=False, + out=None, + ) result.append(r) result = torch.stack(result) - result = result.reshape((params.shape[:axis] - + indices.shape[batch_dims:] - + params.shape[axis + 1:]) - ) + result = result.reshape( + params.shape[:axis] + + max(indices.shape[batch_dims:], torch.Size([1])) + + params.shape[axis + 1 :] + ) + if ivy.exists(out): return ivy.inplace_update(out, result) From 3eed23b68f4ac019f4d215b43eb106c705ce2eed Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Thu, 14 Dec 2023 19:20:40 +0100 Subject: [PATCH 63/95] torch_gather_fix copied, not cleaned --- ivy/functional/backends/torch/general.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 1397aca399778..a134c534f7b07 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -195,11 +195,21 @@ def expand_p_i(params, indices): param_singleton_dims = torch.Size(param_singleton_dims_table) params_ex = ( - params - if (indices.dim() - batch_dims <= 1) + params.reshape( + ( + params.shape[:axis] + + max([ + ind_singleton_dims[batch_dims : batch_dims + 1], + params.shape[axis : axis + 1], + ]) + + ind_singleton_dims[batch_dims + 1 :] + + params.shape[axis + 1 :] + ) + ) + if (indices.dim() - batch_dims <= 1 or params.dim() - axis <= 1) else params.reshape( params.shape[:axis] - + ind_singleton_dims[batch_dims:] + + ind_singleton_dims[batch_dims + 1 :] + params.shape[axis + 1 :] ) ) From c31c34374efa3043e4f4041ed7597a439778f3cb Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Thu, 14 Dec 2023 19:44:16 +0100 Subject: [PATCH 64/95] torch_gather_fix copied, not cleaned --- ivy/functional/backends/torch/general.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index a134c534f7b07..6a7ba47f58c41 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -194,21 +194,13 @@ def expand_p_i(params, indices): param_singleton_dims_table = [1 for dim in list(params.shape)] param_singleton_dims = torch.Size(param_singleton_dims_table) - params_ex = ( - params.reshape( - ( - params.shape[:axis] - + max([ - ind_singleton_dims[batch_dims : batch_dims + 1], - params.shape[axis : axis + 1], - ]) - + ind_singleton_dims[batch_dims + 1 :] - + 
params.shape[axis + 1 :] - ) - ) - if (indices.dim() - batch_dims <= 1 or params.dim() - axis <= 1) - else params.reshape( + params_ex = params.reshape( + ( params.shape[:axis] + + max([ + ind_singleton_dims[batch_dims : batch_dims + 1], + params.shape[axis : axis + 1], + ]) + ind_singleton_dims[batch_dims + 1 :] + params.shape[axis + 1 :] ) From 298e46a3d4f1cd76bcf3114b0b8094108ff9b155 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Fri, 22 Dec 2023 17:16:49 +0100 Subject: [PATCH 65/95] torch_gather_fix fixing problems with the dimensions of params and indices --- ivy/functional/backends/torch/general.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 793d0114951a5..bd30e3c286f79 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -216,7 +216,16 @@ def expand_p_i(params, indices): if ( params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] - ) >= params_ex.shape: + ) >= params_ex.shape and ( + len( + list( + params.shape[:axis] + + indices.shape[batch_dims:] + + params.shape[axis + 1 :] + ) + ) + != params_ex.dim() + ): params_ex = params_ex.expand( params.shape[:axis] + indices.shape[batch_dims:] @@ -224,7 +233,16 @@ def expand_p_i(params, indices): ) if ( params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] - ) >= indices_ex.shape: + ) >= indices_ex.shape and ( + len( + list( + params.shape[:axis] + + indices.shape[batch_dims:] + + params.shape[axis + 1 :] + ) + ) + != indices_ex.dim() + ): indices_ex = indices_ex.expand( params.shape[:axis] + indices.shape[batch_dims:] From ce8e863519edf5ef020cbe5715b561300dc08fd1 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Fri, 22 Dec 2023 17:30:59 +0100 Subject: [PATCH 66/95] torch_gather_fix fixing problems with the dimensions of params and indices --- ivy/functional/backends/torch/general.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index bd30e3c286f79..255ec95e47361 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -233,16 +233,7 @@ def expand_p_i(params, indices): ) if ( params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] - ) >= indices_ex.shape and ( - len( - list( - params.shape[:axis] - + indices.shape[batch_dims:] - + params.shape[axis + 1 :] - ) - ) - != indices_ex.dim() - ): + ) >= indices_ex.shape: indices_ex = indices_ex.expand( params.shape[:axis] + indices.shape[batch_dims:] From 93e5542590af0bd4c71deb3484d09267727fef43 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Fri, 22 Dec 2023 18:35:09 +0100 Subject: [PATCH 67/95] torch_gather_fix fixing problems with the dimensions of params and indices --- ivy/functional/backends/torch/general.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 255ec95e47361..e83e41845191b 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -214,26 +214,21 @@ def expand_p_i(params, indices): ) ) - if ( + params_ex_mask = torch.tensor( params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] - ) >= params_ex.shape and ( - len( - list( - params.shape[:axis] - + 
indices.shape[batch_dims:] - + params.shape[axis + 1 :] - ) - ) - != params_ex.dim() - ): + ) * (torch.tensor(params_ex.shape) == 1) + + if any((params_ex_mask != (torch.tensor(params_ex.shape) == 1)).flatten()): params_ex = params_ex.expand( params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] ) - if ( + indices_ex_mask = torch.tensor( params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] - ) >= indices_ex.shape: + ) * (torch.tensor(indices_ex.shape) == 1) + + if any((indices_ex_mask != (torch.tensor(indices_ex.shape) == 1)).flatten()): indices_ex = indices_ex.expand( params.shape[:axis] + indices.shape[batch_dims:] From e1bff9472612bfb7e543ecdf2fbe3e8d19383745 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Fri, 22 Dec 2023 19:02:36 +0100 Subject: [PATCH 68/95] torch_gather_fix fixing problems with the dimensions of params and indices --- ivy/functional/backends/torch/general.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index e83e41845191b..40c0f07b0b398 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -197,11 +197,8 @@ def expand_p_i(params, indices): params_ex = params.reshape(( params.shape[:axis] - + max([ - ind_singleton_dims[batch_dims : batch_dims + 1], - params.shape[axis : axis + 1], - ]) - + ind_singleton_dims[batch_dims + 1 :] + + ind_singleton_dims[batch_dims:-1] + + max([ind_singleton_dims[-1:], params.shape[axis : axis + 1]]) + params.shape[axis + 1 :] )) indices_ex = ( @@ -218,7 +215,7 @@ def expand_p_i(params, indices): params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] ) * (torch.tensor(params_ex.shape) == 1) - if any((params_ex_mask != (torch.tensor(params_ex.shape) == 1)).flatten()): + if any((params_ex_mask != (torch.tensor(params_ex.shape) == 1).flatten())): params_ex = params_ex.expand( params.shape[:axis] + indices.shape[batch_dims:] @@ -228,7 +225,7 @@ def expand_p_i(params, indices): params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] ) * (torch.tensor(indices_ex.shape) == 1) - if any((indices_ex_mask != (torch.tensor(indices_ex.shape) == 1)).flatten()): + if any((indices_ex_mask != (torch.tensor(indices_ex.shape) == 1).flatten())): indices_ex = indices_ex.expand( params.shape[:axis] + indices.shape[batch_dims:] From 48131e3391104f83122a9c59ad62184c5b2d7993 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Fri, 22 Dec 2023 19:31:35 +0100 Subject: [PATCH 69/95] torch_gather_fix fixing problems with the dimensions of params and indices --- ivy/functional/backends/torch/general.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 40c0f07b0b398..a0b0fc8b2ce5a 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -191,16 +191,26 @@ def gather( def expand_p_i(params, indices): ind_dim_helper_table = [1 for dim in list(indices.shape)] - ind_singleton_dims = torch.Size(ind_dim_helper_table) + torch.Size(ind_dim_helper_table) param_singleton_dims_table = [1 for dim in list(params.shape)] param_singleton_dims = torch.Size(param_singleton_dims_table) - params_ex = params.reshape(( - params.shape[:axis] - + ind_singleton_dims[batch_dims:-1] - + max([ind_singleton_dims[-1:], params.shape[axis : axis + 1]]) - + 
params.shape[axis + 1 :] - )) + params_insert_shape = ( + torch.tensor(indices.shape[batch_dims:]) + == torch.tensor(params.shape[axis : axis + 1]) + ).long() + +( + torch.tensor(indices.shape[batch_dims:]) + != torch.tensor(params.shape[axis : axis + 1]) + ).long() + + params_ex = ( + indices + if (params.dim() <= 1) + else params.reshape( + (params.shape[:axis] + params_insert_shape + params.shape[axis + 1 :]) + ) + ) indices_ex = ( indices if (params.dim() <= 1) From d557df58ecbd7ad83b3282d401f95cc305f13fcd Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Fri, 22 Dec 2023 19:44:14 +0100 Subject: [PATCH 70/95] torch_gather_fix fixing problems with the dimensions of params and indices --- ivy/functional/backends/torch/general.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index a0b0fc8b2ce5a..25ed9ed4a6d50 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -195,14 +195,16 @@ def expand_p_i(params, indices): param_singleton_dims_table = [1 for dim in list(params.shape)] param_singleton_dims = torch.Size(param_singleton_dims_table) - params_insert_shape = ( - torch.tensor(indices.shape[batch_dims:]) - == torch.tensor(params.shape[axis : axis + 1]) - ).long() - +( - torch.tensor(indices.shape[batch_dims:]) - != torch.tensor(params.shape[axis : axis + 1]) - ).long() + params_insert_shape = torch.Size([ + ( + torch.tensor(indices.shape[batch_dims:]) + == torch.tensor(params.shape[axis : axis + 1]) + ).long() + + ( + torch.tensor(indices.shape[batch_dims:]) + != torch.tensor(params.shape[axis : axis + 1]) + ).long() + ]) params_ex = ( indices From 19141e2f1cefd07edec86993c85faece2919e471 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Fri, 22 Dec 2023 20:05:41 +0100 Subject: [PATCH 71/95] torch_gather_fix fixing problems with the dimensions of params and indices --- ivy/functional/backends/torch/general.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 25ed9ed4a6d50..d82098f2e0526 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -256,6 +256,7 @@ def expand_p_i(params, indices): sparse_grad=False, out=out, ) + result = result.to(params.dtype) return result else: params_expanded, indices_expanded = expand_p_i(params, indices) @@ -299,6 +300,7 @@ def expand_p_i(params, indices): + max(indices.shape[batch_dims:], torch.Size([1])) + params.shape[axis + 1 :] ) + result = result.to(params.dtype) if ivy.exists(out): return ivy.inplace_update(out, result) From 54326ed07bac93593183843880ff606cff4d4ad0 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Fri, 22 Dec 2023 20:19:11 +0100 Subject: [PATCH 72/95] torch_gather_fix fixing problems with the dimensions of params and indices --- ivy/functional/backends/torch/general.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index d82098f2e0526..011018e62297a 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -195,16 +195,16 @@ def expand_p_i(params, indices): param_singleton_dims_table = [1 for dim in list(params.shape)] param_singleton_dims = torch.Size(param_singleton_dims_table) - params_insert_shape = torch.Size([ - ( - 
torch.tensor(indices.shape[batch_dims:]) - == torch.tensor(params.shape[axis : axis + 1]) - ).long() - + ( - torch.tensor(indices.shape[batch_dims:]) - != torch.tensor(params.shape[axis : axis + 1]) - ).long() - ]) + params_insert_shape = ( + torch.tensor(indices.shape[batch_dims:]) + == torch.tensor(params.shape[axis : axis + 1]) + ).long() + +( + torch.tensor(indices.shape[batch_dims:]) + != torch.tensor(params.shape[axis : axis + 1]) + ).long() + params_insert_shape = [dim for dim in params_insert_shape] + params_insert_shape = torch.Size(params_insert_shape) params_ex = ( indices From 4b34ec71b03c347cc6970e2079fa15edb2792728 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Fri, 22 Dec 2023 20:34:54 +0100 Subject: [PATCH 73/95] torch_gather_fix fixing problems with the dimensions of params and indices --- ivy/functional/backends/torch/general.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 011018e62297a..9f6546c21efb2 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -267,6 +267,7 @@ def expand_p_i(params, indices): sparse_grad=False, out=out, ) + result = result.to(dtype=params.dtype) else: params_slices = torch.unbind(params, axis=0) if params.dim() > 1 else params From 7c3cfc0ce9722b63e7e85372db602d79dd8cef69 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Fri, 22 Dec 2023 21:47:01 +0100 Subject: [PATCH 74/95] torch_gather_fix fixing problems with the dimensions of params and indices --- ivy/functional/backends/torch/general.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 9f6546c21efb2..dbd6531bbeb11 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -204,7 +204,11 @@ def expand_p_i(params, indices): != torch.tensor(params.shape[axis : axis + 1]) ).long() params_insert_shape = [dim for dim in params_insert_shape] - params_insert_shape = torch.Size(params_insert_shape) + params_insert_shape = ( + torch.Size(params_insert_shape) + if torch.Size(params_insert_shape) != torch.Size([0]) + else params.shape[axis : axis + 1] + ) params_ex = ( indices From cd07b61604196157fdcbeafc5587801c2ce467a4 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Fri, 22 Dec 2023 22:05:13 +0100 Subject: [PATCH 75/95] torch_gather_fix fixing problems with the dimensions of params and indices --- ivy/functional/backends/torch/general.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index dbd6531bbeb11..f4790ca7a61a2 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -211,7 +211,7 @@ def expand_p_i(params, indices): ) params_ex = ( - indices + params if (params.dim() <= 1) else params.reshape( (params.shape[:axis] + params_insert_shape + params.shape[axis + 1 :]) From 7ac2867d479d9a29385c1048742fc5433a4938a8 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Fri, 22 Dec 2023 23:04:16 +0100 Subject: [PATCH 76/95] torch_gather_fix fixing problems with the dimensions of params and indices --- ivy/functional/backends/torch/general.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index f4790ca7a61a2..b877732d8cec6 
100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -196,12 +196,10 @@ def expand_p_i(params, indices): param_singleton_dims = torch.Size(param_singleton_dims_table) params_insert_shape = ( - torch.tensor(indices.shape[batch_dims:]) - == torch.tensor(params.shape[axis : axis + 1]) - ).long() - +( - torch.tensor(indices.shape[batch_dims:]) - != torch.tensor(params.shape[axis : axis + 1]) + (torch.tensor(indices.shape[batch_dims:]) == params.shape[axis]) + * params.shape[axis] + ).long() + ( + torch.tensor(indices.shape[batch_dims:]) != params.shape[axis] ).long() params_insert_shape = [dim for dim in params_insert_shape] params_insert_shape = ( @@ -210,12 +208,8 @@ def expand_p_i(params, indices): else params.shape[axis : axis + 1] ) - params_ex = ( - params - if (params.dim() <= 1) - else params.reshape( - (params.shape[:axis] + params_insert_shape + params.shape[axis + 1 :]) - ) + params_ex = params.reshape( + (params.shape[:axis] + params_insert_shape + params.shape[axis + 1 :]) ) indices_ex = ( indices From cd02d6a680ef0ea0a9ac7c0e07aed11edf39af8c Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Sat, 23 Dec 2023 00:48:16 +0100 Subject: [PATCH 77/95] torch_gather_fix fixing problems with the dimensions of params and indices + --- ivy/functional/backends/torch/general.py | 26 ++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index b877732d8cec6..2f14b8330fb1c 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -196,21 +196,21 @@ def expand_p_i(params, indices): param_singleton_dims = torch.Size(param_singleton_dims_table) params_insert_shape = ( - (torch.tensor(indices.shape[batch_dims:]) == params.shape[axis]) - * params.shape[axis] + ( + torch.tensor(indices.shape[batch_dims:]) + == params.shape[axis % len(params)] + ) + * params.shape[axis % len(params)] ).long() + ( - torch.tensor(indices.shape[batch_dims:]) != params.shape[axis] + torch.tensor(indices.shape[batch_dims:]) != params.shape[axis % len(params)] ).long() params_insert_shape = [dim for dim in params_insert_shape] params_insert_shape = ( torch.Size(params_insert_shape) if torch.Size(params_insert_shape) != torch.Size([0]) - else params.shape[axis : axis + 1] + else indices.shape[batch_dims:] ) - params_ex = params.reshape( - (params.shape[:axis] + params_insert_shape + params.shape[axis + 1 :]) - ) indices_ex = ( indices if (params.dim() <= 1) @@ -220,6 +220,16 @@ def expand_p_i(params, indices): + param_singleton_dims[axis + 1 :] ) ) + params_shape = list( + params.shape[:axis] + params_insert_shape + params.shape[axis + 1 :] + ) + if params.shape[axis] not in params_shape: + params_shape[-1] = params.shape[axis] + params_shape = torch.Size(params_shape) + + params_ex = ( + params if indices_ex.dim() == params.dim() else params.reshape(params_shape) + ) params_ex_mask = torch.tensor( params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] @@ -234,13 +244,13 @@ def expand_p_i(params, indices): indices_ex_mask = torch.tensor( params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] ) * (torch.tensor(indices_ex.shape) == 1) - if any((indices_ex_mask != (torch.tensor(indices_ex.shape) == 1).flatten())): indices_ex = indices_ex.expand( params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] ) + return params_ex, indices_ex if batch_dims == 0: 
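
(For reference: the reshape-and-expand pattern that the commits above converge on can be illustrated in isolation. The sketch below is a minimal standalone example with made-up shapes, assuming axis=1 and batch_dims=0; it is not the patched expand_p_i itself, only the broadcasting idea behind it.)

    # Minimal sketch of the reshape/expand trick, with illustrative shapes.
    import torch

    params = torch.arange(12.0).reshape(3, 4)   # shape (3, 4)
    indices = torch.tensor([0, 2])              # shape (2,), gathered along axis 1
    axis = 1

    # Give indices a singleton dim for every params dim other than `axis`,
    # then broadcast across those dims so torch.gather sees equal ranks.
    idx = indices.reshape(1, -1).expand(params.shape[0], -1)  # shape (3, 2)

    out = torch.gather(params, axis, idx)       # out[i, j] = params[i, idx[i, j]]
    print(out.shape)                            # torch.Size([3, 2])

The result shape here matches params.shape[:axis] + indices.shape + params.shape[axis + 1:], which (with batch_dims=0) is the final_shape the backend computes before its closing reshape.
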
From c587e9f8decd0b3f119a8f01a96b2fd2b4d4e86d Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Sat, 23 Dec 2023 00:50:54 +0100 Subject: [PATCH 78/95] torch_gather_fix fixing problems with the dimensions of params and indices + --- ivy/functional/backends/torch/general.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 2f14b8330fb1c..8949b30fae202 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -223,7 +223,7 @@ def expand_p_i(params, indices): params_shape = list( params.shape[:axis] + params_insert_shape + params.shape[axis + 1 :] ) - if params.shape[axis] not in params_shape: + if params.shape[axis % len(params)] not in params_shape: params_shape[-1] = params.shape[axis] params_shape = torch.Size(params_shape) From a92239f585d85ed3367ea37b94820d208edb90bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20Ko=C5=BCdo=C5=84?= <102428159+Kacper-W-Kozdon@users.noreply.github.com> Date: Sat, 23 Dec 2023 01:38:08 +0100 Subject: [PATCH 79/95] remove manipulation.py from torch.gather() PR --- ivy/functional/backends/torch/manipulation.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/ivy/functional/backends/torch/manipulation.py b/ivy/functional/backends/torch/manipulation.py index 18fe925de1ad1..4cfb62addec6e 100644 --- a/ivy/functional/backends/torch/manipulation.py +++ b/ivy/functional/backends/torch/manipulation.py @@ -39,16 +39,6 @@ def concat( concat.support_native_out = True -def unflatten( - x: torch.Tensor, - /, - *, - axis: int = 0, - sizes: Tuple[int] = None, -) -> torch.Tensor: - return torch.unflatten(input=x, dim=axis, sizes=sizes) - - def expand_dims( x: torch.Tensor, /, From 09d689e55f97a389a621fe2830d64cad239ad1c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20Ko=C5=BCdo=C5=84?= <102428159+Kacper-W-Kozdon@users.noreply.github.com> Date: Sat, 23 Dec 2023 01:39:38 +0100 Subject: [PATCH 80/95] remove miscellaneous_ops.py from torch.gather() PR --- ivy/functional/frontends/torch/miscellaneous_ops.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/ivy/functional/frontends/torch/miscellaneous_ops.py b/ivy/functional/frontends/torch/miscellaneous_ops.py index 01a85bc6b4374..446227f997913 100644 --- a/ivy/functional/frontends/torch/miscellaneous_ops.py +++ b/ivy/functional/frontends/torch/miscellaneous_ops.py @@ -502,11 +502,6 @@ def triu_indices(row, col, offset=0, dtype="int64", device="cpu", layout=None): return ivy.stack(ivy.nonzero(sample_matrix)).astype(dtype) -@to_ivy_arrays_and_back -def unflatten(x, /, *, axis, sizes): - return ivy.unflatten(x, axis, sizes) - - @to_ivy_arrays_and_back def vander(x, N=None, increasing=False): # if N == 0: From b2772c9d248ac46b014fc6ae82efebcb40876ba5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20Ko=C5=BCdo=C5=84?= <102428159+Kacper-W-Kozdon@users.noreply.github.com> Date: Sat, 23 Dec 2023 01:41:18 +0100 Subject: [PATCH 81/95] remove manipulation.py from torch.gather() PR --- ivy/functional/ivy/manipulation.py | 47 ------------------------------ 1 file changed, 47 deletions(-) diff --git a/ivy/functional/ivy/manipulation.py b/ivy/functional/ivy/manipulation.py index 5c2a433302a79..a039a9159e3a0 100644 --- a/ivy/functional/ivy/manipulation.py +++ b/ivy/functional/ivy/manipulation.py @@ -327,53 +327,6 @@ def flip( return current_backend(x).flip(x, copy=copy, axis=axis, out=out) -@handle_exceptions -@handle_backend_invalid -@handle_nestable 
-@handle_array_like_without_promotion -@handle_view -@handle_out_argument -@to_native_arrays_and_back -@handle_array_function -@handle_device -def unflatten( - x: Union[ivy.Array, ivy.NativeArray], - /, - *, - axis: int, - sizes: Tuple[int], -) -> ivy.Array: - """Expand a dimension of the input tensor over multiple dimensions. - - Parameters - ---------- - x - The input tensor. - axis - Dimension to be unflattened, specified as an index into input.shape. - sizes - New shape of the unflattened dimension. One of its elements can be -1 in - which case the corresponding output dimension is inferred. Otherwise, - the product of sizes must equal input.shape[dim]. - - Returns - ------- - ret - A View of input with the specified dimension unflattened. - - - Examples - -------- - >>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape - torch.Size([3, 2, 2, 1]) - >>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape - torch.Size([3, 2, 2, 1]) - >>> torch.unflatten(torch.randn(5, 12, 3), -2, (2, 2, 3, 1, 1)).shape - torch.Size([5, 2, 2, 3, 1, 1, 3]) - """ - return current_backend(x).unflatten(x, axis, sizes) - - @handle_exceptions @handle_backend_invalid @handle_nestable From 2a12e0c6fe5079b4a7148281f10f32a1ec877105 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Tue, 26 Dec 2023 12:07:37 +0100 Subject: [PATCH 82/95] torch_gather_fix for pytest --- ivy/functional/backends/torch/general.py | 31 ++++++++++++------------ 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 8949b30fae202..5a5b314acd31e 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -189,7 +189,7 @@ def gather( ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims) result = [] - def expand_p_i(params, indices): + def expand_p_i(params, indices, axis=axis): ind_dim_helper_table = [1 for dim in list(indices.shape)] torch.Size(ind_dim_helper_table) param_singleton_dims_table = [1 for dim in list(params.shape)] @@ -226,7 +226,7 @@ def expand_p_i(params, indices): if params.shape[axis % len(params)] not in params_shape: params_shape[-1] = params.shape[axis] params_shape = torch.Size(params_shape) - + new_axis = (axis + len(params_insert_shape) - 1) % len(params_shape) params_ex = ( params if indices_ex.dim() == params.dim() else params.reshape(params_shape) ) @@ -244,33 +244,34 @@ def expand_p_i(params, indices): indices_ex_mask = torch.tensor( params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] ) * (torch.tensor(indices_ex.shape) == 1) + if any((indices_ex_mask != (torch.tensor(indices_ex.shape) == 1).flatten())): indices_ex = indices_ex.expand( params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] ) - - return params_ex, indices_ex + return params_ex, indices_ex, new_axis if batch_dims == 0: dim_diff = params.dim() - indices.dim() if dim_diff != 0: - params_expanded, indices_expanded = expand_p_i(params, indices) + params_expanded, indices_expanded, new_axis = expand_p_i(params, indices) result = torch.gather( params_expanded, - axis, + new_axis, indices_expanded.long(), sparse_grad=False, out=out, ) - result = result.to(params.dtype) + result = result.to(dtype=params.dtype) return result else: - params_expanded, indices_expanded = expand_p_i(params, indices) + params_expanded, indices_expanded, new_axis = expand_p_i(params, indices) + result = torch.gather( params_expanded, - axis, + new_axis, 
indices_expanded.long(), sparse_grad=False, out=out, @@ -280,6 +281,7 @@ def expand_p_i(params, indices): else: params_slices = torch.unbind(params, axis=0) if params.dim() > 1 else params indices_slices = torch.unbind(indices, axis=0) if indices.dim() > 1 else indices + for b in range(batch_dims): if b == 0: zip_list = [(p, i) for p, i in zip(params_slices, indices_slices)] @@ -293,10 +295,11 @@ def expand_p_i(params, indices): p_expanded = p i_expanded = i if dim_diff != 0: - p_expanded, i_expanded = expand_p_i(p, i) + p_expanded, i_expanded, new_axis = expand_p_i(p, i) + r = torch.gather( p_expanded, - (axis - batch_dims), + (new_axis - batch_dims), i_expanded.long(), sparse_grad=False, out=None, @@ -309,11 +312,7 @@ def expand_p_i(params, indices): + max(indices.shape[batch_dims:], torch.Size([1])) + params.shape[axis + 1 :] ) - result = result.to(params.dtype) - - if ivy.exists(out): - return ivy.inplace_update(out, result) - + result = result.to(dtype=params.dtype) return result From f1a07b10c55e98af0ca305737dedd067660503fa Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Tue, 26 Dec 2023 22:58:47 +0100 Subject: [PATCH 83/95] torch_gather_fix for pytest --- ivy/functional/backends/torch/general.py | 133 +++++++---------------- 1 file changed, 42 insertions(+), 91 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 5a5b314acd31e..01bcec6bfaef9 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -189,99 +189,57 @@ def gather( ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims) result = [] - def expand_p_i(params, indices, axis=axis): - ind_dim_helper_table = [1 for dim in list(indices.shape)] - torch.Size(ind_dim_helper_table) - param_singleton_dims_table = [1 for dim in list(params.shape)] - param_singleton_dims = torch.Size(param_singleton_dims_table) - - params_insert_shape = ( - ( - torch.tensor(indices.shape[batch_dims:]) - == params.shape[axis % len(params)] - ) - * params.shape[axis % len(params)] - ).long() + ( - torch.tensor(indices.shape[batch_dims:]) != params.shape[axis % len(params)] - ).long() - params_insert_shape = [dim for dim in params_insert_shape] - params_insert_shape = ( - torch.Size(params_insert_shape) - if torch.Size(params_insert_shape) != torch.Size([0]) - else indices.shape[batch_dims:] - ) - - indices_ex = ( - indices - if (params.dim() <= 1) - else indices.reshape( - param_singleton_dims[:axis] - + indices.shape[batch_dims:] - + param_singleton_dims[axis + 1 :] - ) - ) - params_shape = list( - params.shape[:axis] + params_insert_shape + params.shape[axis + 1 :] - ) - if params.shape[axis % len(params)] not in params_shape: - params_shape[-1] = params.shape[axis] - params_shape = torch.Size(params_shape) - new_axis = (axis + len(params_insert_shape) - 1) % len(params_shape) - params_ex = ( - params if indices_ex.dim() == params.dim() else params.reshape(params_shape) - ) - - params_ex_mask = torch.tensor( - params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] - ) * (torch.tensor(params_ex.shape) == 1) + def expand_p_i(params, indices, axis=axis, batch_dims=batch_dims): + axis %= len(params.shape) + new_axis = axis + indices_ex = torch.clone(indices).detach() + abs(params.dim() - (indices_ex.dim() - batch_dims)) + stack_dims1 = params.shape[:axis] + stack_dims2 = params.shape[axis + 1 :] + indices_ex = indices_ex.reshape(( + torch.Size([1 for dim in stack_dims1]) + + torch.Size([-1]) + + 
torch.Size([1 for dim in stack_dims2]) + )) + indices_ex = indices_ex.expand( + (stack_dims1 + torch.Size([-1]) + stack_dims2) + ).reshape((stack_dims1 + torch.Size([-1]) + stack_dims2)) + return indices_ex, new_axis + + final_shape = ( + params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] + ) - if any((params_ex_mask != (torch.tensor(params_ex.shape) == 1).flatten())): - params_ex = params_ex.expand( - params.shape[:axis] - + indices.shape[batch_dims:] - + params.shape[axis + 1 :] - ) - indices_ex_mask = torch.tensor( - params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] - ) * (torch.tensor(indices_ex.shape) == 1) + if batch_dims == 0: + dim_diff = abs(params.dim() - (indices.dim() - batch_dims)) + if dim_diff != 0: + indices_expanded, new_axis = expand_p_i(params, indices) - if any((indices_ex_mask != (torch.tensor(indices_ex.shape) == 1).flatten())): - indices_ex = indices_ex.expand( + result = torch.gather( + params, new_axis, indices_expanded.long(), sparse_grad=False, out=out + ).reshape( params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] ) - return params_ex, indices_ex, new_axis - - if batch_dims == 0: - dim_diff = params.dim() - indices.dim() - if dim_diff != 0: - params_expanded, indices_expanded, new_axis = expand_p_i(params, indices) - result = torch.gather( - params_expanded, - new_axis, - indices_expanded.long(), - sparse_grad=False, - out=out, - ) result = result.to(dtype=params.dtype) + if ivy.exists(out): + return ivy.inplace_update(out, result) return result else: - params_expanded, indices_expanded, new_axis = expand_p_i(params, indices) - + indices_expanded, new_axis = expand_p_i(params, indices) result = torch.gather( - params_expanded, - new_axis, - indices_expanded.long(), - sparse_grad=False, - out=out, - ) + params, new_axis, indices_expanded.long(), sparse_grad=False, out=out + ).reshape(final_shape) result = result.to(dtype=params.dtype) else: - params_slices = torch.unbind(params, axis=0) if params.dim() > 1 else params - indices_slices = torch.unbind(indices, axis=0) if indices.dim() > 1 else indices - + indices_ex = indices + new_axis = axis + params_slices = torch.unbind(params, axis=0) if params.shape[0] > 0 else params + indices_slices = ( + torch.unbind(indices_ex, axis=0) if indices.shape[0] > 0 else indices_ex + ) for b in range(batch_dims): if b == 0: zip_list = [(p, i) for p, i in zip(params_slices, indices_slices)] @@ -291,18 +249,9 @@ def expand_p_i(params, indices, axis=axis): ] for z in zip_list: p, i = z - dim_diff = p.dim() - i.dim() - p_expanded = p - i_expanded = i - if dim_diff != 0: - p_expanded, i_expanded, new_axis = expand_p_i(p, i) - + i_ex, new_axis = expand_p_i(p, i) r = torch.gather( - p_expanded, - (new_axis - batch_dims), - i_expanded.long(), - sparse_grad=False, - out=None, + p, (new_axis - batch_dims), i_ex.long(), sparse_grad=False, out=None ) result.append(r) @@ -313,6 +262,8 @@ def expand_p_i(params, indices, axis=axis): + params.shape[axis + 1 :] ) result = result.to(dtype=params.dtype) + if ivy.exists(out): + return ivy.inplace_update(out, result) return result From da03caecdb434aebf750d8645fcdacabee4b4dcc Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Wed, 27 Dec 2023 00:05:48 +0100 Subject: [PATCH 84/95] torch_gather_fix for pytest --- ivy/functional/backends/torch/general.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 
01bcec6bfaef9..34f6a5bed0956 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -223,8 +223,6 @@ def expand_p_i(params, indices, axis=axis, batch_dims=batch_dims): + params.shape[axis + 1 :] ) result = result.to(dtype=params.dtype) - if ivy.exists(out): - return ivy.inplace_update(out, result) return result else: indices_expanded, new_axis = expand_p_i(params, indices) @@ -249,11 +247,8 @@ def expand_p_i(params, indices, axis=axis, batch_dims=batch_dims): ] for z in zip_list: p, i = z - i_ex, new_axis = expand_p_i(p, i) - r = torch.gather( - p, (new_axis - batch_dims), i_ex.long(), sparse_grad=False, out=None - ) - + i_ex, new_axis = expand_p_i(p, i, axis=axis - batch_dims) + r = torch.gather(p, (new_axis), i_ex.long(), sparse_grad=False, out=None) result.append(r) result = torch.stack(result) result = result.reshape( From 5c604fc71a544ac1528f2be6c27ceef8f1d87e02 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Mon, 22 Jan 2024 18:13:37 +0100 Subject: [PATCH 85/95] torch gather fix- removed redundant var --- ivy/functional/backends/torch/general.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 50a4ede8e1e37..6b89a0fc1771e 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -191,7 +191,6 @@ def gather( def expand_p_i(params, indices, axis=axis, batch_dims=batch_dims): axis %= len(params.shape) - new_axis = axis indices_ex = torch.clone(indices).detach() abs(params.dim() - (indices_ex.dim() - batch_dims)) stack_dims1 = params.shape[:axis] @@ -204,7 +203,7 @@ def expand_p_i(params, indices, axis=axis, batch_dims=batch_dims): indices_ex = indices_ex.expand( (stack_dims1 + torch.Size([-1]) + stack_dims2) ).reshape((stack_dims1 + torch.Size([-1]) + stack_dims2)) - return indices_ex, new_axis + return indices_ex, axis final_shape = ( params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] From 33602e7d38b1dce8b28688c1685ee6bcdd1a8220 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Fri, 26 Jan 2024 18:38:05 +0100 Subject: [PATCH 86/95] added tolerance for test_gather --- ivy_tests/test_ivy/test_functional/test_core/test_general.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_general.py b/ivy_tests/test_ivy/test_functional/test_core/test_general.py index 951c1d05e804c..d770f2ee985dd 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_general.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_general.py @@ -1021,6 +1021,8 @@ def test_gather(params_indices_others, test_flags, backend_fw, fn_name, on_devic params=params, indices=indices, axis=axis, + atol_=1e-3, + rtol_=1e-3, batch_dims=batch_dims, ) From f1c24601f964eaad6f6352837552968c9507ba20 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Tue, 6 Feb 2024 02:04:33 +0100 Subject: [PATCH 87/95] added safety factors in test_gather- needs associated PR update to a helper func --- ivy_tests/test_ivy/test_functional/test_core/test_general.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_general.py b/ivy_tests/test_ivy/test_functional/test_core/test_general.py index d770f2ee985dd..359fff4d141b6 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_general.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_general.py @@ -1007,6 +1007,9 @@ def 
test_function_unsupported_devices(func, backend_fw): max_num_dims=5, min_dim_size=1, max_dim_size=10, + large_abs_safety_factor=1.5, + small_abs_safety_factor=1.5, + safety_factor_scale="log", ), ) def test_gather(params_indices_others, test_flags, backend_fw, fn_name, on_device): From 0f7047be687bbbd1f3b5e23fd631d3cd7963ef81 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Thu, 29 Feb 2024 10:24:33 +0100 Subject: [PATCH 88/95] torch.gather modify axis --- ivy/functional/backends/torch/general.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 6b89a0fc1771e..7905cb35d8b89 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -183,7 +183,6 @@ def gather( batch_dims: int = 0, out: Optional[torch.Tensor] = None, ) -> torch.Tensor: - axis %= len(params.shape) axis = abs(len(params.shape) + axis) if axis < 0 else axis batch_dims %= len(params.shape) ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims) From 53c652f9c313f6107a7f6dfafb9222190b03a6c0 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Thu, 29 Feb 2024 10:35:53 +0100 Subject: [PATCH 89/95] torch.gather modify axis --- ivy/functional/backends/torch/general.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 7905cb35d8b89..05f4192f97013 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -183,7 +183,7 @@ def gather( batch_dims: int = 0, out: Optional[torch.Tensor] = None, ) -> torch.Tensor: - axis = abs(len(params.shape) + axis) if axis < 0 else axis + axis %= len(params.shape) batch_dims %= len(params.shape) ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims) result = [] From c5d6f6cc4b70a52b6800233f01b136cfabfd152b Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Thu, 29 Feb 2024 09:38:40 +0000 Subject: [PATCH 90/95] =?UTF-8?q?=F0=9F=A4=96=20Lint=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ivy/functional/backends/torch/general.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index f693387be067f..38d119be76fa0 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -197,11 +197,13 @@ def expand_p_i(params, indices, axis=axis, batch_dims=batch_dims): abs(params.dim() - (indices_ex.dim() - batch_dims)) stack_dims1 = params.shape[:axis] stack_dims2 = params.shape[axis + 1 :] - indices_ex = indices_ex.reshape(( - torch.Size([1 for dim in stack_dims1]) - + torch.Size([-1]) - + torch.Size([1 for dim in stack_dims2]) - )) + indices_ex = indices_ex.reshape( + ( + torch.Size([1 for dim in stack_dims1]) + + torch.Size([-1]) + + torch.Size([1 for dim in stack_dims2]) + ) + ) indices_ex = indices_ex.expand( (stack_dims1 + torch.Size([-1]) + stack_dims2) ).reshape((stack_dims1 + torch.Size([-1]) + stack_dims2)) From f008bf5545724394e825dd07f2b031191b08cc26 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Thu, 29 Feb 2024 10:41:13 +0100 Subject: [PATCH 91/95] pull --- docs/demos | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/demos b/docs/demos index d3fa2b9c2a7ff..104f88a7be642 160000 --- a/docs/demos +++ b/docs/demos @@ -1 +1 @@ -Subproject 
commit d3fa2b9c2a7ffa93573bb63d2e66abbe3dd198fc +Subproject commit 104f88a7be64234ec58950deed8142bc7748d9da From f05d18ffd53d9bfbd092491fe14c3b060ebd4c30 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Fri, 1 Mar 2024 17:40:58 +0100 Subject: [PATCH 92/95] gather- remove redundant --- ivy/functional/backends/torch/general.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 38d119be76fa0..497719007359a 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -193,21 +193,20 @@ def gather( def expand_p_i(params, indices, axis=axis, batch_dims=batch_dims): axis %= len(params.shape) - indices_ex = torch.clone(indices).detach() - abs(params.dim() - (indices_ex.dim() - batch_dims)) + abs(params.dim() - (indices.dim() - batch_dims)) stack_dims1 = params.shape[:axis] stack_dims2 = params.shape[axis + 1 :] - indices_ex = indices_ex.reshape( + indices = indices.reshape( ( torch.Size([1 for dim in stack_dims1]) + torch.Size([-1]) + torch.Size([1 for dim in stack_dims2]) ) ) - indices_ex = indices_ex.expand( + indices = indices.expand( (stack_dims1 + torch.Size([-1]) + stack_dims2) ).reshape((stack_dims1 + torch.Size([-1]) + stack_dims2)) - return indices_ex, axis + return indices, axis final_shape = ( params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1 :] From edbbf85c4c4448965a6a7cd318649d773b706ee1 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Mon, 4 Mar 2024 10:09:26 +0100 Subject: [PATCH 93/95] demos revert --- docs/demos | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/demos b/docs/demos index 104f88a7be642..d3fa2b9c2a7ff 160000 --- a/docs/demos +++ b/docs/demos @@ -1 +1 @@ -Subproject commit 104f88a7be64234ec58950deed8142bc7748d9da +Subproject commit d3fa2b9c2a7ffa93573bb63d2e66abbe3dd198fc From b9dd6d61a8d4c55fa3a40c0b982afa938a664e02 Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Wed, 13 Mar 2024 12:15:08 +0100 Subject: [PATCH 94/95] explicit bfloat16 tolerance in core test --- ivy_tests/test_ivy/test_functional/test_core/test_general.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_general.py b/ivy_tests/test_ivy/test_functional/test_core/test_general.py index 497772b416968..cc8851bfd5428 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_general.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_general.py @@ -1030,6 +1030,7 @@ def test_gather(params_indices_others, test_flags, backend_fw, fn_name, on_devic axis=axis, atol_=1e-3, rtol_=1e-3, + tolerance_dict={"bfloat16": 1e-1}, batch_dims=batch_dims, ) From 7a88a1f7ed6a8655f0e5eb4b107e88ee3571087b Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Wed, 13 Mar 2024 12:20:03 +0100 Subject: [PATCH 95/95] explicit bfloat16 tolerance in core test --- docs/demos | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/demos b/docs/demos index 15c235ff3aaff..ab2034037be50 160000 --- a/docs/demos +++ b/docs/demos @@ -1 +1 @@ -Subproject commit 15c235ff3aaff4903b80b1c5d574bcc116c24ac1 +Subproject commit ab2034037be50a796f4ef0efcbbf786b02e32cfe
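
(A closing usage sketch of the behavior this series targets: tf.gather-style semantics from the torch backend, where the output shape is params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:]. It assumes the torch backend is installed; values and shapes are illustrative, not taken from the test suite.)

    import ivy

    ivy.set_backend("torch")
    params = ivy.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # shape (2, 3)
    indices = ivy.array([[2, 0], [1, 1]])                   # shape (2, 2)

    # With batch_dims=1, row b of indices selects along axis 1 of row b of
    # params, so out[b, j] == params[b, indices[b, j]].
    out = ivy.gather(params, indices, axis=1, batch_dims=1)
    print(out)        # [[3., 1.], [5., 5.]]
    print(out.shape)  # (2, 2), i.e. params.shape[:1] + indices.shape[1:] + params.shape[2:]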