From 55271e9b1ded9282247f3bb838615d72b2696785 Mon Sep 17 00:00:00 2001 From: ivy-dev-bot Date: Sat, 13 Jul 2024 02:27:16 +0000 Subject: [PATCH] auto-lint code --- ivy/functional/backends/jax/module.py | 11 +-- ivy/functional/frontends/jax/numpy/fft.py | 64 +++++++-------- .../test_jax/test_numpy/test_fft.py | 80 +++++++++---------- 3 files changed, 75 insertions(+), 80 deletions(-) diff --git a/ivy/functional/backends/jax/module.py b/ivy/functional/backends/jax/module.py index 01ff617139b6..44163895c853 100644 --- a/ivy/functional/backends/jax/module.py +++ b/ivy/functional/backends/jax/module.py @@ -1,13 +1,12 @@ # global from __future__ import annotations import re -import os import jax from flax import linen as nn import jax.tree_util as tree import jax.numpy as jnp import functools -from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, Union +from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type import inspect from collections import OrderedDict @@ -451,9 +450,7 @@ def register_parameter(self, name: str, value: jax.Array): def train(self, mode: bool = True): self._training = mode for module in self.children(): - if isinstance(module, nn.Module) and not hasattr( - module, "train" - ): + if isinstance(module, nn.Module) and not hasattr(module, "train"): module.trainable = mode continue module.train(mode) @@ -845,9 +842,7 @@ def register_parameter(self, name: str, value: jax.Array): def train(self, mode: bool = True): self._training = mode for module in self.children(): - if isinstance(module, nn.Module) and not hasattr( - module, "train" - ): + if isinstance(module, nn.Module) and not hasattr(module, "train"): module.trainable = mode continue module.train(mode) diff --git a/ivy/functional/frontends/jax/numpy/fft.py b/ivy/functional/frontends/jax/numpy/fft.py index 77ded083de99..dadd61b25d07 100644 --- a/ivy/functional/frontends/jax/numpy/fft.py +++ b/ivy/functional/frontends/jax/numpy/fft.py @@ -70,6 +70,38 @@ def ifft2(a, s=None, axes=(-2, -1), norm=None): return ivy.array(ivy.ifft2(a, s=s, dim=axes, norm=norm), dtype=ivy.dtype(a)) +@with_unsupported_dtypes({"1.24.3 and below": ("complex64", "bfloat16")}, "numpy") +@to_ivy_arrays_and_back +def ifftn(a, s=None, axes=None, norm=None): + a = ivy.asarray(a, dtype=ivy.complex128) + a = ivy.ifftn(a, s=s, axes=axes, norm=norm) + return a + + +@to_ivy_arrays_and_back +def ifftshift(x, axes=None): + if not ivy.is_array(x): + raise ValueError("Input 'x' must be an array") + + # Get the shape of x + shape = ivy.shape(x) + + # If axes is None, shift all axes + if axes is None: + axes = tuple(range(x.ndim)) + + # Convert axes to a list if it's not already + axes = [axes] if isinstance(axes, int) else list(axes) + + # Perform the shift for each axis + for axis in axes: + axis_size = shape[axis] + shift = -ivy.floor(axis_size / 2).astype(ivy.int32) + result = ivy.roll(x, shift, axis=axis) + + return result + + @to_ivy_arrays_and_back def irfftn(a, s=None, axes=None, norm=None): x = ivy.asarray(a) @@ -121,38 +153,6 @@ def irfftn(a, s=None, axes=None, norm=None): return result_t -@with_unsupported_dtypes({"1.24.3 and below": ("complex64", "bfloat16")}, "numpy") -@to_ivy_arrays_and_back -def ifftn(a, s=None, axes=None, norm=None): - a = ivy.asarray(a, dtype=ivy.complex128) - a = ivy.ifftn(a, s=s, axes=axes, norm=norm) - return a - - -@to_ivy_arrays_and_back -def ifftshift(x, axes=None): - if not ivy.is_array(x): - raise ValueError("Input 'x' must be an array") - - # Get the shape of x - shape = ivy.shape(x) - - # If axes is None, shift all axes - if axes is None: - axes = tuple(range(x.ndim)) - - # Convert axes to a list if it's not already - axes = [axes] if isinstance(axes, int) else list(axes) - - # Perform the shift for each axis - for axis in axes: - axis_size = shape[axis] - shift = -ivy.floor(axis_size / 2).astype(ivy.int32) - result = ivy.roll(x, shift, axis=axis) - - return result - - @to_ivy_arrays_and_back @with_unsupported_dtypes({"1.25.2 and below": ("float16", "bfloat16")}, "numpy") def rfft(a, n=None, axis=-1, norm=None): diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py index 1973c2da0d8c..8b25279a3aad 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py @@ -245,46 +245,6 @@ def test_jax_numpy_ifft2( ) -# irfftn -@handle_frontend_test( - fn_tree="jax.numpy.fft.irfftn", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_value=-10, - max_value=10, - min_num_dims=1, - max_num_dims=3, - min_dim_size=2, - max_dim_size=10, - valid_axis=True, - force_int_axis=True, - num_arrays=1, - ), - norm=st.sampled_from(["backward", "ortho", "forward"]), -) -def test_jax_numpy_irfftn( - dtype_x_axis, - norm, - frontend, - test_flags, - fn_tree, - backend_fw, -): - input_dtypes, x, _ = dtype_x_axis - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - atol=1e-3, - a=x[0], - s=None, # TODO: also test cases where `s` and `axes` are not None - axes=None, - norm=norm, - ) - - @handle_frontend_test( fn_tree="jax.numpy.fft.ifftn", dtype_and_x=_x_and_ifftn_jax(), @@ -348,6 +308,46 @@ def test_jax_numpy_ifftshift( ) +# irfftn +@handle_frontend_test( + fn_tree="jax.numpy.fft.irfftn", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_value=-10, + max_value=10, + min_num_dims=1, + max_num_dims=3, + min_dim_size=2, + max_dim_size=10, + valid_axis=True, + force_int_axis=True, + num_arrays=1, + ), + norm=st.sampled_from(["backward", "ortho", "forward"]), +) +def test_jax_numpy_irfftn( + dtype_x_axis, + norm, + frontend, + test_flags, + fn_tree, + backend_fw, +): + input_dtypes, x, _ = dtype_x_axis + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + atol=1e-3, + a=x[0], + s=None, # TODO: also test cases where `s` and `axes` are not None + axes=None, + norm=norm, + ) + + # rfft @handle_frontend_test( fn_tree="jax.numpy.fft.rfft",