diff --git a/ivy/functional/backends/jax/linear_algebra.py b/ivy/functional/backends/jax/linear_algebra.py
index e223326c8729b..660f000464d94 100644
--- a/ivy/functional/backends/jax/linear_algebra.py
+++ b/ivy/functional/backends/jax/linear_algebra.py
@@ -422,21 +422,17 @@ def vector_norm(
 ) -> JaxArray:
     if dtype and x.dtype != dtype:
         x = x.astype(dtype)
-
-    ret_scalar = False
-    if x.ndim == 0:
-        x = jnp.expand_dims(x, 0)
-        ret_scalar = True
-
-    if axis is None:
-        x = x.reshape([-1])
-    elif isinstance(axis, list):
-        axis = tuple(axis)
-
-    jnp_normalized_vector = jnp.linalg.norm(x, ord, axis, keepdims)
-    if ret_scalar:
-        jnp_normalized_vector = jnp.squeeze(jnp_normalized_vector)
-    return jnp_normalized_vector
+    abs_x = jnp.abs(x)
+    # JAX reductions do not accept a non-None `out`, so `out` is handled by
+    # the function wrapper instead of being forwarded here.
+    if ord == 0:
+        return jnp.sum((abs_x != 0).astype(abs_x.dtype), axis=axis, keepdims=keepdims)
+    elif ord == inf:
+        return jnp.max(abs_x, axis=axis, keepdims=keepdims)
+    elif ord == -inf:
+        return jnp.min(abs_x, axis=axis, keepdims=keepdims)
+    else:
+        return jnp.sum(abs_x**ord, axis=axis, keepdims=keepdims) ** (1.0 / ord)
 
 
 # Extra #
diff --git a/ivy/functional/backends/numpy/linear_algebra.py b/ivy/functional/backends/numpy/linear_algebra.py
index bfe0c65ac47e1..e1334a89012cc 100644
--- a/ivy/functional/backends/numpy/linear_algebra.py
+++ b/ivy/functional/backends/numpy/linear_algebra.py
@@ -378,21 +378,21 @@ def vector_norm(
 ) -> np.ndarray:
     if dtype and x.dtype != dtype:
         x = x.astype(dtype)
-
-    ret_scalar = False
-    if x.ndim == 0:
-        x = np.expand_dims(x, 0)
-        ret_scalar = True
-
-    if axis is None:
-        x = x.reshape([-1])
-    elif isinstance(axis, list):
+    abs_x = np.abs(x)
+    if isinstance(axis, list):
         axis = tuple(axis)
-
-    np_normalized_vector = np.linalg.norm(x, ord, axis, keepdims)
-    if ret_scalar:
-        np_normalized_vector = np.squeeze(np_normalized_vector)
-    return np_normalized_vector
+    if ord == 0:
+        return np.sum(
+            (abs_x != 0).astype(abs_x.dtype), axis=axis, keepdims=keepdims, out=out
+        )
+    elif ord == inf:
+        return np.max(abs_x, axis=axis, keepdims=keepdims, out=out)
+    elif ord == -inf:
+        return np.min(abs_x, axis=axis, keepdims=keepdims, out=out)
+    else:
+        return (
+            np.sum(abs_x**ord, axis=axis, keepdims=keepdims) ** (1.0 / ord)
+        ).astype(abs_x.dtype)
 
 
 # Extra #
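For reference (illustrative sketch, not part of the patch): in the NumPy hunk
above, the generic branch computes sum(|x| ** ord) ** (1.0 / ord); the float
exponent promotes integer inputs to float64, which is why the result is cast
back with .astype(abs_x.dtype). A minimal, self-contained demonstration of
that promotion:

    import numpy as np

    x = np.array([3, -4], dtype=np.int32)
    abs_x = np.abs(x)

    raw = np.sum(abs_x**2) ** (1.0 / 2)  # float power promotes to float64
    assert raw == 5.0 and raw.dtype == np.float64

    restored = raw.astype(abs_x.dtype)  # cast back, as the patch does
    assert restored == 5 and restored.dtype == np.int32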
diff --git a/ivy/functional/backends/paddle/linear_algebra.py b/ivy/functional/backends/paddle/linear_algebra.py
index 15f49c63c5522..9ebfc8f00acbd 100644
--- a/ivy/functional/backends/paddle/linear_algebra.py
+++ b/ivy/functional/backends/paddle/linear_algebra.py
@@ -591,37 +591,26 @@ def vector_norm(
     dtype: Optional[paddle.dtype] = None,
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
-    ret_scalar = False
-    dtype = dtype if dtype is not None else x.dtype
-    if dtype in ["complex64", "complex128"]:
-        dtype = "float" + str(ivy.dtype_bits(dtype) // 2)
-    if x.ndim == 0:
-        x = paddle_backend.expand_dims(x, axis=0)
-        ret_scalar = True
-
-    if x.dtype in [
-        paddle.int8,
-        paddle.int16,
-        paddle.int32,
-        paddle.int64,
-        paddle.uint8,
-        paddle.float16,
-        paddle.complex64,
-        paddle.complex128,
-        paddle.bool,
-    ]:
-        if paddle.is_complex(x):
-            x = paddle.abs(x)
-            ret = paddle.norm(x, p=ord, axis=axis, keepdim=keepdims).astype(dtype)
-        else:
-            ret = paddle.norm(
-                x.cast("float32"), p=ord, axis=axis, keepdim=keepdims
-            ).astype(dtype)
+    if dtype and x.dtype != dtype:
+        x = x.astype(dtype)
+    abs_x = paddle_backend.abs(x)
+    if ord == 0:
+        return paddle_backend.sum(
+            (abs_x != 0).astype(abs_x.dtype), axis=axis, keepdims=keepdims
+        )
+    elif ord == inf:
+        return paddle_backend.max(abs_x, axis=axis, keepdims=keepdims)
+    elif ord == -inf:
+        return paddle_backend.min(abs_x, axis=axis, keepdims=keepdims)
     else:
-        ret = paddle.norm(x, p=ord, axis=axis, keepdim=keepdims).astype(dtype)
-    if ret_scalar or (x.ndim == 1 and not keepdims):
-        ret = paddle_backend.squeeze(ret, axis=axis)
-    return ret
+        return paddle_backend.pow(
+            paddle_backend.sum(
+                paddle_backend.pow(abs_x, ord),
+                axis=axis,
+                keepdims=keepdims,
+            ),
+            (1.0 / ord),
+        )
 
 
 # Extra #
diff --git a/ivy/functional/backends/tensorflow/linear_algebra.py b/ivy/functional/backends/tensorflow/linear_algebra.py
index 9d87364f692b4..edf7dd1bd8d7c 100644
--- a/ivy/functional/backends/tensorflow/linear_algebra.py
+++ b/ivy/functional/backends/tensorflow/linear_algebra.py
@@ -655,25 +655,15 @@ def vector_norm(
 ) -> Union[tf.Tensor, tf.Variable]:
     if dtype and x.dtype != dtype:
         x = tf.cast(x, dtype)
-    # Mathematical Norms
-    if ord > 0:
-        tn_normalized_vector = tf.linalg.norm(x, ord, axis, keepdims)
+    abs_x = tf.abs(x)
+    if ord == 0:
+        return tf.reduce_sum(tf.cast(x != 0, abs_x.dtype), axis=axis, keepdims=keepdims)
+    elif ord == inf:
+        return tf.reduce_max(abs_x, axis=axis, keepdims=keepdims)
+    elif ord == -inf:
+        return tf.reduce_min(abs_x, axis=axis, keepdims=keepdims)
     else:
-        if ord == -float("inf"):
-            tn_normalized_vector = tf.reduce_min(tf.abs(x), axis, keepdims)
-        elif ord == 0:
-            tn_normalized_vector = tf.reduce_sum(
-                tf.cast(x != 0, x.dtype), axis, keepdims
-            )
-        else:
-            tn_normalized_vector = tf.reduce_sum(tf.abs(x) ** ord, axis, keepdims) ** (
-                1.0 / ord
-            )
-        tn_normalized_vector = tf.cast(
-            tn_normalized_vector, tn_normalized_vector.dtype.real_dtype
-        )
-
-    return tn_normalized_vector
+        return tf.reduce_sum(abs_x**ord, axis=axis, keepdims=keepdims) ** (1.0 / ord)
 
 
 # Extra #
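For reference (illustrative sketch, not part of the patch): all four backends
now share the same four-way branch on ord. A self-contained NumPy sketch of
the semantics, using np.linalg.norm only as ground truth for the cases it
also supports:

    import numpy as np

    def ref_vector_norm(x, ord=2, axis=None, keepdims=False):
        # mirrors the branch structure introduced in the backends above
        abs_x = np.abs(x)
        if ord == 0:
            # "zero norm": counts the nonzero entries
            return np.sum(
                (abs_x != 0).astype(abs_x.dtype), axis=axis, keepdims=keepdims
            )
        elif ord == float("inf"):
            return np.max(abs_x, axis=axis, keepdims=keepdims)  # largest magnitude
        elif ord == -float("inf"):
            return np.min(abs_x, axis=axis, keepdims=keepdims)  # smallest magnitude
        else:
            return np.sum(abs_x**ord, axis=axis, keepdims=keepdims) ** (1.0 / ord)

    x = np.array([3.0, -4.0, 0.0])
    assert np.isclose(ref_vector_norm(x, ord=2), np.linalg.norm(x, 2))  # 5.0
    assert ref_vector_norm(x, ord=0) == 2.0
    assert ref_vector_norm(x, ord=float("inf")) == 4.0
    assert ref_vector_norm(x, ord=-float("inf")) == 0.0

    y = np.array([1.0, 2.0, 4.0])  # negative ord needs nonzero entries
    assert np.isclose(ref_vector_norm(y, ord=-1), np.linalg.norm(y, -1))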
diff --git a/ivy/functional/frontends/paddle/tensor/tensor.py b/ivy/functional/frontends/paddle/tensor/tensor.py
index 0d43f4ebdd46d..ba363afaf743c 100644
--- a/ivy/functional/frontends/paddle/tensor/tensor.py
+++ b/ivy/functional/frontends/paddle/tensor/tensor.py
@@ -332,3 +332,471 @@ def fmin(self, y, name=None):
     def minimum(self, y, name=None):
         y_ivy = _to_ivy_array(y)
         return ivy.minimum(self._ivy_array, y_ivy)
+# local
+import ivy
+import ivy.functional.frontends.paddle as paddle_frontend
+from ivy.func_wrapper import with_supported_dtypes, with_unsupported_dtypes
+from ivy.functional.frontends.paddle.func_wrapper import _to_ivy_array
+
+
+class Tensor:
+    def __init__(self, array, dtype=None, place="cpu", stop_gradient=True):
+        self._ivy_array = (
+            ivy.array(array, dtype=dtype, device=place)
+            if not isinstance(array, ivy.Array)
+            else array
+        )
+        self._dtype = dtype
+        self._place = place
+        self._stop_gradient = stop_gradient
+
+    def __repr__(self):
+        return (
+            str(self._ivy_array.__repr__())
+            .replace("ivy.array", "ivy.frontends.paddle.Tensor")
+            .replace("dev", "place")
+        )
+
+    # Properties #
+    # ---------- #
+
+    @property
+    def ivy_array(self):
+        return self._ivy_array
+
+    @property
+    def place(self):
+        return self.ivy_array.device
+
+    @property
+    def dtype(self):
+        return self._ivy_array.dtype
+
+    @property
+    def shape(self):
+        return self._ivy_array.shape
+
+    @property
+    def ndim(self):
+        return self.dim()
+
+    # Setters #
+    # --------#
+
+    @ivy_array.setter
+    def ivy_array(self, array):
+        self._ivy_array = (
+            ivy.array(array) if not isinstance(array, ivy.Array) else array
+        )
+
+    # Special Methods #
+    # -------------------#
+
+    def __getitem__(self, item):
+        ivy_args = ivy.nested_map([self, item], _to_ivy_array)
+        ret = ivy.get_item(*ivy_args)
+        return paddle_frontend.Tensor(ret)
+
+    def __setitem__(self, item, value):
+        item, value = ivy.nested_map([item, value], _to_ivy_array)
+        self.ivy_array[item] = value
+
+    def __iter__(self):
+        if self.ndim == 0:
+            raise TypeError("iteration over a 0-d tensor not supported")
+        for i in range(self.shape[0]):
+            yield self[i]
+
+    # Instance Methods #
+    # ---------------- #
+
+    def reshape(self, *args, shape=None):
+        if args and shape:
+            raise TypeError("reshape() got multiple values for argument 'shape'")
+        if shape is not None:
+            return paddle_frontend.reshape(self._ivy_array, shape)
+        if args:
+            if isinstance(args[0], (tuple, list)):
+                shape = args[0]
+                return paddle_frontend.reshape(self._ivy_array, shape)
+            else:
+                return paddle_frontend.reshape(self._ivy_array, args)
+        return paddle_frontend.reshape(self._ivy_array)
+
+    def dim(self):
+        return self.ivy_array.ndim
+
+    @with_unsupported_dtypes({"2.5.0 and below": ("float16", "bfloat16")}, "paddle")
+    def abs(self):
+        return paddle_frontend.abs(self)
+
+    @with_unsupported_dtypes({"2.5.0 and below": ("float16", "bfloat16")}, "paddle")
+    def ceil(self):
+        return paddle_frontend.ceil(self)
+
+    @with_unsupported_dtypes({"2.5.0 and below": ("float16",)}, "paddle")
+    def asinh(self, name=None):
+        return ivy.asinh(self._ivy_array)
+
+    @with_supported_dtypes({"2.5.0 and below": ("float32", "float64")}, "paddle")
+    def asin(self, name=None):
+        return ivy.asin(self._ivy_array)
+
+    @with_supported_dtypes({"2.5.0 and below": ("float32", "float64")}, "paddle")
+    def cosh(self, name=None):
+        return ivy.cosh(self._ivy_array)
+
+    @with_supported_dtypes({"2.5.0 and below": ("float32", "float64")}, "paddle")
+    def log(self, name=None):
+        return ivy.log(self._ivy_array)
+
+    @with_supported_dtypes({"2.5.0 and below": ("float32", "float64")}, "paddle")
+    def sin(self, name=None):
+        return ivy.sin(self._ivy_array)
+
+    @with_supported_dtypes({"2.5.0 and below": ("float32", "float64")}, "paddle")
+    def sinh(self, name=None):
+        return ivy.sinh(self._ivy_array)
+
+    @with_unsupported_dtypes({"2.5.0 and below": ("float16", "bfloat16")}, "paddle")
+    def argmax(self, axis=None, keepdim=False, dtype=None, name=None):
+        return ivy.argmax(self._ivy_array, axis=axis, keepdims=keepdim, dtype=dtype)
+
+    @with_supported_dtypes({"2.5.0 and below": ("float32", "float64")}, "paddle")
+    def sqrt(self, name=None):
+        return ivy.sqrt(self._ivy_array)
+
+    @with_supported_dtypes({"2.5.0 and below": ("float32", "float64")}, "paddle")
+    def cos(self, name=None):
+        return ivy.cos(self._ivy_array)
+
+    @with_unsupported_dtypes({"2.5.0 and below": ("float16", "bfloat16")}, "paddle")
+    def exp(self, name=None):
+        return ivy.exp(self._ivy_array)
+
+    @with_supported_dtypes({"2.5.0 and below": ("float32", "float64")}, "paddle")
+    def erf(self, name=None):
+        return ivy.erf(self._ivy_array)
+
+    @with_unsupported_dtypes({"2.5.0 and below": ("float16", "bfloat16")}, "paddle")
+    def subtract(self, y, name=None):
+        y_ivy = _to_ivy_array(y)
+        return ivy.subtract(self._ivy_array, y_ivy)
+
+    @with_unsupported_dtypes({"2.5.0 and below": ("float16", "bfloat16")}, "paddle")
+    def log10(self, name=None):
+        return ivy.log10(self._ivy_array)
+
+    @with_unsupported_dtypes({"2.5.0 and below": ("float16", "bfloat16")}, "paddle")
+    def argsort(self, axis=-1, descending=False, name=None):
+        return ivy.argsort(self._ivy_array, axis=axis, descending=descending)
+
+    @with_unsupported_dtypes({"2.5.0 and below": ("float16", "bfloat16")}, "paddle")
+    def floor(self, name=None):
+        return ivy.floor(self._ivy_array)
+
+    @with_unsupported_dtypes({"2.5.0 and below": ("float16", "bfloat16")}, "paddle")
+    def floor_(self):
+        return ivy.floor(self._ivy_array)
+
+    @with_supported_dtypes({"2.5.0 and below": ("float32", "float64")}, "paddle")
+    def tanh(self, name=None):
+        return ivy.tanh(self._ivy_array)
+
+    @with_supported_dtypes({"2.5.0 and below": ("float32", "float64")}, "paddle")
+    def add_(self, y, name=None):
+        return ivy.add(self._ivy_array, _to_ivy_array(y))
+
+    @with_supported_dtypes(
+        {"2.5.0 and below": ("float16", "float32", "float64", "int32", "int64")},
+        "paddle",
+    )
+    def isinf(self, name=None):
+        return ivy.isinf(self._ivy_array)
+
+    @with_supported_dtypes({"2.5.0 and below": ("float32", "float64")}, "paddle")
+    def square(self, name=None):
+        return ivy.square(self._ivy_array)
+
+    @with_supported_dtypes({"2.5.0 and below": ("float32", "float64")}, "paddle")
+    def cholesky(self, upper=False, name=None):
+        return ivy.cholesky(self._ivy_array, upper=upper)
+
+    @with_unsupported_dtypes({"2.5.0 and below": ("float16", "bfloat16")}, "paddle")
+    def multiply(self, y, name=None):
+        return paddle_frontend.multiply(self, y)
+
+    @with_supported_dtypes(
+        {"2.5.0 and below": ("float16", "float32", "float64", "int32", "int64")},
+        "paddle",
+    )
+    def isfinite(self, name=None):
+        return ivy.isfinite(self._ivy_array)
+
+    @with_supported_dtypes({"2.4.2 and below": ("float16", "bfloat16")}, "paddle")
+    def all(self, axis=None, keepdim=False, name=None):
+        return ivy.all(self.ivy_array, axis=axis, keepdims=keepdim)
+
+    @with_supported_dtypes({"2.5.0 and below": ("float16", "bfloat16")}, "paddle")
+    def allclose(self, other, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
+        return ivy.allclose(
+            self._ivy_array, other, rtol=rtol, atol=atol, equal_nan=equal_nan
+        )
+
+    @with_unsupported_dtypes({"2.5.0 and below": ("float16", "bfloat16")}, "paddle")
+    def sort(self, axis=-1, descending=False, name=None):
+        return ivy.sort(self._ivy_array, axis=axis, descending=descending)
+
+    @with_supported_dtypes(
+        {
+            "2.4.2 and below": (
+                "bool",
+                "uint8",
+                "int8",
+                "int16",
+                "int32",
+                "int64",
+            )
+        },
+        "paddle",
+    )
+    def bitwise_and(self, y, out=None, name=None):
+        return paddle_frontend.bitwise_and(self, y)
+
+    @with_supported_dtypes(
+        {
+            "2.5.0 and below": (
+                "bool",
+                "int8",
+                "int16",
+                "int32",
+                "int64",
+                "float32",
+                "float64",
+            )
+        },
+        "paddle",
+    )
+    def logical_or(self, y, out=None, name=None):
+        return paddle_frontend.logical_or(self, y, out=out)
+
+    @with_supported_dtypes(
+        {"2.5.0 and below": ("bool", "uint8", "int8", "int16", "int32", "int64")},
+        "paddle",
+    )
+    def bitwise_xor(self, y, out=None, name=None):
+        return paddle_frontend.bitwise_xor(self, y)
+
+    @with_supported_dtypes({"2.5.0 and below": ("float16", "bfloat16")}, "paddle")
+    def any(self, axis=None, keepdim=False, name=None):
+        return ivy.any(self._ivy_array, axis=axis, keepdims=keepdim)
+
+    @with_supported_dtypes({"2.5.0 and below": ("float16", "bfloat16")}, "paddle")
+    def astype(self, dtype):
+        return ivy.astype(self._ivy_array, dtype=dtype)
+
+    @with_supported_dtypes(
+        {
+            "2.5.0 and below": (
+                "bool",
+                "int8",
+                "int16",
+                "int32",
+                "int64",
+            )
+        },
+        "paddle",
+    )
+    def bitwise_or(self, y, out=None, name=None):
+        return paddle_frontend.bitwise_or(self, y, out=out)
+
+    @with_supported_dtypes(
+        {
+            "2.5.0 and below": (
+                "bool",
+                "int8",
+                "int16",
+                "int32",
+                "int64",
+                "float32",
+                "float64",
+            )
+        },
+        "paddle",
+    )
+    def logical_xor(self, y, out=None, name=None):
+        return paddle_frontend.logical_xor(self, y, out=out)
+
+    @with_unsupported_dtypes(
+        {
+            "2.5.0 and below": (
+                "bool",
+                "uint8",
+                "int8",
+                "int16",
+                "complex64",
+                "complex128",
+            )
+        },
+        "paddle",
+    )
+    def greater_than(self, y, name=None):
+        return paddle_frontend.greater_than(self, y)
+
+    @with_supported_dtypes({"2.5.0 and below": ("float32", "float64")}, "paddle")
+    def rsqrt(self, name=None):
+        return ivy.reciprocal(ivy.sqrt(self._ivy_array))
+
+    @with_supported_dtypes(
+        {
+            "2.5.0 and below": (
+                "bool",
+                "int8",
+                "int16",
+                "int32",
+                "int64",
+                "float32",
+                "float64",
+            )
+        },
+        "paddle",
+    )
+    def logical_and(self, y, out=None, name=None):
+        return paddle_frontend.logical_and(self, y, out=out)
+
+    @with_supported_dtypes({"2.5.0 and below": ("float32", "float64")}, "paddle")
+    def divide(self, y, name=None):
+        return paddle_frontend.divide(self, y)
+
+    @with_unsupported_dtypes(
+        {
+            "2.5.0 and below": (
+                "bool",
+                "uint8",
+                "int8",
+                "int16",
+                "complex64",
+                "complex128",
+            )
+        },
+        "paddle",
+    )
+    def less_than(self, y, name=None):
+        return paddle_frontend.less_than(self, y)
+
+    @with_unsupported_dtypes({"2.5.0 and below": ("float16", "bfloat16")}, "paddle")
+    def cumprod(self, dim=None, dtype=None, name=None):
+        return ivy.cumprod(self._ivy_array, axis=dim, dtype=dtype)
+
+    @with_unsupported_dtypes({"2.5.0 and below": ("float16", "bfloat16")}, "paddle")
+    def cumsum(self, axis=None, dtype=None, name=None):
+        return ivy.cumsum(self._ivy_array, axis=axis, dtype=dtype)
+
+    @with_supported_dtypes(
+        {"2.5.0 and below": ("complex64", "complex128", "float32", "float64")},
+        "paddle",
+    )
+    def angle(self, name=None):
+        return ivy.angle(self._ivy_array)
+
+    @with_unsupported_dtypes(
+        {
+            "2.5.0 and below": (
+                "uint8",
+                "int8",
+                "int16",
+                "complex64",
+                "complex128",
+            )
+        },
+        "paddle",
+    )
+    def equal(self, y, name=None):
+        return paddle_frontend.equal(self, y)
+
+    @with_unsupported_dtypes({"2.5.0 and below": ("float16", "bfloat16")}, "paddle")
+    def rad2deg(self, name=None):
+        return ivy.rad2deg(self._ivy_array)
+
+    @with_unsupported_dtypes(
+        {
+            "2.5.0 and below": (
+                "uint8",
+                "int8",
+                "int16",
+                "float16",
+                "complex64",
+                "complex128",
+            )
+        },
+        "paddle",
+    )
+    def equal_all(self, y, name=None):
+        y_ivy = _to_ivy_array(y)
+        return ivy.array_equal(self._ivy_array, y_ivy)
+
+    @with_unsupported_dtypes({"2.5.0 and below": "bfloat16"}, "paddle")
+    def fmax(self, y, name=None):
+        y_ivy = _to_ivy_array(y)
+        return ivy.fmax(self._ivy_array, y_ivy)
+
+    @with_unsupported_dtypes({"2.5.0 and below": "bfloat16"}, "paddle")
+    def fmin(self, y, name=None):
+        y_ivy = _to_ivy_array(y)
+        return ivy.fmin(self._ivy_array, y_ivy)
+
+    @with_supported_dtypes(
+        {"2.5.0 and below": ("float32", "float64", "int32", "int64")}, "paddle"
+    )
+    def minimum(self, y, name=None):
+        y_ivy = _to_ivy_array(y)
+        return ivy.minimum(self._ivy_array, y_ivy)
+
+    @with_supported_dtypes(
+        {"2.5.0 and below": ("float32", "float64", "int32", "int64")}, "paddle"
+    )
+    def max(self, axis=None, keepdim=False, name=None):
+        return ivy.max(self._ivy_array, axis=axis, keepdims=keepdim)
+
+    @with_unsupported_dtypes({"2.5.0 and below": ("float16", "bfloat16")}, "paddle")
+    def deg2rad(self, name=None):
+        return ivy.deg2rad(self._ivy_array)
+
+    @with_supported_dtypes(
+        {"2.5.0 and below": ("float32", "float64", "int32", "int64", "bool")}, "paddle"
+    )
+    def rot90(self, k=1, axes=(0, 1), name=None):
+        return ivy.rot90(self._ivy_array, k=k, axes=axes)
+
+    @with_supported_dtypes(
+        {"2.5.0 and below": ("complex64", "complex128")},
+        "paddle",
+    )
+    def imag(self, name=None):
+        return paddle_frontend.imag(self)
+
+    def is_tensor(self):
+        return paddle_frontend.is_tensor(self._ivy_array)
+
+    @with_supported_dtypes(
+        {
+            "2.5.0 and below": (
+                "float32",
+                "float64",
+            )
+        },
+        "paddle",
+    )
+    def isclose(self, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
+        return paddle_frontend.isclose(
+            self, y, rtol=rtol, atol=atol, equal_nan=equal_nan
+        )
+
+    @with_supported_dtypes({"2.5.0 and below": ("int32", "int64")}, "paddle")
+    def floor_divide(self, y, name=None):
+        y_ivy = y._ivy_array if isinstance(y, Tensor) else _to_ivy_array(y)
+        return ivy.floor_divide(self._ivy_array, y_ivy)
+
+    @with_unsupported_dtypes({"2.4.2 and below": ("int16", "float16")}, "paddle")
+    def conj(self, name=None):
+        return ivy.conj(self._ivy_array)
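For reference (illustrative sketch, not part of the patch): the frontend
methods above are thin wrappers that forward to the corresponding ivy or
paddle_frontend functions. A quick usage sketch; it assumes an ivy backend
is installed and set (numpy is used here purely for illustration):

    import ivy
    import ivy.functional.frontends.paddle as paddle_frontend

    ivy.set_backend("numpy")

    t = paddle_frontend.Tensor([[1.0, -2.0], [3.0, -4.0]])
    print(t.abs())                    # forwards to paddle_frontend.abs
    print(t.rot90(k=1, axes=(0, 1)))  # newly added, forwards to ivy.rot90
    print(t.equal(t))                 # newly added, forwards to paddle_frontend.equal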
diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_paddle_tensor.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_paddle_tensor.py
index fec3a9d69c078..9defb9e67f8f0 100644
--- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_paddle_tensor.py
+++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_paddle_tensor.py
@@ -9,6 +9,9 @@
 import ivy_tests.test_ivy.helpers as helpers
 from ivy.functional.frontends.paddle import Tensor
 from ivy_tests.test_ivy.helpers import handle_frontend_method
+from ivy_tests.test_ivy.test_functional.test_experimental.test_core.test_manipulation import (
+    _get_dtype_values_k_axes_for_rot90,
+)
 
 
 CLASS_TREE = "ivy.functional.frontends.paddle.Tensor"
@@ -1683,6 +1686,39 @@ def test_paddle_angle(
     )
 
 
+# equal
+@handle_frontend_method(
+    class_tree=CLASS_TREE,
+    init_tree="paddle.to_tensor",
+    method_name="equal",
+    dtypes_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("valid"),
+        num_arrays=2,
+        shared_dtype=True,
+    ),
+)
+def test_paddle_equal(
+    dtypes_and_x,
+    frontend_method_data,
+    init_flags,
+    method_flags,
+    frontend,
+    on_device,
+):
+    input_dtype, x = dtypes_and_x
+    helpers.test_frontend_method(
+        init_input_dtypes=input_dtype,
+        init_all_as_kwargs_np={"data": x[0]},
+        method_input_dtypes=input_dtype,
+        method_all_as_kwargs_np={"y": x[1]},
+        frontend_method_data=frontend_method_data,
+        init_flags=init_flags,
+        method_flags=method_flags,
+        frontend=frontend,
+        on_device=on_device,
+    )
+
+
 # rad2deg
 @handle_frontend_method(
     class_tree=CLASS_TREE,
@@ -1947,6 +1983,47 @@ def test_paddle_deg2rad(
         frontend=frontend,
         on_device=on_device,
     )
+
+
+# rot90
+@handle_frontend_method(
+    class_tree=CLASS_TREE,
+    init_tree="paddle.to_tensor",
+    method_name="rot90",
+    dtype_m_k_axes=_get_dtype_values_k_axes_for_rot90(
+        available_dtypes=helpers.get_dtypes("valid"),
+        min_num_dims=3,
+        max_num_dims=6,
+        min_dim_size=1,
+        max_dim_size=10,
+    ),
+)
+def test_paddle_rot90(
+    dtype_m_k_axes,
+    frontend_method_data,
+    init_flags,
+    method_flags,
+    frontend,
+    on_device,
+):
+    input_dtype, values, k, axes = dtype_m_k_axes
+
+    helpers.test_frontend_method(
+        init_input_dtypes=input_dtype,
+        init_all_as_kwargs_np={
+            "data": values,
+        },
+        method_input_dtypes=input_dtype,
+        method_all_as_kwargs_np={
+            "k": k,
+            "axes": axes,
+        },
+        frontend_method_data=frontend_method_data,
+        init_flags=init_flags,
+        method_flags=method_flags,
+        frontend=frontend,
+        on_device=on_device,
+    )
 
 
 # imag
diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_linalg.py b/ivy_tests/test_ivy/test_functional/test_core/test_linalg.py
index ab598e3c7eb7d..cbfe79b811043 100644
--- a/ivy_tests/test_ivy/test_functional/test_core/test_linalg.py
+++ b/ivy_tests/test_ivy/test_functional/test_core/test_linalg.py
@@ -874,12 +874,12 @@ def test_vecdot(
         max_value=1e04,
         abs_smallest_val=1e-04,
         max_axes_size=2,
-        force_int_axis=True,
+        allow_neg_axes=True,
    ),
     kd=st.booleans(),
     ord=st.one_of(
-        helpers.ints(min_value=0, max_value=5),
-        helpers.floats(min_value=1.0, max_value=5.0),
+        helpers.ints(min_value=-5, max_value=5),
+        helpers.floats(min_value=-5.0, max_value=5.0),
         st.sampled_from((float("inf"), -float("inf"))),
     ),
     dtype=helpers.get_dtypes("numeric", full=False, none=True),
@@ -897,6 +897,10 @@ def test_vector_norm(
     ground_truth_backend,
 ):
     x_dtype, x, axis = dtype_values_axis
+    # normalize a length-1 tuple axis to a plain int; without force_int_axis
+    # the strategy can now generate single-element axis tuples
+    if isinstance(axis, tuple) and len(axis) == 1:
+        axis = axis[0]
     helpers.test_function(
         ground_truth_backend=ground_truth_backend,
         input_dtypes=x_dtype,
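For reference (illustrative sketch, not part of the patch): with
force_int_axis replaced by allow_neg_axes and the ord strategies widened, the
test now exercises the zero, negative, and infinite orders implemented by the
rewritten backends. A short sketch of the newly covered calls through the
public API, assuming the numpy backend; expected values are shown as comments:

    import ivy

    ivy.set_backend("numpy")
    x = ivy.array([3.0, -4.0, 12.0])

    print(ivy.vector_norm(x, ord=2))              # 13.0, Euclidean norm
    print(ivy.vector_norm(x, ord=0))              # 3.0, count of nonzero entries
    print(ivy.vector_norm(x, ord=-1))             # (1/3 + 1/4 + 1/12) ** -1 == 1.5
    print(ivy.vector_norm(x, ord=-float("inf")))  # 3.0, smallest magnitude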