Skip to content

Commit

Permalink
fix: calls to _get_sqla_engine (apache#24953)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored Aug 11, 2023
1 parent 760d3e1 commit d16f4c2
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 37 deletions.
7 changes: 4 additions & 3 deletions superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,10 @@ def extra_table_metadata(
}

if database.has_view_by_name(table_name, schema_name):
metadata["view"] = database.inspector.get_view_definition(
table_name, schema_name
)
with database.get_inspector_with_context() as inspector:
metadata["view"] = inspector.get_view_definition(
table_name, schema_name
)

return metadata

Expand Down
51 changes: 24 additions & 27 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,8 @@ def get_df( # pylint: disable=too-many-locals
mutator: Callable[[pd.DataFrame], None] | None = None,
) -> pd.DataFrame:
sqls = self.db_engine_spec.parse_sql(sql)
engine = self._get_sqla_engine(schema)
with self.get_sqla_engine_with_context(schema) as engine:
engine_url = engine.url
mutate_after_split = config["MUTATE_AFTER_SPLIT"]
sql_query_mutator = config["SQL_QUERY_MUTATOR"]

Expand All @@ -577,7 +578,7 @@ def needs_conversion(df_series: pd.Series) -> bool:
def _log_query(sql: str) -> None:
if log_query:
log_query(
engine.url,
engine_url,
sql,
schema,
__name__,
Expand Down Expand Up @@ -624,13 +625,12 @@ def _log_query(sql: str) -> None:
return df

def compile_sqla_query(self, qry: Select, schema: str | None = None) -> str:
engine = self._get_sqla_engine(schema=schema)
with self.get_sqla_engine_with_context(schema) as engine:
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))

sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))

# pylint: disable=protected-access
if engine.dialect.identifier_preparer._double_percents: # noqa
sql = sql.replace("%%", "%")
# pylint: disable=protected-access
if engine.dialect.identifier_preparer._double_percents: # noqa
sql = sql.replace("%%", "%")

return sql

Expand All @@ -645,18 +645,18 @@ def select_star( # pylint: disable=too-many-arguments
cols: list[ResultSetColumnType] | None = None,
) -> str:
"""Generates a ``select *`` statement in the proper dialect"""
eng = self._get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
return self.db_engine_spec.select_star(
self,
table_name,
schema=schema,
engine=eng,
limit=limit,
show_cols=show_cols,
indent=indent,
latest_partition=latest_partition,
cols=cols,
)
with self.get_sqla_engine_with_context(schema) as engine:
return self.db_engine_spec.select_star(
self,
table_name,
schema=schema,
engine=engine,
limit=limit,
show_cols=show_cols,
indent=indent,
latest_partition=latest_partition,
cols=cols,
)

def apply_limit_to_sql(
self, sql: str, limit: int = 1000, force: bool = False
Expand All @@ -668,11 +668,6 @@ def apply_limit_to_sql(
def safe_sqlalchemy_uri(self) -> str:
return self.sqlalchemy_uri

@property
def inspector(self) -> Inspector:
engine = self._get_sqla_engine()
return sqla.inspect(engine)

@cache_util.memoized_func(
key="db:{self.id}:schema:{schema}:table_list",
cache=cache_manager.cache,
Expand Down Expand Up @@ -955,8 +950,10 @@ def _has_view(
return view_name in view_names

def has_view(self, view_name: str, schema: str | None = None) -> bool:
engine = self._get_sqla_engine()
return engine.run_callable(self._has_view, engine.dialect, view_name, schema)
with self.get_sqla_engine_with_context(schema) as engine:
return engine.run_callable(
self._has_view, engine.dialect, view_name, schema
)

def has_view_by_name(self, view_name: str, schema: str | None = None) -> bool:
return self.has_view(view_name=view_name, schema=schema)
Expand Down
5 changes: 2 additions & 3 deletions tests/integration_tests/celery_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,8 @@ def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None:
def quote_f(value: Optional[str]):
if not value:
return value
return get_example_database().inspector.engine.dialect.identifier_preparer.quote_identifier(
value
)
with get_example_database().get_inspector_with_context() as inspector:
return inspector.engine.dialect.identifier_preparer.quote_identifier(value)


def cta_result(ctas_method: CtasMethod):
Expand Down
7 changes: 4 additions & 3 deletions tests/integration_tests/charts/data/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,10 @@ def get_expected_row_count(self, client_id: str) -> int:

def quote_name(self, name: str):
if get_main_database().backend in {"presto", "hive"}:
return get_example_database().inspector.engine.dialect.identifier_preparer.quote_identifier(
name
)
with get_example_database().get_inspector_with_context() as inspector: # E: Ne
return inspector.engine.dialect.identifier_preparer.quote_identifier(
name
)
return name


Expand Down
3 changes: 2 additions & 1 deletion tests/integration_tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,8 @@ def test_select_star(self):
db = get_example_database()
table_name = "energy_usage"
sql = db.select_star(table_name, show_cols=False, latest_partition=False)
quote = db.inspector.engine.dialect.identifier_preparer.quote_identifier
with db.get_sqla_engine_with_context() as engine:
quote = engine.dialect.identifier_preparer.quote_identifier
expected = (
textwrap.dedent(
f"""\
Expand Down

0 comments on commit d16f4c2

Please sign in to comment.