diff --git a/superset/dashboards/api.py b/superset/dashboards/api.py index 76787117fd61c..763438e047a54 100644 --- a/superset/dashboards/api.py +++ b/superset/dashboards/api.py @@ -19,7 +19,7 @@ from datetime import datetime from io import BytesIO from typing import Any, Dict -from zipfile import ZipFile +from zipfile import is_zipfile, ZipFile from flask import g, make_response, redirect, request, Response, send_file, url_for from flask_appbuilder.api import expose, protect, rison, safe @@ -787,8 +787,12 @@ def import_(self) -> Response: upload = request.files.get("formData") if not upload: return self.response_400() - with ZipFile(upload) as bundle: - contents = get_contents_from_bundle(bundle) + if is_zipfile(upload): + with ZipFile(upload) as bundle: + contents = get_contents_from_bundle(bundle) + else: + upload.seek(0) + contents = {upload.filename: upload.read()} passwords = ( json.loads(request.form["passwords"]) diff --git a/superset/dashboards/commands/importers/v0.py b/superset/dashboards/commands/importers/v0.py index 851ecab941319..5a24c9309c457 100644 --- a/superset/dashboards/commands/importers/v0.py +++ b/superset/dashboards/commands/importers/v0.py @@ -317,7 +317,9 @@ class ImportDashboardsCommand(BaseCommand): in Superset. """ - def __init__(self, contents: Dict[str, str], database_id: Optional[int] = None): + def __init__( + self, contents: Dict[str, str], database_id: Optional[int] = None, **kwargs: Any + ): self.contents = contents self.database_id = database_id diff --git a/superset/datasets/api.py b/superset/datasets/api.py index 96f3532265da0..30ad5bdd677dd 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -20,7 +20,7 @@ from distutils.util import strtobool from io import BytesIO from typing import Any -from zipfile import ZipFile +from zipfile import is_zipfile, ZipFile import yaml from flask import g, request, Response, send_file @@ -687,8 +687,12 @@ def import_(self) -> Response: upload = request.files.get("formData") if not upload: return self.response_400() - with ZipFile(upload) as bundle: - contents = get_contents_from_bundle(bundle) + if is_zipfile(upload): + with ZipFile(upload) as bundle: + contents = get_contents_from_bundle(bundle) + else: + upload.seek(0) + contents = {upload.filename: upload.read()} passwords = ( json.loads(request.form["passwords"]) diff --git a/tests/dashboards/api_tests.py b/tests/dashboards/api_tests.py index a9b0ca4812f45..7bef2edeebda1 100644 --- a/tests/dashboards/api_tests.py +++ b/tests/dashboards/api_tests.py @@ -45,6 +45,7 @@ chart_config, database_config, dashboard_config, + dashboard_export, dashboard_metadata_config, dataset_config, dataset_metadata_config, @@ -1316,6 +1317,38 @@ def test_import_dashboard(self): db.session.delete(database) db.session.commit() + def test_import_dashboard_v0_export(self): + num_dashboards = db.session.query(Dashboard).count() + + self.login(username="admin") + uri = "api/v1/dashboard/import/" + + buf = BytesIO() + buf.write(json.dumps(dashboard_export).encode()) + buf.seek(0) + form_data = { + "formData": (buf, "20201119_181105.json"), + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 200 + assert response == {"message": "OK"} + assert db.session.query(Dashboard).count() == num_dashboards + 1 + + dashboard = ( + db.session.query(Dashboard).filter_by(dashboard_title="Births 2").one() + ) + chart = dashboard.slices[0] + dataset = chart.table + database = dataset.database + + db.session.delete(dashboard) + db.session.delete(chart) + db.session.delete(dataset) + db.session.delete(database) + db.session.commit() + def test_import_dashboard_overwrite(self): """ Dashboard API: Test import existing dashboard diff --git a/tests/datasets/api_tests.py b/tests/datasets/api_tests.py index 6b628d7d2472f..00be830fcf8f0 100644 --- a/tests/datasets/api_tests.py +++ b/tests/datasets/api_tests.py @@ -47,6 +47,7 @@ database_metadata_config, dataset_config, dataset_metadata_config, + dataset_ui_export, ) @@ -1275,6 +1276,31 @@ def test_import_dataset(self): db.session.delete(database) db.session.commit() + def test_import_dataset_v0_export(self): + num_datasets = db.session.query(SqlaTable).count() + + self.login(username="admin") + uri = "api/v1/dataset/import/" + + buf = BytesIO() + buf.write(json.dumps(dataset_ui_export).encode()) + buf.seek(0) + form_data = { + "formData": (buf, "dataset_export.zip"), + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 200 + assert response == {"message": "OK"} + assert db.session.query(SqlaTable).count() == num_datasets + 1 + + dataset = ( + db.session.query(SqlaTable).filter_by(table_name="birth_names_2").one() + ) + db.session.delete(dataset) + db.session.commit() + def test_import_dataset_overwrite(self): """ Dataset API: Test import existing dataset