From 23c8524c1de33ce7cea264769edba5c710b6039e Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Tue, 8 Feb 2022 12:14:36 +0200 Subject: [PATCH 1/3] feat(chart-data-api): download multiple csvs as zip --- superset/charts/data/api.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py index 6983152e248f6..b53f0e2116cb1 100644 --- a/superset/charts/data/api.py +++ b/superset/charts/data/api.py @@ -18,10 +18,12 @@ import json import logging +from io import BytesIO from typing import Any, Dict, Optional, TYPE_CHECKING +from zipfile import ZipFile import simplejson -from flask import g, make_response, request +from flask import current_app, g, make_response, request, Response from flask_appbuilder.api import expose, protect from flask_babel import gettext as _ from marshmallow import ValidationError @@ -49,8 +51,6 @@ from superset.views.base_api import statsd_metrics if TYPE_CHECKING: - from flask import Response - from superset.common.query_context import QueryContext logger = logging.getLogger(__name__) @@ -350,9 +350,24 @@ def _send_chart_response( if not security_manager.can_access("can_csv", "Superset"): return self.response_403() - # return the first result - data = result["queries"][0]["data"] - return CsvResponse(data, headers=generate_download_headers("csv")) + if len(result["queries"]) == 1: + # return single query results csv format + data = result["queries"][0]["data"] + return CsvResponse(data, headers=generate_download_headers("csv")) + else: + # return multi-query csv results bundled as a zip file + encoding = current_app.config["CSV_EXPORT"].get("encoding", "utf-8") + buf = BytesIO() + with ZipFile(buf, "w") as bundle: + for idx, result in enumerate(result["queries"]): + with bundle.open(f"query_{idx + 1}.csv", "w") as fp: + fp.write(result["data"].encode(encoding)) + buf.seek(0) + return Response( + buf, + headers=generate_download_headers("zip"), + mimetype="application/zip", + ) if result_format == ChartDataResultFormat.JSON: response_data = simplejson.dumps( From ee93f12cca15629a782ed3604ebd5e22f81352b9 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Tue, 8 Feb 2022 12:38:26 +0200 Subject: [PATCH 2/3] break out util --- superset/charts/data/api.py | 16 ++++++--------- superset/utils/core.py | 12 +++++++++++ .../charts/data/api_tests.py | 20 +++++++++++++++++++ 3 files changed, 38 insertions(+), 10 deletions(-) diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py index b53f0e2116cb1..19748813d0905 100644 --- a/superset/charts/data/api.py +++ b/superset/charts/data/api.py @@ -18,9 +18,7 @@ import json import logging -from io import BytesIO from typing import Any, Dict, Optional, TYPE_CHECKING -from zipfile import ZipFile import simplejson from flask import current_app, g, make_response, request, Response @@ -46,7 +44,7 @@ from superset.exceptions import QueryObjectValidationError from superset.extensions import event_logger from superset.utils.async_query_manager import AsyncQueryTokenException -from superset.utils.core import json_int_dttm_ser +from superset.utils.core import create_zip, json_int_dttm_ser from superset.views.base import CsvResponse, generate_download_headers from superset.views.base_api import statsd_metrics @@ -357,14 +355,12 @@ def _send_chart_response( else: # return multi-query csv results bundled as a zip file encoding = current_app.config["CSV_EXPORT"].get("encoding", "utf-8") - buf = BytesIO() - with ZipFile(buf, "w") as bundle: - for idx, result in enumerate(result["queries"]): - with bundle.open(f"query_{idx + 1}.csv", "w") as fp: - fp.write(result["data"].encode(encoding)) - buf.seek(0) + files = { + f"query_{idx + 1}.csv": result["data"].encode(encoding) + for idx, result in enumerate(result["queries"]) + } return Response( - buf, + create_zip(files), headers=generate_download_headers("zip"), mimetype="application/zip", ) diff --git a/superset/utils/core.py b/superset/utils/core.py index 4908fd98dc7ad..da69a89a80fd5 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -40,6 +40,7 @@ from email.mime.text import MIMEText from email.utils import formatdate from enum import Enum, IntEnum +from io import BytesIO from timeit import default_timer from types import TracebackType from typing import ( @@ -61,6 +62,7 @@ Union, ) from urllib.parse import unquote_plus +from zipfile import ZipFile import bleach import markdown as md @@ -1788,3 +1790,13 @@ def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int: if limit != 0: return min(max_limit, limit) return max_limit + + +def create_zip(files: Dict[str, Any]) -> BytesIO: + buf = BytesIO() + with ZipFile(buf, "w") as bundle: + for filename, contents in files.items(): + with bundle.open(filename, "w") as fp: + fp.write(contents) + buf.seek(0) + return buf diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index d6ccd6aadfe54..6912f4697f611 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -20,8 +20,11 @@ import unittest import copy from datetime import datetime +from io import BytesIO from typing import Optional from unittest import mock +from zipfile import ZipFile + from flask import Response from tests.integration_tests.conftest import with_feature_flags from superset.models.sql_lab import Query @@ -243,6 +246,22 @@ def test_with_csv_result_format(self): self.query_context_payload["result_format"] = "csv" rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 200 + assert rv.mimetype == "text/csv" + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_with_multi_query_csv_result_format(self): + """ + Chart data API: Test chart data with multi-query CSV result format + """ + self.query_context_payload["result_format"] = "csv" + self.query_context_payload["queries"].append( + self.query_context_payload["queries"][0] + ) + rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") + assert rv.status_code == 200 + assert rv.mimetype == "application/zip" + zipfile = ZipFile(BytesIO(rv.data), "r") + assert zipfile.namelist() == ["query_1.csv", "query_2.csv"] @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_csv_result_format_when_actor_not_permitted_for_csv__403(self): @@ -766,6 +785,7 @@ def test_chart_data_get(self): } ) rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data") + assert rv.mimetype == "application/json" data = json.loads(rv.data.decode("utf-8")) assert data["result"][0]["status"] == "success" assert data["result"][0]["rowcount"] == 2 From 9246e813dd454edff6f364f6bbb5b595f42f707e Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Tue, 8 Feb 2022 12:58:07 +0200 Subject: [PATCH 3/3] check for empty request --- superset/charts/data/api.py | 27 ++++++++++--------- .../charts/data/api_tests.py | 10 +++++++ 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py index 19748813d0905..d6490421c273b 100644 --- a/superset/charts/data/api.py +++ b/superset/charts/data/api.py @@ -348,22 +348,25 @@ def _send_chart_response( if not security_manager.can_access("can_csv", "Superset"): return self.response_403() + if not result["queries"]: + return self.response_400(_("Empty query result")) + if len(result["queries"]) == 1: # return single query results csv format data = result["queries"][0]["data"] return CsvResponse(data, headers=generate_download_headers("csv")) - else: - # return multi-query csv results bundled as a zip file - encoding = current_app.config["CSV_EXPORT"].get("encoding", "utf-8") - files = { - f"query_{idx + 1}.csv": result["data"].encode(encoding) - for idx, result in enumerate(result["queries"]) - } - return Response( - create_zip(files), - headers=generate_download_headers("zip"), - mimetype="application/zip", - ) + + # return multi-query csv results bundled as a zip file + encoding = current_app.config["CSV_EXPORT"].get("encoding", "utf-8") + files = { + f"query_{idx + 1}.csv": result["data"].encode(encoding) + for idx, result in enumerate(result["queries"]) + } + return Response( + create_zip(files), + headers=generate_download_headers("zip"), + mimetype="application/zip", + ) if result_format == ChartDataResultFormat.JSON: response_data = simplejson.dumps( diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index 6912f4697f611..6b047217bf1d2 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -238,6 +238,16 @@ def test_with_query_result_type__200(self): rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 200 + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_empty_request_with_csv_result_format(self): + """ + Chart data API: Test empty chart data with CSV result format + """ + self.query_context_payload["result_format"] = "csv" + self.query_context_payload["queries"] = [] + rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") + assert rv.status_code == 400 + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_csv_result_format(self): """