Skip to content

Commit

Permalink
fix: Fixing pytest discovery issues (#28158)
Browse files Browse the repository at this point in the history
Co-authored-by: Ved Patwardhan <54766411+vedpatwardhan@users.noreply.github.com>
Co-authored-by: NripeshN <nripesh14@gmail.com>
  • Loading branch information
3 people authored Feb 7, 2024
1 parent 935cde8 commit 0c3a4a9
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 5 deletions.
1 change: 1 addition & 0 deletions ivy/functional/frontends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"sklearn": "1.3.0",
"xgboost": "1.7.6",
"torchvision": "0.15.2.",
"mindspore": "2.0.0",
}


Expand Down
21 changes: 19 additions & 2 deletions ivy/functional/frontends/tensorflow/keras/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,21 @@
]


# --- Helpers --- #
# --------------- #


# note: defined to avoid AST call extraction of
# 'tf_frontend.keras.activations.__dict__.items()
# or 'tf_frontend.keras.activations.__dict__.values()'
def _get_tf_keras_activations():
return tf_frontend.keras.activations.__dict__.items()


# --- Main --- #
# ------------ #


@with_supported_dtypes(
{"2.15.0 and below": ("float16", "float32", "float64")},
"tensorflow",
Expand Down Expand Up @@ -132,13 +147,15 @@ def serialize(activation, use_legacy_format=False, custom_objects=None):
if custom_func == activation:
return name

tf_keras_frontend_activations = _get_tf_keras_activations()

# Check if the function is in the ACTIVATION_FUNCTIONS list
if activation.__name__ in ACTIVATION_FUNCTIONS:
return activation.__name__

# Check if the function is in the TensorFlow frontend activations
elif activation in tf_frontend.keras.activations.__dict__.values():
for name, tf_func in tf_frontend.keras.activations.__dict__.items():
elif activation in [fn for name, fn in tf_keras_frontend_activations]:
for name, tf_func in tf_keras_frontend_activations:
if tf_func == activation:
return name

Expand Down
7 changes: 6 additions & 1 deletion ivy/functional/ivy/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def _nested_get(f, base_set, merge_fn, get_fn, wrapper=set):
continue
is_frontend_fn = "frontend" in fn.__module__
is_backend_fn = "backend" in fn.__module__ and not is_frontend_fn
is_einops_fn = "einops" in fn.__name__
is_einops_fn = hasattr(fn, "__name__") and "einops" in fn.__name__
if is_backend_fn:
f_supported = get_fn(fn, False)
if hasattr(fn, "partial_mixed_handler"):
Expand Down Expand Up @@ -183,6 +183,7 @@ def _nested_get(f, base_set, merge_fn, get_fn, wrapper=set):
if is_frontend_fn:
frontends = {
"jax_frontend": "ivy.functional.frontends.jax",
"jnp_frontend": "ivy.functional.frontends.jax.numpy",
"np_frontend": "ivy.functional.frontends.numpy",
"tf_frontend": "ivy.functional.frontends.tensorflow",
"torch_frontend": "ivy.functional.frontends.torch",
Expand All @@ -197,6 +198,10 @@ def _nested_get(f, base_set, merge_fn, get_fn, wrapper=set):
if "(" in key:
key = key.split("(")[0]
frontend_module = ".".join(key.split(".")[:-1])
if (
frontend_module == ""
): # single edge case: fn='frontend_outputs_to_ivy_arrays'
continue
frontend_fl = {key: frontend_fn}
res += list(
_get_functions_from_string(
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/ivy/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,7 @@ def _get_devices(fn: Callable, complement: bool = True) -> Tuple:

is_backend_fn = "backend" in fn.__module__
is_frontend_fn = "frontend" in fn.__module__
is_einops_fn = "einops" in fn.__name__
is_einops_fn = hasattr(fn, "__name__") and "einops" in fn.__name__
if not is_backend_fn and not is_frontend_fn and not is_einops_fn:
if complement:
supported = set(all_devices).difference(supported)
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/ivy/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -3988,7 +3988,7 @@ def _get_devices_and_dtypes(fn, recurse=True, complement=True):

is_frontend_fn = "frontend" in fn.__module__
is_backend_fn = "backend" in fn.__module__ and not is_frontend_fn
is_einops_fn = "einops" in fn.__name__
is_einops_fn = hasattr(fn, "__name__") and "einops" in fn.__name__
if not is_backend_fn and not is_frontend_fn and not is_einops_fn:
if complement:
all_comb = _all_dnd_combinations()
Expand Down

0 comments on commit 0c3a4a9

Please sign in to comment.