From 3a2cf9eed83ee97a14518640605da062d89a8369 Mon Sep 17 00:00:00 2001 From: Daniel Gaspar Date: Tue, 3 May 2022 15:18:57 +0100 Subject: [PATCH 01/12] feat: deprecate /superset/validate_sql_json migrate to api v1 --- superset/constants.py | 1 + superset/databases/api.py | 68 +++++++++++ superset/databases/commands/exceptions.py | 25 ++++ superset/databases/commands/validate_sql.py | 122 ++++++++++++++++++++ superset/databases/schemas.py | 15 +++ 5 files changed, 231 insertions(+) create mode 100644 superset/databases/commands/validate_sql.py diff --git a/superset/constants.py b/superset/constants.py index 8399aa457a882..8108e96c515ac 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -126,6 +126,7 @@ class RouteMethod: # pylint: disable=too-few-public-methods "get_datasets": "read", "function_names": "read", "available": "read", + "validate_sql": "read", "get_data": "read", } diff --git a/superset/databases/api.py b/superset/databases/api.py index ac497bf67dbde..e2552e2a62786 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -49,6 +49,7 @@ from superset.databases.commands.test_connection import TestConnectionDatabaseCommand from superset.databases.commands.update import UpdateDatabaseCommand from superset.databases.commands.validate import ValidateDatabaseParametersCommand +from superset.databases.commands.validate_sql import ValidateSQLCommand from superset.databases.dao import DatabaseDAO from superset.databases.decorators import check_datasource_access from superset.databases.filters import DatabaseFilter, DatabaseUploadEnabledFilter @@ -64,6 +65,8 @@ SchemasResponseSchema, SelectStarResponseSchema, TableMetadataResponseSchema, + ValidateSQLRequest, + ValidateSQLResponse, ) from superset.databases.utils import get_table_metadata from superset.db_engine_specs import get_available_engine_specs @@ -96,6 +99,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "function_names", "available", "validate_parameters", + "validate_sql", } resource_name = "database" class_permission_name = "Database" @@ -191,6 +195,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "database_schemas_query_schema": database_schemas_query_schema, "get_export_ids_schema": get_export_ids_schema, } + openapi_spec_tag = "Database" openapi_spec_component_schemas = ( DatabaseFunctionNamesResponse, @@ -200,6 +205,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi): TableMetadataResponseSchema, SelectStarResponseSchema, SchemasResponseSchema, + ValidateSQLRequest, + ValidateSQLResponse, ) @expose("/", methods=["POST"]) @@ -708,6 +715,67 @@ def related_objects(self, pk: int) -> Response: }, ) + @expose("//validate_sql", methods=["POST"]) + @protect() + @safe + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.validate_sql", + log_to_statsd=False, + ) + def validate_sql(self, pk: int, **kwargs: Any) -> FlaskResponse: + """ + --- + post: + summary: >- + Validates that arbitrary sql is acceptable for the given database + description: >- + Validates arbitrary SQL. + parameters: + - in: path + schema: + type: integer + name: pk + requestBody: + description: Validate SQL request + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/ValidateSQLRequest' + responses: + 200: + description: Validation result + content: + application/json: + schema: + type: object + properties: + result: + description: >- + A List of SQL errors found on the statement + type: array + items: + $ref: '#/components/schemas/ValidateSQLResponse' + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + try: + sql_request = ValidateSQLRequest().load(request.json) + except ValidationError as error: + return self.response_400(message=error.messages) + try: + validator_errors = ValidateSQLCommand(pk, sql_request).run() + return self.response(200, result=validator_errors) + except DatabaseNotFoundError: + return self.response_404() + @expose("/export/", methods=["GET"]) @protect() @safe diff --git a/superset/databases/commands/exceptions.py b/superset/databases/commands/exceptions.py index bde76c021c88a..a49abd3449d03 100644 --- a/superset/databases/commands/exceptions.py +++ b/superset/databases/commands/exceptions.py @@ -137,6 +137,31 @@ class DatabaseTestConnectionUnexpectedError(SupersetErrorsException): message = _("Unexpected error occurred, please check your logs for details") +class NoValidatorConfigFoundError(SupersetErrorException): + status = 422 + message = _("no SQL validator is configured") + + +class NoValidatorFoundError(SupersetErrorException): + status = 422 + message = _("No validator found (configured for the engine)") + + +class ValidatorSQLError(SupersetErrorException): + status = 422 + message = _("Was unable to check your query") + + +class ValidatorSQLUnexpectedError(CommandException): + status = 422 + message = _("An unexpected error occurred") + + +class ValidatorSQL400Error(SupersetErrorException): + status = 400 + message = _("Was unable to check your query") + + class DatabaseImportError(ImportFailedError): message = _("Import database failed for an unknown reason") diff --git a/superset/databases/commands/validate_sql.py b/superset/databases/commands/validate_sql.py new file mode 100644 index 0000000000000..bcd87b9db6e38 --- /dev/null +++ b/superset/databases/commands/validate_sql.py @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +import re +from typing import Any, Dict, List, Optional, Type, TypedDict + +from flask import current_app +from flask_babel import gettext as __ +from marshmallow import ValidationError + +from superset.commands.base import BaseCommand +from superset.databases.commands.exceptions import ( + DatabaseNotFoundError, + NoValidatorConfigFoundError, + NoValidatorFoundError, + ValidatorSQL400Error, + ValidatorSQLError, + ValidatorSQLUnexpectedError, +) +from superset.databases.dao import DatabaseDAO +from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.models.core import Database +from superset.sql_validators import get_validator_by_name +from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation +from superset.utils import core as utils + +logger = logging.getLogger(__name__) + + +class ValidateSQLCommand(BaseCommand): + def __init__(self, model_id: int, data: Dict[str, Any]): + self._properties = data.copy() + self._model_id = model_id + self._model: Optional[Database] = None + self._validator: Optional[Type[BaseSQLValidator]] = None + + def run(self) -> List[Dict[str, Any]]: + """ + Validates a SQL statement + + :return: A List of SQLValidationAnnotation + :raises: DatabaseNotFoundError, NoValidatorConfigFoundError + NoValidatorFoundError, ValidatorSQLUnexpectedError, ValidatorSQLError + ValidatorSQL400Error + """ + self.validate() + if not self._validator or not self._model: + raise ValidatorSQLUnexpectedError() + sql = self._properties["sql"] + schema = self._properties.get("schema") + try: + timeout = current_app.config["SQLLAB_VALIDATION_TIMEOUT"] + timeout_msg = f"The query exceeded the {timeout} seconds timeout." + with utils.timeout(seconds=timeout, error_message=timeout_msg): + errors = self._validator.validate(sql, schema, self._model) + return [err.to_dict() for err in errors] + except Exception as ex: # pylint: disable=broad-except + logger.exception(ex) + superset_error = SupersetError( + message=__( + "%(validator)s was unable to check your query.\n" + "Please recheck your query.\n" + "Exception: %(ex)s", + validator=self._validator.name, + ex=ex, + ), + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ) + + # Return as a 400 if the database error message says we got a 4xx error + if re.search(r"([\W]|^)4\d{2}([\W]|$)", str(ex)): + raise ValidatorSQL400Error(superset_error) + raise ValidatorSQLError(superset_error) + + def validate(self) -> None: + exceptions: List[ValidationError] = [] + # Validate/populate model exists + self._model = DatabaseDAO.find_by_id(self._model_id) + if not self._model: + raise DatabaseNotFoundError() + + spec = self._model.db_engine_spec + validators_by_engine = current_app.config["SQL_VALIDATORS_BY_ENGINE"] + if not validators_by_engine or spec.engine not in validators_by_engine: + raise NoValidatorConfigFoundError( + SupersetError( + message=__( + "no SQL validator is configured for {}".format(spec.engine) + ), + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ), + ) + validator_name = validators_by_engine[spec.engine] + self._validator = get_validator_by_name(validator_name) + if not self._validator: + raise NoValidatorFoundError( + SupersetError( + message=__( + "No validator named {} found (configured for the {} engine)".format( + validator_name, spec.engine + ) + ), + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ), + ) diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index dd60e87d167f0..50ab84997a5b7 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -535,6 +535,21 @@ class SchemasResponseSchema(Schema): result = fields.List(fields.String(description="A database schema name")) +class ValidateSQLRequest(Schema): + sql = fields.String(required=True, description="SQL statement to validate") + schema = fields.String(required=False, allow_none=True) + template_params = fields.Dict( + required=False, + ) + + +class ValidateSQLResponse(Schema): + line_number = fields.Integer() + start_column = fields.Integer() + end_column = fields.Integer() + message = fields.String() + + class DatabaseRelatedChart(Schema): id = fields.Integer() slice_name = fields.String() From 202240e8a985b8cfb5b8e53cb7927ac133defe7b Mon Sep 17 00:00:00 2001 From: Daniel Gaspar Date: Tue, 3 May 2022 17:43:02 +0100 Subject: [PATCH 02/12] use new error handling --- superset/databases/api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/superset/databases/api.py b/superset/databases/api.py index e2552e2a62786..d3e896adfd6f0 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -717,7 +717,6 @@ def related_objects(self, pk: int) -> Response: @expose("//validate_sql", methods=["POST"]) @protect() - @safe @statsd_metrics @event_logger.log_this_with_context( action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.validate_sql", From b636399911f3862f49161150ea38d6078e893706 Mon Sep 17 00:00:00 2001 From: Daniel Gaspar Date: Wed, 4 May 2022 12:28:22 +0100 Subject: [PATCH 03/12] migrate SQLLAb frontend and add tests --- .../src/SqlLab/actions/sqlLab.js | 17 +- superset/databases/api.py | 2 +- superset/databases/commands/validate_sql.py | 7 +- superset/databases/schemas.py | 4 +- superset/views/core.py | 6 + .../integration_tests/databases/api_tests.py | 155 ++++++++++++++++++ 6 files changed, 171 insertions(+), 20 deletions(-) diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.js b/superset-frontend/src/SqlLab/actions/sqlLab.js index 41717dd17488b..a8036998b8ca4 100644 --- a/superset-frontend/src/SqlLab/actions/sqlLab.js +++ b/superset-frontend/src/SqlLab/actions/sqlLab.js @@ -369,24 +369,19 @@ export function validateQuery(query) { dispatch(startQueryValidation(query)); const postPayload = { - client_id: query.id, - database_id: query.dbId, - json: true, schema: query.schema, sql: query.sql, - sql_editor_id: query.sqlEditorId, - templateParams: query.templateParams, - validate_only: true, + template_params: query.templateParams, }; return SupersetClient.post({ - endpoint: `/superset/validate_sql_json/${window.location.search}`, - postPayload, - stringify: false, + endpoint: `/api/v1/database/${query.dbId}/validate_sql`, + body: JSON.stringify(postPayload), + headers: { 'Content-Type': 'application/json' }, }) - .then(({ json }) => dispatch(queryValidationReturned(query, json))) + .then(({ json }) => dispatch(queryValidationReturned(query, json.result))) .catch(response => - getClientErrorObject(response).then(error => { + getClientErrorObject(response.result).then(error => { let message = error.error || error.statusText || t('Unknown error'); if (message.includes('CSRF token')) { message = t(COMMON_ERR_MESSAGES.SESSION_TIMED_OUT); diff --git a/superset/databases/api.py b/superset/databases/api.py index d3e896adfd6f0..795b4535164e7 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -722,7 +722,7 @@ def related_objects(self, pk: int) -> Response: action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.validate_sql", log_to_statsd=False, ) - def validate_sql(self, pk: int, **kwargs: Any) -> FlaskResponse: + def validate_sql(self, pk: int) -> FlaskResponse: """ --- post: diff --git a/superset/databases/commands/validate_sql.py b/superset/databases/commands/validate_sql.py index bcd87b9db6e38..c3c6374145972 100644 --- a/superset/databases/commands/validate_sql.py +++ b/superset/databases/commands/validate_sql.py @@ -16,12 +16,10 @@ # under the License. import logging import re -from typing import Any, Dict, List, Optional, Type, TypedDict +from typing import Any, Dict, List, Optional, Type from flask import current_app from flask_babel import gettext as __ -from marshmallow import ValidationError - from superset.commands.base import BaseCommand from superset.databases.commands.exceptions import ( DatabaseNotFoundError, @@ -35,7 +33,7 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.models.core import Database from superset.sql_validators import get_validator_by_name -from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation +from superset.sql_validators.base import BaseSQLValidator from superset.utils import core as utils logger = logging.getLogger(__name__) @@ -88,7 +86,6 @@ def run(self) -> List[Dict[str, Any]]: raise ValidatorSQLError(superset_error) def validate(self) -> None: - exceptions: List[ValidationError] = [] # Validate/populate model exists self._model = DatabaseDAO.find_by_id(self._model_id) if not self._model: diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 50ab84997a5b7..db38ab2783350 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -538,9 +538,7 @@ class SchemasResponseSchema(Schema): class ValidateSQLRequest(Schema): sql = fields.String(required=True, description="SQL statement to validate") schema = fields.String(required=False, allow_none=True) - template_params = fields.Dict( - required=False, - ) + template_params = fields.Dict(required=False, allow_none=True) class ValidateSQLResponse(Schema): diff --git a/superset/views/core.py b/superset/views/core.py index 478cd93a739d9..b2420fe474371 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -2360,6 +2360,12 @@ def validate_sql_json( """Validates that arbitrary sql is acceptable for the given database. Returns a list of error/warning annotations as json. """ + logger.warning( + "%s.validate_sql_json " + "This API endpoint is deprecated and will be removed in version 3.0.0", + self.__class__.__name__, + ) + sql = request.form["sql"] database_id = request.form["database_id"] schema = request.form.get("schema") or None diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 70640728ac352..c8df4e9c379a0 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -21,6 +21,7 @@ from collections import defaultdict from io import BytesIO from unittest import mock +from unittest.mock import patch, MagicMock from zipfile import is_zipfile, ZipFile from operator import itemgetter @@ -71,6 +72,21 @@ from tests.integration_tests.test_app import app +SQL_VALIDATORS_BY_ENGINE = { + "presto": "PrestoDBSQLValidator", + "sqlite": "PrestoDBSQLValidator", + "postgresql": "PostgreSQLValidator", + "mysql": "PrestoDBSQLValidator", +} + +PRESTO_SQL_VALIDATORS_BY_ENGINE = { + "presto": "PrestoDBSQLValidator", + "sqlite": "PrestoDBSQLValidator", + "postgresql": "PrestoDBSQLValidator", + "mysql": "PrestoDBSQLValidator", +} + + class TestDatabaseApi(SupersetTestCase): def insert_database( self, @@ -2365,3 +2381,142 @@ def test_get_related_objects(self): assert "charts" in rv.json assert "dashboards" in rv.json assert "sqllab_tab_states" in rv.json + + @patch.dict( + "superset.config.SQL_VALIDATORS_BY_ENGINE", + SQL_VALIDATORS_BY_ENGINE, + clear=True, + ) + def test_validate_sql(self): + """ + Database API: validate SQL success + """ + request_payload = { + "sql": "SELECT * from birth_names", + "schema": None, + "template_params": None, + } + + self.login(username="admin") + example_db = get_example_database() + + uri = f"api/v1/database/{example_db.id}/validate_sql" + rv = self.client.post(uri, json=request_payload) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + self.assertEqual(response["result"], []) + + @patch.dict( + "superset.config.SQL_VALIDATORS_BY_ENGINE", + SQL_VALIDATORS_BY_ENGINE, + clear=True, + ) + def test_validate_sql_errors(self): + """ + Database API: validate SQL with errors + """ + request_payload = { + "sql": "SELECT col1 froma table1", + "schema": None, + "template_params": None, + } + + self.login(username="admin") + example_db = get_example_database() + + uri = f"api/v1/database/{example_db.id}/validate_sql" + rv = self.client.post(uri, json=request_payload) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + self.assertEqual( + response["result"], + [ + { + "end_column": None, + "line_number": 1, + "message": 'ERROR: syntax error at or near "table1"', + "start_column": None, + } + ], + ) + + @patch.dict( + "superset.config.SQL_VALIDATORS_BY_ENGINE", + {}, + clear=True, + ) + def test_validate_sql_endpoint_noconfig(self): + """Assert that validate_sql_json errors out when no validators are + configured for any db""" + request_payload = { + "sql": "SELECT col1 from table1", + "schema": None, + "template_params": None, + } + + self.login("admin") + + example_db = get_example_database() + + uri = f"api/v1/database/{example_db.id}/validate_sql" + rv = self.client.post(uri, json=request_payload) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 422) + self.assertEqual( + response, + { + "errors": [ + { + "message": f"no SQL validator is configured for " + f"{example_db.backend}", + "error_type": "GENERIC_DB_ENGINE_ERROR", + "level": "error", + "extra": { + "issue_codes": [ + { + "code": 1002, + "message": "Issue 1002 - The database returned an " + "unexpected error.", + } + ] + }, + } + ] + }, + ) + + @patch("superset.databases.commands.validate_sql.get_validator_by_name") + @patch.dict( + "superset.config.SQL_VALIDATORS_BY_ENGINE", + PRESTO_SQL_VALIDATORS_BY_ENGINE, + clear=True, + ) + def test_validate_sql_endpoint_failure(self, get_validator_by_name): + """Assert that validate_sql_json errors out when the selected validator + raises an unexpected exception""" + + request_payload = { + "sql": "SELECT * FROM birth_names", + "schema": None, + "template_params": None, + } + + self.login("admin") + + validator = MagicMock() + get_validator_by_name.return_value = validator + validator.validate.side_effect = Exception("Kaboom!") + + self.login("admin") + + example_db = get_example_database() + + uri = f"api/v1/database/{example_db.id}/validate_sql" + rv = self.client.post(uri, json=request_payload) + response = json.loads(rv.data.decode("utf-8")) + + # TODO(bkyryliuk): properly handle hive error + if get_example_database().backend == "hive": + return + self.assertEqual(rv.status_code, 422) + self.assertIn("Kaboom!", response["errors"][0]["message"]) From 8a0870a37666cbe670bafb780279d1f2edd0ffc4 Mon Sep 17 00:00:00 2001 From: Daniel Gaspar Date: Wed, 4 May 2022 15:03:41 +0100 Subject: [PATCH 04/12] debug test --- .../integration_tests/databases/api_tests.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index c8df4e9c379a0..49b886b8f2096 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -2106,7 +2106,8 @@ def test_validate_parameters_invalid_payload_schema(self): "issue_codes": [ { "code": 1020, - "message": "Issue 1020 - The submitted payload has the incorrect schema.", + "message": "Issue 1020 - The submitted payload" + " has the incorrect schema.", } ], }, @@ -2120,7 +2121,8 @@ def test_validate_parameters_invalid_payload_schema(self): "issue_codes": [ { "code": 1020, - "message": "Issue 1020 - The submitted payload has the incorrect schema.", + "message": "Issue 1020 - The submitted payload " + "has the incorrect schema.", } ], }, @@ -2153,7 +2155,8 @@ def test_validate_parameters_missing_fields(self): assert response == { "errors": [ { - "message": "One or more parameters are missing: database, host, username", + "message": "One or more parameters are missing: database, host," + " username", "error_type": "CONNECTION_MISSING_PARAMETERS_ERROR", "level": "warning", "extra": { @@ -2161,7 +2164,8 @@ def test_validate_parameters_missing_fields(self): "issue_codes": [ { "code": 1018, - "message": "Issue 1018 - One or more parameters needed to configure a database are missing.", + "message": "Issue 1018 - One or more parameters " + "needed to configure a database are missing.", } ], }, @@ -2240,7 +2244,8 @@ def test_validate_parameters_invalid_port(self): }, }, { - "message": "The port must be an integer between 0 and 65535 (inclusive).", + "message": "The port must be an integer between " + "0 and 65535 (inclusive).", "error_type": "CONNECTION_INVALID_PORT_ERROR", "level": "error", "extra": { @@ -2292,7 +2297,8 @@ def test_validate_parameters_invalid_host(self, is_hostname_valid): "issue_codes": [ { "code": 1018, - "message": "Issue 1018 - One or more parameters needed to configure a database are missing.", + "message": "Issue 1018 - One or more parameters" + " needed to configure a database are missing.", } ], }, @@ -2306,7 +2312,8 @@ def test_validate_parameters_invalid_host(self, is_hostname_valid): "issue_codes": [ { "code": 1007, - "message": "Issue 1007 - The hostname provided can't be resolved.", + "message": "Issue 1007 - The hostname " + "provided can't be resolved.", } ], }, @@ -2403,6 +2410,7 @@ def test_validate_sql(self): uri = f"api/v1/database/{example_db.id}/validate_sql" rv = self.client.post(uri, json=request_payload) response = json.loads(rv.data.decode("utf-8")) + raise Exception(response + self.app.config["SQL_VALIDATORS_BY_ENGINE"]) self.assertEqual(rv.status_code, 200) self.assertEqual(response["result"], []) From b730cda310a948298417b440a2ea23c86040eebe Mon Sep 17 00:00:00 2001 From: Daniel Gaspar Date: Wed, 4 May 2022 15:48:57 +0100 Subject: [PATCH 05/12] debug test --- tests/integration_tests/databases/api_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 49b886b8f2096..ceb20aebfab1d 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -2410,7 +2410,7 @@ def test_validate_sql(self): uri = f"api/v1/database/{example_db.id}/validate_sql" rv = self.client.post(uri, json=request_payload) response = json.loads(rv.data.decode("utf-8")) - raise Exception(response + self.app.config["SQL_VALIDATORS_BY_ENGINE"]) + raise Exception(response + str(self.app.config["SQL_VALIDATORS_BY_ENGINE"])) self.assertEqual(rv.status_code, 200) self.assertEqual(response["result"], []) From 7e12da134ba7dc2ccecad95d7cfbaec4e0dfde11 Mon Sep 17 00:00:00 2001 From: Daniel Gaspar Date: Wed, 4 May 2022 16:29:26 +0100 Subject: [PATCH 06/12] fix frontend test on sqllab --- tests/integration_tests/databases/api_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index ceb20aebfab1d..9f2771b1a8fc4 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -2410,7 +2410,7 @@ def test_validate_sql(self): uri = f"api/v1/database/{example_db.id}/validate_sql" rv = self.client.post(uri, json=request_payload) response = json.loads(rv.data.decode("utf-8")) - raise Exception(response + str(self.app.config["SQL_VALIDATORS_BY_ENGINE"])) + raise Exception(str(response) + str(self.app.config["SQL_VALIDATORS_BY_ENGINE"])) self.assertEqual(rv.status_code, 200) self.assertEqual(response["result"], []) From b776ad79fc7e983ebb1971573a56e3f288086571 Mon Sep 17 00:00:00 2001 From: Daniel Gaspar Date: Wed, 4 May 2022 17:21:54 +0100 Subject: [PATCH 07/12] fix tests --- superset/databases/commands/validate_sql.py | 10 +++-- .../integration_tests/databases/api_tests.py | 45 ++++++++++++++++++- .../integration_tests/sql_validator_tests.py | 8 +++- 3 files changed, 56 insertions(+), 7 deletions(-) diff --git a/superset/databases/commands/validate_sql.py b/superset/databases/commands/validate_sql.py index c3c6374145972..346d684a0d2ca 100644 --- a/superset/databases/commands/validate_sql.py +++ b/superset/databases/commands/validate_sql.py @@ -20,6 +20,7 @@ from flask import current_app from flask_babel import gettext as __ + from superset.commands.base import BaseCommand from superset.databases.commands.exceptions import ( DatabaseNotFoundError, @@ -66,7 +67,7 @@ def run(self) -> List[Dict[str, Any]]: with utils.timeout(seconds=timeout, error_message=timeout_msg): errors = self._validator.validate(sql, schema, self._model) return [err.to_dict() for err in errors] - except Exception as ex: # pylint: disable=broad-except + except Exception as ex: logger.exception(ex) superset_error = SupersetError( message=__( @@ -82,8 +83,8 @@ def run(self) -> List[Dict[str, Any]]: # Return as a 400 if the database error message says we got a 4xx error if re.search(r"([\W]|^)4\d{2}([\W]|$)", str(ex)): - raise ValidatorSQL400Error(superset_error) - raise ValidatorSQLError(superset_error) + raise ValidatorSQL400Error(superset_error) from ex + raise ValidatorSQLError(superset_error) from ex def validate(self) -> None: # Validate/populate model exists @@ -109,7 +110,8 @@ def validate(self) -> None: raise NoValidatorFoundError( SupersetError( message=__( - "No validator named {} found (configured for the {} engine)".format( + "No validator named {} found " + "(configured for the {} engine)".format( validator_name, spec.engine ) ), diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 9f2771b1a8fc4..26eab1f4f475d 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -2410,7 +2410,6 @@ def test_validate_sql(self): uri = f"api/v1/database/{example_db.id}/validate_sql" rv = self.client.post(uri, json=request_payload) response = json.loads(rv.data.decode("utf-8")) - raise Exception(str(response) + str(self.app.config["SQL_VALIDATORS_BY_ENGINE"])) self.assertEqual(rv.status_code, 200) self.assertEqual(response["result"], []) @@ -2448,6 +2447,50 @@ def test_validate_sql_errors(self): ], ) + @patch.dict( + "superset.config.SQL_VALIDATORS_BY_ENGINE", + SQL_VALIDATORS_BY_ENGINE, + clear=True, + ) + def test_validate_sql_not_found(self): + """ + Database API: validate SQL database not found + """ + request_payload = { + "sql": "SELECT * from birth_names", + "schema": None, + "template_params": None, + } + self.login(username="admin") + uri = ( + f"api/v1/database/{self.get_nonexistent_numeric_id(Database)}/validate_sql" + ) + rv = self.client.post(uri, json=request_payload) + self.assertEqual(rv.status_code, 404) + + @patch.dict( + "superset.config.SQL_VALIDATORS_BY_ENGINE", + SQL_VALIDATORS_BY_ENGINE, + clear=True, + ) + def test_validate_sql_validation_fails(self): + """ + Database API: validate SQL database payload validation fails + """ + request_payload = { + "sql": None, + "schema": None, + "template_params": None, + } + self.login(username="admin") + uri = ( + f"api/v1/database/{self.get_nonexistent_numeric_id(Database)}/validate_sql" + ) + rv = self.client.post(uri, json=request_payload) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 400) + self.assertEqual(response, {"message": {"sql": ["Field may not be null."]}}) + @patch.dict( "superset.config.SQL_VALIDATORS_BY_ENGINE", {}, diff --git a/tests/integration_tests/sql_validator_tests.py b/tests/integration_tests/sql_validator_tests.py index b1e661cc2c5bb..7fa9933ff4c86 100644 --- a/tests/integration_tests/sql_validator_tests.py +++ b/tests/integration_tests/sql_validator_tests.py @@ -48,13 +48,17 @@ class TestSqlValidatorEndpoint(SupersetTestCase): def tearDown(self): self.logout() + @patch("superset.views.core.get_validator_by_name") + @patch.dict( + "superset.config.SQL_VALIDATORS_BY_ENGINE", + {}, + clear=True, + ) def test_validate_sql_endpoint_noconfig(self): """Assert that validate_sql_json errors out when no validators are configured for any db""" self.login("admin") - app.config["SQL_VALIDATORS_BY_ENGINE"] = {} - resp = self.validate_sql( "SELECT * FROM birth_names", client_id="1", raise_on_error=False ) From aa7d56e9abaf0f8d64ed47a4130d07f5e47d09fe Mon Sep 17 00:00:00 2001 From: Daniel Gaspar Date: Wed, 4 May 2022 17:36:31 +0100 Subject: [PATCH 08/12] fix frontend test on sqllab --- tests/integration_tests/sql_validator_tests.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration_tests/sql_validator_tests.py b/tests/integration_tests/sql_validator_tests.py index 7fa9933ff4c86..faf22f4b07e53 100644 --- a/tests/integration_tests/sql_validator_tests.py +++ b/tests/integration_tests/sql_validator_tests.py @@ -48,7 +48,6 @@ class TestSqlValidatorEndpoint(SupersetTestCase): def tearDown(self): self.logout() - @patch("superset.views.core.get_validator_by_name") @patch.dict( "superset.config.SQL_VALIDATORS_BY_ENGINE", {}, From 4d863bd8231829cee71487a563892b591f98f0d9 Mon Sep 17 00:00:00 2001 From: dpgaspar Date: Mon, 9 May 2022 13:18:48 +0100 Subject: [PATCH 09/12] fix tests --- tests/integration_tests/sql_validator_tests.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/integration_tests/sql_validator_tests.py b/tests/integration_tests/sql_validator_tests.py index faf22f4b07e53..57f31ba4b750d 100644 --- a/tests/integration_tests/sql_validator_tests.py +++ b/tests/integration_tests/sql_validator_tests.py @@ -234,6 +234,11 @@ def test_validator_query_error(self, flask_g): self.assertEqual(1, len(errors)) + @patch.dict( + "superset.config.SQL_VALIDATORS_BY_ENGINE", + {}, + clear=True, + ) def test_validate_sql_endpoint(self): self.login("admin") # NB this is effectively an integration test -- when there's a default From 20dfed628fe606dcf792f4f26aa05e514ac993ef Mon Sep 17 00:00:00 2001 From: dpgaspar Date: Mon, 9 May 2022 14:06:09 +0100 Subject: [PATCH 10/12] fix tests --- tests/integration_tests/databases/api_tests.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 26eab1f4f475d..30cd25d46d953 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -2410,6 +2410,8 @@ def test_validate_sql(self): uri = f"api/v1/database/{example_db.id}/validate_sql" rv = self.client.post(uri, json=request_payload) response = json.loads(rv.data.decode("utf-8")) + if rv.status_code == 422: + raise Exception(response) self.assertEqual(rv.status_code, 200) self.assertEqual(response["result"], []) From b5246e1de2e8e9a63220fe887d6d4b913645a349 Mon Sep 17 00:00:00 2001 From: dpgaspar Date: Mon, 9 May 2022 14:58:36 +0100 Subject: [PATCH 11/12] fix tests --- tests/integration_tests/databases/api_tests.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 30cd25d46d953..29abcd49772e4 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -74,9 +74,7 @@ SQL_VALIDATORS_BY_ENGINE = { "presto": "PrestoDBSQLValidator", - "sqlite": "PrestoDBSQLValidator", "postgresql": "PostgreSQLValidator", - "mysql": "PrestoDBSQLValidator", } PRESTO_SQL_VALIDATORS_BY_ENGINE = { @@ -2406,12 +2404,11 @@ def test_validate_sql(self): self.login(username="admin") example_db = get_example_database() - + if example_db.backend not in ("presto", "postgresql"): + pytest.skip("Only presto and PG are implemented") uri = f"api/v1/database/{example_db.id}/validate_sql" rv = self.client.post(uri, json=request_payload) response = json.loads(rv.data.decode("utf-8")) - if rv.status_code == 422: - raise Exception(response) self.assertEqual(rv.status_code, 200) self.assertEqual(response["result"], []) From 9225bd4f0780dad67a234b20cd49d10a6f04f982 Mon Sep 17 00:00:00 2001 From: dpgaspar Date: Mon, 9 May 2022 15:18:11 +0100 Subject: [PATCH 12/12] fix tests --- tests/integration_tests/databases/api_tests.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 29abcd49772e4..9a80438a54028 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -2402,10 +2402,11 @@ def test_validate_sql(self): "template_params": None, } - self.login(username="admin") example_db = get_example_database() if example_db.backend not in ("presto", "postgresql"): pytest.skip("Only presto and PG are implemented") + + self.login(username="admin") uri = f"api/v1/database/{example_db.id}/validate_sql" rv = self.client.post(uri, json=request_payload) response = json.loads(rv.data.decode("utf-8")) @@ -2427,9 +2428,11 @@ def test_validate_sql_errors(self): "template_params": None, } - self.login(username="admin") example_db = get_example_database() + if example_db.backend not in ("presto", "postgresql"): + pytest.skip("Only presto and PG are implemented") + self.login(username="admin") uri = f"api/v1/database/{example_db.id}/validate_sql" rv = self.client.post(uri, json=request_payload) response = json.loads(rv.data.decode("utf-8"))