Skip to content

Commit

Permalink
fix: Issue with unsupported dtypes (#27934)
Browse files Browse the repository at this point in the history
1. fixed an issue with _nested_get that was failing the unsupported dtypes function when testing with a frontend tensor method that uses the .to call as we can't infer the type of the instance where .to is being called in arbitrary code.
2. also updated the log1p frontend method and the relu frontend method implementations for the same.
  • Loading branch information
vedpatwardhan authored Jan 17, 2024
1 parent 036dfcb commit 288ea0d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
9 changes: 5 additions & 4 deletions ivy/functional/frontends/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# local
import ivy
import ivy.functional.frontends.torch as torch_frontend
import ivy.functional.frontends.torch.nn.functional as torch_frontend_nn
from ivy.functional.frontends.numpy.creation_routines.from_existing_data import (
array as np_frontend_array,
)
Expand Down Expand Up @@ -390,7 +389,7 @@ def log2(self):

@with_unsupported_dtypes({"2.1.2 and below": ("float16", "bfloat16")}, "torch")
def relu(self):
return torch_frontend_nn.relu(self)
return torch_frontend.nn.functional.relu(self)

@numpy_to_torch_style_args
@with_unsupported_dtypes({"2.1.2 and below": ("complex",)}, "torch")
Expand Down Expand Up @@ -1714,12 +1713,14 @@ def stride(self, dim=None):
)
def log1p(self):
promoted_type = ivy.promote_types(self.dtype, "float32")
return torch_frontend.log1p(self).to(promoted_type)
res = torch_frontend.log1p(self)
return res.to(promoted_type)

@with_supported_dtypes({"2.1.2 and below": ("float32", "float64")}, "torch")
def log1p_(self):
promoted_type = ivy.promote_types(self.dtype, "float32")
self.ivy_array = torch_frontend.log1p(self).to(promoted_type).ivy_array
res = torch_frontend.log1p(self)
self.ivy_array = res.to(promoted_type).ivy_array
return self

def baddbmm(self, batch1, batch2, *, beta=1, alpha=1):
Expand Down
15 changes: 7 additions & 8 deletions ivy/functional/ivy/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,16 +188,15 @@ def _nested_get(f, base_set, merge_fn, get_fn, wrapper=set):
}
for key in fl:
if "frontend" in key:
fn_tree = key.split(".")
root, tree, fn_name = fn_tree[0], fn_tree[1:-1], fn_tree[-1]
root_module = importlib.import_module(
".".join([frontends[root], *tree])
)
frontend_fn = getattr(root_module, fn_name)
frontend_fl = _get_function_list(frontend_fn)
frontend_fn = fl[key]
for frontend in frontends:
if frontend in key:
key = key.replace(frontend, frontends[frontend])
frontend_module = ".".join(key.split(".")[:-1])
frontend_fl = {key: frontend_fn}
res += list(
_get_functions_from_string(
frontend_fl, __import__(frontend_fn.__module__)
frontend_fl, importlib.import_module(frontend_module)
)
)
to_visit.extend(set(res))
Expand Down

0 comments on commit 288ea0d

Please sign in to comment.