feat: Implement Celery SoftTimeLimit handling #13740

Merged · 3 commits · Apr 12, 2021

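This PR threads Celery's SoftTimeLimitExceeded through explicit handlers so that each task fails with a domain-specific timeout error instead of a bare Celery traceback. For orientation, here is a minimal sketch of the mechanism the diff relies on; it is not code from this PR, and the app name, broker URL, and 5-second limit are invented for the example. Celery raises SoftTimeLimitExceeded inside the task body when the soft limit expires, so the task itself can catch, log, and translate it.

    import time

    from celery import Celery
    from celery.exceptions import SoftTimeLimitExceeded

    # Hypothetical app and broker, used only for this sketch.
    app = Celery("sketch", broker="redis://localhost:6379/0")


    @app.task(soft_time_limit=5)  # raised inside the task after 5 seconds
    def slow_task() -> str:
        try:
            time.sleep(10)  # stand-in for a long query or screenshot
        except SoftTimeLimitExceeded:
            # Catch, log, and translate into a domain-specific error:
            # the pattern each file below applies.
            raise RuntimeError("task exceeded its soft time limit")
        return "done"
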
superset/reports/commands/alert.py (2 additions & 1 deletion)

@@ -155,7 +155,8 @@ def _execute_query(self) -> pd.DataFrame:
                 (stop - start) * 1000.0,
             )
             return df
-        except SoftTimeLimitExceeded:
+        except SoftTimeLimitExceeded as ex:
+            logger.warning("A timeout occurred while executing the alert query: %s", ex)
             raise AlertQueryTimeout()
         except Exception as ex:
             raise AlertQueryError(message=str(ex))

superset/reports/commands/execute.py (1 addition & 0 deletions)

@@ -181,6 +181,7 @@ def _get_screenshot(self) -> bytes:
         try:
             image_data = screenshot.get_screenshot(user=user)
         except SoftTimeLimitExceeded:
+            logger.warning("A timeout occurred while taking a screenshot.")
             raise ReportScheduleScreenshotTimeout()
         except Exception as ex:
             raise ReportScheduleScreenshotFailedError(

superset/sql_lab.py (9 additions & 8 deletions)

@@ -159,6 +159,15 @@ def get_sql_results(  # pylint: disable=too-many-arguments
             expand_data=expand_data,
             log_params=log_params,
         )
+    except SoftTimeLimitExceeded as ex:
+        logger.warning("Query %d: Time limit exceeded", query_id)
+        logger.debug("Query %d: %s", query_id, ex)
+        raise SqlLabTimeoutException(
+            _(
+                "SQL Lab timeout. This environment's policy is to kill queries "
+                "after {} seconds.".format(SQLLAB_TIMEOUT)
+            )
+        )
     except Exception as ex:  # pylint: disable=broad-except
         logger.debug("Query %d: %s", query_id, ex)
         stats_logger.incr("error_sqllab_unhandled")

@@ -237,14 +246,6 @@ def execute_sql_statement(
             str(query.to_dict()),
         )
         data = db_engine_spec.fetch_data(cursor, query.limit)
-
-    except SoftTimeLimitExceeded as ex:
-        logger.error("Query %d: Time limit exceeded", query.id)
-        logger.debug("Query %d: %s", query.id, ex)
-        raise SqlLabTimeoutException(
-            "SQL Lab timeout. This environment's policy is to kill queries "
-            "after {} seconds.".format(SQLLAB_TIMEOUT)
-        )
     except Exception as ex:
         logger.error("Query %d: %s", query.id, type(ex))
         logger.debug("Query %d: %s", query.id, ex)

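Note the relocation in this file: the SoftTimeLimitExceeded handler moves out of execute_sql_statement and up into get_sql_results, the Celery task's entry point, and the message is now wrapped in _() for translation. A rough sketch of why the outer placement matters, using simplified names that are not Superset's actual code; a timeout striking outside the inner call would previously have escaped untranslated.

    from celery.exceptions import SoftTimeLimitExceeded


    class QueryTimeout(Exception):
        """Domain-specific timeout, standing in for SqlLabTimeoutException."""


    def look_up_query(query_id: int) -> None:
        raise SoftTimeLimitExceeded()  # timeout strikes outside the inner step


    def get_results(query_id: int) -> None:
        try:
            look_up_query(query_id)  # step 1: can time out
            # step 2: execute statements, fetch data, and so on
        except SoftTimeLimitExceeded:
            # One outer handler now covers every step of the task.
            raise QueryTimeout("query %d exceeded the soft time limit" % query_id)
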
superset/tasks/async_queries.py (13 additions & 4 deletions)

@@ -18,6 +18,7 @@
 import logging
 from typing import Any, cast, Dict, Optional
 
+from celery.exceptions import SoftTimeLimitExceeded
 from flask import current_app, g
 
 from superset import app

@@ -47,9 +48,7 @@ def ensure_user_is_set(user_id: Optional[int]) -> None:
 def load_chart_data_into_cache(
     job_metadata: Dict[str, Any], form_data: Dict[str, Any],
 ) -> None:
-    from superset.charts.commands.data import (
-        ChartDataCommand,
-    )  # load here due to circular imports
+    from superset.charts.commands.data import ChartDataCommand
 
     with app.app_context():  # type: ignore
         try:

@@ -62,6 +61,11 @@ def load_chart_data_into_cache(
             async_query_manager.update_job(
                 job_metadata, async_query_manager.STATUS_DONE, result_url=result_url,
             )
+        except SoftTimeLimitExceeded as exc:
+            logger.warning(
+                "A timeout occurred while loading chart data, error: %s", exc
+            )
+            raise exc
         except Exception as exc:
             # TODO: QueryContext should support SIP-40 style errors
             error = exc.message if hasattr(exc, "message") else str(exc)  # type: ignore # pylint: disable=no-member

@@ -75,7 +79,7 @@
 
 
 @celery_app.task(name="load_explore_json_into_cache", soft_time_limit=query_timeout)
-def load_explore_json_into_cache(
+def load_explore_json_into_cache(  # pylint: disable=too-many-locals
     job_metadata: Dict[str, Any],
     form_data: Dict[str, Any],
     response_type: Optional[str] = None,

@@ -106,6 +110,11 @@ def load_explore_json_into_cache(
             async_query_manager.update_job(
                 job_metadata, async_query_manager.STATUS_DONE, result_url=result_url,
             )
+        except SoftTimeLimitExceeded as ex:
+            logger.warning(
+                "A timeout occurred while loading explore json, error: %s", ex
+            )
+            raise ex
         except Exception as exc:
             if isinstance(exc, SupersetVizException):
                 errors = exc.errors  # pylint: disable=no-member

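Both cache-loading tasks log the timeout and then re-raise it, so the worker log gains a searchable warning while Celery still records the task as failed. A minimal sketch of the log-and-re-raise pattern, with a placeholder body standing in for the real command:

    import logging

    from celery.exceptions import SoftTimeLimitExceeded

    logger = logging.getLogger(__name__)


    def load_into_cache(job_metadata: dict) -> None:
        try:
            raise SoftTimeLimitExceeded()  # placeholder for the command run
        except SoftTimeLimitExceeded as ex:
            logger.warning("A timeout occurred while loading data: %s", ex)
            raise  # re-raising lets Celery mark the task as failed
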
superset/tasks/scheduler.py (3 additions & 0 deletions)

@@ -19,6 +19,7 @@
 from typing import Iterator
 
 import croniter
+from celery.exceptions import SoftTimeLimitExceeded
 from dateutil import parser
 
 from superset import app

@@ -91,5 +92,7 @@ def execute(report_schedule_id: int, scheduled_dttm: str) -> None:
 def prune_log() -> None:
     try:
         AsyncPruneReportScheduleLogCommand().run()
+    except SoftTimeLimitExceeded as ex:
+        logger.warning("A timeout occurred while pruning report schedule logs: %s", ex)
     except CommandException as ex:
         logger.error("An exception occurred while pruning report schedule logs: %s", ex)

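By contrast, prune_log swallows the timeout instead of re-raising it: pruning report schedule logs is periodic housekeeping, and whatever this run fails to delete, the next run will. A small sketch of the swallow-and-log variant, again with a placeholder body:

    import logging

    from celery.exceptions import SoftTimeLimitExceeded

    logger = logging.getLogger(__name__)


    def prune_logs() -> None:
        try:
            raise SoftTimeLimitExceeded()  # placeholder for the prune command
        except SoftTimeLimitExceeded as ex:
            # Swallowed, not re-raised: the next scheduled run resumes the work.
            logger.warning("A timeout occurred while pruning logs: %s", ex)
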
tests/reports/commands_tests.py (22 additions & 6 deletions)

@@ -47,11 +47,13 @@
     ReportScheduleNotFoundError,
     ReportScheduleNotificationError,
     ReportSchedulePreviousWorkingError,
+    ReportSchedulePruneLogError,
     ReportScheduleScreenshotFailedError,
     ReportScheduleScreenshotTimeout,
     ReportScheduleWorkingTimeoutError,
 )
 from superset.reports.commands.execute import AsyncExecuteReportScheduleCommand
+from superset.reports.commands.log_prune import AsyncPruneReportScheduleLogCommand
 from superset.utils.core import get_example_database
 from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices
 from tests.fixtures.world_bank_dashboard import (

@@ -186,7 +188,7 @@ def create_test_table_context(database: Database):
     database.get_sqla_engine().execute("DROP TABLE test_table")
 
 
-@pytest.yield_fixture()
+@pytest.fixture()
 def create_report_email_chart():
     with app.app_context():
         chart = db.session.query(Slice).first()

@@ -198,7 +200,7 @@ def create_report_email_chart():
     cleanup_report_schedule(report_schedule)
 
 
-@pytest.yield_fixture()
+@pytest.fixture()
 def create_report_email_dashboard():
     with app.app_context():
         dashboard = db.session.query(Dashboard).first()

@@ -210,7 +212,7 @@ def create_report_email_dashboard():
     cleanup_report_schedule(report_schedule)
 
 
-@pytest.yield_fixture()
+@pytest.fixture()
 def create_report_slack_chart():
     with app.app_context():
         chart = db.session.query(Slice).first()

@@ -222,7 +224,7 @@ def create_report_slack_chart():
     cleanup_report_schedule(report_schedule)
 
 
-@pytest.yield_fixture()
+@pytest.fixture()
 def create_report_slack_chart_working():
     with app.app_context():
         chart = db.session.query(Slice).first()

@@ -248,7 +250,7 @@ def create_report_slack_chart_working():
     cleanup_report_schedule(report_schedule)
 
 
-@pytest.yield_fixture()
+@pytest.fixture()
 def create_alert_slack_chart_success():
     with app.app_context():
         chart = db.session.query(Slice).first()

@@ -274,7 +276,7 @@ def create_alert_slack_chart_success():
     cleanup_report_schedule(report_schedule)
 
 
-@pytest.yield_fixture()
+@pytest.fixture()
 def create_alert_slack_chart_grace():
     with app.app_context():
         chart = db.session.query(Slice).first()

@@ -1109,3 +1111,17 @@ def test_grace_period_error_flap(
     assert (
         get_notification_error_sent_count(create_invalid_sql_alert_email_chart) == 2
     )
+
+
+@pytest.mark.usefixtures(
+    "load_birth_names_dashboard_with_slices", "create_report_email_dashboard"
+)
+@patch("superset.reports.dao.ReportScheduleDAO.bulk_delete_logs")
+def test_prune_log_soft_time_out(bulk_delete_logs, create_report_email_dashboard):
+    from celery.exceptions import SoftTimeLimitExceeded
+    from datetime import datetime, timedelta
+
+    bulk_delete_logs.side_effect = SoftTimeLimitExceeded()
+    with pytest.raises(SoftTimeLimitExceeded) as excinfo:
+        AsyncPruneReportScheduleLogCommand().run()
+    assert str(excinfo.value) == "SoftTimeLimitExceeded()"

tests/sqllab_tests.py (29 additions & 1 deletion)

@@ -33,7 +33,12 @@
 from superset.models.core import Database
 from superset.models.sql_lab import Query, SavedQuery
 from superset.result_set import SupersetResultSet
-from superset.sql_lab import execute_sql_statements, SqlLabException
+from superset.sql_lab import (
+    execute_sql_statements,
+    get_sql_results,
+    SqlLabException,
+    SqlLabTimeoutException,
+)
 from superset.sql_parse import CtasMethod
 from superset.utils.core import (
     datetime_to_epoch,

@@ -793,3 +798,26 @@ def test_execute_sql_statements_ctas(
             "sure your query has only a SELECT statement. Then, "
             "try running your query again."
         )
+
+    @mock.patch("superset.sql_lab.get_query")
+    @mock.patch("superset.sql_lab.execute_sql_statement")
+    def test_get_sql_results_soft_time_limit(
+        self, mock_execute_sql_statement, mock_get_query
+    ):
+        from celery.exceptions import SoftTimeLimitExceeded
+
+        sql = """
+        -- comment
+        SET @value = 42;
+        SELECT @value AS foo;
+        -- comment
+        """
+        mock_get_query.side_effect = SoftTimeLimitExceeded()
+        with pytest.raises(SqlLabTimeoutException) as excinfo:
+            get_sql_results(
+                1, sql, return_results=True, store_results=False,
+            )
+        assert (
+            str(excinfo.value)
+            == "SQL Lab timeout. This environment's policy is to kill queries after 21600 seconds."
+        )

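The 21600 seconds in the assertion is six hours; it appears to come from the environment's SQLLAB_ASYNC_TIME_LIMIT_SEC setting, which sql_lab.py reads into SQLLAB_TIMEOUT.
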
tests/tasks/async_queries_tests.py (51 additions & 0 deletions)

@@ -20,6 +20,7 @@
 from uuid import uuid4
 
 import pytest
+from celery.exceptions import SoftTimeLimitExceeded
 
 from superset import db
 from superset.charts.commands.data import ChartDataCommand

@@ -94,6 +95,31 @@ def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command
         errors = [{"message": "Error: foo"}]
         mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)
 
+    @mock.patch.object(ChartDataCommand, "run")
+    @mock.patch.object(async_query_manager, "update_job")
+    def test_soft_timeout_load_chart_data_into_cache(
+        self, mock_update_job, mock_run_command
+    ):
+        async_query_manager.init_app(app)
+        user = security_manager.find_user("gamma")
+        form_data = {}
+        job_metadata = {
+            "channel_id": str(uuid4()),
+            "job_id": str(uuid4()),
+            "user_id": user.id,
+            "status": "pending",
+            "errors": [],
+        }
+        errors = ["A timeout occurred while loading chart data"]
+
+        with pytest.raises(SoftTimeLimitExceeded):
+            with mock.patch.object(
+                async_queries, "ensure_user_is_set",
+            ) as ensure_user_is_set:
+                ensure_user_is_set.side_effect = SoftTimeLimitExceeded()
+                load_chart_data_into_cache(job_metadata, form_data)
+        ensure_user_is_set.assert_called_once_with(user.id, "error", errors=errors)
+
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
 @mock.patch.object(async_query_manager, "update_job")
     def test_load_explore_json_into_cache(self, mock_update_job):

@@ -151,3 +177,28 @@ def test_load_explore_json_into_cache_error(self, mock_update_job):
 
         errors = ["The dataset associated with this chart no longer exists"]
         mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)
+
+    @mock.patch.object(ChartDataCommand, "run")
+    @mock.patch.object(async_query_manager, "update_job")
+    def test_soft_timeout_load_explore_json_into_cache(
+        self, mock_update_job, mock_run_command
+    ):
+        async_query_manager.init_app(app)
+        user = security_manager.find_user("gamma")
+        form_data = {}
+        job_metadata = {
+            "channel_id": str(uuid4()),
+            "job_id": str(uuid4()),
+            "user_id": user.id,
+            "status": "pending",
+            "errors": [],
+        }
+        errors = ["A timeout occurred while loading explore json, error"]
+
+        with pytest.raises(SoftTimeLimitExceeded):
+            with mock.patch.object(
+                async_queries, "ensure_user_is_set",
+            ) as ensure_user_is_set:
+                ensure_user_is_set.side_effect = SoftTimeLimitExceeded()
+                load_explore_json_into_cache(job_metadata, form_data)
+        ensure_user_is_set.assert_called_once_with(user.id, "error", errors=errors)