diff --git a/ivy/functional/frontends/torch/miscellaneous_ops.py b/ivy/functional/frontends/torch/miscellaneous_ops.py index 98074a690de62..61624d2398762 100644 --- a/ivy/functional/frontends/torch/miscellaneous_ops.py +++ b/ivy/functional/frontends/torch/miscellaneous_ops.py @@ -211,6 +211,55 @@ def diag(input, diagonal=0, *, out=None): return ivy.diag(input, k=diagonal) +@to_ivy_arrays_and_back +def diag_embed( + input, + offset=0, + dim1=-2, + dim2=-1, +): + def _handle_dim(rank, idx): + if idx >= 0 and idx < rank: + return idx + if idx < 0: + idx = idx + rank + if idx < 0 or idx >= rank: + raise IndexError + return idx + + input_type = ivy.dtype(input) + rank = input.ndim + 1 + dim1 = _handle_dim(rank, dim1) + dim2 = _handle_dim(rank, dim2) + if dim1 > dim2: + dim1, dim2 = dim2, dim1 + offset = -offset + last_dim = list(input.shape)[-1] + if offset != 0: + # add padding to match the new size + t_shape = list(input.shape) + t_shape[-1] = abs(offset) + z = ivy.zeros(t_shape, dtype=input.dtype, device=input.device) + pair = (z, input) if offset > 0 else (input, z) + input = ivy.concat(pair, axis=-1) + last_dim += abs(offset) + input = input.expand_dims(axis=dim1).moveaxis(-1, dim2) + # generate ranges shifting indices based on offset + a_range = ivy.arange(last_dim, device=input.device, dtype=ivy.int64) + b_range = ivy.arange( + offset, last_dim + offset, device=input.device, dtype=ivy.int64 + ) + # broadcast + cond = a_range == b_range.expand_dims(axis=-1) + cond_shape = [last_dim if i in (dim1, dim2) else 1 for i in range(len(input.shape))] + cond = cond.reshape(cond_shape) + if input.dtype == ivy.bool: + ret = cond.logical_and(input) + else: + ret = ivy.where(cond, input, 0) + return ret.astype(input_type) + + @with_supported_dtypes( {"2.2 and below": ("float32", "float64", "int32", "int64")}, "torch" ) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py index cb1f13a438072..5c91c765d6a35 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py @@ -424,6 +424,34 @@ def test_torch_det( ) +@handle_frontend_test( + fn_tree="torch.diag_embed", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.shared(helpers.get_shape(min_num_dims=1, max_num_dims=2), key="shape"), + ), +) +def test_torch_diag_embed( + *, + dtype_and_values, + test_flags, + on_device, + fn_tree, + frontend, + backend_fw, +): + input_dtype, value = dtype_and_values + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + test_flags=test_flags, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + ) + + # eig # TODO: Test for all valid dtypes once ivy.eig supports complex data types @handle_frontend_test(