diff --git a/superset/charts/api.py b/superset/charts/api.py index 39b0c2dbf8980..c87b7bdda8dd2 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -47,6 +47,7 @@ from superset.charts.commands.export import ExportChartsCommand from superset.charts.commands.importers.dispatcher import ImportChartsCommand from superset.charts.commands.update import UpdateChartCommand +from superset.charts.commands.warm_up_cache import ChartWarmUpCacheCommand from superset.charts.filters import ( ChartAllTextFilter, ChartCertifiedFilter, @@ -59,6 +60,7 @@ ) from superset.charts.schemas import ( CHART_SCHEMAS, + ChartCacheWarmUpRequestSchema, ChartPostSchema, ChartPutSchema, get_delete_ids_schema, @@ -68,6 +70,7 @@ screenshot_query_schema, thumbnail_query_schema, ) +from superset.commands.exceptions import CommandException from superset.commands.importers.exceptions import ( IncorrectFormatError, NoValidFilesFoundError, @@ -118,6 +121,7 @@ def ensure_thumbnails_enabled(self) -> Optional[Response]: "thumbnail", "screenshot", "cache_screenshot", + "warm_up_cache", } class_permission_name = "Chart" method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP @@ -942,6 +946,63 @@ def remove_favorite(self, pk: int) -> Response: ChartDAO.remove_favorite(chart) return self.response(200, result="OK") + @expose("/warm_up_cache", methods=("PUT",)) + @protect() + @safe + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" + f".warm_up_cache", + log_to_statsd=False, + ) + def warm_up_cache(self) -> Response: + """ + --- + put: + summary: >- + Warms up the cache for the chart + description: >- + Warms up the cache for the chart. + Note for slices a force refresh occurs. + In terms of the `extra_filters` these can be obtained from records in the JSON + encoded `logs.json` column associated with the `explore_json` action. + requestBody: + description: >- + Identifies the chart to warm up cache for, and any additional dashboard or + filter context to use. + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ChartCacheWarmUpRequestSchema" + responses: + 200: + description: Each chart's warmup status + content: + application/json: + schema: + $ref: "#/components/schemas/ChartCacheWarmUpResponseSchema" + 400: + $ref: '#/components/responses/400' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + try: + body = ChartCacheWarmUpRequestSchema().load(request.json) + except ValidationError as error: + return self.response_400(message=error.messages) + try: + result = ChartWarmUpCacheCommand( + body["chart_id"], + body.get("dashboard_id"), + body.get("extra_filters"), + ).run() + return self.response(200, result=[result]) + except CommandException as ex: + return self.response(ex.status, message=ex.message) + @expose("/import/", methods=("POST",)) @protect() @statsd_metrics diff --git a/superset/charts/commands/exceptions.py b/superset/charts/commands/exceptions.py index 6d5c078b12289..1079cdca81c90 100644 --- a/superset/charts/commands/exceptions.py +++ b/superset/charts/commands/exceptions.py @@ -153,3 +153,8 @@ class ChartBulkDeleteFailedReportsExistError(ChartBulkDeleteFailedError): class ChartImportError(ImportFailedError): message = _("Import chart failed for an unknown reason") + + +class WarmUpCacheChartNotFoundError(CommandException): + status = 404 + message = _("Chart not found") diff --git a/superset/charts/commands/warm_up_cache.py b/superset/charts/commands/warm_up_cache.py new file mode 100644 index 0000000000000..6fe9f94ffa620 --- /dev/null +++ b/superset/charts/commands/warm_up_cache.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from typing import Any, Optional, Union + +import simplejson as json +from flask import g + +from superset.charts.commands.exceptions import WarmUpCacheChartNotFoundError +from superset.commands.base import BaseCommand +from superset.extensions import db +from superset.models.slice import Slice +from superset.utils.core import error_msg_from_exception +from superset.views.utils import get_dashboard_extra_filters, get_form_data, get_viz + + +class ChartWarmUpCacheCommand(BaseCommand): + # pylint: disable=too-many-arguments + def __init__( + self, + chart_or_id: Union[int, Slice], + dashboard_id: Optional[int], + extra_filters: Optional[str], + ): + self._chart_or_id = chart_or_id + self._dashboard_id = dashboard_id + self._extra_filters = extra_filters + + def run(self) -> dict[str, Any]: + self.validate() + chart: Slice = self._chart_or_id # type: ignore + try: + form_data = get_form_data(chart.id, use_slice_data=True)[0] + if self._dashboard_id: + form_data["extra_filters"] = ( + json.loads(self._extra_filters) + if self._extra_filters + else get_dashboard_extra_filters(chart.id, self._dashboard_id) + ) + + if not chart.datasource: + raise Exception("Chart's datasource does not exist") + + obj = get_viz( + datasource_type=chart.datasource.type, + datasource_id=chart.datasource.id, + form_data=form_data, + force=True, + ) + + # pylint: disable=assigning-non-slot + g.form_data = form_data + payload = obj.get_payload() + delattr(g, "form_data") + error = payload["errors"] or None + status = payload["status"] + except Exception as ex: # pylint: disable=broad-except + error = error_msg_from_exception(ex) + status = None + + return {"chart_id": chart.id, "viz_error": error, "viz_status": status} + + def validate(self) -> None: + if isinstance(self._chart_or_id, Slice): + return + chart = db.session.query(Slice).filter_by(id=self._chart_or_id).scalar() + if not chart: + raise WarmUpCacheChartNotFoundError() + self._chart_or_id = chart diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index a5e0a6c44cac0..1145d5be73694 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -1557,7 +1557,45 @@ class ImportV1ChartSchema(Schema): external_url = fields.String(allow_none=True) +class ChartCacheWarmUpRequestSchema(Schema): + chart_id = fields.Integer( + required=True, + metadata={"description": "The ID of the chart to warm up cache for"}, + ) + dashboard_id = fields.Integer( + metadata={ + "description": "The ID of the dashboard to get filters for when warming cache" + } + ) + extra_filters = fields.String( + metadata={"description": "Extra filters to apply when warming up cache"} + ) + + +class ChartCacheWarmUpResponseSingleSchema(Schema): + chart_id = fields.Integer( + metadata={"description": "The ID of the chart the status belongs to"} + ) + viz_error = fields.String( + metadata={"description": "Error that occurred when warming cache for chart"} + ) + viz_status = fields.String( + metadata={"description": "Status of the underlying query for the viz"} + ) + + +class ChartCacheWarmUpResponseSchema(Schema): + result = fields.List( + fields.Nested(ChartCacheWarmUpResponseSingleSchema), + metadata={ + "description": "A list of each chart's warmup status and errors if any" + }, + ) + + CHART_SCHEMAS = ( + ChartCacheWarmUpRequestSchema, + ChartCacheWarmUpResponseSchema, ChartDataQueryContextSchema, ChartDataResponseSchema, ChartDataAsyncResponseSchema, diff --git a/superset/datasets/api.py b/superset/datasets/api.py index b2457b066a64b..6e6cf38aad89e 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=too-many-lines import json import logging from datetime import datetime @@ -29,6 +30,7 @@ from marshmallow import ValidationError from superset import event_logger, is_feature_enabled +from superset.commands.exceptions import CommandException from superset.commands.importers.exceptions import NoValidFilesFoundError from superset.commands.importers.v1.utils import get_contents_from_bundle from superset.connectors.sqla.models import SqlaTable @@ -53,8 +55,11 @@ from superset.datasets.commands.importers.dispatcher import ImportDatasetsCommand from superset.datasets.commands.refresh import RefreshDatasetCommand from superset.datasets.commands.update import UpdateDatasetCommand +from superset.datasets.commands.warm_up_cache import DatasetWarmUpCacheCommand from superset.datasets.filters import DatasetCertifiedFilter, DatasetIsNullOrEmptyFilter from superset.datasets.schemas import ( + DatasetCacheWarmUpRequestSchema, + DatasetCacheWarmUpResponseSchema, DatasetDuplicateSchema, DatasetPostSchema, DatasetPutSchema, @@ -95,6 +100,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): "related_objects", "duplicate", "get_or_create_dataset", + "warm_up_cache", } list_columns = [ "id", @@ -244,6 +250,8 @@ class DatasetRestApi(BaseSupersetModelRestApi): "get_export_ids_schema": get_export_ids_schema, } openapi_spec_component_schemas = ( + DatasetCacheWarmUpRequestSchema, + DatasetCacheWarmUpResponseSchema, DatasetRelatedObjectsResponse, DatasetDuplicateSchema, GetOrCreateDatasetSchema, @@ -992,3 +1000,61 @@ def get_or_create_dataset(self) -> Response: exc_info=True, ) return self.response_422(message=ex.message) + + @expose("/warm_up_cache", methods=("PUT",)) + @protect() + @safe + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" + f".warm_up_cache", + log_to_statsd=False, + ) + def warm_up_cache(self) -> Response: + """ + --- + put: + summary: >- + Warms up the cache for each chart powered by the given table + description: >- + Warms up the cache for the table. + Note for slices a force refresh occurs. + In terms of the `extra_filters` these can be obtained from records in the JSON + encoded `logs.json` column associated with the `explore_json` action. + requestBody: + description: >- + Identifies the database and table to warm up cache for, and any + additional dashboard or filter context to use. + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/DatasetCacheWarmUpRequestSchema" + responses: + 200: + description: Each chart's warmup status + content: + application/json: + schema: + $ref: "#/components/schemas/DatasetCacheWarmUpResponseSchema" + 400: + $ref: '#/components/responses/400' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + try: + body = DatasetCacheWarmUpRequestSchema().load(request.json) + except ValidationError as error: + return self.response_400(message=error.messages) + try: + result = DatasetWarmUpCacheCommand( + body["db_name"], + body["table_name"], + body.get("dashboard_id"), + body.get("extra_filters"), + ).run() + return self.response(200, result=result) + except CommandException as ex: + return self.response(ex.status, message=ex.message) diff --git a/superset/datasets/commands/exceptions.py b/superset/datasets/commands/exceptions.py index e06e92802f04c..7c6ef86634f3a 100644 --- a/superset/datasets/commands/exceptions.py +++ b/superset/datasets/commands/exceptions.py @@ -212,3 +212,8 @@ class DatasetDuplicateFailedError(CreateFailedError): class DatasetForbiddenDataURI(ImportFailedError): message = _("Data URI is not allowed.") + + +class WarmUpCacheTableNotFoundError(CommandException): + status = 404 + message = _("The provided table was not found in the provided database") diff --git a/superset/datasets/commands/warm_up_cache.py b/superset/datasets/commands/warm_up_cache.py new file mode 100644 index 0000000000000..62044e7224f36 --- /dev/null +++ b/superset/datasets/commands/warm_up_cache.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from typing import Any, Optional + +from superset.charts.commands.warm_up_cache import ChartWarmUpCacheCommand +from superset.commands.base import BaseCommand +from superset.connectors.sqla.models import SqlaTable +from superset.datasets.commands.exceptions import WarmUpCacheTableNotFoundError +from superset.extensions import db +from superset.models.core import Database +from superset.models.slice import Slice + + +class DatasetWarmUpCacheCommand(BaseCommand): + # pylint: disable=too-many-arguments + def __init__( + self, + db_name: str, + table_name: str, + dashboard_id: Optional[int], + extra_filters: Optional[str], + ): + self._db_name = db_name + self._table_name = table_name + self._dashboard_id = dashboard_id + self._extra_filters = extra_filters + self._charts: list[Slice] = [] + + def run(self) -> list[dict[str, Any]]: + self.validate() + return [ + ChartWarmUpCacheCommand( + chart, self._dashboard_id, self._extra_filters + ).run() + for chart in self._charts + ] + + def validate(self) -> None: + table = ( + db.session.query(SqlaTable) + .join(Database) + .filter( + Database.database_name == self._db_name, + SqlaTable.table_name == self._table_name, + ) + ).one_or_none() + if not table: + raise WarmUpCacheTableNotFoundError() + self._charts = ( + db.session.query(Slice) + .filter_by(datasource_id=table.id, datasource_type=table.type) + .all() + ) diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py index 9a2af980666e8..f95897ce59b15 100644 --- a/superset/datasets/schemas.py +++ b/superset/datasets/schemas.py @@ -254,3 +254,43 @@ class Meta: # pylint: disable=too-few-public-methods model = Dataset load_instance = True include_relationships = True + + +class DatasetCacheWarmUpRequestSchema(Schema): + db_name = fields.String( + required=True, + metadata={"description": "The name of the database where the table is located"}, + ) + table_name = fields.String( + required=True, + metadata={"description": "The name of the table to warm up cache for"}, + ) + dashboard_id = fields.Integer( + metadata={ + "description": "The ID of the dashboard to get filters for when warming cache" + } + ) + extra_filters = fields.String( + metadata={"description": "Extra filters to apply when warming up cache"} + ) + + +class DatasetCacheWarmUpResponseSingleSchema(Schema): + chart_id = fields.Integer( + metadata={"description": "The ID of the chart the status belongs to"} + ) + viz_error = fields.String( + metadata={"description": "Error that occurred when warming cache for chart"} + ) + viz_status = fields.String( + metadata={"description": "Status of the underlying query for the viz"} + ) + + +class DatasetCacheWarmUpResponseSchema(Schema): + result = fields.List( + fields.Nested(DatasetCacheWarmUpResponseSingleSchema), + metadata={ + "description": "A list of each chart's warmup status and errors if any" + }, + ) diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py index 448271269a209..68b5657a22545 100644 --- a/superset/tasks/cache.py +++ b/superset/tasks/cache.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import json import logging from typing import Any, Optional, Union from urllib import request @@ -36,22 +37,20 @@ logger.setLevel(logging.INFO) -def get_url(chart: Slice, dashboard: Optional[Dashboard] = None) -> str: - """Return external URL for warming up a given chart/table cache.""" - with app.test_request_context(): - baseurl = "{WEBDRIVER_BASEURL}".format(**app.config) - url = f"{baseurl}superset/warm_up_cache/?slice_id={chart.id}" - if dashboard: - url += f"&dashboard_id={dashboard.id}" - return url +def get_payload(chart: Slice, dashboard: Optional[Dashboard] = None) -> dict[str, int]: + """Return payload for warming up a given chart/table cache.""" + payload = {"chart_id": chart.id} + if dashboard: + payload["dashboard_id"] = dashboard.id + return payload class Strategy: # pylint: disable=too-few-public-methods """ A cache warm up strategy. - Each strategy defines a `get_urls` method that returns a list of URLs to - be fetched from the `/superset/warm_up_cache/` endpoint. + Each strategy defines a `get_payloads` method that returns a list of payloads to + send to the `/api/v1/chart/warm_up_cache` endpoint. Strategies can be configured in `superset/config.py`: @@ -72,8 +71,8 @@ class Strategy: # pylint: disable=too-few-public-methods def __init__(self) -> None: pass - def get_urls(self) -> list[str]: - raise NotImplementedError("Subclasses must implement get_urls!") + def get_payloads(self) -> list[dict[str, int]]: + raise NotImplementedError("Subclasses must implement get_payloads!") class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods @@ -94,11 +93,11 @@ class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods name = "dummy" - def get_urls(self) -> list[str]: + def get_payloads(self) -> list[dict[str, int]]: session = db.create_scoped_session() charts = session.query(Slice).all() - return [get_url(chart) for chart in charts] + return [get_payload(chart) for chart in charts] class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-methods @@ -126,8 +125,8 @@ def __init__(self, top_n: int = 5, since: str = "7 days ago") -> None: self.top_n = top_n self.since = parse_human_datetime(since) if since else None - def get_urls(self) -> list[str]: - urls = [] + def get_payloads(self) -> list[dict[str, int]]: + payloads = [] session = db.create_scoped_session() records = ( @@ -142,9 +141,9 @@ def get_urls(self) -> list[str]: dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all() for dashboard in dashboards: for chart in dashboard.slices: - urls.append(get_url(chart, dashboard)) + payloads.append(get_payload(chart, dashboard)) - return urls + return payloads class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods @@ -169,8 +168,8 @@ def __init__(self, tags: Optional[list[str]] = None) -> None: super().__init__() self.tags = tags or [] - def get_urls(self) -> list[str]: - urls = [] + def get_payloads(self) -> list[dict[str, int]]: + payloads = [] session = db.create_scoped_session() tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all() @@ -191,7 +190,7 @@ def get_urls(self) -> list[str]: tagged_dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)) for dashboard in tagged_dashboards: for chart in dashboard.slices: - urls.append(get_url(chart)) + payloads.append(get_payload(chart)) # add charts that are tagged tagged_objects = ( @@ -207,35 +206,46 @@ def get_urls(self) -> list[str]: chart_ids = [tagged_object.object_id for tagged_object in tagged_objects] tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids)) for chart in tagged_charts: - urls.append(get_url(chart)) + payloads.append(get_payload(chart)) - return urls + return payloads strategies = [DummyStrategy, TopNDashboardsStrategy, DashboardTagsStrategy] @celery_app.task(name="fetch_url") -def fetch_url(url: str, headers: dict[str, str]) -> dict[str, str]: +def fetch_url(data: str, headers: dict[str, str]) -> dict[str, str]: """ Celery job to fetch url """ result = {} try: - logger.info("Fetching %s", url) - req = request.Request(url, headers=headers) + baseurl = "{WEBDRIVER_BASEURL}".format(**app.config) + url = f"{baseurl}api/v1/chart/warm_up_cache" + logger.info("Fetching %s with payload %s", url, data) + req = request.Request( + url, data=bytes(data, "utf-8"), headers=headers, method="PUT" + ) response = request.urlopen( # pylint: disable=consider-using-with req, timeout=600 ) - logger.info("Fetched %s, status code: %s", url, response.code) + logger.info( + "Fetched %s with payload %s, status code: %s", url, data, response.code + ) if response.code == 200: - result = {"success": url, "response": response.read().decode("utf-8")} + result = {"success": data, "response": response.read().decode("utf-8")} else: - result = {"error": url, "status_code": response.code} - logger.error("Error fetching %s, status code: %s", url, response.code) + result = {"error": data, "status_code": response.code} + logger.error( + "Error fetching %s with payload %s, status code: %s", + url, + data, + response.code, + ) except URLError as err: logger.exception("Error warming up cache!") - result = {"error": url, "exception": str(err)} + result = {"error": data, "exception": str(err)} return result @@ -270,16 +280,20 @@ def cache_warmup( user = security_manager.get_user_by_username(app.config["THUMBNAIL_SELENIUM_USER"]) cookies = MachineAuthProvider.get_auth_cookies(user) - headers = {"Cookie": f"session={cookies.get('session', '')}"} + headers = { + "Cookie": f"session={cookies.get('session', '')}", + "Content-Type": "application/json", + } results: dict[str, list[str]] = {"scheduled": [], "errors": []} - for url in strategy.get_urls(): + for payload in strategy.get_payloads(): try: - logger.info("Scheduling %s", url) - fetch_url.delay(url, headers) - results["scheduled"].append(url) + payload = json.dumps(payload) + logger.info("Scheduling %s", payload) + fetch_url.delay(payload, headers) + results["scheduled"].append(payload) except SchedulingError: - logger.exception("Error scheduling fetch_url: %s", url) - results["errors"].append(url) + logger.exception("Error scheduling fetch_url for payload: %s", payload) + results["errors"].append(payload) return results diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index 60633c8894471..69e99978e5a29 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -33,6 +33,7 @@ from superset.reports.models import ReportSchedule, ReportScheduleType from superset.models.slice import Slice from superset.utils.core import get_example_default_schema +from superset.utils.database import get_example_database from tests.integration_tests.conftest import with_feature_flags from tests.integration_tests.base_api_tests import ApiOwnersTestCaseMixin @@ -199,7 +200,12 @@ def test_info_security_chart(self): rv = self.get_assert_metric(uri, "info") data = json.loads(rv.data.decode("utf-8")) assert rv.status_code == 200 - assert set(data["permissions"]) == {"can_read", "can_write", "can_export"} + assert set(data["permissions"]) == { + "can_read", + "can_write", + "can_export", + "can_warm_up_cache", + } def create_chart_import(self): buf = BytesIO() @@ -1682,3 +1688,85 @@ def test_gets_owned_created_favorited_by_me_filter(self): assert data["result"][0]["slice_name"] == "name0" assert data["result"][0]["datasource_id"] == 1 + + @pytest.mark.usefixtures( + "load_energy_table_with_slice", "load_birth_names_dashboard_with_slices" + ) + def test_warm_up_cache(self): + self.login() + slc = self.get_slice("Girls", db.session) + rv = self.client.put("/api/v1/chart/warm_up_cache", json={"chart_id": slc.id}) + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + + self.assertEqual( + data["result"], + [{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}], + ) + + dashboard = self.get_dash_by_slug("births") + + rv = self.client.put( + "/api/v1/chart/warm_up_cache", + json={"chart_id": slc.id, "dashboard_id": dashboard.id}, + ) + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual( + data["result"], + [{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}], + ) + + rv = self.client.put( + "/api/v1/chart/warm_up_cache", + json={ + "chart_id": slc.id, + "dashboard_id": dashboard.id, + "extra_filters": json.dumps( + [{"col": "name", "op": "in", "val": ["Jennifer"]}] + ), + }, + ) + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual( + data["result"], + [{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}], + ) + + def test_warm_up_cache_chart_id_required(self): + self.login() + rv = self.client.put("/api/v1/chart/warm_up_cache", json={"dashboard_id": 1}) + self.assertEqual(rv.status_code, 400) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual( + data, + {"message": {"chart_id": ["Missing data for required field."]}}, + ) + + def test_warm_up_cache_chart_not_found(self): + self.login() + rv = self.client.put("/api/v1/chart/warm_up_cache", json={"chart_id": 99999}) + self.assertEqual(rv.status_code, 404) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(data, {"message": "Chart not found"}) + + def test_warm_up_cache_payload_validation(self): + self.login() + rv = self.client.put( + "/api/v1/chart/warm_up_cache", + json={"chart_id": "id", "dashboard_id": "id", "extra_filters": 4}, + ) + self.assertEqual(rv.status_code, 400) + data = json.loads(rv.data.decode("utf-8")) + print(data) + self.assertEqual( + data, + { + "message": { + "chart_id": ["Not a valid integer."], + "dashboard_id": ["Not a valid integer."], + "extra_filters": ["Not a valid string."], + } + }, + ) diff --git a/tests/integration_tests/charts/commands_tests.py b/tests/integration_tests/charts/commands_tests.py index 4d365d56b53a0..217b1655a5f05 100644 --- a/tests/integration_tests/charts/commands_tests.py +++ b/tests/integration_tests/charts/commands_tests.py @@ -23,16 +23,24 @@ from superset import db, security_manager from superset.charts.commands.create import CreateChartCommand -from superset.charts.commands.exceptions import ChartNotFoundError +from superset.charts.commands.exceptions import ( + ChartNotFoundError, + WarmUpCacheChartNotFoundError, +) from superset.charts.commands.export import ExportChartsCommand from superset.charts.commands.importers.v1 import ImportChartsCommand from superset.charts.commands.update import UpdateChartCommand +from superset.charts.commands.warm_up_cache import ChartWarmUpCacheCommand from superset.commands.exceptions import CommandInvalidError from superset.commands.importers.exceptions import IncorrectVersionError from superset.connectors.sqla.models import SqlaTable from superset.models.core import Database from superset.models.slice import Slice from tests.integration_tests.base_tests import SupersetTestCase +from tests.integration_tests.fixtures.birth_names_dashboard import ( + load_birth_names_dashboard_with_slices, + load_birth_names_data, +) from tests.integration_tests.fixtures.energy_dashboard import ( load_energy_table_data, load_energy_table_with_slice, @@ -442,3 +450,23 @@ def test_query_context_update_command(self, mock_sm_g, mock_g): assert chart.query_context == query_context assert len(chart.owners) == 1 assert chart.owners[0] == admin + + +class TestChartWarmUpCacheCommand(SupersetTestCase): + def test_warm_up_cache_command_chart_not_found(self): + with self.assertRaises(WarmUpCacheChartNotFoundError): + ChartWarmUpCacheCommand(99999, None, None).run() + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_warm_up_cache(self): + slc = self.get_slice("Girls", db.session) + result = ChartWarmUpCacheCommand(slc.id, None, None).run() + self.assertEqual( + result, {"chart_id": slc.id, "viz_error": None, "viz_status": "success"} + ) + + # can just pass in chart as well + result = ChartWarmUpCacheCommand(slc, None, None).run() + self.assertEqual( + result, {"chart_id": slc.id, "viz_error": None, "viz_status": "success"} + ) diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 55fda1af65fd3..2f55a1e97815e 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -39,6 +39,7 @@ from superset.datasets.models import Dataset from superset.extensions import db, security_manager from superset.models.core import Database +from superset.models.slice import Slice from superset.utils.core import backend, get_example_default_schema from superset.utils.database import get_example_database, get_main_database from superset.utils.dict_import_export import export_to_dict @@ -514,6 +515,7 @@ def test_info_security_dataset(self): "can_export", "can_duplicate", "can_get_or_create_dataset", + "can_warm_up_cache", } def test_create_dataset_item(self): @@ -2501,3 +2503,117 @@ def test_get_or_create_dataset_creates_table(self): with examples_db.get_sqla_engine_with_context() as engine: engine.execute("DROP TABLE test_create_sqla_table_api") db.session.commit() + + @pytest.mark.usefixtures( + "load_energy_table_with_slice", "load_birth_names_dashboard_with_slices" + ) + def test_warm_up_cache(self): + """ + Dataset API: Test warm up cache endpoint + """ + self.login() + energy_table = self.get_energy_usage_dataset() + energy_charts = ( + db.session.query(Slice) + .filter( + Slice.datasource_id == energy_table.id, Slice.datasource_type == "table" + ) + .all() + ) + rv = self.client.put( + "/api/v1/dataset/warm_up_cache", + json={ + "table_name": "energy_usage", + "db_name": get_example_database().database_name, + }, + ) + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual( + len(data["result"]), + len(energy_charts), + ) + for chart_result in data["result"]: + assert "chart_id" in chart_result + assert "viz_error" in chart_result + assert "viz_status" in chart_result + + # With dashboard id + dashboard = self.get_dash_by_slug("births") + birth_table = self.get_birth_names_dataset() + birth_charts = ( + db.session.query(Slice) + .filter( + Slice.datasource_id == birth_table.id, Slice.datasource_type == "table" + ) + .all() + ) + rv = self.client.put( + "/api/v1/dataset/warm_up_cache", + json={ + "table_name": "birth_names", + "db_name": get_example_database().database_name, + "dashboard_id": dashboard.id, + }, + ) + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual( + len(data["result"]), + len(birth_charts), + ) + for chart_result in data["result"]: + assert "chart_id" in chart_result + assert "viz_error" in chart_result + assert "viz_status" in chart_result + + # With extra filters + rv = self.client.put( + "/api/v1/dataset/warm_up_cache", + json={ + "table_name": "birth_names", + "db_name": get_example_database().database_name, + "dashboard_id": dashboard.id, + "extra_filters": json.dumps( + [{"col": "name", "op": "in", "val": ["Jennifer"]}] + ), + }, + ) + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual( + len(data["result"]), + len(birth_charts), + ) + for chart_result in data["result"]: + assert "chart_id" in chart_result + assert "viz_error" in chart_result + assert "viz_status" in chart_result + + def test_warm_up_cache_db_and_table_name_required(self): + self.login() + rv = self.client.put("/api/v1/dataset/warm_up_cache", json={"dashboard_id": 1}) + self.assertEqual(rv.status_code, 400) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual( + data, + { + "message": { + "db_name": ["Missing data for required field."], + "table_name": ["Missing data for required field."], + } + }, + ) + + def test_warm_up_cache_table_not_found(self): + self.login() + rv = self.client.put( + "/api/v1/dataset/warm_up_cache", + json={"table_name": "not_here", "db_name": "abc"}, + ) + self.assertEqual(rv.status_code, 404) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual( + data, + {"message": "The provided table was not found in the provided database"}, + ) diff --git a/tests/integration_tests/datasets/commands_tests.py b/tests/integration_tests/datasets/commands_tests.py index 953c34059fdd8..34a0625b36926 100644 --- a/tests/integration_tests/datasets/commands_tests.py +++ b/tests/integration_tests/datasets/commands_tests.py @@ -31,13 +31,20 @@ from superset.datasets.commands.exceptions import ( DatasetInvalidError, DatasetNotFoundError, + WarmUpCacheTableNotFoundError, ) from superset.datasets.commands.export import ExportDatasetsCommand from superset.datasets.commands.importers import v0, v1 +from superset.datasets.commands.warm_up_cache import DatasetWarmUpCacheCommand from superset.models.core import Database +from superset.models.slice import Slice from superset.utils.core import get_example_default_schema from superset.utils.database import get_example_database from tests.integration_tests.base_tests import SupersetTestCase +from tests.integration_tests.fixtures.birth_names_dashboard import ( + load_birth_names_dashboard_with_slices, + load_birth_names_data, +) from tests.integration_tests.fixtures.energy_dashboard import ( load_energy_table_data, load_energy_table_with_slice, @@ -575,3 +582,28 @@ def test_create_dataset_command(self, mock_g, mock_g2): with examples_db.get_sqla_engine_with_context() as engine: engine.execute("DROP TABLE test_create_dataset_command") db.session.commit() + + +class TestDatasetWarmUpCacheCommand(SupersetTestCase): + def test_warm_up_cache_command_table_not_found(self): + with self.assertRaises(WarmUpCacheTableNotFoundError): + DatasetWarmUpCacheCommand("not", "here", None, None).run() + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_warm_up_cache(self): + birth_table = self.get_birth_names_dataset() + birth_charts = ( + db.session.query(Slice) + .filter( + Slice.datasource_id == birth_table.id, Slice.datasource_type == "table" + ) + .all() + ) + results = DatasetWarmUpCacheCommand( + get_example_database().database_name, "birth_names", None, None + ).run() + self.assertEqual(len(results), len(birth_charts)) + for chart_result in results: + assert "chart_id" in chart_result + assert "viz_error" in chart_result + assert "viz_status" in chart_result diff --git a/tests/integration_tests/strategy_tests.py b/tests/integration_tests/strategy_tests.py index f6d664c649971..6fec16ca7475b 100644 --- a/tests/integration_tests/strategy_tests.py +++ b/tests/integration_tests/strategy_tests.py @@ -76,14 +76,11 @@ def test_top_n_dashboards_strategy(self): self.client.get(f"/superset/dashboard/{dash.id}/") strategy = TopNDashboardsStrategy(1) - result = sorted(strategy.get_urls()) - expected = sorted( - [ - f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}&dashboard_id={dash.id}" - for slc in dash.slices - ] - ) - self.assertEqual(result, expected) + result = strategy.get_payloads() + expected = [ + {"chart_id": chart.id, "dashboard_id": dash.id} for chart in dash.slices + ] + self.assertCountEqual(result, expected) def reset_tag(self, tag): """Remove associated object from tag, used to reset tests""" @@ -95,57 +92,52 @@ def reset_tag(self, tag): @pytest.mark.usefixtures( "load_unicode_dashboard_with_slice", "load_birth_names_dashboard_with_slices" ) - def test_dashboard_tags(self): + def test_dashboard_tags_strategy(self): tag1 = get_tag("tag1", db.session, TagTypes.custom) # delete first to make test idempotent self.reset_tag(tag1) strategy = DashboardTagsStrategy(["tag1"]) - result = sorted(strategy.get_urls()) + result = strategy.get_payloads() expected = [] self.assertEqual(result, expected) # tag dashboard 'births' with `tag1` tag1 = get_tag("tag1", db.session, TagTypes.custom) dash = self.get_dash_by_slug("births") - tag1_urls = sorted( - [ - f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}" - for slc in dash.slices - ] - ) + tag1_urls = [{"chart_id": chart.id} for chart in dash.slices] tagged_object = TaggedObject( tag_id=tag1.id, object_id=dash.id, object_type=ObjectTypes.dashboard ) db.session.add(tagged_object) db.session.commit() - self.assertEqual(sorted(strategy.get_urls()), tag1_urls) + self.assertCountEqual(strategy.get_payloads(), tag1_urls) strategy = DashboardTagsStrategy(["tag2"]) tag2 = get_tag("tag2", db.session, TagTypes.custom) self.reset_tag(tag2) - result = sorted(strategy.get_urls()) + result = strategy.get_payloads() expected = [] self.assertEqual(result, expected) # tag first slice dash = self.get_dash_by_slug("unicode-test") - slc = dash.slices[0] - tag2_urls = [f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}"] - object_id = slc.id + chart = dash.slices[0] + tag2_urls = [{"chart_id": chart.id}] + object_id = chart.id tagged_object = TaggedObject( tag_id=tag2.id, object_id=object_id, object_type=ObjectTypes.chart ) db.session.add(tagged_object) db.session.commit() - result = sorted(strategy.get_urls()) - self.assertEqual(result, tag2_urls) + result = strategy.get_payloads() + self.assertCountEqual(result, tag2_urls) strategy = DashboardTagsStrategy(["tag1", "tag2"]) - result = sorted(strategy.get_urls()) - expected = sorted(tag1_urls + tag2_urls) - self.assertEqual(result, expected) + result = strategy.get_payloads() + expected = tag1_urls + tag2_urls + self.assertCountEqual(result, expected)