Skip to content

Commit

Permalink
feat: implement diag_embed for torch frontend and the test
Browse files Browse the repository at this point in the history
the implementation might be pushed to backends to simplify frontned.
the test is minimal and needs to test for dims and offsets which is doable locally but can cause random healthcheck failure. looking to fix and push.
  • Loading branch information
Ishticode committed Feb 17, 2024
1 parent 8c8fd6d commit 5f50540
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 0 deletions.
49 changes: 49 additions & 0 deletions ivy/functional/frontends/torch/miscellaneous_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,55 @@ def diag(input, diagonal=0, *, out=None):
return ivy.diag(input, k=diagonal)


@to_ivy_arrays_and_back
def diag_embed(
input,
offset=0,
dim1=-2,
dim2=-1,
):
def _handle_dim(rank, idx):
if idx >= 0 and idx < rank:
return idx
if idx < 0:
idx = idx + rank
if idx < 0 or idx >= rank:
raise IndexError
return idx

input_type = ivy.dtype(input)
rank = input.ndim + 1
dim1 = _handle_dim(rank, dim1)
dim2 = _handle_dim(rank, dim2)
if dim1 > dim2:
dim1, dim2 = dim2, dim1
offset = -offset
last_dim = list(input.shape)[-1]
if offset != 0:
# add padding to match the new size
t_shape = list(input.shape)
t_shape[-1] = abs(offset)
z = ivy.zeros(t_shape, dtype=input.dtype, device=input.device)
pair = (z, input) if offset > 0 else (input, z)
input = ivy.concat(pair, axis=-1)
last_dim += abs(offset)
input = input.expand_dims(axis=dim1).moveaxis(-1, dim2)
# generate ranges shifting indices based on offset
a_range = ivy.arange(last_dim, device=input.device, dtype=ivy.int64)
b_range = ivy.arange(
offset, last_dim + offset, device=input.device, dtype=ivy.int64
)
# broadcast
cond = a_range == b_range.expand_dims(axis=-1)
cond_shape = [last_dim if i in (dim1, dim2) else 1 for i in range(len(input.shape))]
cond = cond.reshape(cond_shape)
if input.dtype == ivy.bool:
ret = cond.logical_and(input)
else:
ret = ivy.where(cond, input, 0)
return ret.astype(input_type)


@with_supported_dtypes(
{"2.2 and below": ("float32", "float64", "int32", "int64")}, "torch"
)
Expand Down
28 changes: 28 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,34 @@ def test_torch_det(
)


@handle_frontend_test(
fn_tree="torch.diag_embed",
dtype_and_values=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
shape=st.shared(helpers.get_shape(min_num_dims=1, max_num_dims=2), key="shape"),
),
)
def test_torch_diag_embed(
*,
dtype_and_values,
test_flags,
on_device,
fn_tree,
frontend,
backend_fw,
):
input_dtype, value = dtype_and_values
helpers.test_frontend_function(
input_dtypes=input_dtype,
backend_to_test=backend_fw,
test_flags=test_flags,
frontend=frontend,
fn_tree=fn_tree,
on_device=on_device,
input=value[0],
)


# eig
# TODO: Test for all valid dtypes once ivy.eig supports complex data types
@handle_frontend_test(
Expand Down

0 comments on commit 5f50540

Please sign in to comment.