diff --git a/ivy/functional/backends/tensorflow/general.py b/ivy/functional/backends/tensorflow/general.py index 0346b00b9641e..275770f162009 100644 --- a/ivy/functional/backends/tensorflow/general.py +++ b/ivy/functional/backends/tensorflow/general.py @@ -316,6 +316,8 @@ def scatter_flat( if ivy.exists(size) and ivy.exists(target): ivy.utils.assertions.check_equal(len(target.shape), 1, as_array=False) ivy.utils.assertions.check_equal(target.shape[0], size, as_array=False) + if target_given: + updates = ivy.astype(updates, target.dtype) if not target_given: target = tf.zeros([size], dtype=updates.dtype) res = tf.tensor_scatter_nd_update(target, tf.expand_dims(indices, -1), updates) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index e1abd35a45c64..2a78520a9a9f4 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -1,6 +1,7 @@ # global from typing import Iterable import math +from ivy.functional.frontends.torch.func_wrapper import to_ivy_arrays_and_back # local import ivy @@ -2336,6 +2337,10 @@ def minimum(self, other, *, out=None): def rad2deg(self, *, out=None): return torch_frontend.rad2deg(self, out=out) + def fill_diagonal_(self, fill_value, wrap=False): + self._ivy_array = ivy.fill_diagonal(self._ivy_array, fill_value, wrap=wrap) + return self + @with_supported_dtypes( {"2.2 and below": "valid"}, "torch", diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index a1fe83a3b5d22..960095e9552c8 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -6907,6 +6907,49 @@ def test_torch_fill_( ) +# fill_diagonal_ +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="fill_diagonal_", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, + min_dim_size=2, + max_num_dims=2, + ), + val=helpers.floats(min_value=-10, max_value=10), + wrap=st.booleans(), +) +def test_torch_fill_diagonal_( + dtype_x_axis, + wrap, + val, + frontend, + frontend_method_data, + init_flags, + method_flags, + on_device, + backend_fw, +): + input_dtype, x, axis = dtype_x_axis + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={"data": x[0]}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "fill_value": val, + "wrap": wrap, + }, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + # fix @handle_frontend_method( class_tree=CLASS_TREE,