From c2715d68d09a52e9a1fc89f57ae966834583b979 Mon Sep 17 00:00:00 2001
From: Ville Brofeldt <33317356+villebro@users.noreply.github.com>
Date: Sat, 24 Dec 2022 06:31:46 +0200
Subject: [PATCH] feat(trino): support early cancellation of queries (#22498)
(cherry picked from commit b6d39d194c90dbbf0050bb3d32d2e1a513dfc0a6)
---
.../src/SqlLab/actions/sqlLab.js | 6 +-
.../src/SqlLab/actions/sqlLab.test.js | 31 ++++++++++
.../src/SqlLab/components/ResultSet/index.tsx | 26 ++++----
.../components/SqlEditorTabHeader/index.tsx | 4 +-
.../src/SqlLab/reducers/sqlLab.js | 29 +++++----
.../views/CRUD/data/query/QueryList.test.tsx | 3 +-
.../src/views/CRUD/data/query/QueryList.tsx | 34 ++++++++---
.../data/query/QueryPreviewModal.test.tsx | 3 +-
superset-frontend/src/views/CRUD/types.ts | 10 +---
superset/constants.py | 3 +
superset/db_engine_specs/base.py | 52 ++++++++++------
superset/db_engine_specs/hive.py | 2 +-
superset/db_engine_specs/presto.py | 2 +-
superset/db_engine_specs/trino.py | 23 +++++++-
superset/sql_lab.py | 16 ++++-
.../unit_tests/db_engine_specs/test_trino.py | 59 +++++++++++++++++++
16 files changed, 230 insertions(+), 73 deletions(-)
diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.js b/superset-frontend/src/SqlLab/actions/sqlLab.js
index a58630cda155e..12487d1a9437d 100644
--- a/superset-frontend/src/SqlLab/actions/sqlLab.js
+++ b/superset-frontend/src/SqlLab/actions/sqlLab.js
@@ -17,7 +17,7 @@
* under the License.
*/
import shortid from 'shortid';
-import { t, SupersetClient } from '@superset-ui/core';
+import { SupersetClient, t } from '@superset-ui/core';
import invert from 'lodash/invert';
import mapKeys from 'lodash/mapKeys';
import { isFeatureEnabled, FeatureFlag } from 'src/featureFlags';
@@ -229,11 +229,13 @@ export function startQuery(query) {
export function querySuccess(query, results) {
return function (dispatch) {
+ const sqlEditorId = results?.query?.sqlEditorId;
const sync =
+ sqlEditorId &&
!query.isDataPreview &&
isFeatureEnabled(FeatureFlag.SQLLAB_BACKEND_PERSISTENCE)
? SupersetClient.put({
- endpoint: encodeURI(`/tabstateview/${results.query.sqlEditorId}`),
+ endpoint: encodeURI(`/tabstateview/${sqlEditorId}`),
postPayload: { latest_query_id: query.id },
})
: Promise.resolve();
diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.test.js b/superset-frontend/src/SqlLab/actions/sqlLab.test.js
index 7792f1da8ad63..acc79031edc15 100644
--- a/superset-frontend/src/SqlLab/actions/sqlLab.test.js
+++ b/superset-frontend/src/SqlLab/actions/sqlLab.test.js
@@ -30,6 +30,7 @@ import {
initialState,
queryId,
} from 'src/SqlLab/fixtures';
+import { QueryState } from '@superset-ui/core';
const middlewares = [thunk];
const mockStore = configureMockStore(middlewares);
@@ -502,6 +503,7 @@ describe('async actions', () => {
const results = {
data: mockBigNumber,
query: { sqlEditorId: 'abcd' },
+ status: QueryState.SUCCESS,
query_id: 'efgh',
};
fetchMock.get(fetchQueryEndpoint, JSON.stringify(results), {
@@ -525,6 +527,35 @@ describe('async actions', () => {
expect(fetchMock.calls(updateTabStateEndpoint)).toHaveLength(1);
});
});
+
+ it("doesn't update the tab state in the backend on stoppped query", () => {
+ expect.assertions(2);
+
+ const results = {
+ status: QueryState.STOPPED,
+ query_id: 'efgh',
+ };
+ fetchMock.get(fetchQueryEndpoint, JSON.stringify(results), {
+ overwriteRoutes: true,
+ });
+ const store = mockStore({});
+ const expectedActions = [
+ {
+ type: actions.REQUEST_QUERY_RESULTS,
+ query,
+ },
+ // missing below
+ {
+ type: actions.QUERY_SUCCESS,
+ query,
+ results,
+ },
+ ];
+ return store.dispatch(actions.fetchQueryResults(query)).then(() => {
+ expect(store.getActions()).toEqual(expectedActions);
+ expect(fetchMock.calls(updateTabStateEndpoint)).toHaveLength(0);
+ });
+ });
});
describe('addQueryEditor', () => {
diff --git a/superset-frontend/src/SqlLab/components/ResultSet/index.tsx b/superset-frontend/src/SqlLab/components/ResultSet/index.tsx
index 0d61d8ddab648..10cdd8a39e29f 100644
--- a/superset-frontend/src/SqlLab/components/ResultSet/index.tsx
+++ b/superset-frontend/src/SqlLab/components/ResultSet/index.tsx
@@ -16,13 +16,13 @@
* specific language governing permissions and limitations
* under the License.
*/
-import React, { useState, useEffect, useCallback } from 'react';
+import React, { useCallback, useEffect, useState } from 'react';
import { useDispatch } from 'react-redux';
import ButtonGroup from 'src/components/ButtonGroup';
import Alert from 'src/components/Alert';
import Button from 'src/components/Button';
import shortid from 'shortid';
-import { styled, t, QueryResponse } from '@superset-ui/core';
+import { QueryResponse, QueryState, styled, t } from '@superset-ui/core';
import { usePrevious } from 'src/hooks/usePrevious';
import ErrorMessageWithStackTrace from 'src/components/ErrorMessage/ErrorMessageWithStackTrace';
import {
@@ -43,9 +43,9 @@ import CopyToClipboard from 'src/components/CopyToClipboard';
import { addDangerToast } from 'src/components/MessageToasts/actions';
import { prepareCopyToClipboardTabularData } from 'src/utils/common';
import {
- CtasEnum,
- clearQueryResults,
addQueryEditor,
+ clearQueryResults,
+ CtasEnum,
fetchQueryResults,
reFetchQueryResults,
reRunQuery,
@@ -387,8 +387,8 @@ const ResultSet = ({
let trackingUrl;
if (
query.trackingUrl &&
- query.state !== 'success' &&
- query.state !== 'fetching'
+ query.state !== QueryState.SUCCESS &&
+ query.state !== QueryState.FETCHING
) {
trackingUrl = (
);
}
@@ -406,11 +408,11 @@ const ResultSet = ({
sql = ;
}
- if (query.state === 'stopped') {
+ if (query.state === QueryState.STOPPED) {
return ;
}
- if (query.state === 'failed') {
+ if (query.state === QueryState.FAILED) {
return (
= ({ queryEditor }) => {
}),
shallowEqual,
);
- const queryStatus = useSelector(
+ const queryState = useSelector(
({ sqlLab }) => sqlLab.queries[qe.latestQueryId || '']?.state || '',
);
const dispatch = useDispatch();
@@ -139,7 +139,7 @@ const SqlEditorTabHeader: React.FC = ({ queryEditor }) => {
}
/>
- {qe.name} {' '}
+ {qe.name} {' '}
);
};
diff --git a/superset-frontend/src/SqlLab/reducers/sqlLab.js b/superset-frontend/src/SqlLab/reducers/sqlLab.js
index ed103a2afe1bc..478487d6e239d 100644
--- a/superset-frontend/src/SqlLab/reducers/sqlLab.js
+++ b/superset-frontend/src/SqlLab/reducers/sqlLab.js
@@ -16,8 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
-import { t } from '@superset-ui/core';
-
+import { QueryState, t } from '@superset-ui/core';
import getInitialState from './getInitialState';
import * as actions from '../actions/sqlLab';
import { now } from '../../utils/dates';
@@ -391,7 +390,7 @@ export default function sqlLabReducer(state = {}, action) {
},
[actions.STOP_QUERY]() {
return alterInObject(state, 'queries', action.query, {
- state: 'stopped',
+ state: QueryState.STOPPED,
results: [],
});
},
@@ -405,12 +404,16 @@ export default function sqlLabReducer(state = {}, action) {
},
[actions.REQUEST_QUERY_RESULTS]() {
return alterInObject(state, 'queries', action.query, {
- state: 'fetching',
+ state: QueryState.FETCHING,
});
},
[actions.QUERY_SUCCESS]() {
- // prevent race condition were query succeeds shortly after being canceled
- if (action.query.state === 'stopped') {
+ // prevent race condition where query succeeds shortly after being canceled
+ // or the final result was unsuccessful
+ if (
+ action.query.state === QueryState.STOPPED ||
+ action.results.status !== QueryState.SUCCESS
+ ) {
return state;
}
const alts = {
@@ -418,7 +421,7 @@ export default function sqlLabReducer(state = {}, action) {
progress: 100,
results: action.results,
rows: action?.results?.query?.rows || 0,
- state: 'success',
+ state: QueryState.SUCCESS,
limitingFactor: action?.results?.query?.limitingFactor,
tempSchema: action?.results?.query?.tempSchema,
tempTable: action?.results?.query?.tempTable,
@@ -434,11 +437,11 @@ export default function sqlLabReducer(state = {}, action) {
return alterInObject(state, 'queries', action.query, alts);
},
[actions.QUERY_FAILED]() {
- if (action.query.state === 'stopped') {
+ if (action.query.state === QueryState.STOPPED) {
return state;
}
const alts = {
- state: 'failed',
+ state: QueryState.FAILED,
errors: action.errors,
errorMessage: action.msg,
endDttm: now(),
@@ -723,8 +726,8 @@ export default function sqlLabReducer(state = {}, action) {
Object.entries(action.alteredQueries).forEach(([id, changedQuery]) => {
if (
!state.queries.hasOwnProperty(id) ||
- (state.queries[id].state !== 'stopped' &&
- state.queries[id].state !== 'failed')
+ (state.queries[id].state !== QueryState.STOPPED &&
+ state.queries[id].state !== QueryState.FAILED)
) {
if (changedQuery.changedOn > queriesLastUpdate) {
queriesLastUpdate = changedQuery.changedOn;
@@ -738,8 +741,8 @@ export default function sqlLabReducer(state = {}, action) {
// because of async behavior, sql lab may still poll a couple of seconds
// when it started fetching or finished rendering results
state:
- currentState === 'success' &&
- ['fetching', 'success'].includes(prevState)
+ currentState === QueryState.SUCCESS &&
+ [QueryState.FETCHING, QueryState.SUCCESS].includes(prevState)
? prevState
: currentState,
};
diff --git a/superset-frontend/src/views/CRUD/data/query/QueryList.test.tsx b/superset-frontend/src/views/CRUD/data/query/QueryList.test.tsx
index eaaa75a1cb792..be28d7e2dfa85 100644
--- a/superset-frontend/src/views/CRUD/data/query/QueryList.test.tsx
+++ b/superset-frontend/src/views/CRUD/data/query/QueryList.test.tsx
@@ -33,6 +33,7 @@ import ListView from 'src/components/ListView';
import Filters from 'src/components/ListView/Filters';
import SyntaxHighlighter from 'react-syntax-highlighter/dist/cjs/light';
import SubMenu from 'src/views/components/SubMenu';
+import { QueryState } from '@superset-ui/core';
// store needed for withToasts
const mockStore = configureStore([thunk]);
@@ -54,7 +55,7 @@ const mockQueries: QueryObject[] = [...new Array(3)].map((_, i) => ({
{ schema: 'foo', table: 'table' },
{ schema: 'bar', table: 'table_2' },
],
- status: 'success',
+ status: QueryState.SUCCESS,
tab_name: 'Main Tab',
user: {
first_name: 'cool',
diff --git a/superset-frontend/src/views/CRUD/data/query/QueryList.tsx b/superset-frontend/src/views/CRUD/data/query/QueryList.tsx
index bbee625092a6b..dbe8e259dacce 100644
--- a/superset-frontend/src/views/CRUD/data/query/QueryList.tsx
+++ b/superset-frontend/src/views/CRUD/data/query/QueryList.tsx
@@ -17,7 +17,13 @@
* under the License.
*/
import React, { useMemo, useState, useCallback, ReactElement } from 'react';
-import { SupersetClient, t, styled, useTheme } from '@superset-ui/core';
+import {
+ QueryState,
+ styled,
+ SupersetClient,
+ t,
+ useTheme,
+} from '@superset-ui/core';
import moment from 'moment';
import {
createFetchRelated,
@@ -127,7 +133,13 @@ function QueryList({ addDangerToast }: QueryListProps) {
row: {
original: { status },
},
- }: any) => {
+ }: {
+ row: {
+ original: {
+ status: QueryState;
+ };
+ };
+ }) => {
const statusConfig: {
name: ReactElement | null;
label: string;
@@ -135,33 +147,39 @@ function QueryList({ addDangerToast }: QueryListProps) {
name: null,
label: '',
};
- if (status === 'success') {
+ if (status === QueryState.SUCCESS) {
statusConfig.name = (
);
statusConfig.label = t('Success');
- } else if (status === 'failed' || status === 'stopped') {
+ } else if (
+ status === QueryState.FAILED ||
+ status === QueryState.STOPPED
+ ) {
statusConfig.name = (
);
statusConfig.label = t('Failed');
- } else if (status === 'running') {
+ } else if (status === QueryState.RUNNING) {
statusConfig.name = (
);
statusConfig.label = t('Running');
- } else if (status === 'timed_out') {
+ } else if (status === QueryState.TIMED_OUT) {
statusConfig.name = (
);
statusConfig.label = t('Offline');
- } else if (status === 'scheduled' || status === 'pending') {
+ } else if (
+ status === QueryState.SCHEDULED ||
+ status === QueryState.PENDING
+ ) {
statusConfig.name = (
);
diff --git a/superset-frontend/src/views/CRUD/data/query/QueryPreviewModal.test.tsx b/superset-frontend/src/views/CRUD/data/query/QueryPreviewModal.test.tsx
index 7a85e4c292123..96498f6e69a65 100644
--- a/superset-frontend/src/views/CRUD/data/query/QueryPreviewModal.test.tsx
+++ b/superset-frontend/src/views/CRUD/data/query/QueryPreviewModal.test.tsx
@@ -27,6 +27,7 @@ import QueryPreviewModal from 'src/views/CRUD/data/query/QueryPreviewModal';
import { QueryObject } from 'src/views/CRUD/types';
import SyntaxHighlighter from 'react-syntax-highlighter/dist/cjs/light';
import { act } from 'react-dom/test-utils';
+import { QueryState } from '@superset-ui/core';
// store needed for withToasts
const mockStore = configureStore([thunk]);
@@ -46,7 +47,7 @@ const mockQueries: QueryObject[] = [...new Array(3)].map((_, i) => ({
{ schema: 'foo', table: 'table' },
{ schema: 'bar', table: 'table_2' },
],
- status: 'success',
+ status: QueryState.SUCCESS,
tab_name: 'Main Tab',
user: {
first_name: 'cool',
diff --git a/superset-frontend/src/views/CRUD/types.ts b/superset-frontend/src/views/CRUD/types.ts
index 0090697747ac5..07b26d27340a6 100644
--- a/superset-frontend/src/views/CRUD/types.ts
+++ b/superset-frontend/src/views/CRUD/types.ts
@@ -16,6 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
+import { QueryState } from '@superset-ui/core';
import { User } from 'src/types/bootstrapTypes';
import Database from 'src/types/Database';
import Owner from 'src/types/Owner';
@@ -91,14 +92,7 @@ export interface QueryObject {
sql: string;
executed_sql: string | null;
sql_tables?: { catalog?: string; schema: string; table: string }[];
- status:
- | 'success'
- | 'failed'
- | 'stopped'
- | 'running'
- | 'timed_out'
- | 'scheduled'
- | 'pending';
+ status: QueryState;
tab_name: string;
user: {
first_name: string;
diff --git a/superset/constants.py b/superset/constants.py
index 7d759acf6741c..5091d65a432dc 100644
--- a/superset/constants.py
+++ b/superset/constants.py
@@ -34,6 +34,9 @@
NO_TIME_RANGE = "No filter"
+QUERY_CANCEL_KEY = "cancel_query"
+QUERY_EARLY_CANCEL_KEY = "early_cancel_query"
+
class RouteMethod: # pylint: disable=too-few-public-methods
"""
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 3b4d38ee9b641..0f124de34aa65 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines
+
+from __future__ import annotations
+
import json
import logging
import re
@@ -477,7 +480,7 @@ def get_text_clause(cls, clause: str) -> TextClause:
@classmethod
def get_engine(
cls,
- database: "Database",
+ database: Database,
schema: Optional[str] = None,
source: Optional[utils.QuerySource] = None,
) -> ContextManager[Engine]:
@@ -732,7 +735,7 @@ def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]
@classmethod
def extra_table_metadata( # pylint: disable=unused-argument
cls,
- database: "Database",
+ database: Database,
table_name: str,
schema_name: Optional[str],
) -> Dict[str, Any]:
@@ -749,7 +752,7 @@ def extra_table_metadata( # pylint: disable=unused-argument
@classmethod
def apply_limit_to_sql(
- cls, sql: str, limit: int, database: "Database", force: bool = False
+ cls, sql: str, limit: int, database: Database, force: bool = False
) -> str:
"""
Alters the SQL statement to apply a LIMIT clause
@@ -891,7 +894,7 @@ def get_cte_query(cls, sql: str) -> Optional[str]:
@classmethod
def df_to_sql(
cls,
- database: "Database",
+ database: Database,
table: Table,
df: pd.DataFrame,
to_sql_kwargs: Dict[str, Any],
@@ -938,7 +941,7 @@ def convert_dttm( # pylint: disable=unused-argument
return None
@classmethod
- def handle_cursor(cls, cursor: Any, query: "Query", session: Session) -> None:
+ def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
"""Handle a live cursor between the execute and fetchall calls
The flow works without this method doing anything, but it allows
@@ -1030,7 +1033,7 @@ def get_schema_names(cls, inspector: Inspector) -> List[str]:
@classmethod
def get_table_names( # pylint: disable=unused-argument
cls,
- database: "Database",
+ database: Database,
inspector: Inspector,
schema: Optional[str],
) -> Set[str]:
@@ -1058,7 +1061,7 @@ def get_table_names( # pylint: disable=unused-argument
@classmethod
def get_view_names( # pylint: disable=unused-argument
cls,
- database: "Database",
+ database: Database,
inspector: Inspector,
schema: Optional[str],
) -> Set[str]:
@@ -1124,7 +1127,7 @@ def get_columns(
@classmethod
def get_metrics( # pylint: disable=unused-argument
cls,
- database: "Database",
+ database: Database,
inspector: Inspector,
table_name: str,
schema: Optional[str],
@@ -1146,7 +1149,7 @@ def where_latest_partition( # pylint: disable=too-many-arguments,unused-argumen
cls,
table_name: str,
schema: Optional[str],
- database: "Database",
+ database: Database,
query: Select,
columns: Optional[List[Dict[str, str]]] = None,
) -> Optional[Select]:
@@ -1171,7 +1174,7 @@ def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[Any]:
@classmethod
def select_star( # pylint: disable=too-many-arguments,too-many-locals
cls,
- database: "Database",
+ database: Database,
table_name: str,
engine: Engine,
schema: Optional[str] = None,
@@ -1250,7 +1253,7 @@ def query_cost_formatter(
raise Exception("Database does not support cost estimation")
@classmethod
- def process_statement(cls, statement: str, database: "Database") -> str:
+ def process_statement(cls, statement: str, database: Database) -> str:
"""
Process a SQL statement by stripping and mutating it.
@@ -1274,7 +1277,7 @@ def process_statement(cls, statement: str, database: "Database") -> str:
@classmethod
def estimate_query_cost(
cls,
- database: "Database",
+ database: Database,
schema: str,
sql: str,
source: Optional[utils.QuerySource] = None,
@@ -1468,7 +1471,7 @@ def column_datatype_to_string(
@classmethod
def get_function_names( # pylint: disable=unused-argument
cls,
- database: "Database",
+ database: Database,
) -> List[str]:
"""
Get a list of function names that are able to be called on the database.
@@ -1493,7 +1496,7 @@ def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple[Any, ...]]:
@staticmethod
def mutate_db_for_connection_test( # pylint: disable=unused-argument
- database: "Database",
+ database: Database,
) -> None:
"""
Some databases require passing additional parameters for validating database
@@ -1505,7 +1508,7 @@ def mutate_db_for_connection_test( # pylint: disable=unused-argument
return None
@staticmethod
- def get_extra_params(database: "Database") -> Dict[str, Any]:
+ def get_extra_params(database: Database) -> Dict[str, Any]:
"""
Some databases require adding elements to connection parameters,
like passing certificates to `extra`. This can be done here.
@@ -1524,7 +1527,7 @@ def get_extra_params(database: "Database") -> Dict[str, Any]:
@staticmethod
def update_params_from_encrypted_extra( # pylint: disable=invalid-name
- database: "Database", params: Dict[str, Any]
+ database: Database, params: Dict[str, Any]
) -> None:
"""
Some databases require some sensitive information which do not conform to
@@ -1586,11 +1589,22 @@ def get_column_spec( # pylint: disable=unused-argument
)
return None
+ # pylint: disable=unused-argument
+ @classmethod
+ def prepare_cancel_query(cls, query: Query, session: Session) -> None:
+ """
+ Some databases may acquire the query cancelation id after the query
+ cancelation request has been received. For those cases, the db engine spec
+ can record the cancelation intent so that the query can either be stopped
+ prior to execution, or canceled once the query id is acquired.
+ """
+ return None
+
@classmethod
def has_implicit_cancel(cls) -> bool:
"""
Return True if the live cursor handles the implicit cancelation of the query,
- False otherise.
+ False otherwise.
:return: Whether the live cursor implicitly cancels the query
:see: handle_cursor
@@ -1602,7 +1616,7 @@ def has_implicit_cancel(cls) -> bool:
def get_cancel_query_id( # pylint: disable=unused-argument
cls,
cursor: Any,
- query: "Query",
+ query: Query,
) -> Optional[str]:
"""
Select identifiers from the database engine that uniquely identifies the
@@ -1620,7 +1634,7 @@ def get_cancel_query_id( # pylint: disable=unused-argument
def cancel_query( # pylint: disable=unused-argument
cls,
cursor: Any,
- query: "Query",
+ query: Query,
cancel_query_id: str,
) -> bool:
"""
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index 8b6767e80ec57..af794ff0c23fa 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -559,7 +559,7 @@ def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
def has_implicit_cancel(cls) -> bool:
"""
Return True if the live cursor handles the implicit cancelation of the query,
- False otherise.
+ False otherwise.
:return: Whether the live cursor implicitly cancels the query
:see: handle_cursor
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 8b7574a021db9..2e8fc09fd1fbb 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -1304,7 +1304,7 @@ def get_column_spec(
def has_implicit_cancel(cls) -> bool:
"""
Return True if the live cursor handles the implicit cancelation of the query,
- False otherise.
+ False otherwise.
:return: Whether the live cursor implicitly cancels the query
:see: handle_cursor
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index 2a1d8cc63984b..3b23f7987327e 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -26,7 +26,7 @@
from sqlalchemy.engine.url import URL
from sqlalchemy.orm import Session
-from superset.constants import USER_AGENT
+from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
@@ -181,11 +181,30 @@ def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
query.tracking_url = tracking_url
# Adds the executed query id to the extra payload so the query can be cancelled
- query.set_extra_json_key("cancel_query", cursor.stats["queryId"])
+ query.set_extra_json_key(
+ key=QUERY_CANCEL_KEY,
+ value=(cancel_query_id := cursor.stats["queryId"]),
+ )
session.commit()
+
+ # if query cancelation was requested prior to the handle_cursor call, but
+ # the query was still executed, trigger the actual query cancelation now
+ if query.extra.get(QUERY_EARLY_CANCEL_KEY):
+ cls.cancel_query(
+ cursor=cursor,
+ query=query,
+ cancel_query_id=cancel_query_id,
+ )
+
super().handle_cursor(cursor=cursor, query=query, session=session)
+ @classmethod
+ def prepare_cancel_query(cls, query: Query, session: Session) -> None:
+ if QUERY_CANCEL_KEY not in query.extra:
+ query.set_extra_json_key(QUERY_EARLY_CANCEL_KEY, True)
+ session.commit()
+
@classmethod
def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool:
"""
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 57b9ede0fde21..c8b3bca2b1ebb 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -33,12 +33,14 @@
from superset import (
app,
+ db,
is_feature_enabled,
results_backend,
results_backend_use_msgpack,
security_manager,
)
from superset.common.db_query_status import QueryStatus
+from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY
from superset.dataframe import df_to_records
from superset.db_engine_specs import BaseEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
@@ -69,7 +71,6 @@
SQL_QUERY_MUTATOR = config["SQL_QUERY_MUTATOR"]
log_query = config["QUERY_LOGGER"]
logger = logging.getLogger(__name__)
-cancel_query_key = "cancel_query"
class SqlLabException(Exception):
@@ -603,7 +604,7 @@ def cancel_query(query: Query) -> bool:
"""
Cancel a running query.
- Note some engines implicitly handle the cancelation of a query and thus no expliicit
+ Note some engines implicitly handle the cancelation of a query and thus no explicit
action is required.
:param query: Query to cancel
@@ -613,7 +614,16 @@ def cancel_query(query: Query) -> bool:
if query.database.db_engine_spec.has_implicit_cancel():
return True
- cancel_query_id = query.extra.get(cancel_query_key)
+ # Some databases may need to make preparations for query cancellation
+ query.database.db_engine_spec.prepare_cancel_query(query, db.session)
+
+ if query.extra.get(QUERY_EARLY_CANCEL_KEY):
+ # Query has been cancelled prior to being able to set the cancel key.
+ # This can happen if the query cancellation key can only be acquired after the
+ # query has been executed
+ return True
+
+ cancel_query_id = query.extra.get(QUERY_CANCEL_KEY)
if cancel_query_id is None:
return False
diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py
index 6a77e63236091..382b65ce52547 100644
--- a/tests/unit_tests/db_engine_specs/test_trino.py
+++ b/tests/unit_tests/db_engine_specs/test_trino.py
@@ -15,8 +15,15 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
+import json
+from typing import Any, Dict
from unittest import mock
+import pytest
+from pytest_mock import MockerFixture
+
+from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY
+
@mock.patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_success(engine_mock: mock.Mock) -> None:
@@ -36,3 +43,55 @@ def test_cancel_query_failed(engine_mock: mock.Mock) -> None:
query = Query()
cursor_mock = engine_mock.raiseError.side_effect = Exception()
assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is False
+
+
+@pytest.mark.parametrize(
+ "initial_extra,final_extra",
+ [
+ ({}, {QUERY_EARLY_CANCEL_KEY: True}),
+ ({QUERY_CANCEL_KEY: "my_key"}, {QUERY_CANCEL_KEY: "my_key"}),
+ ],
+)
+def test_prepare_cancel_query(
+ initial_extra: Dict[str, Any],
+ final_extra: Dict[str, Any],
+ mocker: MockerFixture,
+) -> None:
+ from superset.db_engine_specs.trino import TrinoEngineSpec
+ from superset.models.sql_lab import Query
+
+ session_mock = mocker.MagicMock()
+ query = Query(extra_json=json.dumps(initial_extra))
+ TrinoEngineSpec.prepare_cancel_query(query=query, session=session_mock)
+ assert query.extra == final_extra
+
+
+@pytest.mark.parametrize("cancel_early", [True, False])
+@mock.patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
+@mock.patch("sqlalchemy.engine.Engine.connect")
+def test_handle_cursor_early_cancel(
+ engine_mock: mock.Mock,
+ cancel_query_mock: mock.Mock,
+ cancel_early: bool,
+ mocker: MockerFixture,
+) -> None:
+ from superset.db_engine_specs.trino import TrinoEngineSpec
+ from superset.models.sql_lab import Query
+
+ query_id = "myQueryId"
+
+ cursor_mock = engine_mock.return_value.__enter__.return_value
+ cursor_mock.stats = {"queryId": query_id}
+ session_mock = mocker.MagicMock()
+
+ query = Query()
+
+ if cancel_early:
+ TrinoEngineSpec.prepare_cancel_query(query=query, session=session_mock)
+
+ TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query, session=session_mock)
+
+ if cancel_early:
+ assert cancel_query_mock.call_args[1]["cancel_query_id"] == query_id
+ else:
+ assert cancel_query_mock.call_args is None