From 67c9de137887d20e8ec7fa08da52b5273e57a083 Mon Sep 17 00:00:00 2001 From: Krishna Gopal Date: Tue, 24 Oct 2023 11:45:40 -0700 Subject: [PATCH] Refactor validation decorators (#1354) * Refactor validation decorators --- .../server/lib/elasticsearch/search_table.py | 11 +++- .../base_sqlglot_validation_decorator.py} | 34 ++++------- .../decorators/base_validation_decorator.py | 35 +++++++++++ .../metadata_decorators.py} | 60 +++++++------------ .../validators/presto_optimizing_validator.py | 48 +++++---------- querybook/server/logic/admin.py | 6 ++ .../test_presto_optimizing_validator.py | 12 +++- 7 files changed, 104 insertions(+), 102 deletions(-) rename querybook/server/lib/query_analysis/validation/{validators/base_sqlglot_validator.py => decorators/base_sqlglot_validation_decorator.py} (71%) create mode 100644 querybook/server/lib/query_analysis/validation/decorators/base_validation_decorator.py rename querybook/server/lib/query_analysis/validation/{validators/metadata_suggesters.py => decorators/metadata_decorators.py} (73%) diff --git a/querybook/server/lib/elasticsearch/search_table.py b/querybook/server/lib/elasticsearch/search_table.py index 55ac8ee66..d080eb497 100644 --- a/querybook/server/lib/elasticsearch/search_table.py +++ b/querybook/server/lib/elasticsearch/search_table.py @@ -217,7 +217,9 @@ def get_column_name_suggestion( return get_matching_objects(search_query, ES_CONFIG["tables"]["index_name"], True) -def get_table_name_suggestion(fuzzy_table_name: str) -> Tuple[Dict, int]: +def get_table_name_suggestion( + fuzzy_table_name: str, metastore_id: int +) -> Tuple[Dict, int]: """Given an invalid table name use fuzzy search to search the correctly-spelled table name""" schema_name, fuzzy_name = None, fuzzy_table_name @@ -229,7 +231,12 @@ def get_table_name_suggestion(fuzzy_table_name: str) -> Tuple[Dict, int]: { "match": { "name": {"query": fuzzy_name, "fuzziness": "AUTO"}, - } + }, + }, + { + "match": { + "metastore_id": metastore_id, + }, }, ] if schema_name: diff --git a/querybook/server/lib/query_analysis/validation/validators/base_sqlglot_validator.py b/querybook/server/lib/query_analysis/validation/decorators/base_sqlglot_validation_decorator.py similarity index 71% rename from querybook/server/lib/query_analysis/validation/validators/base_sqlglot_validator.py rename to querybook/server/lib/query_analysis/validation/decorators/base_sqlglot_validation_decorator.py index 16dab444d..f4faadcac 100644 --- a/querybook/server/lib/query_analysis/validation/validators/base_sqlglot_validator.py +++ b/querybook/server/lib/query_analysis/validation/decorators/base_sqlglot_validation_decorator.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Any, Dict, List, Tuple +from typing import List, Tuple from sqlglot import Tokenizer from sqlglot.tokens import Token @@ -8,13 +8,12 @@ QueryValidationResultObjectType, QueryValidationSeverity, ) -from lib.query_analysis.validation.base_query_validator import BaseQueryValidator - +from lib.query_analysis.validation.decorators.base_validation_decorator import ( + BaseValidationDecorator, +) -class BaseSQLGlotValidator(BaseQueryValidator): - def __init__(self, name: str = "", config: Dict[str, Any] = {}): - super(BaseSQLGlotValidator, self).__init__(name, config) +class BaseSQLGlotValidationDecorator(BaseValidationDecorator): @property @abstractmethod def message(self) -> str: @@ -65,7 +64,6 @@ def _get_query_validation_result( suggestion=suggestion, ) - @abstractmethod def validate( self, query: str, @@ -74,20 +72,8 @@ def validate( raw_tokens: List[Token] = None, **kwargs, ) -> List[QueryValidationResult]: - raise NotImplementedError() - - -class BaseSQLGlotDecorator(BaseSQLGlotValidator): - def __init__(self, validator: BaseQueryValidator): - self._validator = validator - - def validate( - self, - query: str, - uid: int, - engine_id: int, - raw_tokens: List[Token] = None, - **kwargs, - ): - """Override this method to add suggestions to validation results""" - return self._validator.validate(query, uid, engine_id, **kwargs) + if raw_tokens is None: + raw_tokens = self._tokenize_query(query) + return super(BaseSQLGlotValidationDecorator, self).validate( + query, uid, engine_id, raw_tokens=raw_tokens, **kwargs + ) diff --git a/querybook/server/lib/query_analysis/validation/decorators/base_validation_decorator.py b/querybook/server/lib/query_analysis/validation/decorators/base_validation_decorator.py new file mode 100644 index 000000000..f0b4b1e02 --- /dev/null +++ b/querybook/server/lib/query_analysis/validation/decorators/base_validation_decorator.py @@ -0,0 +1,35 @@ +from abc import ABCMeta, abstractmethod +from typing import List + +from lib.query_analysis.validation.base_query_validator import ( + QueryValidationResult, +) +from lib.query_analysis.validation.base_query_validator import BaseQueryValidator + + +class BaseValidationDecorator(metaclass=ABCMeta): + def __init__(self, validator: BaseQueryValidator): + self._validator = validator + + @abstractmethod + def decorate_validation_results( + self, + validation_results: List[QueryValidationResult], + query: str, + uid: int, + engine_id: int, + **kwargs, + ) -> List[QueryValidationResult]: + raise NotImplementedError() + + def validate( + self, + query: str, + uid: int, + engine_id: int, + **kwargs, + ) -> List[QueryValidationResult]: + validation_results = self._validator.validate(query, uid, engine_id, **kwargs) + return self.decorate_validation_results( + validation_results, query, uid, engine_id, **kwargs + ) diff --git a/querybook/server/lib/query_analysis/validation/validators/metadata_suggesters.py b/querybook/server/lib/query_analysis/validation/decorators/metadata_decorators.py similarity index 73% rename from querybook/server/lib/query_analysis/validation/validators/metadata_suggesters.py rename to querybook/server/lib/query_analysis/validation/decorators/metadata_decorators.py index 726e2ce9d..e786d1092 100644 --- a/querybook/server/lib/query_analysis/validation/validators/metadata_suggesters.py +++ b/querybook/server/lib/query_analysis/validation/decorators/metadata_decorators.py @@ -6,23 +6,14 @@ from lib.query_analysis.lineage import process_query from lib.query_analysis.validation.base_query_validator import ( QueryValidationResult, - QueryValidationSeverity, ) -from lib.query_analysis.validation.validators.base_sqlglot_validator import ( - BaseSQLGlotDecorator, +from lib.query_analysis.validation.decorators.base_sqlglot_validation_decorator import ( + BaseValidationDecorator, ) -from logic.admin import get_query_engine_by_id +from logic import admin as admin_logic -class BaseColumnNameSuggester(BaseSQLGlotDecorator): - @property - def severity(self): - return QueryValidationSeverity.WARNING # Unused, severity is not changed - - @property - def message(self): - return "" # Unused, message is not changed - +class BaseColumnNameSuggester(BaseValidationDecorator): @abstractmethod def get_column_name_from_error( self, validation_result: QueryValidationResult @@ -32,7 +23,7 @@ def get_column_name_from_error( raise NotImplementedError() def _get_tables_in_query(self, query: str, engine_id: int) -> List[str]: - engine = get_query_engine_by_id(engine_id) + engine = admin_logic.get_query_engine_by_id(engine_id) tables_per_statement, _ = process_query(query, language=engine.language) return list(chain.from_iterable(tables_per_statement)) @@ -69,34 +60,21 @@ def _suggest_column_name_if_needed( validation_result.start_ch + len(fuzzy_column_name) - 1 ) - def validate( + def decorate_validation_results( self, + validation_results: List[QueryValidationResult], query: str, uid: int, engine_id: int, - raw_tokens: List[QueryValidationResult] = None, **kwargs, ) -> List[QueryValidationResult]: - if raw_tokens is None: - raw_tokens = self._tokenize_query(query) - validation_results = self._validator.validate( - query, uid, engine_id, raw_tokens=raw_tokens - ) tables_in_query = self._get_tables_in_query(query, engine_id) for result in validation_results: self._suggest_column_name_if_needed(result, tables_in_query) return validation_results -class BaseTableNameSuggester(BaseSQLGlotDecorator): - @property - def severity(self): - return QueryValidationSeverity.WARNING # Unused, severity is not changed - - @property - def message(self): - return "" # Unused, message is not changed - +class BaseTableNameSuggester(BaseValidationDecorator): @abstractmethod def get_full_table_name_from_error(self, validation_result: QueryValidationResult): """Returns invalid table name if the validation result is a table name error, otherwise @@ -104,14 +82,21 @@ def get_full_table_name_from_error(self, validation_result: QueryValidationResul raise NotImplementedError() def _suggest_table_name_if_needed( - self, validation_result: QueryValidationResult + self, + validation_result: QueryValidationResult, + engine_id: int, ) -> Optional[str]: """Takes validation result and tables in query to update validation result to provide table name suggestion""" fuzzy_table_name = self.get_full_table_name_from_error(validation_result) if not fuzzy_table_name: return - results, count = search_table.get_table_name_suggestion(fuzzy_table_name) + metastore_id = admin_logic.get_query_metastore_id_by_engine_id(engine_id) + if metastore_id is None: + return + results, count = search_table.get_table_name_suggestion( + fuzzy_table_name, metastore_id + ) if count > 0: table_result = results[0] # Get top match table_suggestion = f"{table_result['schema']}.{table_result['name']}" @@ -121,19 +106,14 @@ def _suggest_table_name_if_needed( validation_result.start_ch + len(fuzzy_table_name) - 1 ) - def validate( + def decorate_validation_results( self, + validation_results: List[QueryValidationResult], query: str, uid: int, engine_id: int, - raw_tokens: List[QueryValidationResult] = None, **kwargs, ) -> List[QueryValidationResult]: - if raw_tokens is None: - raw_tokens = self._tokenize_query(query) - validation_results = self._validator.validate( - query, uid, engine_id, raw_tokens=raw_tokens - ) for result in validation_results: - self._suggest_table_name_if_needed(result) + self._suggest_table_name_if_needed(result, engine_id) return validation_results diff --git a/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py b/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py index 6cf87de45..70bf71e0c 100644 --- a/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py +++ b/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py @@ -12,19 +12,16 @@ from lib.query_analysis.validation.validators.presto_explain_validator import ( PrestoExplainValidator, ) -from lib.query_analysis.validation.validators.base_sqlglot_validator import ( - BaseSQLGlotDecorator, +from lib.query_analysis.validation.decorators.base_sqlglot_validation_decorator import ( + BaseSQLGlotValidationDecorator, ) -from lib.query_analysis.validation.validators.metadata_suggesters import ( +from lib.query_analysis.validation.decorators.metadata_decorators import ( BaseColumnNameSuggester, BaseTableNameSuggester, ) -class BasePrestoSQLGlotDecorator(BaseSQLGlotDecorator): - def languages(self): - return ["presto", "trino"] - +class BasePrestoSQLGlotDecorator(BaseSQLGlotValidationDecorator): @property def tokenizer(self) -> Tokenizer: return Trino.Tokenizer() @@ -39,19 +36,15 @@ def message(self): def severity(self) -> str: return QueryValidationSeverity.WARNING - def validate( + def decorate_validation_results( self, + validation_results: List[QueryValidationResult], query: str, uid: int, engine_id: int, - raw_tokens: List[Token] = None, + raw_tokens: List[Token] = [], **kwargs, ) -> List[QueryValidationResult]: - if raw_tokens is None: - raw_tokens = self._tokenize_query(query) - validation_results = self._validator.validate( - query, uid, engine_id, raw_tokens=raw_tokens - ) for i, token in enumerate(raw_tokens): if token.token_type == TokenType.UNION: if ( @@ -77,20 +70,15 @@ def message(self): def severity(self) -> str: return QueryValidationSeverity.WARNING - def validate( + def decorate_validation_results( self, + validation_results: List[QueryValidationResult], query: str, uid: int, engine_id: int, - raw_tokens: List[Token] = None, + raw_tokens: List[Token] = [], **kwargs, ) -> List[QueryValidationResult]: - if raw_tokens is None: - raw_tokens = self._tokenize_query(query) - - validation_results = self._validator.validate( - query, uid, engine_id, raw_tokens=raw_tokens - ) for i, token in enumerate(raw_tokens): if ( i < len(raw_tokens) - 2 @@ -125,21 +113,15 @@ def _get_regexp_like_suggestion(self, column_name: str, like_strings: List[str]) ] return f"REGEXP_LIKE({column_name}, '{'|'.join(sanitized_like_strings)}')" - def validate( + def decorate_validation_results( self, + validation_results: List[QueryValidationResult], query: str, uid: int, engine_id: int, - raw_tokens: List[Token] = None, + raw_tokens: List[Token] = [], **kwargs, ) -> List[QueryValidationResult]: - if raw_tokens is None: - raw_tokens = self._tokenize_query(query) - - validation_results = self._validator.validate( - query, uid, engine_id, raw_tokens=raw_tokens - ) - start_column_token = None like_strings = [] token_idx = 0 @@ -203,7 +185,7 @@ def validate( return validation_results -class PrestoColumnNameSuggester(BasePrestoSQLGlotDecorator, BaseColumnNameSuggester): +class PrestoColumnNameSuggester(BaseColumnNameSuggester): def get_column_name_from_error(self, validation_result: QueryValidationResult): regex_result = re.match( r"line \d+:\d+: Column '(.*)' cannot be resolved", validation_result.message @@ -211,7 +193,7 @@ def get_column_name_from_error(self, validation_result: QueryValidationResult): return regex_result.groups()[0] if regex_result else None -class PrestoTableNameSuggester(BasePrestoSQLGlotDecorator, BaseTableNameSuggester): +class PrestoTableNameSuggester(BaseTableNameSuggester): def get_full_table_name_from_error(self, validation_result: QueryValidationResult): regex_result = re.match( r"line \d+:\d+: Table '(.*)' does not exist", validation_result.message diff --git a/querybook/server/logic/admin.py b/querybook/server/logic/admin.py index 7c7bd80ba..c570041db 100644 --- a/querybook/server/logic/admin.py +++ b/querybook/server/logic/admin.py @@ -224,6 +224,12 @@ def get_query_metastore_by_name(name, session=None): return session.query(QueryMetastore).filter(QueryMetastore.name == name).first() +@with_session +def get_query_metastore_id_by_engine_id(engine_id: int, session=None): + query_engine = get_query_engine_by_id(engine_id, session=session) + return query_engine.metastore_id if query_engine else None + + @with_session def get_all_query_metastore(session=None): return session.query(QueryMetastore).all() diff --git a/querybook/tests/test_lib/test_query_analysis/test_validation/test_validators/test_presto_optimizing_validator.py b/querybook/tests/test_lib/test_query_analysis/test_validation/test_validators/test_presto_optimizing_validator.py index 93a98bbb6..9e055af12 100644 --- a/querybook/tests/test_lib/test_query_analysis/test_validation/test_validators/test_presto_optimizing_validator.py +++ b/querybook/tests/test_lib/test_query_analysis/test_validation/test_validators/test_presto_optimizing_validator.py @@ -373,6 +373,12 @@ def test__get_column_name_suggestion(self, mock_get_column_name_suggestion): class PrestoTableNameSuggesterTestCase(BaseValidatorTestCase): def setUp(self): self._validator = PrestoTableNameSuggester(MagicMock()) + patch_get_metastore_id = patch( + "logic.admin.get_query_metastore_id_by_engine_id" + ) + mock_get_metastore_id = patch_get_metastore_id.start() + mock_get_metastore_id.return_value = 1 + self.addCleanup(patch_get_metastore_id.stop) def test_get_full_table_name_from_error(self): self.assertEquals( @@ -411,7 +417,7 @@ def test__suggest_table_name_if_needed_single_hit(self, mock_table_suggestion): mock_table_suggestion.return_value = [ {"schema": "main", "name": "world_happiness_rank_2015"} ], 1 - self._validator._suggest_table_name_if_needed(validation_result) + self._validator._suggest_table_name_if_needed(validation_result, 0) self.assertEquals( validation_result.suggestion, "main.world_happiness_rank_2015" ) @@ -430,7 +436,7 @@ def test__suggest_table_name_if_needed_multiple_hits(self, mock_table_suggestion {"schema": "main", "name": "world_happiness_rank_2015"}, {"schema": "main", "name": "world_happiness_rank_2016"}, ], 2 - self._validator._suggest_table_name_if_needed(validation_result) + self._validator._suggest_table_name_if_needed(validation_result, 0) self.assertEquals( validation_result.suggestion, "main.world_happiness_rank_2015" ) @@ -446,7 +452,7 @@ def test__suggest_table_name_if_needed_no_hits(self, mock_table_suggestion): "line 0:1: Table 'world_happiness_15' does not exist", ) mock_table_suggestion.return_value = [], 0 - self._validator._suggest_table_name_if_needed(validation_result) + self._validator._suggest_table_name_if_needed(validation_result, 0) self.assertEquals(validation_result.suggestion, None)