Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: create function for get_sqla_engine with context #21790

Merged
merged 12 commits into from
Oct 25, 2022
14 changes: 13 additions & 1 deletion superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import logging
import textwrap
from ast import literal_eval
from contextlib import closing
from contextlib import closing, contextmanager
from copy import deepcopy
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
Expand Down Expand Up @@ -362,6 +362,18 @@ def get_effective_user(self, object_url: URL) -> Optional[str]:
else None
)

@contextmanager
def get_sqla_engine_with_context(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the difference between having this new get_sqla_engine_with_context with the decorator VS using it in our existing get_sqla_engine ? I Mean, couldn't we just use the existing one and make use or not of the new functionality the decorator brings when needed? or would that mean changing a ton of places where the get_sqla_engine is being used right now?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now there's no difference, but Hugh is planning to add support for SSH tunneling, which would require a setup phase before the engine is created, and a teardown after. In order for the SSH tunnel to work everywhere we will need to replace all existing calls with the new context manager.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh ok ok, makes sense then 😎 thanks!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eventually I think we want to rename this to get_sqla_engine, since we want this to be the one and only way to create an engine.

self,
schema: Optional[str] = None,
nullpool: bool = True,
source: Optional[utils.QuerySource] = None,
) -> Engine:
try:
yield self.get_sqla_engine(schema=schema, nullpool=nullpool, source=source)
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)

def get_sqla_engine(
self,
schema: Optional[str] = None,
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def mock_provider() -> Mock:

@fixture(scope="session")
def example_db_engine(example_db_provider: Callable[[], Database]) -> Engine:
return example_db_provider().get_sqla_engine()
with example_db_provider().get_sqla_engine_with_context() as engine:
return engine


@fixture(scope="session")
Expand Down
90 changes: 45 additions & 45 deletions tests/integration_tests/access_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,64 +158,64 @@ def test_override_role_permissions_is_admin_only(self):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_override_role_permissions_1_table(self):
database = get_example_database()
engine = database.get_sqla_engine()
schema = inspect(engine).default_schema_name
with database.get_sqla_engine_with_context() as engine:
schema = inspect(engine).default_schema_name

perm_data = ROLE_TABLES_PERM_DATA.copy()
perm_data["database"][0]["schema"][0]["name"] = schema
perm_data = ROLE_TABLES_PERM_DATA.copy()
perm_data["database"][0]["schema"][0]["name"] = schema

response = self.client.post(
"/superset/override_role_permissions/",
data=json.dumps(perm_data),
content_type="application/json",
)
self.assertEqual(201, response.status_code)
response = self.client.post(
"/superset/override_role_permissions/",
data=json.dumps(perm_data),
content_type="application/json",
)
self.assertEqual(201, response.status_code)

updated_override_me = security_manager.find_role("override_me")
self.assertEqual(1, len(updated_override_me.permissions))
birth_names = self.get_table(name="birth_names")
self.assertEqual(
birth_names.perm, updated_override_me.permissions[0].view_menu.name
)
self.assertEqual(
"datasource_access", updated_override_me.permissions[0].permission.name
)
updated_override_me = security_manager.find_role("override_me")
self.assertEqual(1, len(updated_override_me.permissions))
birth_names = self.get_table(name="birth_names")
self.assertEqual(
birth_names.perm, updated_override_me.permissions[0].view_menu.name
)
self.assertEqual(
"datasource_access", updated_override_me.permissions[0].permission.name
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this doesn't have to be inside the context manager.


@pytest.mark.usefixtures(
"load_energy_table_with_slice", "load_birth_names_dashboard_with_slices"
)
def test_override_role_permissions_drops_absent_perms(self):
database = get_example_database()
engine = database.get_sqla_engine()
schema = inspect(engine).default_schema_name
with database.get_sqla_engine_with_context() as engine:
schema = inspect(engine).default_schema_name

override_me = security_manager.find_role("override_me")
override_me.permissions.append(
security_manager.find_permission_view_menu(
view_menu_name=self.get_table(name="energy_usage").perm,
permission_name="datasource_access",
override_me = security_manager.find_role("override_me")
override_me.permissions.append(
security_manager.find_permission_view_menu(
view_menu_name=self.get_table(name="energy_usage").perm,
permission_name="datasource_access",
)
)
)
db.session.flush()
db.session.flush()

perm_data = ROLE_TABLES_PERM_DATA.copy()
perm_data["database"][0]["schema"][0]["name"] = schema
perm_data = ROLE_TABLES_PERM_DATA.copy()
perm_data["database"][0]["schema"][0]["name"] = schema

response = self.client.post(
"/superset/override_role_permissions/",
data=json.dumps(perm_data),
content_type="application/json",
)
self.assertEqual(201, response.status_code)
updated_override_me = security_manager.find_role("override_me")
self.assertEqual(1, len(updated_override_me.permissions))
birth_names = self.get_table(name="birth_names")
self.assertEqual(
birth_names.perm, updated_override_me.permissions[0].view_menu.name
)
self.assertEqual(
"datasource_access", updated_override_me.permissions[0].permission.name
)
response = self.client.post(
"/superset/override_role_permissions/",
data=json.dumps(perm_data),
content_type="application/json",
)
self.assertEqual(201, response.status_code)
updated_override_me = security_manager.find_role("override_me")
self.assertEqual(1, len(updated_override_me.permissions))
birth_names = self.get_table(name="birth_names")
self.assertEqual(
birth_names.perm, updated_override_me.permissions[0].view_menu.name
)
self.assertEqual(
"datasource_access", updated_override_me.permissions[0].permission.name
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.


def test_clean_requests_after_role_extend(self):
session = db.session
Expand Down
4 changes: 3 additions & 1 deletion tests/integration_tests/celery_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def run_sql(
def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None:
"""Drop table if it exists, works on any DB"""
sql = f"DROP {table_type} IF EXISTS {table_name}"
get_example_database().get_sqla_engine().execute(sql)
database = get_example_database()
with database.get_sqla_engine_with_context() as engine:
engine.execute(sql)


def quote_f(value: Optional[str]):
Expand Down
18 changes: 8 additions & 10 deletions tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,6 @@ def setup_sample_data() -> Any:
yield

with app.app_context():
engine = get_example_database().get_sqla_engine()

# drop sqlachemy tables

db.session.commit()
Expand Down Expand Up @@ -210,14 +208,14 @@ def setup_presto_if_needed():

if backend in {"presto", "hive"}:
database = get_example_database()
engine = database.get_sqla_engine()
drop_from_schema(engine, CTAS_SCHEMA_NAME)
engine.execute(f"DROP SCHEMA IF EXISTS {CTAS_SCHEMA_NAME}")
engine.execute(f"CREATE SCHEMA {CTAS_SCHEMA_NAME}")

drop_from_schema(engine, ADMIN_SCHEMA_NAME)
engine.execute(f"DROP SCHEMA IF EXISTS {ADMIN_SCHEMA_NAME}")
engine.execute(f"CREATE SCHEMA {ADMIN_SCHEMA_NAME}")
with database.get_sqla_engine_with_context() as engine:
drop_from_schema(engine, CTAS_SCHEMA_NAME)
engine.execute(f"DROP SCHEMA IF EXISTS {CTAS_SCHEMA_NAME}")
engine.execute(f"CREATE SCHEMA {CTAS_SCHEMA_NAME}")

drop_from_schema(engine, ADMIN_SCHEMA_NAME)
engine.execute(f"DROP SCHEMA IF EXISTS {ADMIN_SCHEMA_NAME}")
engine.execute(f"CREATE SCHEMA {ADMIN_SCHEMA_NAME}")


def with_feature_flags(**mock_feature_flags):
Expand Down
16 changes: 8 additions & 8 deletions tests/integration_tests/csv_upload_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,14 @@ def setup_csv_upload(login_as_admin):
yield

upload_db = get_upload_db()
engine = upload_db.get_sqla_engine()
engine.execute(f"DROP TABLE IF EXISTS {EXCEL_UPLOAD_TABLE}")
engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE}")
engine.execute(f"DROP TABLE IF EXISTS {PARQUET_UPLOAD_TABLE}")
engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_SCHEMA}")
engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_EXPLORE}")
db.session.delete(upload_db)
db.session.commit()
with upload_db.get_sqla_engine_with_context() as engine:
engine.execute(f"DROP TABLE IF EXISTS {EXCEL_UPLOAD_TABLE}")
engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE}")
engine.execute(f"DROP TABLE IF EXISTS {PARQUET_UPLOAD_TABLE}")
engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_SCHEMA}")
engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_EXPLORE}")
db.session.delete(upload_db)
db.session.commit()


@pytest.fixture(scope="module")
Expand Down
27 changes: 14 additions & 13 deletions tests/integration_tests/datasets/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,9 +615,10 @@ def test_create_dataset_same_name_different_schema(self):
return

example_db = get_example_database()
example_db.get_sqla_engine().execute(
f"CREATE TABLE {CTAS_SCHEMA_NAME}.birth_names AS SELECT 2 as two"
)
with example_db.get_sqla_engine_with_context() as engine:
engine.execute(
f"CREATE TABLE {CTAS_SCHEMA_NAME}.birth_names AS SELECT 2 as two"
)

self.login(username="admin")
table_data = {
Expand All @@ -635,9 +636,8 @@ def test_create_dataset_same_name_different_schema(self):
uri = f'api/v1/dataset/{data.get("id")}'
rv = self.client.delete(uri)
assert rv.status_code == 200
example_db.get_sqla_engine().execute(
f"DROP TABLE {CTAS_SCHEMA_NAME}.birth_names"
)
with example_db.get_sqla_engine_with_context() as engine:
engine.execute(f"DROP TABLE {CTAS_SCHEMA_NAME}.birth_names")

def test_create_dataset_validate_database(self):
"""
Expand Down Expand Up @@ -703,13 +703,14 @@ def test_create_dataset_validate_view_exists(
mock_get_table.return_value = None

example_db = get_example_database()
engine = example_db.get_sqla_engine()
dialect = engine.dialect

with patch.object(
dialect, "get_view_names", wraps=dialect.get_view_names
) as patch_get_view_names:
patch_get_view_names.return_value = ["test_case_view"]
with example_db.get_sqla_engine_with_context() as engine:
engine = engine
dialect = engine.dialect

with patch.object(
dialect, "get_view_names", wraps=dialect.get_view_names
) as patch_get_view_names:
patch_get_view_names.return_value = ["test_case_view"]

self.login(username="admin")
table_data = {
Expand Down
19 changes: 9 additions & 10 deletions tests/integration_tests/datasource_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,17 @@ def create_test_table_context(database: Database):
schema = get_example_default_schema()
full_table_name = f"{schema}.test_table" if schema else "test_table"

database.get_sqla_engine().execute(
f"CREATE TABLE IF NOT EXISTS {full_table_name} AS SELECT 1 as first, 2 as second"
)
database.get_sqla_engine().execute(
f"INSERT INTO {full_table_name} (first, second) VALUES (1, 2)"
)
database.get_sqla_engine().execute(
f"INSERT INTO {full_table_name} (first, second) VALUES (3, 4)"
)
with database.get_sqla_engine_with_context() as engine:
engine.execute(
f"CREATE TABLE IF NOT EXISTS {full_table_name} AS SELECT 1 as first, 2 as second"
)
engine.execute(f"INSERT INTO {full_table_name} (first, second) VALUES (1, 2)")
engine.execute(f"INSERT INTO {full_table_name} (first, second) VALUES (3, 4)")

yield db.session
database.get_sqla_engine().execute(f"DROP TABLE {full_table_name}")

with database.get_sqla_engine_with_context() as engine:
engine.execute(f"DROP TABLE {full_table_name}")


class TestDatasource(SupersetTestCase):
Expand Down
27 changes: 14 additions & 13 deletions tests/integration_tests/fixtures/energy_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,22 @@
def load_energy_table_data():
with app.app_context():
database = get_example_database()
df = _get_dataframe()
df.to_sql(
ENERGY_USAGE_TBL_NAME,
database.get_sqla_engine(),
if_exists="replace",
chunksize=500,
index=False,
dtype={"source": String(255), "target": String(255), "value": Float()},
method="multi",
schema=get_example_default_schema(),
)
with database.get_sqla_engine_with_context() as engine:
df = _get_dataframe()
df.to_sql(
ENERGY_USAGE_TBL_NAME,
engine,
if_exists="replace",
chunksize=500,
index=False,
dtype={"source": String(255), "target": String(255), "value": Float()},
method="multi",
schema=get_example_default_schema(),
)
yield
with app.app_context():
engine = get_example_database().get_sqla_engine()
engine.execute("DROP TABLE IF EXISTS energy_usage")
with get_example_database().get_sqla_engine_with_context() as engine:
engine.execute("DROP TABLE IF EXISTS energy_usage")


@pytest.fixture()
Expand Down
21 changes: 11 additions & 10 deletions tests/integration_tests/fixtures/unicode_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,17 @@
@pytest.fixture(scope="session")
def load_unicode_data():
with app.app_context():
_get_dataframe().to_sql(
UNICODE_TBL_NAME,
get_example_database().get_sqla_engine(),
if_exists="replace",
chunksize=500,
dtype={"phrase": String(500)},
index=False,
method="multi",
schema=get_example_default_schema(),
)
with get_example_database().get_sqla_engine_with_context() as engine:
_get_dataframe().to_sql(
UNICODE_TBL_NAME,
engine,
if_exists="replace",
chunksize=500,
dtype={"phrase": String(500)},
index=False,
method="multi",
schema=get_example_default_schema(),
)

yield
with app.app_context():
Expand Down
21 changes: 11 additions & 10 deletions tests/integration_tests/fixtures/world_bank_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,17 @@ def load_world_bank_data():
"country_name": String(255),
"region": String(255),
}
_get_dataframe(database).to_sql(
WB_HEALTH_POPULATION,
get_example_database().get_sqla_engine(),
if_exists="replace",
chunksize=500,
dtype=dtype,
index=False,
method="multi",
schema=get_example_default_schema(),
)
with database.get_sqla_engine_with_context() as engine:
_get_dataframe(database).to_sql(
WB_HEALTH_POPULATION,
engine,
if_exists="replace",
chunksize=500,
dtype=dtype,
index=False,
method="multi",
schema=get_example_default_schema(),
)

yield
with app.app_context():
Expand Down
Loading