Skip to content

Commit

Permalink
fix(query-exec): missing context info during trino's individual threa…
Browse files Browse the repository at this point in the history
…d query execution
  • Loading branch information
joaoferrao committed Aug 31, 2024
1 parent 7ebecc3 commit 528ce78
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 27 deletions.
3 changes: 2 additions & 1 deletion superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import numpy as np
import pandas as pd
import pyarrow as pa
from flask import ctx, current_app, Flask, g
from flask import copy_current_request_context, ctx, current_app, Flask, g
from sqlalchemy import text
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
Expand Down Expand Up @@ -242,6 +242,7 @@ def execute_with_cursor(
execute_result: dict[str, Any] = {}
execute_event = threading.Event()

@copy_current_request_context
def _execute(
results: dict[str, Any],
event: threading.Event,
Expand Down
56 changes: 30 additions & 26 deletions tests/unit_tests/db_engine_specs/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,21 +421,23 @@ def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
def _mock_execute(*args, **kwargs):
mock_cursor.query_id = query_id

mock_cursor.execute.side_effect = _mock_execute
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
{},
clear=True,
):
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)
with app.test_request_context("/some/place/"):
mock_cursor.execute.side_effect = _mock_execute

mock_query.set_extra_json_key.assert_called_once_with(
key=QUERY_CANCEL_KEY, value=query_id
)
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
{},
clear=True,
):
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)

mock_query.set_extra_json_key.assert_called_once_with(
key=QUERY_CANCEL_KEY, value=query_id
)


def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
Expand All @@ -446,23 +448,25 @@ def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
mock_cursor.query_id = None

mock_query = mocker.MagicMock()
g.some_value = "some_value"

def _mock_execute(*args, **kwargs):
assert has_app_context()
assert g.some_value == "some_value"

with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute):
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
{},
clear=True,
):
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)
with app.test_request_context("/some/place/"):
g.some_value = "some_value"

with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute):
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
{},
clear=True,
):
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)


def test_get_columns(mocker: MockerFixture):
Expand Down

0 comments on commit 528ce78

Please sign in to comment.