From 802c8e25ec29ccefd12f5f574e6aab9098f7c62b Mon Sep 17 00:00:00 2001 From: AnnaTz <111577222+AnnaTz@users.noreply.github.com> Date: Tue, 17 Jan 2023 13:02:48 +0000 Subject: [PATCH] updated `test_frontend_function` to allow testing for functions of the "tensorflow.compat.v1" module --- .../test_ivy/helpers/function_testing.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/ivy_tests/test_ivy/helpers/function_testing.py b/ivy_tests/test_ivy/helpers/function_testing.py index 222eda9057e20..de43f85d313c5 100644 --- a/ivy_tests/test_ivy/helpers/function_testing.py +++ b/ivy_tests/test_ivy/helpers/function_testing.py @@ -676,13 +676,24 @@ def _test_frontend_function(args, kwargs, args_ivy, kwargs_ivy): *args_frontend, **kwargs_frontend ) except KeyError: - # catch cases where the alias belongs to a higher-level module - # e.g. torch.inverse tested as an alias to torch.linalg.inv - module_name = module_name[: module_name.rfind(".")] - frontend_fw = importlib.import_module(module_name) - frontend_ret = frontend_fw.__dict__[fn_name]( - *args_frontend, **kwargs_frontend - ) + try: + # catch cases where the alias belongs to a higher-level module + # e.g. torch.inverse tested as an alias to torch.linalg.inv + higher_module_name = module_name[: module_name.rfind(".")] + frontend_fw = importlib.import_module(higher_module_name) + frontend_ret = frontend_fw.__dict__[fn_name]( + *args_frontend, **kwargs_frontend + ) + except KeyError: + if frontend == 'tensorflow': + compat_module_name = 'tensorflow.compat.v1' + module_name[ + module_name.rfind("."):] + frontend_fw = importlib.import_module(compat_module_name) + frontend_ret = frontend_fw.__dict__[fn_name]( + *args_frontend, **kwargs_frontend + ) + else: + raise if ivy.isscalar(frontend_ret): frontend_ret_np_flat = [np.asarray(frontend_ret)]