# superset/queries/saved_queries/api.py -- method of SavedQueryRestApi
@expose("/import/", methods=["POST"])
@protect()
@safe
@statsd_metrics
@event_logger.log_this_with_context(
    action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.import_",
    log_to_statsd=False,
)
def import_(self) -> Response:
    """Import Saved Queries with associated databases
    ---
    post:
      requestBody:
        required: true
        content:
          multipart/form-data:
            schema:
              type: object
              properties:
                formData:
                  description: upload file (ZIP)
                  type: string
                  format: binary
                passwords:
                  description: JSON map of passwords for each file
                  type: string
                overwrite:
                  description: overwrite existing saved queries?
                  type: boolean
      responses:
        200:
          description: Saved Query import result
          content:
            application/json:
              schema:
                type: object
                properties:
                  message:
                    type: string
        400:
          $ref: '#/components/responses/400'
        401:
          $ref: '#/components/responses/401'
        422:
          $ref: '#/components/responses/422'
        500:
          $ref: '#/components/responses/500'
    """
    # NOTE: everything after the ``---`` marker above is the OpenAPI
    # description of this endpoint, so it must stay valid YAML. The
    # ``overwrite`` property uses ``boolean`` because ``bool`` is not an
    # OpenAPI data type.

    # Local import: the module-level import only brings in ``ZipFile``.
    from zipfile import BadZipFile  # pylint: disable=import-outside-toplevel

    # The bundle is expected in the multipart field named "formData".
    upload = request.files.get("formData")
    if not upload:
        return self.response_400()

    # A file that is not a ZIP archive is a client error; answer with the
    # documented 400 rather than letting ``BadZipFile`` escape the handler.
    try:
        with ZipFile(upload) as bundle:
            contents = get_contents_from_bundle(bundle)
    except BadZipFile:
        return self.response_400(message="Uploaded file is not a valid ZIP archive")

    # ``passwords`` is optional; when present it must be a JSON document.
    passwords = None
    if "passwords" in request.form:
        try:
            passwords = json.loads(request.form["passwords"])
        except json.JSONDecodeError:
            return self.response_400(message="The passwords field is not valid JSON")

    # Only the literal string "true" enables overwriting.
    overwrite = request.form.get("overwrite") == "true"

    command = ImportSavedQueriesCommand(
        contents, passwords=passwords, overwrite=overwrite
    )
    try:
        command.run()
        return self.response(200, message="OK")
    except CommandInvalidError as exc:
        # The bundle was understood but its content failed validation.
        logger.warning("Import Saved Query failed")
        return self.response_422(message=exc.normalized_messages())
    except Exception as exc:  # pylint: disable=broad-except
        # Endpoint boundary: log the traceback and report a server error.
        logger.exception("Import Saved Query failed")
        return self.response_500(message=str(exc))
# superset/queries/saved_queries/commands/exceptions.py
class SavedQueryNotFoundError(CommandException):
    """Command error carrying the "Saved query not found." message."""

    message = _("Saved query not found.")


class SavedQueryImportError(ImportFailedError):
    """Import failure for saved queries.

    The v1 saved-query import command references this class as its
    ``import_error``.
    """

    message = _("Import saved query failed for an unknown reason.")


class SavedQueryInvalidError(CommandInvalidError):
    """Invalid-command error carrying the "parameters are invalid" message."""

    message = _("Saved query parameters are invalid.")
# superset/queries/saved_queries/commands/importers/dispatcher.py
logger = logging.getLogger(__name__)

# Importer versions, tried in this order by the dispatcher below.
command_versions = [
    v1.ImportSavedQueriesCommand,
]


class ImportSavedQueriesCommand(BaseCommand):
    """
    Version-dispatching entry point for importing saved queries.

    Each class in ``command_versions`` is instantiated with the arguments
    given to this command and run in turn; the first one that does not
    reject the bundle with ``IncorrectVersionError`` ends the dispatch.
    """

    # pylint: disable=unused-argument
    def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
        # Everything is stored untouched and forwarded to the versioned command.
        self.contents = contents
        self.args = args
        self.kwargs = kwargs

    def run(self) -> None:
        """Run the first versioned importer that accepts the contents.

        Raises:
            CommandInvalidError: re-raised from an importer, or raised here
                when every importer rejected the bundle's version.
            ValidationError: re-raised from an importer.
        """
        for importer_cls in command_versions:
            importer = importer_cls(self.contents, *self.args, **self.kwargs)
            try:
                importer.run()
            except IncorrectVersionError:
                # Not this importer's format; try the next version.
                logger.debug("File not handled by command, skipping")
                continue
            except (CommandInvalidError, ValidationError) as exc:
                # found right version, but file is invalid
                logger.exception("Error running import command")
                raise exc
            # The importer completed, nothing left to do.
            return

        raise CommandInvalidError("Could not find a valid command to import file")

    def validate(self) -> None:
        """No-op: this dispatcher performs no validation of its own."""
# superset/queries/saved_queries/commands/importers/v1/__init__.py
class ImportSavedQueriesCommand(ImportModelsCommand):
    """Import Saved Queries"""

    dao = SavedQueryDAO
    model_name = "saved_queries"
    prefix = "queries/"
    # One schema per bundle sub-directory handled by this command.
    schemas: Dict[str, Schema] = {
        "databases/": ImportV1DatabaseSchema(),
        "queries/": ImportV1SavedQuerySchema(),
    }
    import_error = SavedQueryImportError

    @staticmethod
    def _import(
        session: Session, configs: Dict[str, Any], overwrite: bool = False
    ) -> None:
        """Import the databases referenced by saved queries, then the queries.

        ``configs`` maps bundle file names to their configuration dicts.
        ``overwrite`` is applied to saved queries only; databases are always
        imported with ``overwrite=False``.
        """
        # Saved-query configs, in the order they appear in ``configs``.
        query_configs = [
            config
            for file_name, config in configs.items()
            if file_name.startswith("queries/")
        ]

        # discover databases associated with saved queries
        database_uuids: Set[str] = {
            config["database_uuid"] for config in query_configs
        }

        # import related databases, remembering the id assigned to each uuid
        database_ids: Dict[str, int] = {}
        for file_name, config in configs.items():
            if not file_name.startswith("databases/"):
                continue
            if config["uuid"] not in database_uuids:
                continue
            database = import_database(session, config, overwrite=False)
            database_ids[str(database.uuid)] = database.id

        # import saved queries with the correct parent ref; a query whose
        # database was not imported above is left out
        for config in query_configs:
            parent_uuid = config["database_uuid"]
            if parent_uuid not in database_ids:
                continue
            config["db_id"] = database_ids[parent_uuid]
            import_saved_query(session, config, overwrite=overwrite)
# superset/queries/saved_queries/commands/importers/v1/utils.py
def import_saved_query(
    session: Session, config: Dict[str, Any], overwrite: bool = False
) -> SavedQuery:
    """Create or update a saved query from an import configuration.

    The saved query is looked up by ``config["uuid"]``. A match is returned
    untouched unless ``overwrite`` is set, in which case its id is written
    into ``config`` (the caller's dict is modified) before the import.
    """
    match = session.query(SavedQuery).filter_by(uuid=config["uuid"]).first()
    if match and not overwrite:
        return match
    if match:
        # Carry the existing row's id into the config handed to the import.
        config["id"] = match.id

    imported = SavedQuery.import_from_dict(session, config, recursive=False)
    if imported.id is None:
        # No id yet: flush the session before returning the object.
        session.flush()

    return imported


# superset/queries/saved_queries/schemas.py
class ImportV1SavedQuerySchema(Schema):
    """Marshmallow schema for a saved query file in a v1 import bundle."""

    # Optional, nullable, length-limited text fields.
    schema = fields.String(
        allow_none=True,
        validate=Length(0, 128),
    )
    label = fields.String(
        allow_none=True,
        validate=Length(0, 256),
    )
    description = fields.String(allow_none=True)
    # Required fields.
    sql = fields.String(required=True)
    uuid = fields.UUID(required=True)
    version = fields.String(required=True)
    database_uuid = fields.UUID(required=True)
# tests/fixtures/importexport.py
# Contents of ``metadata.yaml`` for a saved-query import bundle.
saved_queries_metadata_config: Dict[str, Any] = {
    "version": "1.0.0",
    "type": "SavedQuery",
    "timestamp": "2021-03-30T20:37:54.791187+00:00",
}

# Configuration of a single saved query. It is annotated like the other
# fixtures in this module so type checkers treat them all uniformly.
saved_queries_config: Dict[str, Any] = {
    "schema": "public",
    "label": "Test Saved Query",
    "description": None,
    "sql": "-- Note: Unless you save your query, these tabs will NOT persist if you clear\nyour cookies or change browsers.\n\n\nSELECT * from birth_names",
    "uuid": "05b679b5-8eaf-452c-b874-a7a774cfa4e9",
    "version": "1.0.0",
    "database_uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
}
# tests/queries/saved_queries/api_tests.py -- methods of the saved query API test case
def create_saved_query_import(self):
    """Return an in-memory ZIP bundle holding metadata, a database and a saved query."""
    members = (
        ("saved_query_export/metadata.yaml", saved_queries_metadata_config),
        ("saved_query_export/databases/imported_database.yaml", database_config),
        (
            "saved_query_export/queries/imported_database/public/imported_saved_query.yaml",
            saved_queries_config,
        ),
    )
    archive = BytesIO()
    with ZipFile(archive, "w") as bundle:
        for member_name, member_config in members:
            with bundle.open(member_name, "w") as fp:
                fp.write(yaml.safe_dump(member_config).encode())
    # Rewind so the buffer can be read from the start.
    archive.seek(0)
    return archive


def test_import_saved_queries(self):
    """
    Saved Query API: Test import
    """
    self.login(username="admin")

    payload = {"formData": (self.create_saved_query_import(), "saved_query.zip")}
    rv = self.client.post(
        "api/v1/saved_query/import/",
        data=payload,
        content_type="multipart/form-data",
    )
    body = json.loads(rv.data.decode("utf-8"))

    assert rv.status_code == 200
    assert body == {"message": "OK"}

    # Both the database and the saved query must now exist exactly once.
    database_query = db.session.query(Database).filter_by(uuid=database_config["uuid"])
    database = database_query.one()
    assert database.database_name == "imported_database"

    saved_query_lookup = db.session.query(SavedQuery).filter_by(
        uuid=saved_queries_config["uuid"]
    )
    saved_query = saved_query_lookup.one()
    assert saved_query.database == database

    # Remove what the import created.
    for created in (saved_query, database):
        db.session.delete(created)
    db.session.commit()
# tests/queries/saved_queries/commands_tests.py
class TestImportSavedQueriesCommand(SupersetTestCase):
    """Tests for the v1 ``ImportSavedQueriesCommand``."""

    @staticmethod
    def _bundle():
        """Return the contents of a complete, valid import bundle."""
        return {
            "metadata.yaml": yaml.safe_dump(saved_queries_metadata_config),
            "databases/imported_database.yaml": yaml.safe_dump(database_config),
            "queries/imported_query.yaml": yaml.safe_dump(saved_queries_config),
        }

    def test_import_v1_saved_queries(self):
        """Test that we can import a saved query"""
        ImportSavedQueriesCommand(self._bundle()).run()

        saved_query = (
            db.session.query(SavedQuery)
            .filter_by(uuid=saved_queries_config["uuid"])
            .one()
        )
        assert saved_query.schema == "public"

        database = (
            db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
        )

        # Remove what the import created.
        for created in (saved_query, database):
            db.session.delete(created)
        db.session.commit()

    def test_import_v1_saved_queries_multiple(self):
        """Test that a saved query can be imported multiple times"""
        command = ImportSavedQueriesCommand(self._bundle(), overwrite=True)
        # Run the same command twice; the second run must not duplicate the query.
        for _ in range(2):
            command.run()

        database = (
            db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
        )
        saved_queries = (
            db.session.query(SavedQuery).filter_by(db_id=database.id).all()
        )
        assert len(saved_queries) == 1

        db.session.delete(saved_queries[0])
        db.session.delete(database)
        db.session.commit()

    def test_import_v1_saved_queries_validation(self):
        """Test different validations applied when importing a saved query"""
        # metadata.yaml must be present
        contents = self._bundle()
        del contents["metadata.yaml"]
        with pytest.raises(IncorrectVersionError) as excinfo:
            ImportSavedQueriesCommand(contents).run()
        assert str(excinfo.value) == "Missing metadata.yaml"

        # version should be 1.0.0
        wrong_version = {
            "version": "2.0.0",
            "type": "SavedQuery",
            "timestamp": "2021-03-30T20:37:54.791187+00:00",
        }
        contents["metadata.yaml"] = yaml.safe_dump(wrong_version)
        with pytest.raises(IncorrectVersionError) as excinfo:
            ImportSavedQueriesCommand(contents).run()
        assert str(excinfo.value) == "Must be equal to 1.0.0."

        # type should be a SavedQuery
        contents["metadata.yaml"] = yaml.safe_dump(database_metadata_config)
        with pytest.raises(CommandInvalidError) as excinfo:
            ImportSavedQueriesCommand(contents).run()
        assert str(excinfo.value) == "Error importing saved_queries"
        assert excinfo.value.normalized_messages() == {
            "metadata.yaml": {"type": ["Must be equal to SavedQuery."]}
        }

        # must also validate databases
        broken_config = database_config.copy()
        broken_config.pop("database_name")
        contents["metadata.yaml"] = yaml.safe_dump(saved_queries_metadata_config)
        contents["databases/imported_database.yaml"] = yaml.safe_dump(broken_config)
        with pytest.raises(CommandInvalidError) as excinfo:
            ImportSavedQueriesCommand(contents).run()
        assert str(excinfo.value) == "Error importing saved_queries"
        assert excinfo.value.normalized_messages() == {
            "databases/imported_database.yaml": {
                "database_name": ["Missing data for required field."],
            }
        }