Skip to content

Commit

Permalink
refactor: chartDataCommand - remove the responsibly of creating query…
Browse files Browse the repository at this point in the history
… context from command (apache#17461)
  • Loading branch information
ofekisr authored Nov 17, 2021
1 parent c54027a commit 3f2129b
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 20 deletions.
26 changes: 20 additions & 6 deletions superset/charts/data/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from superset.charts.data.query_context_cache_loader import QueryContextCacheLoader
from superset.charts.post_processing import apply_post_process
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.exceptions import QueryObjectValidationError
from superset.extensions import event_logger
Expand All @@ -49,6 +50,8 @@
if TYPE_CHECKING:
from flask import Response

from superset.common.query_context import QueryContext

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -130,8 +133,8 @@ def get_data(self, pk: int) -> Response:
json_body["result_type"] = request.args.get("type", ChartDataResultType.FULL)

try:
command = ChartDataCommand()
query_context = command.set_query_context(json_body)
query_context = self._create_query_context_from_form(json_body)
command = ChartDataCommand(query_context)
command.validate()
except QueryObjectValidationError as error:
return self.response_400(message=error.message)
Expand Down Expand Up @@ -216,8 +219,8 @@ def data(self) -> Response:
return self.response_400(message=_("Request is not JSON"))

try:
command = ChartDataCommand()
query_context = command.set_query_context(json_body)
query_context = self._create_query_context_from_form(json_body)
command = ChartDataCommand(query_context)
command.validate()
except QueryObjectValidationError as error:
return self.response_400(message=error.message)
Expand Down Expand Up @@ -278,10 +281,10 @@ def data_from_cache(self, cache_key: str) -> Response:
500:
$ref: '#/components/responses/500'
"""
command = ChartDataCommand()
try:
cached_data = self._load_query_context_form_from_cache(cache_key)
command.set_query_context(cached_data)
query_context = self._create_query_context_from_form(cached_data)
command = ChartDataCommand(query_context)
command.validate()
except ChartDataCacheLoadError:
return self.response_404()
Expand Down Expand Up @@ -374,3 +377,14 @@ def _get_data_response(
# pylint: disable=invalid-name, no-self-use
def _load_query_context_form_from_cache(self, cache_key: str) -> Dict[str, Any]:
return QueryContextCacheLoader.load(cache_key)

# pylint: disable=no-self-use
def _create_query_context_from_form(
self, form_data: Dict[str, Any]
) -> QueryContext:
try:
return ChartDataQueryContextSchema().load(form_data)
except KeyError as ex:
raise ValidationError("Request is incorrect") from ex
except ValidationError as error:
raise error
14 changes: 3 additions & 11 deletions superset/charts/data/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@
from typing import Any, Dict, Optional

from flask import Request
from marshmallow import ValidationError

from superset.charts.commands.exceptions import (
ChartDataCacheLoadError,
ChartDataQueryFailedError,
)
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.commands.base import BaseCommand
from superset.common.query_context import QueryContext
from superset.exceptions import CacheLoadError
Expand All @@ -37,6 +35,9 @@
class ChartDataCommand(BaseCommand):
_query_context: QueryContext

def __init__(self, query_context: QueryContext):
self._query_context = query_context

def run(self, **kwargs: Any) -> Dict[str, Any]:
# caching is handled in query_context.get_df_payload
# (also evals `force` property)
Expand All @@ -63,15 +64,6 @@ def run(self, **kwargs: Any) -> Dict[str, Any]:

return return_value

def set_query_context(self, form_data: Dict[str, Any]) -> QueryContext:
try:
self._query_context = ChartDataQueryContextSchema().load(form_data)
except KeyError as ex:
raise ValidationError("Request is incorrect") from ex
except ValidationError as error:
raise error
return self._query_context

def validate(self) -> None:
self._query_context.raise_for_access()

Expand Down
21 changes: 18 additions & 3 deletions superset/tasks/async_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import copy
import logging
from typing import Any, cast, Dict, Optional
from typing import Any, cast, Dict, Optional, TYPE_CHECKING

from celery.exceptions import SoftTimeLimitExceeded
from flask import current_app, g
from marshmallow import ValidationError

from superset.charts.schemas import ChartDataQueryContextSchema
from superset.exceptions import SupersetVizException
from superset.extensions import (
async_query_manager,
Expand All @@ -32,6 +35,9 @@
from superset.utils.cache import generate_cache_key, set_and_log_cache
from superset.views.utils import get_datasource_info, get_viz

if TYPE_CHECKING:
from superset.common.query_context import QueryContext

logger = logging.getLogger(__name__)
query_timeout = current_app.config[
"SQLLAB_ASYNC_TIME_LIMIT_SEC"
Expand All @@ -50,6 +56,15 @@ def set_form_data(form_data: Dict[str, Any]) -> None:
g.form_data = form_data


def _create_query_context_from_form(form_data: Dict[str, Any]) -> QueryContext:
try:
return ChartDataQueryContextSchema().load(form_data)
except KeyError as ex:
raise ValidationError("Request is incorrect") from ex
except ValidationError as error:
raise error


@celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout)
def load_chart_data_into_cache(
job_metadata: Dict[str, Any], form_data: Dict[str, Any],
Expand All @@ -60,8 +75,8 @@ def load_chart_data_into_cache(
try:
ensure_user_is_set(job_metadata.get("user_id"))
set_form_data(form_data)
command = ChartDataCommand()
command.set_query_context(form_data)
query_context = _create_query_context_from_form(form_data)
command = ChartDataCommand(query_context)
result = command.run(cache=True)
cache_key = result["cache_key"]
result_url = f"/api/v1/chart/data/{cache_key}"
Expand Down

0 comments on commit 3f2129b

Please sign in to comment.