Skip to content

Commit

Permalink
updated test_frontend_function to allow testing for functions of th…
Browse files Browse the repository at this point in the history
…e "tensorflow.compat.v1" module
  • Loading branch information
AnnaTz authored and Tristan1649 committed Feb 22, 2023
1 parent 1ceed78 commit 802c8e2
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions ivy_tests/test_ivy/helpers/function_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,13 +676,24 @@ def _test_frontend_function(args, kwargs, args_ivy, kwargs_ivy):
*args_frontend, **kwargs_frontend
)
except KeyError:
# catch cases where the alias belongs to a higher-level module
# e.g. torch.inverse tested as an alias to torch.linalg.inv
module_name = module_name[: module_name.rfind(".")]
frontend_fw = importlib.import_module(module_name)
frontend_ret = frontend_fw.__dict__[fn_name](
*args_frontend, **kwargs_frontend
)
try:
# catch cases where the alias belongs to a higher-level module
# e.g. torch.inverse tested as an alias to torch.linalg.inv
higher_module_name = module_name[: module_name.rfind(".")]
frontend_fw = importlib.import_module(higher_module_name)
frontend_ret = frontend_fw.__dict__[fn_name](
*args_frontend, **kwargs_frontend
)
except KeyError:
if frontend == 'tensorflow':
compat_module_name = 'tensorflow.compat.v1' + module_name[
module_name.rfind("."):]
frontend_fw = importlib.import_module(compat_module_name)
frontend_ret = frontend_fw.__dict__[fn_name](
*args_frontend, **kwargs_frontend
)
else:
raise

if ivy.isscalar(frontend_ret):
frontend_ret_np_flat = [np.asarray(frontend_ret)]
Expand Down

0 comments on commit 802c8e2

Please sign in to comment.