From 528ce78640222d368a5715032a6cfc5f37868140 Mon Sep 17 00:00:00 2001 From: Joao Ferrao Date: Sat, 31 Aug 2024 13:10:28 +0200 Subject: [PATCH] fix(query-exec): missing context info during trino's individual thread query execution --- superset/db_engine_specs/trino.py | 3 +- .../unit_tests/db_engine_specs/test_trino.py | 56 ++++++++++--------- 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 62924c50458c0..755dd21b2ae97 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -27,7 +27,7 @@ import numpy as np import pandas as pd import pyarrow as pa -from flask import ctx, current_app, Flask, g +from flask import copy_current_request_context, ctx, current_app, Flask, g from sqlalchemy import text from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL @@ -242,6 +242,7 @@ def execute_with_cursor( execute_result: dict[str, Any] = {} execute_event = threading.Event() + @copy_current_request_context def _execute( results: dict[str, Any], event: threading.Event, diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index 5a32cd05044cd..990ae891c465c 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -421,21 +421,23 @@ def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture): def _mock_execute(*args, **kwargs): mock_cursor.query_id = query_id - mock_cursor.execute.side_effect = _mock_execute - with patch.dict( - "superset.config.DISALLOWED_SQL_FUNCTIONS", - {}, - clear=True, - ): - TrinoEngineSpec.execute_with_cursor( - cursor=mock_cursor, - sql="SELECT 1 FROM foo", - query=mock_query, - ) + with app.test_request_context("/some/place/"): + mock_cursor.execute.side_effect = _mock_execute - mock_query.set_extra_json_key.assert_called_once_with( - key=QUERY_CANCEL_KEY, value=query_id - ) + with patch.dict( + "superset.config.DISALLOWED_SQL_FUNCTIONS", + {}, + clear=True, + ): + TrinoEngineSpec.execute_with_cursor( + cursor=mock_cursor, + sql="SELECT 1 FROM foo", + query=mock_query, + ) + + mock_query.set_extra_json_key.assert_called_once_with( + key=QUERY_CANCEL_KEY, value=query_id + ) def test_execute_with_cursor_app_context(app, mocker: MockerFixture): @@ -446,23 +448,25 @@ def test_execute_with_cursor_app_context(app, mocker: MockerFixture): mock_cursor.query_id = None mock_query = mocker.MagicMock() - g.some_value = "some_value" def _mock_execute(*args, **kwargs): assert has_app_context() assert g.some_value == "some_value" - with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute): - with patch.dict( - "superset.config.DISALLOWED_SQL_FUNCTIONS", - {}, - clear=True, - ): - TrinoEngineSpec.execute_with_cursor( - cursor=mock_cursor, - sql="SELECT 1 FROM foo", - query=mock_query, - ) + with app.test_request_context("/some/place/"): + g.some_value = "some_value" + + with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute): + with patch.dict( + "superset.config.DISALLOWED_SQL_FUNCTIONS", + {}, + clear=True, + ): + TrinoEngineSpec.execute_with_cursor( + cursor=mock_cursor, + sql="SELECT 1 FROM foo", + query=mock_query, + ) def test_get_columns(mocker: MockerFixture):