added lu_factor to torch frontend. Tests failing (#10365)
Co-authored-by: WilliamHirst <bod_holthe@outlook.com>
TriniKiskadee authored Feb 26, 2023
1 parent 48df381 commit e5498b7
Showing 3 changed files with 109 additions and 0 deletions.
4 changes: 4 additions & 0 deletions ivy/functional/frontends/torch/linalg.py
@@ -185,5 +185,9 @@ def tensorsolve(A, B, dims=None, *, out=None):

@to_ivy_arrays_and_back
@with_unsupported_dtypes({"1.11.0 and below": ("bfloat16", "float16")}, "torch")
def lu_factor(A, *, pivot=True, out=None):
    return ivy.lu_factor(A, pivot=pivot, out=out)


def matmul(input, other, *, out=None):
    return ivy.matmul(input, other, out=out)
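
For context, a short usage sketch of the new frontend wrapper (illustrative only, not part of the diff; it assumes a backend implementation of ivy.lu_factor is available, which the commit title notes is still pending, hence the failing tests):

# Illustrative sketch only: calling the new torch-frontend lu_factor through ivy.
# Assumes the torch backend provides ivy.lu_factor (not included in this commit).
import ivy
import ivy.functional.frontends.torch.linalg as torch_frontend_linalg

ivy.set_backend("torch")
A = ivy.array([[4.0, 3.0], [6.0, 3.0]])
# Expected to mirror torch.linalg.lu_factor: a (LU, pivots) pair.
LU, pivots = torch_frontend_linalg.lu_factor(A)
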
33 changes: 33 additions & 0 deletions ivy/functional/ivy/linear_algebra.py
@@ -2692,3 +2692,36 @@ def vector_to_skew_symmetric_matrix(
"""
return current_backend(vector).vector_to_skew_symmetric_matrix(vector, out=out)


@to_native_arrays_and_back
@handle_out_argument
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
def lu_factor(
    A: Union[ivy.Array, ivy.NativeArray],
    /,
    *,
    pivot: Optional[bool] = True,
    out: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
) -> Tuple[Union[ivy.Array, ivy.NativeArray], Union[ivy.Array, ivy.NativeArray]]:
    """
    Compute the LU factorization of a matrix or a batch of matrices.

    Parameters
    ----------
    A
        tensor of shape (*, m, n) where * is zero or more batch dimensions.
    pivot
        Whether to compute the LU decomposition with partial pivoting, or the
        regular LU decomposition. pivot = False is not supported on CPU.
        Default: True.
    out
        tuple of two tensors to write the output to. Ignored if None.
        Default: None.

    Returns
    -------
    ret
        A named tuple (LU, pivots).
    """
    return current_backend(A).lu_factor(A, pivot=pivot, out=out)
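
The new ivy.lu_factor dispatches to current_backend(A).lu_factor, and that backend piece is not part of this commit (which is why the tests fail). A hypothetical torch backend implementation could look like the sketch below; the file placement and exact decorator set are assumptions, only the delegation pattern is shown.

# Hypothetical torch backend sketch (not in this commit): delegate to
# torch.linalg.lu_factor, which already returns a named tuple (LU, pivots).
import torch


def lu_factor(A, /, *, pivot=True, out=None):
    return torch.linalg.lu_factor(A, pivot=pivot, out=out)
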
72 changes: 72 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py
@@ -849,6 +849,78 @@ def test_torch_tensorsolve(
    )


# lu_factor
@st.composite
def _lu_factor_helper(draw):
    # generate an input matrix of shape (*, m, n), where '*' is zero or more
    # batch dimensions (batched inputs are currently disabled below)
    input_dtype = draw(helpers.get_dtypes("float"))

    dim1 = draw(helpers.ints(min_value=2, max_value=3))
    dim2 = draw(helpers.ints(min_value=2, max_value=3))
    # batch_dim = draw(helpers.ints(min_value=0, max_value=2))
    batch_dim = 0

    if batch_dim == 0:
        input_matrix = draw(
            helpers.array_values(
                dtype=input_dtype[0],
                shape=(dim1, dim2),
                min_value=-1,
                max_value=1,
            )
        )
    else:
        input_matrix = draw(
            helpers.array_values(
                dtype=input_dtype[0],
                shape=(batch_dim, dim1, dim2),
                min_value=-1,
                max_value=1,
            )
        )

    return input_dtype, input_matrix
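
For a quick look at what this strategy generates, one can draw a sample interactively (illustrative only; Hypothesis's .example() is meant for exploration, not for use inside tests):

# Illustrative only: inspect a draw from the composite strategy defined above.
dtype, matrix = _lu_factor_helper().example()
print(dtype, matrix)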


@handle_frontend_test(
    fn_tree="torch.linalg.lu_factor",
    input_dtype_and_input=_lu_factor_helper(),
)
def test_torch_lu_factor(
    *,
    input_dtype_and_input,
    on_device,
    fn_tree,
    frontend,
    test_flags,
):
    dtype, input = input_dtype_and_input
    ret, frontend_ret = helpers.test_frontend_function(
        input_dtypes=dtype,
        test_flags=test_flags,
        frontend=frontend,
        fn_tree=fn_tree,
        on_device=on_device,
        test_values=False,
        # rtol=1e-03,
        # atol=1e-02,
        A=input,
    )
    ret = [ivy.to_numpy(x) for x in ret]
    frontend_ret = [np.asarray(x) for x in frontend_ret]

    LU, pivot = ret
    frontend_LU, frontend_pivot = frontend_ret

    assert_all_close(
        ret_np=[LU, pivot],
        ret_from_gt_np=[frontend_LU, frontend_pivot],
        ground_truth_backend=frontend,
    )


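Because test_values=False, the helper returns both the ivy result and the ground-truth frontend result, and the (LU, pivots) values are compared manually with assert_all_close. As a standalone reference for the factorization convention being compared (illustrative only, using scipy rather than the frameworks under test; scipy's pivots are 0-indexed, whereas torch follows LAPACK's 1-indexed convention):

# Illustrative only: confirm that an (LU, pivots) factorization reproduces the
# behaviour of solving with the original matrix.
import numpy as np
from scipy.linalg import lu_factor, lu_solve

A = np.array([[4.0, 3.0], [6.0, 3.0]])
lu, piv = lu_factor(A)  # packed L/U factors and 0-indexed pivot rows
b = np.array([1.0, 2.0])
assert np.allclose(lu_solve((lu, piv), b), np.linalg.solve(A, b))
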
@handle_frontend_test(
fn_tree="torch.linalg.matmul",
dtype_x=helpers.dtype_and_values(
