Skip to content

Commit

Permalink
refactor move ChartDataResult enums to common (#17399)
Browse files Browse the repository at this point in the history
  • Loading branch information
ofekisr authored Nov 11, 2021
1 parent 0257cf7 commit 45480f7
Show file tree
Hide file tree
Showing 12 changed files with 74 additions and 72 deletions.
7 changes: 2 additions & 5 deletions superset/charts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,14 @@
)
from superset.commands.importers.exceptions import NoValidFilesFoundError
from superset.commands.importers.v1.utils import get_contents_from_bundle
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.exceptions import QueryObjectValidationError
from superset.extensions import event_logger, security_manager
from superset.models.slice import Slice
from superset.tasks.thumbnails import cache_chart_thumbnail
from superset.utils.async_query_manager import AsyncQueryTokenException
from superset.utils.core import (
ChartDataResultFormat,
ChartDataResultType,
json_int_dttm_ser,
)
from superset.utils.core import json_int_dttm_ser
from superset.utils.screenshots import ChartScreenshot
from superset.utils.urls import get_url_path
from superset.views.base_api import (
Expand Down
8 changes: 2 additions & 6 deletions superset/charts/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,8 @@

import pandas as pd

from superset.utils.core import (
ChartDataResultFormat,
DTTM_ALIAS,
extract_dataframe_dtypes,
get_metric_name,
)
from superset.common.chart_data import ChartDataResultFormat
from superset.utils.core import DTTM_ALIAS, extract_dataframe_dtypes, get_metric_name


def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]:
Expand Down
3 changes: 1 addition & 2 deletions superset/charts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@
from marshmallow_enum import EnumField

from superset import app
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context import QueryContext
from superset.db_engine_specs.base import builtin_time_grains
from superset.utils import schema as utils
from superset.utils.core import (
AnnotationType,
ChartDataResultFormat,
ChartDataResultType,
FilterOperator,
PostProcessingBoxplotWhiskerType,
PostProcessingContributionOrientation,
Expand Down
40 changes: 40 additions & 0 deletions superset/common/chart_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from enum import Enum


class ChartDataResultFormat(str, Enum):
"""
Chart data response format
"""

CSV = "csv"
JSON = "json"


class ChartDataResultType(str, Enum):
"""
Chart data response type
"""

COLUMNS = "columns"
FULL = "full"
QUERY = "query"
RESULTS = "results"
SAMPLES = "samples"
TIMEGRAINS = "timegrains"
POST_PROCESSED = "post_processed"
2 changes: 1 addition & 1 deletion superset/common/query_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
from flask_babel import _

from superset import app
from superset.common.chart_data import ChartDataResultType
from superset.common.db_query_status import QueryStatus
from superset.connectors.base.models import BaseDatasource
from superset.exceptions import QueryObjectValidationError
from superset.utils.core import (
ChartDataResultType,
extract_column_dtype,
extract_dataframe_dtypes,
ExtraFiltersReasonType,
Expand Down
3 changes: 1 addition & 2 deletions superset/common/query_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from superset import app, db, is_feature_enabled
from superset.annotation_layers.dao import AnnotationLayerDAO
from superset.charts.dao import ChartDAO
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.db_query_status import QueryStatus
from superset.common.query_actions import get_query_results
from superset.common.query_object import QueryObject
Expand All @@ -42,8 +43,6 @@
from superset.utils import csv
from superset.utils.cache import generate_cache_key, set_and_log_cache
from superset.utils.core import (
ChartDataResultFormat,
ChartDataResultType,
DatasourceDict,
DTTM_ALIAS,
error_msg_from_exception,
Expand Down
2 changes: 1 addition & 1 deletion superset/common/query_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
from pandas import DataFrame

from superset import app, db
from superset.common.chart_data import ChartDataResultType
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
from superset.exceptions import QueryObjectValidationError
from superset.typing import Metric, OrderBy
from superset.utils import pandas_postprocessing
from superset.utils.core import (
apply_max_row_limit,
ChartDataResultType,
DatasourceDict,
DTTM_ALIAS,
find_duplicates,
Expand Down
2 changes: 1 addition & 1 deletion superset/reports/commands/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from superset import app
from superset.commands.base import BaseCommand
from superset.commands.exceptions import CommandException
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.extensions import feature_flag_manager, machine_auth_provider_factory
from superset.models.reports import (
ReportDataFormat,
Expand Down Expand Up @@ -64,7 +65,6 @@
from superset.reports.notifications.base import NotificationContent
from superset.reports.notifications.exceptions import NotificationError
from superset.utils.celery import session_scope
from superset.utils.core import ChartDataResultFormat, ChartDataResultType
from superset.utils.csv import get_chart_csv_data, get_chart_dataframe
from superset.utils.screenshots import (
BaseScreenshot,
Expand Down
23 changes: 0 additions & 23 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,29 +174,6 @@ class GenericDataType(IntEnum):
# ROW = 7


class ChartDataResultFormat(str, Enum):
"""
Chart data response format
"""

CSV = "csv"
JSON = "json"


class ChartDataResultType(str, Enum):
"""
Chart data response type
"""

COLUMNS = "columns"
FULL = "full"
QUERY = "query"
RESULTS = "results"
SAMPLES = "samples"
TIMEGRAINS = "timegrains"
POST_PROCESSED = "post_processed"


class DatasourceDict(TypedDict):
type: str
id: int
Expand Down
23 changes: 12 additions & 11 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
viz,
)
from superset.charts.dao import ChartDAO
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.db_query_status import QueryStatus
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
Expand Down Expand Up @@ -459,18 +460,18 @@ def send_data_payload_response(viz_obj: BaseViz, payload: Any) -> FlaskResponse:
def generate_json(
self, viz_obj: BaseViz, response_type: Optional[str] = None
) -> FlaskResponse:
if response_type == utils.ChartDataResultFormat.CSV:
if response_type == ChartDataResultFormat.CSV:
return CsvResponse(
viz_obj.get_csv(), headers=generate_download_headers("csv")
)

if response_type == utils.ChartDataResultType.QUERY:
if response_type == ChartDataResultType.QUERY:
return self.get_query_string_response(viz_obj)

if response_type == utils.ChartDataResultType.RESULTS:
if response_type == ChartDataResultType.RESULTS:
return self.get_raw_results(viz_obj)

if response_type == utils.ChartDataResultType.SAMPLES:
if response_type == ChartDataResultType.SAMPLES:
return self.get_samples(viz_obj)

payload = viz_obj.get_payload()
Expand Down Expand Up @@ -598,19 +599,19 @@ def explore_json(
TODO: break into one endpoint for each return shape"""

response_type = utils.ChartDataResultFormat.JSON.value
responses: List[
Union[utils.ChartDataResultFormat, utils.ChartDataResultType]
] = list(utils.ChartDataResultFormat)
responses.extend(list(utils.ChartDataResultType))
response_type = ChartDataResultFormat.JSON.value
responses: List[Union[ChartDataResultFormat, ChartDataResultType]] = list(
ChartDataResultFormat
)
responses.extend(list(ChartDataResultType))
for response_option in responses:
if request.args.get(response_option) == "true":
response_type = response_option
break

# Verify user has permission to export CSV file
if (
response_type == utils.ChartDataResultFormat.CSV
response_type == ChartDataResultFormat.CSV
and not security_manager.can_access("can_csv", "Superset")
):
return json_error_response(
Expand All @@ -628,7 +629,7 @@ def explore_json(
# TODO: support CSV, SQL query and other non-JSON types
if (
is_feature_enabled("GLOBAL_ASYNC_QUERIES")
and response_type == utils.ChartDataResultFormat.JSON
and response_type == ChartDataResultFormat.JSON
):
# First, look for the chart query results in the cache.
try:
Expand Down
24 changes: 11 additions & 13 deletions tests/integration_tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,13 @@
from superset.models.dashboard import Dashboard
from superset.models.reports import ReportSchedule, ReportScheduleType
from superset.models.slice import Slice
from superset.utils import core as utils
from superset.utils.core import (
AnnotationType,
ChartDataResultFormat,
get_example_database,
get_example_default_schema,
get_main_database,
)

from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType

from tests.integration_tests.base_api_tests import ApiOwnersTestCaseMixin
from tests.integration_tests.base_tests import (
Expand Down Expand Up @@ -1239,7 +1237,7 @@ def test_chart_data_sample_default_limit(self):
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
request_payload["result_type"] = ChartDataResultType.SAMPLES
del request_payload["queries"][0]["row_limit"]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
Expand All @@ -1258,7 +1256,7 @@ def test_chart_data_sample_custom_limit(self):
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
request_payload["result_type"] = ChartDataResultType.SAMPLES
request_payload["queries"][0]["row_limit"] = 10
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
Expand All @@ -1276,7 +1274,7 @@ def test_chart_data_sql_max_row_sample_limit(self):
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
request_payload["result_type"] = ChartDataResultType.SAMPLES
request_payload["queries"][0]["row_limit"] = 10000000
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
Expand Down Expand Up @@ -1326,7 +1324,7 @@ def test_chart_data_query_result_type(self):
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_type"] = utils.ChartDataResultType.QUERY
request_payload["result_type"] = ChartDataResultType.QUERY
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)

Expand Down Expand Up @@ -1453,7 +1451,7 @@ def test_chart_data_query_missing_filter(self):
request_payload["queries"][0]["filters"] = [
{"col": "non_existent_filter", "op": "==", "val": "foo"},
]
request_payload["result_type"] = utils.ChartDataResultType.QUERY
request_payload["result_type"] = ChartDataResultType.QUERY
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
response_payload = json.loads(rv.data.decode("utf-8"))
Expand Down Expand Up @@ -1532,7 +1530,7 @@ def test_chart_data_jinja_filter_request(self):
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_type"] = utils.ChartDataResultType.QUERY
request_payload["result_type"] = ChartDataResultType.QUERY
request_payload["queries"][0]["filters"] = [
{"col": "gender", "op": "==", "val": "boy"}
]
Expand Down Expand Up @@ -1574,7 +1572,7 @@ def test_chart_data_async_cached_sync_response(self):

class QueryContext:
result_format = ChartDataResultFormat.JSON
result_type = utils.ChartDataResultType.FULL
result_type = ChartDataResultType.FULL

cmd_run_val = {
"query_context": QueryContext(),
Expand All @@ -1585,7 +1583,7 @@ class QueryContext:
ChartDataCommand, "run", return_value=cmd_run_val
) as patched_run:
request_payload = get_query_context("birth_names")
request_payload["result_type"] = utils.ChartDataResultType.FULL
request_payload["result_type"] = ChartDataResultType.FULL
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
Expand Down Expand Up @@ -1997,8 +1995,8 @@ def test_chart_data_timegrains(self):
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["queries"] = [
{"result_type": utils.ChartDataResultType.TIMEGRAINS},
{"result_type": utils.ChartDataResultType.COLUMNS},
{"result_type": ChartDataResultType.TIMEGRAINS},
{"result_type": ChartDataResultType.COLUMNS},
]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
Expand Down
9 changes: 2 additions & 7 deletions tests/integration_tests/query_context_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,13 @@

from superset import db
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context import QueryContext
from superset.common.query_object import QueryObject
from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.sqla.models import SqlMetric
from superset.extensions import cache_manager
from superset.utils.core import (
AdhocMetricExpressionType,
backend,
ChartDataResultFormat,
ChartDataResultType,
TimeRangeEndpoint,
)
from superset.utils.core import AdhocMetricExpressionType, backend, TimeRangeEndpoint
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
Expand Down

0 comments on commit 45480f7

Please sign in to comment.