diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py
index 0d0758819ed02..8e58440c76299 100644
--- a/superset/charts/data/api.py
+++ b/superset/charts/data/api.py
@@ -18,7 +18,7 @@

 import json
 import logging
-from typing import Any, Dict, Optional, TYPE_CHECKING
+from typing import Any, Dict, Optional, TYPE_CHECKING, Union

 import simplejson
 from flask import current_app, make_response, request, Response
@@ -44,6 +44,7 @@
 from superset.dao.exceptions import DatasourceNotFound
 from superset.exceptions import QueryObjectValidationError
 from superset.extensions import event_logger
+from superset.models.sql_lab import Query
 from superset.utils.async_query_manager import AsyncQueryTokenException
 from superset.utils.core import create_zip, get_user_id, json_int_dttm_ser
 from superset.views.base import CsvResponse, generate_download_headers, XlsxResponse
@@ -342,7 +343,7 @@ def _send_chart_response(
         self,
         result: Dict[Any, Any],
         form_data: Optional[Dict[str, Any]] = None,
-        datasource: Optional[BaseDatasource] = None,
+        datasource: Optional[Union[BaseDatasource, Query]] = None,
     ) -> Response:
         result_type = result["query_context"].result_type
         result_format = result["query_context"].result_format
@@ -405,7 +406,7 @@ def _get_data_response(
         command: ChartDataCommand,
         force_cached: bool = False,
         form_data: Optional[Dict[str, Any]] = None,
-        datasource: Optional[BaseDatasource] = None,
+        datasource: Optional[Union[BaseDatasource, Query]] = None,
     ) -> Response:
         try:
             result = command.run(force_cached=force_cached)
diff --git a/superset/charts/post_processing.py b/superset/charts/post_processing.py
index fd10930db0358..1165769fc8df4 100644
--- a/superset/charts/post_processing.py
+++ b/superset/charts/post_processing.py
@@ -27,7 +27,7 @@
 """

 from io import StringIO
-from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
+from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

 import pandas as pd
 from flask_babel import gettext as __
@@ -42,6 +42,7 @@

 if TYPE_CHECKING:
     from superset.connectors.base.models import BaseDatasource
+    from superset.models.sql_lab import Query


 def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]:
@@ -223,7 +224,7 @@ def list_unique_values(series: pd.Series) -> str:
 def pivot_table_v2(
     df: pd.DataFrame,
     form_data: Dict[str, Any],
-    datasource: Optional["BaseDatasource"] = None,
+    datasource: Optional[Union["BaseDatasource", "Query"]] = None,
 ) -> pd.DataFrame:
     """
     Pivot table v2.
@@ -249,7 +250,7 @@ def pivot_table(
     df: pd.DataFrame,
     form_data: Dict[str, Any],
-    datasource: Optional["BaseDatasource"] = None,
+    datasource: Optional[Union["BaseDatasource", "Query"]] = None,
 ) -> pd.DataFrame:
     """
     Pivot table (v1).
@@ -285,7 +286,9 @@ def table(
     df: pd.DataFrame,
     form_data: Dict[str, Any],
-    datasource: Optional["BaseDatasource"] = None,  # pylint: disable=unused-argument
+    datasource: Optional[  # pylint: disable=unused-argument
+        Union["BaseDatasource", "Query"]
+    ] = None,
 ) -> pd.DataFrame:
     """
     Table.
@@ -314,7 +317,7 @@ def table(

 def apply_post_process(
     result: Dict[Any, Any],
     form_data: Optional[Dict[str, Any]] = None,
-    datasource: Optional["BaseDatasource"] = None,
+    datasource: Optional[Union["BaseDatasource", "Query"]] = None,
 ) -> Dict[Any, Any]:
     form_data = form_data or {}
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index 3b5f171f41b7f..976ee177f94e5 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -246,6 +246,7 @@ def data(self) -> Dict[str, Any]:
             "database": {"id": self.database_id, "backend": self.database.backend},
             "order_by_choices": order_by_choices,
             "schema": self.schema,
+            "verbose_map": {},
         }

     def raise_for_access(self) -> None:
diff --git a/superset/utils/core.py b/superset/utils/core.py
index d229942834e81..460c17b949dac 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -126,6 +126,7 @@

 if TYPE_CHECKING:
     from superset.connectors.base.models import BaseColumn, BaseDatasource
+    from superset.models.sql_lab import Query

 logging.getLogger("MARKDOWN").setLevel(logging.INFO)
 logger = logging.getLogger(__name__)
@@ -1711,7 +1712,7 @@ def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]:


 def extract_dataframe_dtypes(
     df: pd.DataFrame,
-    datasource: Optional["BaseDatasource"] = None,
+    datasource: Optional[Union["BaseDatasource", "Query"]] = None,
 ) -> List[GenericDataType]:
     """Serialize pandas/numpy dtypes to generic types"""
@@ -1731,13 +1732,13 @@
     if datasource:
         for column in datasource.columns:
             if isinstance(column, dict):
-                columns_by_name[column.get("column_name")] = column
+                columns_by_name[column.get("column_name")] = column  # type: ignore
             else:
                 columns_by_name[column.column_name] = column

     generic_types: List[GenericDataType] = []
     for column in df.columns:
-        column_object = columns_by_name.get(column)
+        column_object = columns_by_name.get(column)  # type: ignore
         series = df[column]
         inferred_type = infer_dtype(series)
         if isinstance(column_object, dict):
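For context on the duck typing this patch relies on: `BaseDatasource.columns` yields ORM column objects carrying a `column_name` attribute, while `Query.columns` yields plain dicts, which is why `extract_dataframe_dtypes` branches on `isinstance(column, dict)` and needs the `# type: ignore` hints once the parameter is widened to `Union["BaseDatasource", "Query"]`. Below is a minimal, self-contained sketch of that lookup pattern; `OrmColumn`, `index_by_name`, and the sample data are hypothetical stand-ins for illustration, not Superset code:

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import pandas as pd


@dataclass
class OrmColumn:
    """Hypothetical stand-in for a BaseDatasource ORM column object."""

    column_name: str
    verbose_name: str


# Query-style datasources expose columns as plain dicts; BaseDatasource
# exposes ORM objects. A shared lookup must handle both shapes, mirroring
# the isinstance(column, dict) branch in extract_dataframe_dtypes.
ColumnLike = Union[OrmColumn, Dict[str, Any]]


def index_by_name(columns: List[ColumnLike]) -> Dict[str, ColumnLike]:
    columns_by_name: Dict[str, ColumnLike] = {}
    for column in columns:
        if isinstance(column, dict):
            columns_by_name[column["column_name"]] = column
        else:
            columns_by_name[column.column_name] = column
    return columns_by_name


df = pd.DataFrame({"ds": ["2021-01-01"], "num_boys": [1]})
columns: List[ColumnLike] = [
    {"column_name": "ds", "is_dttm": True},               # Query-style dict
    OrmColumn(column_name="num_boys", verbose_name="Boys"),  # ORM-style object
]
by_name = index_by_name(columns)
for name in df.columns:
    column_object = by_name.get(name)
    # Same shape check the patched code performs before reading metadata.
    if isinstance(column_object, dict):
        print(name, "-> dict column, is_dttm =", column_object.get("is_dttm"))
    else:
        print(name, "-> ORM column:", column_object)
```

The `"verbose_map": {}` entry added to `Query.data` serves the same contract: the chart post-processing helpers read `datasource.data["verbose_map"]` when relabeling columns, so a `Query` acting as a datasource has to supply at least an empty mapping there rather than raise a `KeyError`.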