Skip to content

Commit

Permalink
feat: Updated jax version mapping from 0.4.23 to 0.4.24 (#28237)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sai-Suraj-27 authored Feb 10, 2024
1 parent 24bd1ce commit 1603886
Show file tree
Hide file tree
Showing 24 changed files with 119 additions and 119 deletions.
24 changes: 12 additions & 12 deletions ivy/functional/backends/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _array_unflatten(aux_data, children):

# update these to add new dtypes
valid_dtypes = {
"0.4.23 and below": (
"0.4.24 and below": (
ivy.int8,
ivy.int16,
ivy.int32,
Expand All @@ -121,7 +121,7 @@ def _array_unflatten(aux_data, children):
)
}
valid_numeric_dtypes = {
"0.4.23 and below": (
"0.4.24 and below": (
ivy.int8,
ivy.int16,
ivy.int32,
Expand All @@ -140,7 +140,7 @@ def _array_unflatten(aux_data, children):
}

valid_int_dtypes = {
"0.4.23 and below": (
"0.4.24 and below": (
ivy.int8,
ivy.int16,
ivy.int32,
Expand All @@ -153,12 +153,12 @@ def _array_unflatten(aux_data, children):
}

valid_uint_dtypes = {
"0.4.23 and below": (ivy.uint8, ivy.uint16, ivy.uint32, ivy.uint64)
"0.4.24 and below": (ivy.uint8, ivy.uint16, ivy.uint32, ivy.uint64)
}
valid_float_dtypes = {
"0.4.23 and below": (ivy.bfloat16, ivy.float16, ivy.float32, ivy.float64)
"0.4.24 and below": (ivy.bfloat16, ivy.float16, ivy.float32, ivy.float64)
}
valid_complex_dtypes = {"0.4.23 and below": (ivy.complex64, ivy.complex128)}
valid_complex_dtypes = {"0.4.24 and below": (ivy.complex64, ivy.complex128)}


# leave these untouched
Expand All @@ -173,12 +173,12 @@ def _array_unflatten(aux_data, children):
# invalid data types

# update these to add new dtypes
invalid_dtypes = {"0.4.23 and below": ()}
invalid_numeric_dtypes = {"0.4.23 and below": ()}
invalid_int_dtypes = {"0.4.23 and below": ()}
invalid_float_dtypes = {"0.4.23 and below": ()}
invalid_uint_dtypes = {"0.4.23 and below": ()}
invalid_complex_dtypes = {"0.4.23 and below": ()}
invalid_dtypes = {"0.4.24 and below": ()}
invalid_numeric_dtypes = {"0.4.24 and below": ()}
invalid_int_dtypes = {"0.4.24 and below": ()}
invalid_float_dtypes = {"0.4.24 and below": ()}
invalid_uint_dtypes = {"0.4.24 and below": ()}
invalid_complex_dtypes = {"0.4.24 and below": ()}

# leave these untouched
invalid_dtypes = _dtype_from_version(invalid_dtypes, backend_version)
Expand Down
40 changes: 20 additions & 20 deletions ivy/functional/backends/jax/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def abs(
return jnp.where(x != 0, jnp.absolute(x), 0)


@with_unsupported_dtypes({"0.4.23 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("complex",)}, backend_version)
def acos(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jnp.arccos(x)

Expand All @@ -52,12 +52,12 @@ def add(
return jnp.add(x1, x2)


@with_unsupported_dtypes({"0.4.23 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("complex",)}, backend_version)
def asin(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jnp.arcsin(x)


@with_unsupported_dtypes({"0.4.23 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("complex",)}, backend_version)
def asinh(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jnp.arcsinh(x)

Expand All @@ -71,12 +71,12 @@ def atan2(x1: JaxArray, x2: JaxArray, /, *, out: Optional[JaxArray] = None) -> J
return jnp.arctan2(x1, x2)


@with_unsupported_dtypes({"0.4.23 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("complex",)}, backend_version)
def atanh(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jnp.arctanh(x)


@with_unsupported_dtypes({"0.4.23 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("complex",)}, backend_version)
def bitwise_and(
x1: Union[int, JaxArray],
x2: Union[int, JaxArray],
Expand All @@ -88,14 +88,14 @@ def bitwise_and(
return jnp.bitwise_and(x1, x2)


@with_unsupported_dtypes({"0.4.23 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("complex",)}, backend_version)
def bitwise_invert(
x: Union[int, JaxArray], /, *, out: Optional[JaxArray] = None
) -> JaxArray:
return jnp.bitwise_not(x)


@with_unsupported_dtypes({"0.4.23 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("complex",)}, backend_version)
def bitwise_left_shift(
x1: Union[int, JaxArray],
x2: Union[int, JaxArray],
Expand All @@ -107,7 +107,7 @@ def bitwise_left_shift(
return jnp.left_shift(x1, x2)


@with_unsupported_dtypes({"0.4.23 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("complex",)}, backend_version)
def bitwise_or(
x1: Union[int, JaxArray],
x2: Union[int, JaxArray],
Expand All @@ -119,7 +119,7 @@ def bitwise_or(
return jnp.bitwise_or(x1, x2)


@with_unsupported_dtypes({"0.4.23 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("complex",)}, backend_version)
def bitwise_right_shift(
x1: Union[int, JaxArray],
x2: Union[int, JaxArray],
Expand All @@ -131,7 +131,7 @@ def bitwise_right_shift(
return jnp.right_shift(x1, x2)


@with_unsupported_dtypes({"0.4.23 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("complex",)}, backend_version)
def bitwise_xor(
x1: Union[int, JaxArray],
x2: Union[int, JaxArray],
Expand All @@ -143,7 +143,7 @@ def bitwise_xor(
return jnp.bitwise_xor(x1, x2)


@with_unsupported_dtypes({"0.4.23 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("complex",)}, backend_version)
def ceil(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
if "int" in str(x.dtype):
return x
Expand All @@ -155,7 +155,7 @@ def cos(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jnp.cos(x)


@with_unsupported_dtypes({"0.4.23 and below": ("float16",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("float16",)}, backend_version)
def cosh(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jnp.cosh(x)

Expand Down Expand Up @@ -195,15 +195,15 @@ def expm1(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jnp.expm1(x)


@with_unsupported_dtypes({"0.4.23 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("complex",)}, backend_version)
def floor(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
if "int" in str(x.dtype):
return x
else:
return jnp.floor(x)


@with_unsupported_dtypes({"0.4.23 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("complex",)}, backend_version)
def floor_divide(
x1: Union[float, JaxArray],
x2: Union[float, JaxArray],
Expand Down Expand Up @@ -251,7 +251,7 @@ def isfinite(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jnp.isfinite(x)


@with_unsupported_dtypes({"0.4.23 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("complex",)}, backend_version)
def isinf(
x: JaxArray,
/,
Expand Down Expand Up @@ -432,7 +432,7 @@ def pow(
return jnp.power(x1, x2)


@with_unsupported_dtypes({"0.4.23 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("complex",)}, backend_version)
def remainder(
x1: Union[float, JaxArray],
x2: Union[float, JaxArray],
Expand Down Expand Up @@ -520,7 +520,7 @@ def trapz(


@with_unsupported_dtypes(
{"0.4.23 and below": ("complex", "float16", "bfloat16")}, backend_version
{"0.4.24 and below": ("complex", "float16", "bfloat16")}, backend_version
)
def tan(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jnp.tan(x)
Expand All @@ -532,7 +532,7 @@ def tanh(
return jnp.tanh(x)


@with_unsupported_dtypes({"0.4.23 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("complex",)}, backend_version)
def trunc(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
if "int" in str(x.dtype):
return x
Expand Down Expand Up @@ -572,7 +572,7 @@ def angle(
# ------#


@with_unsupported_dtypes({"0.4.23 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("complex",)}, backend_version)
def erf(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jax.scipy.special.erf(x)

Expand Down Expand Up @@ -623,7 +623,7 @@ def isreal(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jnp.isreal(x)


@with_unsupported_dtypes({"0.4.23 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("complex",)}, backend_version)
def fmod(
x1: JaxArray,
x2: JaxArray,
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/jax/experimental/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def sinc(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:


@with_supported_dtypes(
{"0.4.23 and below": ("float16", "float32", "float64")}, backend_version
{"0.4.24 and below": ("float16", "float32", "float64")}, backend_version
)
def lgamma(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jlax.lgamma(x)
Expand Down
6 changes: 3 additions & 3 deletions ivy/functional/backends/jax/experimental/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def avg_pool3d(
return res


@with_supported_dtypes({"0.4.23 and below": ("float32", "float64")}, backend_version)
@with_supported_dtypes({"0.4.24 and below": ("float32", "float64")}, backend_version)
def dct(
x: JaxArray,
/,
Expand Down Expand Up @@ -822,7 +822,7 @@ def ifftn(


@with_unsupported_dtypes(
{"0.4.23 and below": ("bfloat16", "float16", "complex")}, backend_version
{"0.4.24 and below": ("bfloat16", "float16", "complex")}, backend_version
)
def embedding(
weights: JaxArray,
Expand Down Expand Up @@ -870,7 +870,7 @@ def rfft(
return ret


@with_unsupported_dtypes({"0.4.23 and below": ("float16", "complex")}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("float16", "complex")}, backend_version)
def rfftn(
x: JaxArray,
s: Optional[Sequence[int]] = None,
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/jax/experimental/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def beta(
return jax.random.beta(rng_input, a, b, shape, dtype)


@with_unsupported_dtypes({"0.4.23 and below": ("bfloat16",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("bfloat16",)}, backend_version)
def gamma(
alpha: Union[float, JaxArray],
beta: Union[float, JaxArray],
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/jax/experimental/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def invert_permutation(


# lexsort
@with_unsupported_dtypes({"0.4.23 and below": ("bfloat16",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("bfloat16",)}, backend_version)
def lexsort(
keys: JaxArray,
/,
Expand Down
6 changes: 3 additions & 3 deletions ivy/functional/backends/jax/experimental/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@with_unsupported_dtypes(
{"0.4.23 and below": ("bfloat16",)},
{"0.4.24 and below": ("bfloat16",)},
backend_version,
)
def histogram(
Expand Down Expand Up @@ -121,7 +121,7 @@ def histogram(


@with_unsupported_dtypes(
{"0.4.23 and below": ("complex64", "complex128")}, backend_version
{"0.4.24 and below": ("complex64", "complex128")}, backend_version
)
def median(
input: JaxArray,
Expand Down Expand Up @@ -406,7 +406,7 @@ def __get_index(lst, indices=None, prefix=None):

@with_unsupported_dtypes(
{
"0.4.23 and below": (
"0.4.24 and below": (
"bfloat16",
"bool",
)
Expand Down
4 changes: 2 additions & 2 deletions ivy/functional/backends/jax/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def array_equal(x0: JaxArray, x1: JaxArray, /) -> bool:
return bool(jnp.array_equal(x0, x1))


@with_unsupported_dtypes({"0.4.23 and below": ("bfloat16",)}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("bfloat16",)}, backend_version)
def to_numpy(x: JaxArray, /, *, copy: bool = True) -> np.ndarray:
if copy:
return np.array(_to_array(x))
Expand Down Expand Up @@ -422,7 +422,7 @@ def vmap(
)


@with_unsupported_dtypes({"0.4.23 and below": ("float16", "bfloat16")}, backend_version)
@with_unsupported_dtypes({"0.4.24 and below": ("float16", "bfloat16")}, backend_version)
def isin(
elements: JaxArray,
test_elements: JaxArray,
Expand Down
Loading

0 comments on commit 1603886

Please sign in to comment.