diff --git a/eva/binder/statement_binder.py b/eva/binder/statement_binder.py index f7be1f1fe..13d9eb04e 100644 --- a/eva/binder/statement_binder.py +++ b/eva/binder/statement_binder.py @@ -36,7 +36,7 @@ from eva.parser.select_statement import SelectStatement from eva.parser.statement import AbstractStatement from eva.parser.table_ref import TableRef -from eva.utils.generic_utils import path_to_class +from eva.utils.generic_utils import load_udf_class_from_file from eva.utils.logging_manager import logger if sys.version_info >= (3, 8): @@ -228,7 +228,9 @@ def _bind_func_expr(self, node: FunctionExpression): raise BinderError(err_msg) try: - node.function = path_to_class(udf_obj.impl_file_path, udf_obj.name) + node.function = load_udf_class_from_file( + udf_obj.impl_file_path, udf_obj.name + ) except Exception as e: err_msg = ( f"{str(e)}. Please verify that the UDF class name in the" diff --git a/eva/executor/create_udf_executor.py b/eva/executor/create_udf_executor.py index d07ec6bb0..a2f6118d3 100644 --- a/eva/executor/create_udf_executor.py +++ b/eva/executor/create_udf_executor.py @@ -18,7 +18,7 @@ from eva.executor.abstract_executor import AbstractExecutor from eva.models.storage.batch import Batch from eva.plan_nodes.create_udf_plan import CreateUDFPlan -from eva.utils.generic_utils import path_to_class +from eva.utils.generic_utils import load_udf_class_from_file from eva.utils.logging_manager import logger @@ -52,12 +52,9 @@ def exec(self): impl_path = self.node.impl_path.absolute().as_posix() # check if we can create the udf object try: - path_to_class(impl_path, self.node.name)() + load_udf_class_from_file(impl_path, self.node.name)() except Exception as e: - err_msg = ( - f"{str(e)}. Please verify that the UDF class name in the " - f"implementation file matches the provided UDF name {self.node.name}." 
- ) + err_msg = f"Error creating UDF: {str(e)}" logger.error(err_msg) raise RuntimeError(err_msg) catalog_manager.insert_udf_catalog_entry( diff --git a/eva/utils/generic_utils.py b/eva/utils/generic_utils.py index 1bc17ae00..1c9386c19 100644 --- a/eva/utils/generic_utils.py +++ b/eva/utils/generic_utils.py @@ -14,6 +14,7 @@ # limitations under the License. import hashlib import importlib +import inspect import pickle import sys import uuid @@ -49,28 +50,44 @@ def str_to_class(class_path: str): return getattr(module, class_name) -def path_to_class(filepath: str, classname: str): +def load_udf_class_from_file(filepath, classname=None): """ - Convert the class in the path file into an object + Load a class from a Python file. If the classname is not specified, the function will check if there is only one class in the file and load that. If there are no classes or multiple classes, it will raise an error. - Arguments: - filepath: absolute path of file - classname: the name of the imported class + Args: + filepath (str): The path to the Python file. + classname (str, optional): The name of the class to load. If not specified, or if no class with this name exists in the file, the function loads the only class defined in the file. Defaults to None. Returns: - type: A class for given path + The loaded class object (the class itself, not an instance). + + Raises: + RuntimeError: If the file cannot be loaded, or if the class is not found by name and the file does not define exactly one class. """ try: abs_path = Path(filepath).resolve() spec = importlib.util.spec_from_file_location(abs_path.stem, abs_path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - classobj = getattr(module, classname) except Exception as e: - err_msg = f"Failed to import {classname} from {filepath}\nException: {str(e)}" - logger.error(err_msg) + err_msg = f"Couldn't load UDF from {filepath} : {str(e)}. Ensure that the file exists and that it is a valid Python file." 
raise RuntimeError(err_msg) - return classobj + + # Try to load the specified class by name + if classname and hasattr(module, classname): + return getattr(module, classname) + + # If class name not specified, check if there is only one class in the file + classes = [ + obj + for _, obj in inspect.getmembers(module, inspect.isclass) + if obj.__module__ == module.__name__ + ] + if len(classes) != 1: + raise RuntimeError( + f"{filepath} contains {len(classes)} classes, please specify the correct class to load by naming the UDF with the same name in the CREATE query." + ) + return classes[0] def is_gpu_available() -> bool: diff --git a/test/binder/test_statement_binder.py b/test/binder/test_statement_binder.py index 56e792295..3bee343d5 100644 --- a/test/binder/test_statement_binder.py +++ b/test/binder/test_statement_binder.py @@ -122,8 +122,8 @@ def test_bind_explain_statement(self): mock_binder.assert_called_with(stmt.explainable_stmt) @patch("eva.binder.statement_binder.CatalogManager") - @patch("eva.binder.statement_binder.path_to_class") - def test_bind_func_expr(self, mock_path_to_class, mock_catalog): + @patch("eva.binder.statement_binder.load_udf_class_from_file") + def test_bind_func_expr(self, mock_load_udf_class_from_file, mock_catalog): # setup func_expr = MagicMock( name="func_expr", alias=Alias("func_expr"), output_col_aliases=[] @@ -142,7 +142,9 @@ def test_bind_func_expr(self, mock_path_to_class, mock_catalog): mock_catalog().get_udf_io_catalog_output_entries ) = MagicMock() mock_get_udf_outputs.return_value = func_ouput_objs - mock_path_to_class.return_value.return_value = "path_to_class" + mock_load_udf_class_from_file.return_value.return_value = ( + "load_udf_class_from_file" + ) # Case 1 set output func_expr.output = "out1" @@ -151,14 +153,16 @@ def test_bind_func_expr(self, mock_path_to_class, mock_catalog): mock_get_name.assert_called_with(func_expr.name) mock_get_udf_outputs.assert_called_with(udf_obj) - 
mock_path_to_class.assert_called_with(udf_obj.impl_file_path, udf_obj.name) + mock_load_udf_class_from_file.assert_called_with( + udf_obj.impl_file_path, udf_obj.name + ) self.assertEqual(func_expr.output_objs, [obj1]) print(str(func_expr.alias)) self.assertEqual( func_expr.alias, Alias("func_expr", ["out1"]), ) - self.assertEqual(func_expr.function(), "path_to_class") + self.assertEqual(func_expr.function(), "load_udf_class_from_file") # Case 2 output not set func_expr.output = None @@ -168,7 +172,9 @@ def test_bind_func_expr(self, mock_path_to_class, mock_catalog): mock_get_name.assert_called_with(func_expr.name) mock_get_udf_outputs.assert_called_with(udf_obj) - mock_path_to_class.assert_called_with(udf_obj.impl_file_path, udf_obj.name) + mock_load_udf_class_from_file.assert_called_with( + udf_obj.impl_file_path, udf_obj.name + ) self.assertEqual(func_expr.output_objs, func_ouput_objs) self.assertEqual( func_expr.alias, @@ -177,12 +183,12 @@ def test_bind_func_expr(self, mock_path_to_class, mock_catalog): ["out1", "out2"], ), ) - self.assertEqual(func_expr.function(), "path_to_class") + self.assertEqual(func_expr.function(), "load_udf_class_from_file") # Raise error if the class object cannot be created - mock_path_to_class.reset_mock() - mock_error_msg = "mock_path_to_class_error" - mock_path_to_class.side_effect = MagicMock( + mock_load_udf_class_from_file.reset_mock() + mock_error_msg = "mock_load_udf_class_from_file_error" + mock_load_udf_class_from_file.side_effect = MagicMock( side_effect=RuntimeError(mock_error_msg) ) binder = StatementBinder(StatementBinderContext()) diff --git a/test/executor/test_create_udf_executor.py b/test/executor/test_create_udf_executor.py index 0a9544863..d1abfb95a 100644 --- a/test/executor/test_create_udf_executor.py +++ b/test/executor/test_create_udf_executor.py @@ -21,15 +21,15 @@ class CreateUdfExecutorTest(unittest.TestCase): @patch("eva.executor.create_udf_executor.CatalogManager") - 
@patch("eva.executor.create_udf_executor.path_to_class") - def test_should_create_udf(self, path_to_class_mock, mock): + @patch("eva.executor.create_udf_executor.load_udf_class_from_file") + def test_should_create_udf(self, load_udf_class_from_file_mock, mock): catalog_instance = mock.return_value catalog_instance.get_udf_catalog_entry_by_name.return_value = None catalog_instance.insert_udf_catalog_entry.return_value = "udf" impl_path = MagicMock() abs_path = impl_path.absolute.return_value = MagicMock() abs_path.as_posix.return_value = "test.py" - path_to_class_mock.return_value.return_value = "mock_class" + load_udf_class_from_file_mock.return_value.return_value = "mock_class" plan = type( "CreateUDFPlan", (), diff --git a/test/utils/test_generic_utils.py b/test/utils/test_generic_utils.py index d8e93f616..422360b91 100644 --- a/test/utils/test_generic_utils.py +++ b/test/utils/test_generic_utils.py @@ -23,27 +23,47 @@ from eva.utils.generic_utils import ( generate_file_path, is_gpu_available, - path_to_class, + load_udf_class_from_file, str_to_class, + validate_kwargs, ) class ModulePathTest(unittest.TestCase): + def test_helper_validates_kwargs(self): + with self.assertRaises(TypeError): + validate_kwargs({"a": 1, "b": 2}, ["a"], "Invalid keyword argument:") + def test_should_return_correct_class_for_string(self): vl = str_to_class("eva.readers.opencv_reader.OpenCVReader") self.assertEqual(vl, OpenCVReader) - @unittest.skip( - "This returns opecv_reader.OpenCVReader \ - instead of eva.readers.opencv_reader.OpenCVReader" - ) def test_should_return_correct_class_for_path(self): - vl = path_to_class("eva/readers/opencv_reader.py", "OpenCVReader") - self.assertEqual(vl, OpenCVReader) + vl = load_udf_class_from_file("eva/readers/opencv_reader.py", "OpenCVReader") + # Can't check that vl == OpenCVReader because the above function returns opencv_reader.OpenCVReader instead of eva.readers.opencv_reader.OpenCVReader + # So we check the qualname instead, qualname is the 
dotted path to the class within its module; it does not include the module name + # Ref: https://peps.python.org/pep-3155/#rationale + assert vl.__qualname__ == OpenCVReader.__qualname__ + + def test_should_return_correct_class_for_path_without_classname(self): + vl = load_udf_class_from_file("eva/readers/opencv_reader.py") + assert vl.__qualname__ == OpenCVReader.__qualname__ + + def test_should_raise_on_missing_file(self): + with self.assertRaises(RuntimeError): + load_udf_class_from_file("eva/readers/opencv_reader_abdfdsfds.py") def test_should_raise_if_class_does_not_exists(self): with self.assertRaises(RuntimeError): - path_to_class("eva/readers/opencv_reader.py", "OpenCV") + # eva/utils/s3_utils.py has no class in it + # if this test fails due to change in s3_utils.py, change the file to something else + load_udf_class_from_file("eva/utils/s3_utils.py") + + def test_should_raise_if_multiple_classes_exist_and_no_class_mentioned(self): + with self.assertRaises(RuntimeError): + # eva/utils/generic_utils.py has multiple classes in it + # if this test fails due to change in generic_utils.py, change the file to something else + load_udf_class_from_file("eva/utils/generic_utils.py") def test_should_use_torch_to_check_if_gpu_is_available(self): # Emulate a missing import