From 5f50540f845aac0f9743964d875de660ad3ea57f Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Sat, 17 Feb 2024 16:56:14 +0000 Subject: [PATCH] feat: implement `diag_embed` for torch frontend and the test the implementation might be pushed to backends to simplify frontned. the test is minimal and needs to test for dims and offsets which is doable locally but can cause random healthcheck failure. looking to fix and push. --- .../frontends/torch/miscellaneous_ops.py | 49 +++++++++++++++++++ .../test_frontends/test_torch/test_linalg.py | 28 +++++++++++ 2 files changed, 77 insertions(+) diff --git a/ivy/functional/frontends/torch/miscellaneous_ops.py b/ivy/functional/frontends/torch/miscellaneous_ops.py index 98074a690de62..61624d2398762 100644 --- a/ivy/functional/frontends/torch/miscellaneous_ops.py +++ b/ivy/functional/frontends/torch/miscellaneous_ops.py @@ -211,6 +211,55 @@ def diag(input, diagonal=0, *, out=None): return ivy.diag(input, k=diagonal) +@to_ivy_arrays_and_back +def diag_embed( + input, + offset=0, + dim1=-2, + dim2=-1, +): + def _handle_dim(rank, idx): + if idx >= 0 and idx < rank: + return idx + if idx < 0: + idx = idx + rank + if idx < 0 or idx >= rank: + raise IndexError + return idx + + input_type = ivy.dtype(input) + rank = input.ndim + 1 + dim1 = _handle_dim(rank, dim1) + dim2 = _handle_dim(rank, dim2) + if dim1 > dim2: + dim1, dim2 = dim2, dim1 + offset = -offset + last_dim = list(input.shape)[-1] + if offset != 0: + # add padding to match the new size + t_shape = list(input.shape) + t_shape[-1] = abs(offset) + z = ivy.zeros(t_shape, dtype=input.dtype, device=input.device) + pair = (z, input) if offset > 0 else (input, z) + input = ivy.concat(pair, axis=-1) + last_dim += abs(offset) + input = input.expand_dims(axis=dim1).moveaxis(-1, dim2) + # generate ranges shifting indices based on offset + a_range = ivy.arange(last_dim, device=input.device, dtype=ivy.int64) + b_range = ivy.arange( + offset, last_dim + offset, device=input.device, dtype=ivy.int64 + ) + # broadcast + cond = a_range == b_range.expand_dims(axis=-1) + cond_shape = [last_dim if i in (dim1, dim2) else 1 for i in range(len(input.shape))] + cond = cond.reshape(cond_shape) + if input.dtype == ivy.bool: + ret = cond.logical_and(input) + else: + ret = ivy.where(cond, input, 0) + return ret.astype(input_type) + + @with_supported_dtypes( {"2.2 and below": ("float32", "float64", "int32", "int64")}, "torch" ) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py index cb1f13a438072..5c91c765d6a35 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py @@ -424,6 +424,34 @@ def test_torch_det( ) +@handle_frontend_test( + fn_tree="torch.diag_embed", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.shared(helpers.get_shape(min_num_dims=1, max_num_dims=2), key="shape"), + ), +) +def test_torch_diag_embed( + *, + dtype_and_values, + test_flags, + on_device, + fn_tree, + frontend, + backend_fw, +): + input_dtype, value = dtype_and_values + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + test_flags=test_flags, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + ) + + # eig # TODO: Test for all valid dtypes once ivy.eig supports complex data types @handle_frontend_test(