Driver handling in svdvals function in torch_frontend (#23718)
Co-authored-by: ivy-branch <ivy.branch@lets-unify.ai>
Co-authored-by: juliagsy <67888047+juliagsy@users.noreply.github.com>
3 people authored Oct 24, 2023
1 parent d8feef7 commit bb0b201
Showing 9 changed files with 43 additions and 10 deletions.
ivy/functional/backends/jax/linear_algebra.py (5 changes: 4 additions & 1 deletion)
@@ -371,7 +371,10 @@ def svd(
{"0.4.19 and below": ("bfloat16", "float16", "complex")},
backend_version,
)
def svdvals(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
def svdvals(
x: JaxArray, /, *, driver: Optional[str] = None, out: Optional[JaxArray] = None
) -> JaxArray:
# TODO: handling the driver argument
return jnp.linalg.svd(x, compute_uv=False)


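Note: the JAX, NumPy, TensorFlow, and Paddle backends all take the same approach here: driver enters the signature so the unified API can forward it uniformly, but these libraries have no cuSOLVER driver concept, so the argument is left unused (hence the TODOs). A minimal sketch of the underlying call, assuming only that JAX is installed:

import jax.numpy as jnp

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
# compute_uv=False skips the U and Vh factors and returns only the
# singular values, which is exactly what svdvals needs
s = jnp.linalg.svd(x, compute_uv=False)
print(s)  # descending order, approximately [5.4650, 0.3660]
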
ivy/functional/backends/mxnet/linear_algebra.py (2 changes: 2 additions & 0 deletions)
@@ -228,8 +228,10 @@ def svdvals(
     x: Union[(None, mx.ndarray.NDArray)],
     /,
     *,
+    driver: Optional[str] = None,
     out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
 ) -> Union[(None, mx.ndarray.NDArray)]:
+    # TODO: handle the driver argument
     raise IvyNotImplementedException()


ivy/functional/backends/numpy/linear_algebra.py (5 changes: 4 additions & 1 deletion)
@@ -298,7 +298,10 @@ def svd(


@with_unsupported_dtypes({"1.26.1 and below": ("float16",)}, backend_version)
def svdvals(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray:
def svdvals(
x: np.ndarray, /, *, driver: Optional[str] = None, out: Optional[np.ndarray] = None
) -> np.ndarray:
# TODO: handling the driver argument
return np.linalg.svd(x, compute_uv=False)


ivy/functional/backends/paddle/linear_algebra.py (7 changes: 6 additions & 1 deletion)
@@ -521,8 +521,13 @@ def svd(
     backend_version,
 )
 def svdvals(
-    x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None
+    x: paddle.Tensor,
+    /,
+    *,
+    driver: Optional[str] = None,
+    out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
+    # TODO: handle the driver argument
     return paddle_backend.svd(x)[1]


ivy/functional/backends/tensorflow/linear_algebra.py (2 changes: 2 additions & 0 deletions)
@@ -564,8 +564,10 @@ def svdvals(
     x: Union[tf.Tensor, tf.Variable],
     /,
     *,
+    driver: Optional[str] = None,
     out: Optional[Union[tf.Tensor, tf.Variable]] = None,
 ) -> Union[tf.Tensor, tf.Variable]:
+    # TODO: handle the driver argument
     ret = tf.linalg.svd(x, compute_uv=False)
     return ret

ivy/functional/backends/torch/linear_algebra.py (12 changes: 9 additions & 3 deletions)
@@ -414,9 +414,15 @@ def svd(
     return results(D)


-@with_unsupported_dtypes({"2.1.0 and below": ("float16", "bfloat16")}, backend_version)
-def svdvals(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor:
-    return torch.linalg.svdvals(x, out=out)
+@with_unsupported_dtypes({"2.0.1 and below": ("float16", "bfloat16")}, backend_version)
+def svdvals(
+    x: torch.Tensor,
+    /,
+    *,
+    driver: Optional[str] = None,
+    out: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    return torch.linalg.svdvals(x, driver=driver, out=out)


 svdvals.support_native_out = True
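
Note: the torch backend is the only one that can honor driver natively, because torch.linalg.svdvals itself accepts the keyword. PyTorch restricts the cuSOLVER drivers to CUDA tensors, so CPU callers should leave it as None. A short sketch, assuming a PyTorch build recent enough to expose the driver keyword:

import torch

A = torch.randn(3, 5)

# on CPU, driver must stay None; "gesvd", "gesvdj", and "gesvda" are
# cuSOLVER methods and are only valid for CUDA tensors
s = torch.linalg.svdvals(A)

if torch.cuda.is_available():
    s_gpu = torch.linalg.svdvals(A.cuda(), driver="gesvdj")
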
ivy/functional/frontends/torch/linalg.py (6 changes: 4 additions & 2 deletions)
@@ -314,8 +314,10 @@ def svd(A, /, *, full_matrices=True, driver=None, out=None):
{"2.1.0 and below": ("float32", "float64", "complex32", "complex64")}, "torch"
)
def svdvals(A, *, driver=None, out=None):
# TODO: add handling for driver
return ivy.svdvals(A, out=out)
if driver in ["gesvd", "gesvdj", "gesvda", None]:
return ivy.svdvals(A, driver=driver, out=out)
else:
raise ValueError("Unsupported SVD driver")


@to_ivy_arrays_and_back
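
Note: with this frontend change an unrecognized driver now fails loudly instead of being silently dropped. A hypothetical usage sketch (the import path follows Ivy's frontend layout; treat the exact calls as illustrative):

import ivy
from ivy.functional.frontends.torch import linalg as torch_linalg

ivy.set_backend("numpy")
A = ivy.array([[1.0, 2.0], [3.0, 4.0]])

s = torch_linalg.svdvals(A)  # driver=None passes the validation

try:
    torch_linalg.svdvals(A, driver="magma")  # not in the accepted list
except ValueError as err:
    print(err)  # Unsupported SVD driver
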
ivy/functional/ivy/linear_algebra.py (11 changes: 9 additions & 2 deletions)
@@ -2264,7 +2264,11 @@ def svd(
 @handle_array_function
 @handle_device
 def svdvals(
-    x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
+    x: Union[ivy.Array, ivy.NativeArray],
+    /,
+    *,
+    driver: Optional[str] = None,
+    out: Optional[ivy.Array] = None,
 ) -> ivy.Array:
     """
     Return the singular values of a matrix (or a stack of matrices) ``x``.
@@ -2274,6 +2278,9 @@ def svdvals(
     x
         input array having shape ``(..., M, N)`` and whose innermost two dimensions form
         ``MxN`` matrices.
+    driver
+        name of the cuSOLVER method to be used. This keyword argument only works on CUDA inputs.
+        Available options are: None, gesvd, gesvdj, and gesvda. Default: None.
     out
         optional output array, for writing the result to. It must have a shape that the
         inputs broadcast to.
@@ -2387,7 +2394,7 @@ def svdvals(
         b: ivy.array([23.16134834, 10.35037804, 4.31025076, 1.35769391])
     }
     """
-    return current_backend(x).svdvals(x, out=out)
+    return current_backend(x).svdvals(x, driver=driver, out=out)


@handle_exceptions
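
Note: at the unified API level the new keyword is pure pass-through: current_backend(x) selects the backend module and forwards driver verbatim. A short usage sketch, assuming the torch backend is installed:

import ivy

ivy.set_backend("torch")
x = ivy.array([[1.0, 2.0], [3.0, 4.0]])

# driver reaches torch.linalg.svdvals on this backend; the other
# backends currently accept the keyword but ignore it (see the TODOs)
s = ivy.svdvals(x, driver=None)
print(s)  # approximately ivy.array([5.4650, 0.3660])
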
ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py (3 changes: 3 additions & 0 deletions)
@@ -1156,10 +1156,12 @@ def test_torch_svd(
 @handle_frontend_test(
     fn_tree="torch.linalg.svdvals",
     dtype_and_x=_get_dtype_and_matrix(batch=True),
+    driver=st.sampled_from([None, "gesvd", "gesvdj", "gesvda"]),
 )
 def test_torch_svdvals(
     *,
     dtype_and_x,
+    driver,
     on_device,
     fn_tree,
     frontend,
@@ -1174,6 +1176,7 @@ def test_torch_svdvals(
         test_flags=test_flags,
         fn_tree=fn_tree,
         on_device=on_device,
+        driver=driver,
         A=x[0],
)

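Note: st.sampled_from draws one of the four accepted driver values for each generated example, and helpers.test_frontend_function passes the extra keyword through to the function under test. A stripped-down illustration of the strategy on its own (hypothesis only, no Ivy test harness):

from hypothesis import given, strategies as st

@given(driver=st.sampled_from([None, "gesvd", "gesvdj", "gesvda"]))
def check_driver_accepted(driver):
    # mirrors the frontend's validation branch
    assert driver in ["gesvd", "gesvdj", "gesvda", None]

check_driver_accepted()  # hypothesis runs the property over sampled values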
