From c2715d68d09a52e9a1fc89f57ae966834583b979 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Sat, 24 Dec 2022 06:31:46 +0200 Subject: [PATCH] feat(trino): support early cancellation of queries (#22498) (cherry picked from commit b6d39d194c90dbbf0050bb3d32d2e1a513dfc0a6) --- .../src/SqlLab/actions/sqlLab.js | 6 +- .../src/SqlLab/actions/sqlLab.test.js | 31 ++++++++++ .../src/SqlLab/components/ResultSet/index.tsx | 26 ++++---- .../components/SqlEditorTabHeader/index.tsx | 4 +- .../src/SqlLab/reducers/sqlLab.js | 29 +++++---- .../views/CRUD/data/query/QueryList.test.tsx | 3 +- .../src/views/CRUD/data/query/QueryList.tsx | 34 ++++++++--- .../data/query/QueryPreviewModal.test.tsx | 3 +- superset-frontend/src/views/CRUD/types.ts | 10 +--- superset/constants.py | 3 + superset/db_engine_specs/base.py | 52 ++++++++++------ superset/db_engine_specs/hive.py | 2 +- superset/db_engine_specs/presto.py | 2 +- superset/db_engine_specs/trino.py | 23 +++++++- superset/sql_lab.py | 16 ++++- .../unit_tests/db_engine_specs/test_trino.py | 59 +++++++++++++++++++ 16 files changed, 230 insertions(+), 73 deletions(-) diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.js b/superset-frontend/src/SqlLab/actions/sqlLab.js index a58630cda155e..12487d1a9437d 100644 --- a/superset-frontend/src/SqlLab/actions/sqlLab.js +++ b/superset-frontend/src/SqlLab/actions/sqlLab.js @@ -17,7 +17,7 @@ * under the License. */ import shortid from 'shortid'; -import { t, SupersetClient } from '@superset-ui/core'; +import { SupersetClient, t } from '@superset-ui/core'; import invert from 'lodash/invert'; import mapKeys from 'lodash/mapKeys'; import { isFeatureEnabled, FeatureFlag } from 'src/featureFlags'; @@ -229,11 +229,13 @@ export function startQuery(query) { export function querySuccess(query, results) { return function (dispatch) { + const sqlEditorId = results?.query?.sqlEditorId; const sync = + sqlEditorId && !query.isDataPreview && isFeatureEnabled(FeatureFlag.SQLLAB_BACKEND_PERSISTENCE) ? SupersetClient.put({ - endpoint: encodeURI(`/tabstateview/${results.query.sqlEditorId}`), + endpoint: encodeURI(`/tabstateview/${sqlEditorId}`), postPayload: { latest_query_id: query.id }, }) : Promise.resolve(); diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.test.js b/superset-frontend/src/SqlLab/actions/sqlLab.test.js index 7792f1da8ad63..acc79031edc15 100644 --- a/superset-frontend/src/SqlLab/actions/sqlLab.test.js +++ b/superset-frontend/src/SqlLab/actions/sqlLab.test.js @@ -30,6 +30,7 @@ import { initialState, queryId, } from 'src/SqlLab/fixtures'; +import { QueryState } from '@superset-ui/core'; const middlewares = [thunk]; const mockStore = configureMockStore(middlewares); @@ -502,6 +503,7 @@ describe('async actions', () => { const results = { data: mockBigNumber, query: { sqlEditorId: 'abcd' }, + status: QueryState.SUCCESS, query_id: 'efgh', }; fetchMock.get(fetchQueryEndpoint, JSON.stringify(results), { @@ -525,6 +527,35 @@ describe('async actions', () => { expect(fetchMock.calls(updateTabStateEndpoint)).toHaveLength(1); }); }); + + it("doesn't update the tab state in the backend on stoppped query", () => { + expect.assertions(2); + + const results = { + status: QueryState.STOPPED, + query_id: 'efgh', + }; + fetchMock.get(fetchQueryEndpoint, JSON.stringify(results), { + overwriteRoutes: true, + }); + const store = mockStore({}); + const expectedActions = [ + { + type: actions.REQUEST_QUERY_RESULTS, + query, + }, + // missing below + { + type: actions.QUERY_SUCCESS, + query, + results, + }, + ]; + return store.dispatch(actions.fetchQueryResults(query)).then(() => { + expect(store.getActions()).toEqual(expectedActions); + expect(fetchMock.calls(updateTabStateEndpoint)).toHaveLength(0); + }); + }); }); describe('addQueryEditor', () => { diff --git a/superset-frontend/src/SqlLab/components/ResultSet/index.tsx b/superset-frontend/src/SqlLab/components/ResultSet/index.tsx index 0d61d8ddab648..10cdd8a39e29f 100644 --- a/superset-frontend/src/SqlLab/components/ResultSet/index.tsx +++ b/superset-frontend/src/SqlLab/components/ResultSet/index.tsx @@ -16,13 +16,13 @@ * specific language governing permissions and limitations * under the License. */ -import React, { useState, useEffect, useCallback } from 'react'; +import React, { useCallback, useEffect, useState } from 'react'; import { useDispatch } from 'react-redux'; import ButtonGroup from 'src/components/ButtonGroup'; import Alert from 'src/components/Alert'; import Button from 'src/components/Button'; import shortid from 'shortid'; -import { styled, t, QueryResponse } from '@superset-ui/core'; +import { QueryResponse, QueryState, styled, t } from '@superset-ui/core'; import { usePrevious } from 'src/hooks/usePrevious'; import ErrorMessageWithStackTrace from 'src/components/ErrorMessage/ErrorMessageWithStackTrace'; import { @@ -43,9 +43,9 @@ import CopyToClipboard from 'src/components/CopyToClipboard'; import { addDangerToast } from 'src/components/MessageToasts/actions'; import { prepareCopyToClipboardTabularData } from 'src/utils/common'; import { - CtasEnum, - clearQueryResults, addQueryEditor, + clearQueryResults, + CtasEnum, fetchQueryResults, reFetchQueryResults, reRunQuery, @@ -387,8 +387,8 @@ const ResultSet = ({ let trackingUrl; if ( query.trackingUrl && - query.state !== 'success' && - query.state !== 'fetching' + query.state !== QueryState.SUCCESS && + query.state !== QueryState.FETCHING ) { trackingUrl = ( ); } @@ -406,11 +408,11 @@ const ResultSet = ({ sql = ; } - if (query.state === 'stopped') { + if (query.state === QueryState.STOPPED) { return ; } - if (query.state === 'failed') { + if (query.state === QueryState.FAILED) { return ( = ({ queryEditor }) => { }), shallowEqual, ); - const queryStatus = useSelector( + const queryState = useSelector( ({ sqlLab }) => sqlLab.queries[qe.latestQueryId || '']?.state || '', ); const dispatch = useDispatch(); @@ -139,7 +139,7 @@ const SqlEditorTabHeader: React.FC = ({ queryEditor }) => { } /> - {qe.name} {' '} + {qe.name} {' '} ); }; diff --git a/superset-frontend/src/SqlLab/reducers/sqlLab.js b/superset-frontend/src/SqlLab/reducers/sqlLab.js index ed103a2afe1bc..478487d6e239d 100644 --- a/superset-frontend/src/SqlLab/reducers/sqlLab.js +++ b/superset-frontend/src/SqlLab/reducers/sqlLab.js @@ -16,8 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -import { t } from '@superset-ui/core'; - +import { QueryState, t } from '@superset-ui/core'; import getInitialState from './getInitialState'; import * as actions from '../actions/sqlLab'; import { now } from '../../utils/dates'; @@ -391,7 +390,7 @@ export default function sqlLabReducer(state = {}, action) { }, [actions.STOP_QUERY]() { return alterInObject(state, 'queries', action.query, { - state: 'stopped', + state: QueryState.STOPPED, results: [], }); }, @@ -405,12 +404,16 @@ export default function sqlLabReducer(state = {}, action) { }, [actions.REQUEST_QUERY_RESULTS]() { return alterInObject(state, 'queries', action.query, { - state: 'fetching', + state: QueryState.FETCHING, }); }, [actions.QUERY_SUCCESS]() { - // prevent race condition were query succeeds shortly after being canceled - if (action.query.state === 'stopped') { + // prevent race condition where query succeeds shortly after being canceled + // or the final result was unsuccessful + if ( + action.query.state === QueryState.STOPPED || + action.results.status !== QueryState.SUCCESS + ) { return state; } const alts = { @@ -418,7 +421,7 @@ export default function sqlLabReducer(state = {}, action) { progress: 100, results: action.results, rows: action?.results?.query?.rows || 0, - state: 'success', + state: QueryState.SUCCESS, limitingFactor: action?.results?.query?.limitingFactor, tempSchema: action?.results?.query?.tempSchema, tempTable: action?.results?.query?.tempTable, @@ -434,11 +437,11 @@ export default function sqlLabReducer(state = {}, action) { return alterInObject(state, 'queries', action.query, alts); }, [actions.QUERY_FAILED]() { - if (action.query.state === 'stopped') { + if (action.query.state === QueryState.STOPPED) { return state; } const alts = { - state: 'failed', + state: QueryState.FAILED, errors: action.errors, errorMessage: action.msg, endDttm: now(), @@ -723,8 +726,8 @@ export default function sqlLabReducer(state = {}, action) { Object.entries(action.alteredQueries).forEach(([id, changedQuery]) => { if ( !state.queries.hasOwnProperty(id) || - (state.queries[id].state !== 'stopped' && - state.queries[id].state !== 'failed') + (state.queries[id].state !== QueryState.STOPPED && + state.queries[id].state !== QueryState.FAILED) ) { if (changedQuery.changedOn > queriesLastUpdate) { queriesLastUpdate = changedQuery.changedOn; @@ -738,8 +741,8 @@ export default function sqlLabReducer(state = {}, action) { // because of async behavior, sql lab may still poll a couple of seconds // when it started fetching or finished rendering results state: - currentState === 'success' && - ['fetching', 'success'].includes(prevState) + currentState === QueryState.SUCCESS && + [QueryState.FETCHING, QueryState.SUCCESS].includes(prevState) ? prevState : currentState, }; diff --git a/superset-frontend/src/views/CRUD/data/query/QueryList.test.tsx b/superset-frontend/src/views/CRUD/data/query/QueryList.test.tsx index eaaa75a1cb792..be28d7e2dfa85 100644 --- a/superset-frontend/src/views/CRUD/data/query/QueryList.test.tsx +++ b/superset-frontend/src/views/CRUD/data/query/QueryList.test.tsx @@ -33,6 +33,7 @@ import ListView from 'src/components/ListView'; import Filters from 'src/components/ListView/Filters'; import SyntaxHighlighter from 'react-syntax-highlighter/dist/cjs/light'; import SubMenu from 'src/views/components/SubMenu'; +import { QueryState } from '@superset-ui/core'; // store needed for withToasts const mockStore = configureStore([thunk]); @@ -54,7 +55,7 @@ const mockQueries: QueryObject[] = [...new Array(3)].map((_, i) => ({ { schema: 'foo', table: 'table' }, { schema: 'bar', table: 'table_2' }, ], - status: 'success', + status: QueryState.SUCCESS, tab_name: 'Main Tab', user: { first_name: 'cool', diff --git a/superset-frontend/src/views/CRUD/data/query/QueryList.tsx b/superset-frontend/src/views/CRUD/data/query/QueryList.tsx index bbee625092a6b..dbe8e259dacce 100644 --- a/superset-frontend/src/views/CRUD/data/query/QueryList.tsx +++ b/superset-frontend/src/views/CRUD/data/query/QueryList.tsx @@ -17,7 +17,13 @@ * under the License. */ import React, { useMemo, useState, useCallback, ReactElement } from 'react'; -import { SupersetClient, t, styled, useTheme } from '@superset-ui/core'; +import { + QueryState, + styled, + SupersetClient, + t, + useTheme, +} from '@superset-ui/core'; import moment from 'moment'; import { createFetchRelated, @@ -127,7 +133,13 @@ function QueryList({ addDangerToast }: QueryListProps) { row: { original: { status }, }, - }: any) => { + }: { + row: { + original: { + status: QueryState; + }; + }; + }) => { const statusConfig: { name: ReactElement | null; label: string; @@ -135,33 +147,39 @@ function QueryList({ addDangerToast }: QueryListProps) { name: null, label: '', }; - if (status === 'success') { + if (status === QueryState.SUCCESS) { statusConfig.name = ( ); statusConfig.label = t('Success'); - } else if (status === 'failed' || status === 'stopped') { + } else if ( + status === QueryState.FAILED || + status === QueryState.STOPPED + ) { statusConfig.name = ( ); statusConfig.label = t('Failed'); - } else if (status === 'running') { + } else if (status === QueryState.RUNNING) { statusConfig.name = ( ); statusConfig.label = t('Running'); - } else if (status === 'timed_out') { + } else if (status === QueryState.TIMED_OUT) { statusConfig.name = ( ); statusConfig.label = t('Offline'); - } else if (status === 'scheduled' || status === 'pending') { + } else if ( + status === QueryState.SCHEDULED || + status === QueryState.PENDING + ) { statusConfig.name = ( ); diff --git a/superset-frontend/src/views/CRUD/data/query/QueryPreviewModal.test.tsx b/superset-frontend/src/views/CRUD/data/query/QueryPreviewModal.test.tsx index 7a85e4c292123..96498f6e69a65 100644 --- a/superset-frontend/src/views/CRUD/data/query/QueryPreviewModal.test.tsx +++ b/superset-frontend/src/views/CRUD/data/query/QueryPreviewModal.test.tsx @@ -27,6 +27,7 @@ import QueryPreviewModal from 'src/views/CRUD/data/query/QueryPreviewModal'; import { QueryObject } from 'src/views/CRUD/types'; import SyntaxHighlighter from 'react-syntax-highlighter/dist/cjs/light'; import { act } from 'react-dom/test-utils'; +import { QueryState } from '@superset-ui/core'; // store needed for withToasts const mockStore = configureStore([thunk]); @@ -46,7 +47,7 @@ const mockQueries: QueryObject[] = [...new Array(3)].map((_, i) => ({ { schema: 'foo', table: 'table' }, { schema: 'bar', table: 'table_2' }, ], - status: 'success', + status: QueryState.SUCCESS, tab_name: 'Main Tab', user: { first_name: 'cool', diff --git a/superset-frontend/src/views/CRUD/types.ts b/superset-frontend/src/views/CRUD/types.ts index 0090697747ac5..07b26d27340a6 100644 --- a/superset-frontend/src/views/CRUD/types.ts +++ b/superset-frontend/src/views/CRUD/types.ts @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +import { QueryState } from '@superset-ui/core'; import { User } from 'src/types/bootstrapTypes'; import Database from 'src/types/Database'; import Owner from 'src/types/Owner'; @@ -91,14 +92,7 @@ export interface QueryObject { sql: string; executed_sql: string | null; sql_tables?: { catalog?: string; schema: string; table: string }[]; - status: - | 'success' - | 'failed' - | 'stopped' - | 'running' - | 'timed_out' - | 'scheduled' - | 'pending'; + status: QueryState; tab_name: string; user: { first_name: string; diff --git a/superset/constants.py b/superset/constants.py index 7d759acf6741c..5091d65a432dc 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -34,6 +34,9 @@ NO_TIME_RANGE = "No filter" +QUERY_CANCEL_KEY = "cancel_query" +QUERY_EARLY_CANCEL_KEY = "early_cancel_query" + class RouteMethod: # pylint: disable=too-few-public-methods """ diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 3b4d38ee9b641..0f124de34aa65 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=too-many-lines + +from __future__ import annotations + import json import logging import re @@ -477,7 +480,7 @@ def get_text_clause(cls, clause: str) -> TextClause: @classmethod def get_engine( cls, - database: "Database", + database: Database, schema: Optional[str] = None, source: Optional[utils.QuerySource] = None, ) -> ContextManager[Engine]: @@ -732,7 +735,7 @@ def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any] @classmethod def extra_table_metadata( # pylint: disable=unused-argument cls, - database: "Database", + database: Database, table_name: str, schema_name: Optional[str], ) -> Dict[str, Any]: @@ -749,7 +752,7 @@ def extra_table_metadata( # pylint: disable=unused-argument @classmethod def apply_limit_to_sql( - cls, sql: str, limit: int, database: "Database", force: bool = False + cls, sql: str, limit: int, database: Database, force: bool = False ) -> str: """ Alters the SQL statement to apply a LIMIT clause @@ -891,7 +894,7 @@ def get_cte_query(cls, sql: str) -> Optional[str]: @classmethod def df_to_sql( cls, - database: "Database", + database: Database, table: Table, df: pd.DataFrame, to_sql_kwargs: Dict[str, Any], @@ -938,7 +941,7 @@ def convert_dttm( # pylint: disable=unused-argument return None @classmethod - def handle_cursor(cls, cursor: Any, query: "Query", session: Session) -> None: + def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None: """Handle a live cursor between the execute and fetchall calls The flow works without this method doing anything, but it allows @@ -1030,7 +1033,7 @@ def get_schema_names(cls, inspector: Inspector) -> List[str]: @classmethod def get_table_names( # pylint: disable=unused-argument cls, - database: "Database", + database: Database, inspector: Inspector, schema: Optional[str], ) -> Set[str]: @@ -1058,7 +1061,7 @@ def get_table_names( # pylint: disable=unused-argument @classmethod def get_view_names( # pylint: disable=unused-argument cls, - database: "Database", + database: Database, inspector: Inspector, schema: Optional[str], ) -> Set[str]: @@ -1124,7 +1127,7 @@ def get_columns( @classmethod def get_metrics( # pylint: disable=unused-argument cls, - database: "Database", + database: Database, inspector: Inspector, table_name: str, schema: Optional[str], @@ -1146,7 +1149,7 @@ def where_latest_partition( # pylint: disable=too-many-arguments,unused-argumen cls, table_name: str, schema: Optional[str], - database: "Database", + database: Database, query: Select, columns: Optional[List[Dict[str, str]]] = None, ) -> Optional[Select]: @@ -1171,7 +1174,7 @@ def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[Any]: @classmethod def select_star( # pylint: disable=too-many-arguments,too-many-locals cls, - database: "Database", + database: Database, table_name: str, engine: Engine, schema: Optional[str] = None, @@ -1250,7 +1253,7 @@ def query_cost_formatter( raise Exception("Database does not support cost estimation") @classmethod - def process_statement(cls, statement: str, database: "Database") -> str: + def process_statement(cls, statement: str, database: Database) -> str: """ Process a SQL statement by stripping and mutating it. @@ -1274,7 +1277,7 @@ def process_statement(cls, statement: str, database: "Database") -> str: @classmethod def estimate_query_cost( cls, - database: "Database", + database: Database, schema: str, sql: str, source: Optional[utils.QuerySource] = None, @@ -1468,7 +1471,7 @@ def column_datatype_to_string( @classmethod def get_function_names( # pylint: disable=unused-argument cls, - database: "Database", + database: Database, ) -> List[str]: """ Get a list of function names that are able to be called on the database. @@ -1493,7 +1496,7 @@ def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple[Any, ...]]: @staticmethod def mutate_db_for_connection_test( # pylint: disable=unused-argument - database: "Database", + database: Database, ) -> None: """ Some databases require passing additional parameters for validating database @@ -1505,7 +1508,7 @@ def mutate_db_for_connection_test( # pylint: disable=unused-argument return None @staticmethod - def get_extra_params(database: "Database") -> Dict[str, Any]: + def get_extra_params(database: Database) -> Dict[str, Any]: """ Some databases require adding elements to connection parameters, like passing certificates to `extra`. This can be done here. @@ -1524,7 +1527,7 @@ def get_extra_params(database: "Database") -> Dict[str, Any]: @staticmethod def update_params_from_encrypted_extra( # pylint: disable=invalid-name - database: "Database", params: Dict[str, Any] + database: Database, params: Dict[str, Any] ) -> None: """ Some databases require some sensitive information which do not conform to @@ -1586,11 +1589,22 @@ def get_column_spec( # pylint: disable=unused-argument ) return None + # pylint: disable=unused-argument + @classmethod + def prepare_cancel_query(cls, query: Query, session: Session) -> None: + """ + Some databases may acquire the query cancelation id after the query + cancelation request has been received. For those cases, the db engine spec + can record the cancelation intent so that the query can either be stopped + prior to execution, or canceled once the query id is acquired. + """ + return None + @classmethod def has_implicit_cancel(cls) -> bool: """ Return True if the live cursor handles the implicit cancelation of the query, - False otherise. + False otherwise. :return: Whether the live cursor implicitly cancels the query :see: handle_cursor @@ -1602,7 +1616,7 @@ def has_implicit_cancel(cls) -> bool: def get_cancel_query_id( # pylint: disable=unused-argument cls, cursor: Any, - query: "Query", + query: Query, ) -> Optional[str]: """ Select identifiers from the database engine that uniquely identifies the @@ -1620,7 +1634,7 @@ def get_cancel_query_id( # pylint: disable=unused-argument def cancel_query( # pylint: disable=unused-argument cls, cursor: Any, - query: "Query", + query: Query, cancel_query_id: str, ) -> bool: """ diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 8b6767e80ec57..af794ff0c23fa 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -559,7 +559,7 @@ def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool: def has_implicit_cancel(cls) -> bool: """ Return True if the live cursor handles the implicit cancelation of the query, - False otherise. + False otherwise. :return: Whether the live cursor implicitly cancels the query :see: handle_cursor diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 8b7574a021db9..2e8fc09fd1fbb 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -1304,7 +1304,7 @@ def get_column_spec( def has_implicit_cancel(cls) -> bool: """ Return True if the live cursor handles the implicit cancelation of the query, - False otherise. + False otherwise. :return: Whether the live cursor implicitly cancels the query :see: handle_cursor diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 2a1d8cc63984b..3b23f7987327e 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -26,7 +26,7 @@ from sqlalchemy.engine.url import URL from sqlalchemy.orm import Session -from superset.constants import USER_AGENT +from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT from superset.databases.utils import make_url_safe from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError @@ -181,11 +181,30 @@ def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None: query.tracking_url = tracking_url # Adds the executed query id to the extra payload so the query can be cancelled - query.set_extra_json_key("cancel_query", cursor.stats["queryId"]) + query.set_extra_json_key( + key=QUERY_CANCEL_KEY, + value=(cancel_query_id := cursor.stats["queryId"]), + ) session.commit() + + # if query cancelation was requested prior to the handle_cursor call, but + # the query was still executed, trigger the actual query cancelation now + if query.extra.get(QUERY_EARLY_CANCEL_KEY): + cls.cancel_query( + cursor=cursor, + query=query, + cancel_query_id=cancel_query_id, + ) + super().handle_cursor(cursor=cursor, query=query, session=session) + @classmethod + def prepare_cancel_query(cls, query: Query, session: Session) -> None: + if QUERY_CANCEL_KEY not in query.extra: + query.set_extra_json_key(QUERY_EARLY_CANCEL_KEY, True) + session.commit() + @classmethod def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool: """ diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 57b9ede0fde21..c8b3bca2b1ebb 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -33,12 +33,14 @@ from superset import ( app, + db, is_feature_enabled, results_backend, results_backend_use_msgpack, security_manager, ) from superset.common.db_query_status import QueryStatus +from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY from superset.dataframe import df_to_records from superset.db_engine_specs import BaseEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType @@ -69,7 +71,6 @@ SQL_QUERY_MUTATOR = config["SQL_QUERY_MUTATOR"] log_query = config["QUERY_LOGGER"] logger = logging.getLogger(__name__) -cancel_query_key = "cancel_query" class SqlLabException(Exception): @@ -603,7 +604,7 @@ def cancel_query(query: Query) -> bool: """ Cancel a running query. - Note some engines implicitly handle the cancelation of a query and thus no expliicit + Note some engines implicitly handle the cancelation of a query and thus no explicit action is required. :param query: Query to cancel @@ -613,7 +614,16 @@ def cancel_query(query: Query) -> bool: if query.database.db_engine_spec.has_implicit_cancel(): return True - cancel_query_id = query.extra.get(cancel_query_key) + # Some databases may need to make preparations for query cancellation + query.database.db_engine_spec.prepare_cancel_query(query, db.session) + + if query.extra.get(QUERY_EARLY_CANCEL_KEY): + # Query has been cancelled prior to being able to set the cancel key. + # This can happen if the query cancellation key can only be acquired after the + # query has been executed + return True + + cancel_query_id = query.extra.get(QUERY_CANCEL_KEY) if cancel_query_id is None: return False diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index 6a77e63236091..382b65ce52547 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -15,8 +15,15 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=unused-argument, import-outside-toplevel, protected-access +import json +from typing import Any, Dict from unittest import mock +import pytest +from pytest_mock import MockerFixture + +from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY + @mock.patch("sqlalchemy.engine.Engine.connect") def test_cancel_query_success(engine_mock: mock.Mock) -> None: @@ -36,3 +43,55 @@ def test_cancel_query_failed(engine_mock: mock.Mock) -> None: query = Query() cursor_mock = engine_mock.raiseError.side_effect = Exception() assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is False + + +@pytest.mark.parametrize( + "initial_extra,final_extra", + [ + ({}, {QUERY_EARLY_CANCEL_KEY: True}), + ({QUERY_CANCEL_KEY: "my_key"}, {QUERY_CANCEL_KEY: "my_key"}), + ], +) +def test_prepare_cancel_query( + initial_extra: Dict[str, Any], + final_extra: Dict[str, Any], + mocker: MockerFixture, +) -> None: + from superset.db_engine_specs.trino import TrinoEngineSpec + from superset.models.sql_lab import Query + + session_mock = mocker.MagicMock() + query = Query(extra_json=json.dumps(initial_extra)) + TrinoEngineSpec.prepare_cancel_query(query=query, session=session_mock) + assert query.extra == final_extra + + +@pytest.mark.parametrize("cancel_early", [True, False]) +@mock.patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query") +@mock.patch("sqlalchemy.engine.Engine.connect") +def test_handle_cursor_early_cancel( + engine_mock: mock.Mock, + cancel_query_mock: mock.Mock, + cancel_early: bool, + mocker: MockerFixture, +) -> None: + from superset.db_engine_specs.trino import TrinoEngineSpec + from superset.models.sql_lab import Query + + query_id = "myQueryId" + + cursor_mock = engine_mock.return_value.__enter__.return_value + cursor_mock.stats = {"queryId": query_id} + session_mock = mocker.MagicMock() + + query = Query() + + if cancel_early: + TrinoEngineSpec.prepare_cancel_query(query=query, session=session_mock) + + TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query, session=session_mock) + + if cancel_early: + assert cancel_query_mock.call_args[1]["cancel_query_id"] == query_id + else: + assert cancel_query_mock.call_args is None