Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(ssh-tunnel): Refactor establishing raw connection with contextmanager #22366

Merged
merged 13 commits into from
Dec 29, 2022
39 changes: 15 additions & 24 deletions superset/connectors/sqla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from __future__ import annotations

import logging
from contextlib import closing
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -136,18 +135,13 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
# TODO(villebro): refactor to use same code that's used by
# sql_lab.py:execute_sql_statements
try:
with dataset.database.get_sqla_engine_with_context(
schema=dataset.schema
) as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
db_engine_spec.execute(cursor, query)
result = db_engine_spec.fetch_data(cursor, limit=1)
result_set = SupersetResultSet(
result, cursor.description, db_engine_spec
)
cols = result_set.columns
with dataset.database.get_raw_connection(schema=dataset.schema) as conn:
cursor = conn.cursor()
query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
db_engine_spec.execute(cursor, query)
result = db_engine_spec.fetch_data(cursor, limit=1)
result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
cols = result_set.columns
except Exception as ex:
raise SupersetGenericDBErrorException(message=str(ex)) from ex
return cols
Expand All @@ -159,17 +153,14 @@ def get_columns_description(
) -> List[ResultSetColumnType]:
db_engine_spec = database.db_engine_spec
try:
with database.get_sqla_engine_with_context() as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
query = database.apply_limit_to_sql(query, limit=1)
cursor.execute(query)
db_engine_spec.execute(cursor, query)
result = db_engine_spec.fetch_data(cursor, limit=1)
result_set = SupersetResultSet(
result, cursor.description, db_engine_spec
)
return result_set.columns
with database.get_raw_connection() as conn:
cursor = conn.cursor()
query = database.apply_limit_to_sql(query, limit=1)
cursor.execute(query)
db_engine_spec.execute(cursor, query)
result = db_engine_spec.fetch_data(cursor, limit=1)
result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
return result_set.columns
except Exception as ex:
raise SupersetGenericDBErrorException(message=str(ex)) from ex

Expand Down
15 changes: 6 additions & 9 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import json
import logging
import re
from contextlib import closing
from datetime import datetime
from typing import (
Any,
Expand Down Expand Up @@ -1299,14 +1298,12 @@ def estimate_query_cost(
statements = parsed_query.get_statements()

costs = []
with cls.get_engine(database, schema=schema, source=source) as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
for statement in statements:
processed_statement = cls.process_statement(statement, database)
costs.append(
cls.estimate_statement_cost(processed_statement, cursor)
)
with database.get_raw_connection(schema=schema, source=source) as conn:
cursor = conn.cursor()
for statement in statements:
processed_statement = cls.process_statement(statement, database)
costs.append(cls.estimate_statement_cost(processed_statement, cursor))

return costs

@classmethod
Expand Down
11 changes: 4 additions & 7 deletions superset/db_engine_specs/gsheets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
import json
import re
from contextlib import closing
from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING

from apispec import APISpec
Expand Down Expand Up @@ -109,12 +108,10 @@ def extra_table_metadata(
table_name: str,
schema_name: Optional[str],
) -> Dict[str, Any]:
with cls.get_engine(database, schema=schema_name) as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
cursor.execute(f'SELECT GET_METADATA("{table_name}")')
results = cursor.fetchone()[0]

with database.get_raw_connection(schema=schema_name) as conn:
cursor = conn.cursor()
cursor.execute(f'SELECT GET_METADATA("{table_name}")')
results = cursor.fetchone()[0]
try:
metadata = json.loads(results)
except Exception: # pylint: disable=broad-except
Expand Down
30 changes: 13 additions & 17 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import time
from abc import ABCMeta
from collections import defaultdict, deque
from contextlib import closing
from datetime import datetime
from distutils.version import StrictVersion
from textwrap import dedent
Expand Down Expand Up @@ -667,13 +666,11 @@ def get_view_names(
).strip()
params = {}

with cls.get_engine(database, schema=schema) as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
cursor.execute(sql, params)
results = cursor.fetchall()

return {row[0] for row in results}
with database.get_raw_connection(schema=schema) as conn:
cursor = conn.cursor()
cursor.execute(sql, params)
results = cursor.fetchall()
return {row[0] for row in results}

@classmethod
def _create_column_info(
Expand Down Expand Up @@ -1196,16 +1193,15 @@ def get_create_view(
# pylint: disable=import-outside-toplevel
from pyhive.exc import DatabaseError

with cls.get_engine(database, schema=schema) as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
sql = f"SHOW CREATE VIEW {schema}.{table}"
try:
cls.execute(cursor, sql)
with database.get_raw_connection(schema=schema) as conn:
cursor = conn.cursor()
sql = f"SHOW CREATE VIEW {schema}.{table}"
try:
cls.execute(cursor, sql)
except DatabaseError: # not a VIEW
return None
rows = cls.fetch_data(cursor, 1)

except DatabaseError: # not a VIEW
return None
rows = cls.fetch_data(cursor, 1)
return rows[0][0]

@classmethod
Expand Down
13 changes: 13 additions & 0 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,19 @@ def _get_sqla_engine(
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)

@contextmanager
def get_raw_connection(
    self,
    schema: Optional[str] = None,
    nullpool: bool = True,
    source: Optional[utils.QuerySource] = None,
) -> Connection:
    """
    Context manager that yields a raw connection for this database.

    An engine is acquired through ``get_sqla_engine_with_context`` using the
    given ``schema``, ``nullpool`` and ``source``. The connection returned by
    ``engine.raw_connection()`` is yielded to the caller and is closed when
    the ``with`` block exits, before the engine context itself is released.
    """
    engine_context = self.get_sqla_engine_with_context(
        schema=schema, nullpool=nullpool, source=source
    )
    # A single multi-item ``with`` is equivalent to the nested form: the
    # connection is closed first, then the engine context exits.
    with engine_context as sqla_engine, closing(
        sqla_engine.raw_connection()
    ) as raw_conn:
        yield raw_conn

@property
def quote_identifier(self) -> Callable[[str], str]:
"""Add quotes to potential identifier expressions if needed"""
Expand Down
104 changes: 47 additions & 57 deletions superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,66 +464,56 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
)
)

with database.get_sqla_engine_with_context(
query.schema, source=QuerySource.SQL_LAB
) as engine:
with database.get_raw_connection(query.schema, source=QuerySource.SQL_LAB) as conn:
# Sharing a single connection and cursor across the
# execution of all statements (if many)
with closing(engine.raw_connection()) as conn:
# closing the connection closes the cursor as well
cursor = conn.cursor()
cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
if cancel_query_id is not None:
query.set_extra_json_key(QUERY_CANCEL_KEY, cancel_query_id)
session.commit()
statement_count = len(statements)
for i, statement in enumerate(statements):
# Check if stopped
session.refresh(query)
if query.status == QueryStatus.STOPPED:
payload.update({"status": query.status})
return payload

# For CTAS we create the table only on the last statement
apply_ctas = query.select_as_cta and (
query.ctas_method == CtasMethod.VIEW
or (
query.ctas_method == CtasMethod.TABLE
and i == len(statements) - 1
)
cursor = conn.cursor()
cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
if cancel_query_id is not None:
query.set_extra_json_key(QUERY_CANCEL_KEY, cancel_query_id)
session.commit()
statement_count = len(statements)
for i, statement in enumerate(statements):
# Check if stopped
session.refresh(query)
if query.status == QueryStatus.STOPPED:
payload.update({"status": query.status})
return payload
# For CTAS we create the table only on the last statement
apply_ctas = query.select_as_cta and (
query.ctas_method == CtasMethod.VIEW
or (query.ctas_method == CtasMethod.TABLE and i == len(statements) - 1)
)
# Run statement
msg = f"Running statement {i+1} out of {statement_count}"
logger.info("Query %s: %s", str(query_id), msg)
query.set_extra_json_key("progress", msg)
session.commit()
try:
result_set = execute_sql_statement(
statement,
query,
session,
cursor,
log_params,
apply_ctas,
)

# Run statement
msg = f"Running statement {i+1} out of {statement_count}"
logger.info("Query %s: %s", str(query_id), msg)
query.set_extra_json_key("progress", msg)
session.commit()
try:
result_set = execute_sql_statement(
statement,
query,
session,
cursor,
log_params,
apply_ctas,
)
except SqlLabQueryStoppedException:
payload.update({"status": QueryStatus.STOPPED})
return payload
except Exception as ex: # pylint: disable=broad-except
msg = str(ex)
prefix_message = (
f"[Statement {i+1} out of {statement_count}]"
if statement_count > 1
else ""
)
payload = handle_query_error(
ex, query, session, payload, prefix_message
)
return payload

# Commit the connection so CTA queries will create the table.
conn.commit()
except SqlLabQueryStoppedException:
payload.update({"status": QueryStatus.STOPPED})
return payload
except Exception as ex: # pylint: disable=broad-except
msg = str(ex)
prefix_message = (
f"[Statement {i+1} out of {statement_count}]"
if statement_count > 1
else ""
)
payload = handle_query_error(
ex, query, session, payload, prefix_message
)
return payload
# Commit the connection so CTA queries will create the table.
conn.commit()

# Success, updating the query entry in database
query.rows = result_set.size
Expand Down
2 changes: 2 additions & 0 deletions superset/sql_validators/presto_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ def validate(
statements = parsed_query.get_statements()

logger.info("Validating %i statement(s)", len(statements))
# TODO(hughhh): update this to use the new database.get_raw_connection();
# this function keeps stalling CI
with database.get_sqla_engine_with_context(
schema, source=QuerySource.SQL_LAB
) as engine:
Expand Down
32 changes: 9 additions & 23 deletions tests/integration_tests/db_engine_specs/presto_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,8 @@ def test_get_datatype_presto(self):
def test_get_view_names_with_schema(self):
database = mock.MagicMock()
mock_execute = mock.MagicMock()
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
database.get_raw_connection().__enter__().cursor().execute = mock_execute
database.get_raw_connection().__enter__().cursor().fetchall = mock.MagicMock(
return_value=[["a", "b,", "c"], ["d", "e"]]
)

Expand All @@ -61,10 +59,8 @@ def test_get_view_names_with_schema(self):
def test_get_view_names_without_schema(self):
database = mock.MagicMock()
mock_execute = mock.MagicMock()
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
database.get_raw_connection().__enter__().cursor().execute = mock_execute
database.get_raw_connection().__enter__().cursor().fetchall = mock.MagicMock(
return_value=[["a", "b,", "c"], ["d", "e"]]
)
result = PrestoEngineSpec.get_view_names(database, mock.Mock(), None)
Expand Down Expand Up @@ -823,15 +819,9 @@ def test_get_create_view(self):
mock_execute = mock.MagicMock()
mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]])
database = mock.MagicMock()
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = (
mock_fetchall
)
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.poll.return_value = (
False
)
database.get_raw_connection().__enter__().cursor().execute = mock_execute
database.get_raw_connection().__enter__().cursor().fetchall = mock_fetchall
database.get_raw_connection().__enter__().cursor().return_value = False
schema = "schema"
table = "table"
result = PrestoEngineSpec.get_create_view(database, schema=schema, table=table)
Expand All @@ -841,9 +831,7 @@ def test_get_create_view(self):
def test_get_create_view_exception(self):
mock_execute = mock.MagicMock(side_effect=Exception())
database = mock.MagicMock()
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
database.get_raw_connection().__enter__().cursor().execute = mock_execute
schema = "schema"
table = "table"
with self.assertRaises(Exception):
Expand All @@ -854,9 +842,7 @@ def test_get_create_view_database_error(self):

mock_execute = mock.MagicMock(side_effect=DatabaseError())
database = mock.MagicMock()
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
database.get_raw_connection().__enter__().cursor().execute = mock_execute
schema = "schema"
table = "table"
result = PrestoEngineSpec.get_create_view(database, schema=schema, table=table)
Expand Down
Loading