feat: Implement Celery SoftTimeLimit handling (apache#13740)
* log soft time limit error

* lint

* update test
Lily Kuang authored and Allan Caetano de Oliveira committed May 21, 2021
1 parent f33e0c1 commit a806e41
Showing 9 changed files with 142 additions and 32 deletions.
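Background for the changes below: Celery raises SoftTimeLimitExceeded inside a task once its soft_time_limit elapses, leaving a window (until the hard time_limit) to log, clean up, and re-raise or convert the error. A minimal sketch of that mechanism in plain Celery — the app name, broker URL, and helper below are illustrative, not part of this commit:

# --- sketch (not part of the commit) ---
import logging
import time

from celery import Celery
from celery.exceptions import SoftTimeLimitExceeded

logger = logging.getLogger(__name__)
app = Celery("sketch", broker="redis://localhost:6379/0")


def run_query() -> str:
    time.sleep(120)  # stand-in for a long-running query
    return "done"


@app.task(soft_time_limit=30, time_limit=60)
def slow_task() -> str:
    try:
        return run_query()
    except SoftTimeLimitExceeded:
        # Raised in-task at ~30s; until the 60s hard limit kills the worker
        # process we can still log, update state, and re-raise.
        logger.warning("soft time limit hit; cleaning up")
        raise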
3 changes: 2 additions & 1 deletion superset/reports/commands/alert.py
@@ -155,7 +155,8 @@ def _execute_query(self) -> pd.DataFrame:
                 (stop - start) * 1000.0,
             )
             return df
-        except SoftTimeLimitExceeded:
+        except SoftTimeLimitExceeded as ex:
+            logger.warning("A timeout occurred while executing the alert query: %s", ex)
             raise AlertQueryTimeout()
         except Exception as ex:
             raise AlertQueryError(message=str(ex))
1 change: 1 addition & 0 deletions superset/reports/commands/execute.py
@@ -184,6 +184,7 @@ def _get_screenshot(self) -> bytes:
         try:
             image_data = screenshot.get_screenshot(user=user)
         except SoftTimeLimitExceeded:
+            logger.warning("A timeout occurred while taking a screenshot.")
             raise ReportScheduleScreenshotTimeout()
         except Exception as ex:
             raise ReportScheduleScreenshotFailedError(
17 changes: 9 additions & 8 deletions superset/sql_lab.py
@@ -159,6 +159,15 @@ def get_sql_results(  # pylint: disable=too-many-arguments
             expand_data=expand_data,
             log_params=log_params,
         )
+    except SoftTimeLimitExceeded as ex:
+        logger.warning("Query %d: Time limit exceeded", query_id)
+        logger.debug("Query %d: %s", query_id, ex)
+        raise SqlLabTimeoutException(
+            _(
+                "SQL Lab timeout. This environment's policy is to kill queries "
+                "after {} seconds.".format(SQLLAB_TIMEOUT)
+            )
+        )
     except Exception as ex:  # pylint: disable=broad-except
         logger.debug("Query %d: %s", query_id, ex)
         stats_logger.incr("error_sqllab_unhandled")
@@ -237,14 +246,6 @@ def execute_sql_statement(
             str(query.to_dict()),
         )
         data = db_engine_spec.fetch_data(cursor, query.limit)
-
-    except SoftTimeLimitExceeded as ex:
-        logger.error("Query %d: Time limit exceeded", query.id)
-        logger.debug("Query %d: %s", query.id, ex)
-        raise SqlLabTimeoutException(
-            "SQL Lab timeout. This environment's policy is to kill queries "
-            "after {} seconds.".format(SQLLAB_TIMEOUT)
-        )
     except Exception as ex:
         logger.error("Query %d: %s", query.id, type(ex))
         logger.debug("Query %d: %s", query.id, ex)
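Net effect in sql_lab.py: the SoftTimeLimitExceeded handler moves out of execute_sql_statement and up into get_sql_results, the Celery task entry point. Because Celery raises the exception asynchronously, at whatever line the task happens to be on, one top-level handler covers parsing, execution, and fetching alike. A simplified sketch of that shape, with placeholder names rather than the Superset ones:

# --- sketch (not part of the commit) ---
from celery.exceptions import SoftTimeLimitExceeded


class QueryTimeoutError(Exception):
    """Placeholder for a domain error like SqlLabTimeoutException."""


def run_all_statements(query_id: int) -> dict:
    # parsing, execution, fetch: no timeout handling needed at this level
    return {"status": "success", "query_id": query_id}


def task_entry_point(query_id: int) -> dict:
    try:
        return run_all_statements(query_id)
    except SoftTimeLimitExceeded as ex:
        # Fires at whatever line the task was on when the limit elapsed,
        # so one handler here covers the whole run.
        raise QueryTimeoutError(f"query {query_id} hit the soft time limit") from ex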
17 changes: 13 additions & 4 deletions superset/tasks/async_queries.py
@@ -18,6 +18,7 @@
 import logging
 from typing import Any, cast, Dict, Optional
 
+from celery.exceptions import SoftTimeLimitExceeded
 from flask import current_app, g
 
 from superset import app
@@ -47,9 +48,7 @@ def ensure_user_is_set(user_id: Optional[int]) -> None:
 def load_chart_data_into_cache(
     job_metadata: Dict[str, Any], form_data: Dict[str, Any],
 ) -> None:
-    from superset.charts.commands.data import (
-        ChartDataCommand,
-    )  # load here due to circular imports
+    from superset.charts.commands.data import ChartDataCommand
 
     with app.app_context():  # type: ignore
         try:
@@ -62,6 +61,11 @@ def load_chart_data_into_cache(
             async_query_manager.update_job(
                 job_metadata, async_query_manager.STATUS_DONE, result_url=result_url,
             )
+        except SoftTimeLimitExceeded as exc:
+            logger.warning(
+                "A timeout occurred while loading chart data, error: %s", exc
+            )
+            raise exc
         except Exception as exc:
             # TODO: QueryContext should support SIP-40 style errors
             error = exc.message if hasattr(exc, "message") else str(exc)  # type: ignore # pylint: disable=no-member
@@ -75,7 +79,7 @@ def load_chart_data_into_cache(
 
 
 @celery_app.task(name="load_explore_json_into_cache", soft_time_limit=query_timeout)
-def load_explore_json_into_cache(
+def load_explore_json_into_cache(  # pylint: disable=too-many-locals
     job_metadata: Dict[str, Any],
     form_data: Dict[str, Any],
     response_type: Optional[str] = None,
@@ -106,6 +110,11 @@ def load_explore_json_into_cache(
             async_query_manager.update_job(
                 job_metadata, async_query_manager.STATUS_DONE, result_url=result_url,
             )
+        except SoftTimeLimitExceeded as ex:
+            logger.warning(
+                "A timeout occurred while loading explore json, error: %s", ex
+            )
+            raise ex
         except Exception as exc:
             if isinstance(exc, SupersetVizException):
                 errors = exc.errors  # pylint: disable=no-member
3 changes: 3 additions & 0 deletions superset/tasks/scheduler.py
@@ -19,6 +19,7 @@
 from typing import Iterator
 
 import croniter
+from celery.exceptions import SoftTimeLimitExceeded
 from dateutil import parser
 
 from superset import app
@@ -91,5 +92,7 @@ def execute(report_schedule_id: int, scheduled_dttm: str) -> None:
 def prune_log() -> None:
     try:
         AsyncPruneReportScheduleLogCommand().run()
+    except SoftTimeLimitExceeded as ex:
+        logger.warning("A timeout occurred while pruning report schedule logs: %s", ex)
     except CommandException as ex:
         logger.error("An exception occurred while pruning report schedule logs: %s", ex)
28 changes: 22 additions & 6 deletions tests/reports/commands_tests.py
@@ -47,11 +47,13 @@
     ReportScheduleNotFoundError,
     ReportScheduleNotificationError,
     ReportSchedulePreviousWorkingError,
+    ReportSchedulePruneLogError,
     ReportScheduleScreenshotFailedError,
     ReportScheduleScreenshotTimeout,
     ReportScheduleWorkingTimeoutError,
 )
 from superset.reports.commands.execute import AsyncExecuteReportScheduleCommand
+from superset.reports.commands.log_prune import AsyncPruneReportScheduleLogCommand
 from superset.utils.core import get_example_database
 from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices
 from tests.fixtures.world_bank_dashboard import (
@@ -193,7 +195,7 @@ def create_test_table_context(database: Database):
     database.get_sqla_engine().execute("DROP TABLE test_table")
 
 
-@pytest.yield_fixture()
+@pytest.fixture()
 def create_report_email_chart():
     with app.app_context():
         chart = db.session.query(Slice).first()
@@ -205,7 +207,7 @@ def create_report_email_chart():
         cleanup_report_schedule(report_schedule)
 
 
-@pytest.yield_fixture()
+@pytest.fixture()
 def create_report_email_dashboard():
     with app.app_context():
         dashboard = db.session.query(Dashboard).first()
@@ -217,7 +219,7 @@ def create_report_email_dashboard():
         cleanup_report_schedule(report_schedule)
 
 
-@pytest.yield_fixture()
+@pytest.fixture()
 def create_report_slack_chart():
     with app.app_context():
         chart = db.session.query(Slice).first()
@@ -229,7 +231,7 @@ def create_report_slack_chart():
         cleanup_report_schedule(report_schedule)
 
 
-@pytest.yield_fixture()
+@pytest.fixture()
 def create_report_slack_chart_working():
     with app.app_context():
         chart = db.session.query(Slice).first()
@@ -255,7 +257,7 @@ def create_report_slack_chart_working():
         cleanup_report_schedule(report_schedule)
 
 
-@pytest.yield_fixture()
+@pytest.fixture()
 def create_alert_slack_chart_success():
     with app.app_context():
         chart = db.session.query(Slice).first()
@@ -281,7 +283,7 @@ def create_alert_slack_chart_success():
         cleanup_report_schedule(report_schedule)
 
 
-@pytest.yield_fixture()
+@pytest.fixture()
 def create_alert_slack_chart_grace():
     with app.app_context():
         chart = db.session.query(Slice).first()
@@ -1115,3 +1117,17 @@ def test_grace_period_error_flap(
     assert (
         get_notification_error_sent_count(create_invalid_sql_alert_email_chart) == 2
     )
+
+
+@pytest.mark.usefixtures(
+    "load_birth_names_dashboard_with_slices", "create_report_email_dashboard"
+)
+@patch("superset.reports.dao.ReportScheduleDAO.bulk_delete_logs")
+def test_prune_log_soft_time_out(bulk_delete_logs, create_report_email_dashboard):
+    from celery.exceptions import SoftTimeLimitExceeded
+    from datetime import datetime, timedelta
+
+    bulk_delete_logs.side_effect = SoftTimeLimitExceeded()
+    with pytest.raises(SoftTimeLimitExceeded) as excinfo:
+        AsyncPruneReportScheduleLogCommand().run()
+    assert str(excinfo.value) == "SoftTimeLimitExceeded()"
30 changes: 29 additions & 1 deletion tests/sqllab_tests.py
@@ -33,7 +33,12 @@
 from superset.models.core import Database
 from superset.models.sql_lab import Query, SavedQuery
 from superset.result_set import SupersetResultSet
-from superset.sql_lab import execute_sql_statements, SqlLabException
+from superset.sql_lab import (
+    execute_sql_statements,
+    get_sql_results,
+    SqlLabException,
+    SqlLabTimeoutException,
+)
 from superset.sql_parse import CtasMethod
 from superset.utils.core import (
     datetime_to_epoch,
@@ -793,3 +798,26 @@ def test_execute_sql_statements_ctas(
             "sure your query has only a SELECT statement. Then, "
             "try running your query again."
         )
+
+    @mock.patch("superset.sql_lab.get_query")
+    @mock.patch("superset.sql_lab.execute_sql_statement")
+    def test_get_sql_results_soft_time_limit(
+        self, mock_execute_sql_statement, mock_get_query
+    ):
+        from celery.exceptions import SoftTimeLimitExceeded
+
+        sql = """
+        -- comment
+        SET @value = 42;
+        SELECT @value AS foo;
+        -- comment
+        """
+        mock_get_query.side_effect = SoftTimeLimitExceeded()
+        with pytest.raises(SqlLabTimeoutException) as excinfo:
+            get_sql_results(
+                1, sql, return_results=True, store_results=False,
+            )
+        assert (
+            str(excinfo.value)
+            == "SQL Lab timeout. This environment's policy is to kill queries after 21600 seconds."
+        )
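The 21600 seconds asserted above is six hours, which matches the default of Superset's SQLLAB_ASYNC_TIME_LIMIT_SEC setting; on my reading, SQLLAB_TIMEOUT in sql_lab.py is derived from it and also feeds the task's soft_time_limit. A configuration sketch under that assumption:

# --- sketch (not part of the commit): superset_config.py ---
# Assumption: SQLLAB_ASYNC_TIME_LIMIT_SEC is the setting behind SQLLAB_TIMEOUT
# and the SQL Lab task's soft_time_limit; its default of 6 h = 21600 s is the
# number asserted in the test above.
from datetime import timedelta

SQLLAB_ASYNC_TIME_LIMIT_SEC = int(timedelta(hours=6).total_seconds())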
51 changes: 51 additions & 0 deletions tests/tasks/async_queries_tests.py
@@ -20,6 +20,7 @@
 from uuid import uuid4
 
 import pytest
+from celery.exceptions import SoftTimeLimitExceeded
 
 from superset import db
 from superset.charts.commands.data import ChartDataCommand
@@ -94,6 +95,31 @@ def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command):
         errors = [{"message": "Error: foo"}]
         mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)
 
+    @mock.patch.object(ChartDataCommand, "run")
+    @mock.patch.object(async_query_manager, "update_job")
+    def test_soft_timeout_load_chart_data_into_cache(
+        self, mock_update_job, mock_run_command
+    ):
+        async_query_manager.init_app(app)
+        user = security_manager.find_user("gamma")
+        form_data = {}
+        job_metadata = {
+            "channel_id": str(uuid4()),
+            "job_id": str(uuid4()),
+            "user_id": user.id,
+            "status": "pending",
+            "errors": [],
+        }
+        errors = ["A timeout occurred while loading chart data"]
+
+        with pytest.raises(SoftTimeLimitExceeded):
+            with mock.patch.object(
+                async_queries, "ensure_user_is_set",
+            ) as ensure_user_is_set:
+                ensure_user_is_set.side_effect = SoftTimeLimitExceeded()
+                load_chart_data_into_cache(job_metadata, form_data)
+        ensure_user_is_set.assert_called_once_with(user.id, "error", errors=errors)
+
 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
 @mock.patch.object(async_query_manager, "update_job")
 def test_load_explore_json_into_cache(self, mock_update_job):
@@ -151,3 +177,28 @@ def test_load_explore_json_into_cache_error(self, mock_update_job):
 
         errors = ["The dataset associated with this chart no longer exists"]
         mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)
+
+    @mock.patch.object(ChartDataCommand, "run")
+    @mock.patch.object(async_query_manager, "update_job")
+    def test_soft_timeout_load_explore_json_into_cache(
+        self, mock_update_job, mock_run_command
+    ):
+        async_query_manager.init_app(app)
+        user = security_manager.find_user("gamma")
+        form_data = {}
+        job_metadata = {
+            "channel_id": str(uuid4()),
+            "job_id": str(uuid4()),
+            "user_id": user.id,
+            "status": "pending",
+            "errors": [],
+        }
+        errors = ["A timeout occurred while loading explore json, error"]
+
+        with pytest.raises(SoftTimeLimitExceeded):
+            with mock.patch.object(
+                async_queries, "ensure_user_is_set",
+            ) as ensure_user_is_set:
+                ensure_user_is_set.side_effect = SoftTimeLimitExceeded()
+                load_explore_json_into_cache(job_metadata, form_data)
+        ensure_user_is_set.assert_called_once_with(user.id, "error", errors=errors)
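The simulation pattern in these tests generalizes: assign SoftTimeLimitExceeded as a mock side_effect on a step the task calls, then assert on what propagates. A self-contained sketch with stand-in functions rather than Superset internals:

# --- sketch (not part of the commit) ---
from unittest import mock

import pytest
from celery.exceptions import SoftTimeLimitExceeded


def fetch_rows() -> list:  # stand-in for the slow step a task performs
    return []


def run_task() -> list:  # stand-in for the task body under test
    return fetch_rows()


def test_soft_time_limit_is_propagated() -> None:
    # Simulate Celery's signal-driven timeout by making the inner step raise it.
    with mock.patch(f"{__name__}.fetch_rows", side_effect=SoftTimeLimitExceeded()):
        with pytest.raises(SoftTimeLimitExceeded):
            run_task()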