Skip to content

Commit

Permalink
auto-lint code
Browse files Browse the repository at this point in the history
  • Loading branch information
ivy-dev-bot committed Jul 13, 2024
1 parent 084047b commit 55271e9
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 80 deletions.
11 changes: 3 additions & 8 deletions ivy/functional/backends/jax/module.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# global
from __future__ import annotations
import re
import os
import jax
from flax import linen as nn
import jax.tree_util as tree
import jax.numpy as jnp
import functools
from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, Union
from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type
import inspect
from collections import OrderedDict

Expand Down Expand Up @@ -451,9 +450,7 @@ def register_parameter(self, name: str, value: jax.Array):
def train(self, mode: bool = True):
self._training = mode
for module in self.children():
if isinstance(module, nn.Module) and not hasattr(
module, "train"
):
if isinstance(module, nn.Module) and not hasattr(module, "train"):
module.trainable = mode
continue
module.train(mode)
Expand Down Expand Up @@ -845,9 +842,7 @@ def register_parameter(self, name: str, value: jax.Array):
def train(self, mode: bool = True):
self._training = mode
for module in self.children():
if isinstance(module, nn.Module) and not hasattr(
module, "train"
):
if isinstance(module, nn.Module) and not hasattr(module, "train"):
module.trainable = mode
continue
module.train(mode)
Expand Down
64 changes: 32 additions & 32 deletions ivy/functional/frontends/jax/numpy/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,38 @@ def ifft2(a, s=None, axes=(-2, -1), norm=None):
return ivy.array(ivy.ifft2(a, s=s, dim=axes, norm=norm), dtype=ivy.dtype(a))


@with_unsupported_dtypes({"1.24.3 and below": ("complex64", "bfloat16")}, "numpy")
@to_ivy_arrays_and_back
def ifftn(a, s=None, axes=None, norm=None):
a = ivy.asarray(a, dtype=ivy.complex128)
a = ivy.ifftn(a, s=s, axes=axes, norm=norm)
return a


@to_ivy_arrays_and_back
def ifftshift(x, axes=None):
if not ivy.is_array(x):
raise ValueError("Input 'x' must be an array")

# Get the shape of x
shape = ivy.shape(x)

# If axes is None, shift all axes
if axes is None:
axes = tuple(range(x.ndim))

# Convert axes to a list if it's not already
axes = [axes] if isinstance(axes, int) else list(axes)

# Perform the shift for each axis
for axis in axes:
axis_size = shape[axis]
shift = -ivy.floor(axis_size / 2).astype(ivy.int32)
result = ivy.roll(x, shift, axis=axis)

return result


@to_ivy_arrays_and_back
def irfftn(a, s=None, axes=None, norm=None):
x = ivy.asarray(a)
Expand Down Expand Up @@ -121,38 +153,6 @@ def irfftn(a, s=None, axes=None, norm=None):
return result_t


@with_unsupported_dtypes({"1.24.3 and below": ("complex64", "bfloat16")}, "numpy")
@to_ivy_arrays_and_back
def ifftn(a, s=None, axes=None, norm=None):
a = ivy.asarray(a, dtype=ivy.complex128)
a = ivy.ifftn(a, s=s, axes=axes, norm=norm)
return a


@to_ivy_arrays_and_back
def ifftshift(x, axes=None):
if not ivy.is_array(x):
raise ValueError("Input 'x' must be an array")

# Get the shape of x
shape = ivy.shape(x)

# If axes is None, shift all axes
if axes is None:
axes = tuple(range(x.ndim))

# Convert axes to a list if it's not already
axes = [axes] if isinstance(axes, int) else list(axes)

# Perform the shift for each axis
for axis in axes:
axis_size = shape[axis]
shift = -ivy.floor(axis_size / 2).astype(ivy.int32)
result = ivy.roll(x, shift, axis=axis)

return result


@to_ivy_arrays_and_back
@with_unsupported_dtypes({"1.25.2 and below": ("float16", "bfloat16")}, "numpy")
def rfft(a, n=None, axis=-1, norm=None):
Expand Down
80 changes: 40 additions & 40 deletions ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,46 +245,6 @@ def test_jax_numpy_ifft2(
)


# irfftn
@handle_frontend_test(
fn_tree="jax.numpy.fft.irfftn",
dtype_x_axis=helpers.dtype_values_axis(
available_dtypes=helpers.get_dtypes("numeric"),
min_value=-10,
max_value=10,
min_num_dims=1,
max_num_dims=3,
min_dim_size=2,
max_dim_size=10,
valid_axis=True,
force_int_axis=True,
num_arrays=1,
),
norm=st.sampled_from(["backward", "ortho", "forward"]),
)
def test_jax_numpy_irfftn(
dtype_x_axis,
norm,
frontend,
test_flags,
fn_tree,
backend_fw,
):
input_dtypes, x, _ = dtype_x_axis
helpers.test_frontend_function(
input_dtypes=input_dtypes,
backend_to_test=backend_fw,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
atol=1e-3,
a=x[0],
s=None, # TODO: also test cases where `s` and `axes` are not None
axes=None,
norm=norm,
)


@handle_frontend_test(
fn_tree="jax.numpy.fft.ifftn",
dtype_and_x=_x_and_ifftn_jax(),
Expand Down Expand Up @@ -348,6 +308,46 @@ def test_jax_numpy_ifftshift(
)


# irfftn
@handle_frontend_test(
fn_tree="jax.numpy.fft.irfftn",
dtype_x_axis=helpers.dtype_values_axis(
available_dtypes=helpers.get_dtypes("numeric"),
min_value=-10,
max_value=10,
min_num_dims=1,
max_num_dims=3,
min_dim_size=2,
max_dim_size=10,
valid_axis=True,
force_int_axis=True,
num_arrays=1,
),
norm=st.sampled_from(["backward", "ortho", "forward"]),
)
def test_jax_numpy_irfftn(
dtype_x_axis,
norm,
frontend,
test_flags,
fn_tree,
backend_fw,
):
input_dtypes, x, _ = dtype_x_axis
helpers.test_frontend_function(
input_dtypes=input_dtypes,
backend_to_test=backend_fw,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
atol=1e-3,
a=x[0],
s=None, # TODO: also test cases where `s` and `axes` are not None
axes=None,
norm=norm,
)


# rfft
@handle_frontend_test(
fn_tree="jax.numpy.fft.rfft",
Expand Down

0 comments on commit 55271e9

Please sign in to comment.