From 7600da80412186d0f5d0c85e6cd831fbae2e9d9e Mon Sep 17 00:00:00 2001 From: "Hugh A. Miles II" Date: Tue, 25 Oct 2022 14:12:48 -0400 Subject: [PATCH] feat: create function for get_sqla_engine with context (#21790) --- superset/models/core.py | 14 +++- tests/conftest.py | 3 +- tests/integration_tests/access_tests.py | 8 +- tests/integration_tests/celery_tests.py | 4 +- tests/integration_tests/conftest.py | 18 ++--- tests/integration_tests/csv_upload_tests.py | 16 ++-- tests/integration_tests/datasets/api_tests.py | 27 +++---- tests/integration_tests/datasource_tests.py | 19 +++-- .../fixtures/energy_dashboard.py | 27 +++---- .../fixtures/unicode_dashboard.py | 21 +++--- .../fixtures/world_bank_dashboard.py | 21 +++--- tests/integration_tests/model_tests.py | 75 +++++++++++++------ .../reports/commands_tests.py | 16 ++-- tests/integration_tests/sqllab_tests.py | 56 +++++++------- 14 files changed, 182 insertions(+), 143 deletions(-) diff --git a/superset/models/core.py b/superset/models/core.py index 008230ef4874f..d0a32a1864c2b 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -21,7 +21,7 @@ import logging import textwrap from ast import literal_eval -from contextlib import closing +from contextlib import closing, contextmanager from copy import deepcopy from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type @@ -362,6 +362,18 @@ def get_effective_user(self, object_url: URL) -> Optional[str]: else None ) + @contextmanager + def get_sqla_engine_with_context( + self, + schema: Optional[str] = None, + nullpool: bool = True, + source: Optional[utils.QuerySource] = None, + ) -> Engine: + try: + yield self.get_sqla_engine(schema=schema, nullpool=nullpool, source=source) + except Exception as ex: + raise self.db_engine_spec.get_dbapi_mapped_exception(ex) + def get_sqla_engine( self, schema: Optional[str] = None, diff --git a/tests/conftest.py b/tests/conftest.py index 2c129965f1bd6..a5945f2f5c4c4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -70,7 +70,8 @@ def mock_provider() -> Mock: @fixture(scope="session") def example_db_engine(example_db_provider: Callable[[], Database]) -> Engine: - return example_db_provider().get_sqla_engine() + with example_db_provider().get_sqla_engine_with_context() as engine: + return engine @fixture(scope="session") diff --git a/tests/integration_tests/access_tests.py b/tests/integration_tests/access_tests.py index 5ab03055d9ee6..ae8b39a8d289a 100644 --- a/tests/integration_tests/access_tests.py +++ b/tests/integration_tests/access_tests.py @@ -158,8 +158,8 @@ def test_override_role_permissions_is_admin_only(self): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_override_role_permissions_1_table(self): database = get_example_database() - engine = database.get_sqla_engine() - schema = inspect(engine).default_schema_name + with database.get_sqla_engine_with_context() as engine: + schema = inspect(engine).default_schema_name perm_data = ROLE_TABLES_PERM_DATA.copy() perm_data["database"][0]["schema"][0]["name"] = schema @@ -186,8 +186,8 @@ def test_override_role_permissions_1_table(self): ) def test_override_role_permissions_drops_absent_perms(self): database = get_example_database() - engine = database.get_sqla_engine() - schema = inspect(engine).default_schema_name + with database.get_sqla_engine_with_context() as engine: + schema = inspect(engine).default_schema_name override_me = security_manager.find_role("override_me") override_me.permissions.append( diff --git a/tests/integration_tests/celery_tests.py b/tests/integration_tests/celery_tests.py index f057d3128e574..da6db727e7114 100644 --- a/tests/integration_tests/celery_tests.py +++ b/tests/integration_tests/celery_tests.py @@ -112,7 +112,9 @@ def run_sql( def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None: """Drop table if it exists, works on any DB""" sql = f"DROP {table_type} IF EXISTS {table_name}" - get_example_database().get_sqla_engine().execute(sql) + database = get_example_database() + with database.get_sqla_engine_with_context() as engine: + engine.execute(sql) def quote_f(value: Optional[str]): diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 463f93b833681..efbc6bf7f07d3 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -134,8 +134,6 @@ def setup_sample_data() -> Any: yield with app.app_context(): - engine = get_example_database().get_sqla_engine() - # drop sqlachemy tables db.session.commit() @@ -210,14 +208,14 @@ def setup_presto_if_needed(): if backend in {"presto", "hive"}: database = get_example_database() - engine = database.get_sqla_engine() - drop_from_schema(engine, CTAS_SCHEMA_NAME) - engine.execute(f"DROP SCHEMA IF EXISTS {CTAS_SCHEMA_NAME}") - engine.execute(f"CREATE SCHEMA {CTAS_SCHEMA_NAME}") - - drop_from_schema(engine, ADMIN_SCHEMA_NAME) - engine.execute(f"DROP SCHEMA IF EXISTS {ADMIN_SCHEMA_NAME}") - engine.execute(f"CREATE SCHEMA {ADMIN_SCHEMA_NAME}") + with database.get_sqla_engine_with_context() as engine: + drop_from_schema(engine, CTAS_SCHEMA_NAME) + engine.execute(f"DROP SCHEMA IF EXISTS {CTAS_SCHEMA_NAME}") + engine.execute(f"CREATE SCHEMA {CTAS_SCHEMA_NAME}") + + drop_from_schema(engine, ADMIN_SCHEMA_NAME) + engine.execute(f"DROP SCHEMA IF EXISTS {ADMIN_SCHEMA_NAME}") + engine.execute(f"CREATE SCHEMA {ADMIN_SCHEMA_NAME}") def with_feature_flags(**mock_feature_flags): diff --git a/tests/integration_tests/csv_upload_tests.py b/tests/integration_tests/csv_upload_tests.py index 6cd228b9cb1d1..3941606aba2ca 100644 --- a/tests/integration_tests/csv_upload_tests.py +++ b/tests/integration_tests/csv_upload_tests.py @@ -71,14 +71,14 @@ def setup_csv_upload(login_as_admin): yield upload_db = get_upload_db() - engine = upload_db.get_sqla_engine() - engine.execute(f"DROP TABLE IF EXISTS {EXCEL_UPLOAD_TABLE}") - engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE}") - engine.execute(f"DROP TABLE IF EXISTS {PARQUET_UPLOAD_TABLE}") - engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_SCHEMA}") - engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_EXPLORE}") - db.session.delete(upload_db) - db.session.commit() + with upload_db.get_sqla_engine_with_context() as engine: + engine.execute(f"DROP TABLE IF EXISTS {EXCEL_UPLOAD_TABLE}") + engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE}") + engine.execute(f"DROP TABLE IF EXISTS {PARQUET_UPLOAD_TABLE}") + engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_SCHEMA}") + engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_EXPLORE}") + db.session.delete(upload_db) + db.session.commit() @pytest.fixture(scope="module") diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index ef003d05dc600..0175a2c3341b3 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -670,9 +670,10 @@ def test_create_dataset_same_name_different_schema(self): return example_db = get_example_database() - example_db.get_sqla_engine().execute( - f"CREATE TABLE {CTAS_SCHEMA_NAME}.birth_names AS SELECT 2 as two" - ) + with example_db.get_sqla_engine_with_context() as engine: + engine.execute( + f"CREATE TABLE {CTAS_SCHEMA_NAME}.birth_names AS SELECT 2 as two" + ) self.login(username="admin") table_data = { @@ -690,9 +691,8 @@ def test_create_dataset_same_name_different_schema(self): uri = f'api/v1/dataset/{data.get("id")}' rv = self.client.delete(uri) assert rv.status_code == 200 - example_db.get_sqla_engine().execute( - f"DROP TABLE {CTAS_SCHEMA_NAME}.birth_names" - ) + with example_db.get_sqla_engine_with_context() as engine: + engine.execute(f"DROP TABLE {CTAS_SCHEMA_NAME}.birth_names") def test_create_dataset_validate_database(self): """ @@ -758,13 +758,14 @@ def test_create_dataset_validate_view_exists( mock_get_table.return_value = None example_db = get_example_database() - engine = example_db.get_sqla_engine() - dialect = engine.dialect - - with patch.object( - dialect, "get_view_names", wraps=dialect.get_view_names - ) as patch_get_view_names: - patch_get_view_names.return_value = ["test_case_view"] + with example_db.get_sqla_engine_with_context() as engine: + engine = engine + dialect = engine.dialect + + with patch.object( + dialect, "get_view_names", wraps=dialect.get_view_names + ) as patch_get_view_names: + patch_get_view_names.return_value = ["test_case_view"] self.login(username="admin") table_data = { diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index ef3ba0c69d6b8..0896971743a34 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -45,18 +45,17 @@ def create_test_table_context(database: Database): schema = get_example_default_schema() full_table_name = f"{schema}.test_table" if schema else "test_table" - database.get_sqla_engine().execute( - f"CREATE TABLE IF NOT EXISTS {full_table_name} AS SELECT 1 as first, 2 as second" - ) - database.get_sqla_engine().execute( - f"INSERT INTO {full_table_name} (first, second) VALUES (1, 2)" - ) - database.get_sqla_engine().execute( - f"INSERT INTO {full_table_name} (first, second) VALUES (3, 4)" - ) + with database.get_sqla_engine_with_context() as engine: + engine.execute( + f"CREATE TABLE IF NOT EXISTS {full_table_name} AS SELECT 1 as first, 2 as second" + ) + engine.execute(f"INSERT INTO {full_table_name} (first, second) VALUES (1, 2)") + engine.execute(f"INSERT INTO {full_table_name} (first, second) VALUES (3, 4)") yield db.session - database.get_sqla_engine().execute(f"DROP TABLE {full_table_name}") + + with database.get_sqla_engine_with_context() as engine: + engine.execute(f"DROP TABLE {full_table_name}") class TestDatasource(SupersetTestCase): diff --git a/tests/integration_tests/fixtures/energy_dashboard.py b/tests/integration_tests/fixtures/energy_dashboard.py index 436ba1ce55b6a..202f494aa2d15 100644 --- a/tests/integration_tests/fixtures/energy_dashboard.py +++ b/tests/integration_tests/fixtures/energy_dashboard.py @@ -39,21 +39,22 @@ def load_energy_table_data(): with app.app_context(): database = get_example_database() - df = _get_dataframe() - df.to_sql( - ENERGY_USAGE_TBL_NAME, - database.get_sqla_engine(), - if_exists="replace", - chunksize=500, - index=False, - dtype={"source": String(255), "target": String(255), "value": Float()}, - method="multi", - schema=get_example_default_schema(), - ) + with database.get_sqla_engine_with_context() as engine: + df = _get_dataframe() + df.to_sql( + ENERGY_USAGE_TBL_NAME, + engine, + if_exists="replace", + chunksize=500, + index=False, + dtype={"source": String(255), "target": String(255), "value": Float()}, + method="multi", + schema=get_example_default_schema(), + ) yield with app.app_context(): - engine = get_example_database().get_sqla_engine() - engine.execute("DROP TABLE IF EXISTS energy_usage") + with get_example_database().get_sqla_engine_with_context() as engine: + engine.execute("DROP TABLE IF EXISTS energy_usage") @pytest.fixture() diff --git a/tests/integration_tests/fixtures/unicode_dashboard.py b/tests/integration_tests/fixtures/unicode_dashboard.py index c7b828176f2c8..9368df7614a9f 100644 --- a/tests/integration_tests/fixtures/unicode_dashboard.py +++ b/tests/integration_tests/fixtures/unicode_dashboard.py @@ -37,16 +37,17 @@ @pytest.fixture(scope="session") def load_unicode_data(): with app.app_context(): - _get_dataframe().to_sql( - UNICODE_TBL_NAME, - get_example_database().get_sqla_engine(), - if_exists="replace", - chunksize=500, - dtype={"phrase": String(500)}, - index=False, - method="multi", - schema=get_example_default_schema(), - ) + with get_example_database().get_sqla_engine_with_context() as engine: + _get_dataframe().to_sql( + UNICODE_TBL_NAME, + engine, + if_exists="replace", + chunksize=500, + dtype={"phrase": String(500)}, + index=False, + method="multi", + schema=get_example_default_schema(), + ) yield with app.app_context(): diff --git a/tests/integration_tests/fixtures/world_bank_dashboard.py b/tests/integration_tests/fixtures/world_bank_dashboard.py index 2c6fb2c3e26e4..e29962a8c9787 100644 --- a/tests/integration_tests/fixtures/world_bank_dashboard.py +++ b/tests/integration_tests/fixtures/world_bank_dashboard.py @@ -50,16 +50,17 @@ def load_world_bank_data(): "country_name": String(255), "region": String(255), } - _get_dataframe(database).to_sql( - WB_HEALTH_POPULATION, - get_example_database().get_sqla_engine(), - if_exists="replace", - chunksize=500, - dtype=dtype, - index=False, - method="multi", - schema=get_example_default_schema(), - ) + with database.get_sqla_engine_with_context() as engine: + _get_dataframe(database).to_sql( + WB_HEALTH_POPULATION, + engine, + if_exists="replace", + chunksize=500, + dtype=dtype, + index=False, + method="multi", + schema=get_example_default_schema(), + ) yield with app.app_context(): diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index c92de47f038b4..3e13664b63e36 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -57,30 +57,36 @@ def test_database_schema_presto(self): sqlalchemy_uri = "presto://presto.airbnb.io:8080/hive/default" model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri) - db = make_url(model.get_sqla_engine().url).database - self.assertEqual("hive/default", db) + with model.get_sqla_engine_with_context() as engine: + db = make_url(engine.url).database + self.assertEqual("hive/default", db) - db = make_url(model.get_sqla_engine(schema="core_db").url).database - self.assertEqual("hive/core_db", db) + with model.get_sqla_engine_with_context(schema="core_db") as engine: + db = make_url(engine.url).database + self.assertEqual("hive/core_db", db) sqlalchemy_uri = "presto://presto.airbnb.io:8080/hive" model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri) - db = make_url(model.get_sqla_engine().url).database - self.assertEqual("hive", db) + with model.get_sqla_engine_with_context() as engine: + db = make_url(engine.url).database + self.assertEqual("hive", db) - db = make_url(model.get_sqla_engine(schema="core_db").url).database - self.assertEqual("hive/core_db", db) + with model.get_sqla_engine_with_context(schema="core_db") as engine: + db = make_url(engine.url).database + self.assertEqual("hive/core_db", db) def test_database_schema_postgres(self): sqlalchemy_uri = "postgresql+psycopg2://postgres.airbnb.io:5439/prod" model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri) - db = make_url(model.get_sqla_engine().url).database - self.assertEqual("prod", db) + with model.get_sqla_engine_with_context() as engine: + db = make_url(engine.url).database + self.assertEqual("prod", db) - db = make_url(model.get_sqla_engine(schema="foo").url).database - self.assertEqual("prod", db) + with model.get_sqla_engine_with_context(schema="foo") as engine: + db = make_url(engine.url).database + self.assertEqual("prod", db) @unittest.skipUnless( SupersetTestCase.is_module_installed("thrift"), "thrift not installed" @@ -91,11 +97,14 @@ def test_database_schema_postgres(self): def test_database_schema_hive(self): sqlalchemy_uri = "hive://hive@hive.airbnb.io:10000/default?auth=NOSASL" model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri) - db = make_url(model.get_sqla_engine().url).database - self.assertEqual("default", db) - db = make_url(model.get_sqla_engine(schema="core_db").url).database - self.assertEqual("core_db", db) + with model.get_sqla_engine_with_context() as engine: + db = make_url(engine.url).database + self.assertEqual("default", db) + + with model.get_sqla_engine_with_context(schema="core_db") as engine: + db = make_url(engine.url).database + self.assertEqual("core_db", db) @unittest.skipUnless( SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed" @@ -104,11 +113,13 @@ def test_database_schema_mysql(self): sqlalchemy_uri = "mysql://root@localhost/superset" model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri) - db = make_url(model.get_sqla_engine().url).database - self.assertEqual("superset", db) + with model.get_sqla_engine_with_context() as engine: + db = make_url(engine.url).database + self.assertEqual("superset", db) - db = make_url(model.get_sqla_engine(schema="staging").url).database - self.assertEqual("staging", db) + with model.get_sqla_engine_with_context(schema="staging") as engine: + db = make_url(engine.url).database + self.assertEqual("staging", db) @unittest.skipUnless( SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed" @@ -120,12 +131,14 @@ def test_database_impersonate_user(self): with override_user(example_user): model.impersonate_user = True - username = make_url(model.get_sqla_engine().url).username - self.assertEqual(example_user.username, username) + with model.get_sqla_engine_with_context() as engine: + username = make_url(engine.url).username + self.assertEqual(example_user.username, username) model.impersonate_user = False - username = make_url(model.get_sqla_engine().url).username - self.assertNotEqual(example_user.username, username) + with model.get_sqla_engine_with_context() as engine: + username = make_url(engine.url).username + self.assertNotEqual(example_user.username, username) @mock.patch("superset.models.core.create_engine") def test_impersonate_user_presto(self, mocked_create_engine): @@ -369,6 +382,20 @@ def test_get_sqla_engine(self, mocked_create_engine): with self.assertRaises(SupersetException): model.get_sqla_engine() + # todo(hughhh): update this test + # @mock.patch("superset.models.core.create_engine") + # def test_get_sqla_engine_with_context(self, mocked_create_engine): + # model = Database( + # database_name="test_database", + # sqlalchemy_uri="mysql://root@localhost", + # ) + # model.db_engine_spec.get_dbapi_exception_mapping = mock.Mock( + # return_value={Exception: SupersetException} + # ) + # mocked_create_engine.side_effect = Exception() + # with self.assertRaises(SupersetException): + # model.get_sqla_engine() + class TestSqlaTableModel(SupersetTestCase): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") diff --git a/tests/integration_tests/reports/commands_tests.py b/tests/integration_tests/reports/commands_tests.py index c82c1b5fdb2d1..b3ef86c5e32cb 100644 --- a/tests/integration_tests/reports/commands_tests.py +++ b/tests/integration_tests/reports/commands_tests.py @@ -130,18 +130,14 @@ def assert_log(state: str, error_message: Optional[str] = None): @contextmanager def create_test_table_context(database: Database): - database.get_sqla_engine().execute( - "CREATE TABLE test_table AS SELECT 1 as first, 2 as second" - ) - database.get_sqla_engine().execute( - "INSERT INTO test_table (first, second) VALUES (1, 2)" - ) - database.get_sqla_engine().execute( - "INSERT INTO test_table (first, second) VALUES (3, 4)" - ) + with database.get_sqla_engine_with_context() as engine: + engine.execute("CREATE TABLE test_table AS SELECT 1 as first, 2 as second") + engine.execute("INSERT INTO test_table (first, second) VALUES (1, 2)") + engine.execute("INSERT INTO test_table (first, second) VALUES (3, 4)") yield db.session - database.get_sqla_engine().execute("DROP TABLE test_table") + with database.get_sqla_engine_with_context() as engine: + engine.execute("DROP TABLE test_table") @pytest.fixture() diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index 0c4019e7d9f17..bee9b08114a40 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -207,19 +207,21 @@ def test_sql_json_cta_dynamic_db(self, ctas_method): # assertions db.session.commit() examples_db = get_example_database() - engine = examples_db.get_sqla_engine() - data = engine.execute( - f"SELECT * FROM admin_database.{tmp_table_name}" - ).fetchall() - names_count = engine.execute(f"SELECT COUNT(*) FROM birth_names").first() - self.assertEqual( - names_count[0], len(data) - ) # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True - - # cleanup - engine.execute(f"DROP {ctas_method} admin_database.{tmp_table_name}") - examples_db.allow_ctas = old_allow_ctas - db.session.commit() + with examples_db.get_sqla_engine_with_context() as engine: + data = engine.execute( + f"SELECT * FROM admin_database.{tmp_table_name}" + ).fetchall() + names_count = engine.execute( + f"SELECT COUNT(*) FROM birth_names" + ).first() + self.assertEqual( + names_count[0], len(data) + ) # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True + + # cleanup + engine.execute(f"DROP {ctas_method} admin_database.{tmp_table_name}") + examples_db.allow_ctas = old_allow_ctas + db.session.commit() @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_multi_sql(self): @@ -275,9 +277,10 @@ def test_sql_json_schema_access(self): "SchemaUser", ["SchemaPermission", "Gamma", "sql_lab"] ) - examples_db.get_sqla_engine().execute( - f"CREATE TABLE IF NOT EXISTS {CTAS_SCHEMA_NAME}.test_table AS SELECT 1 as c1, 2 as c2" - ) + with examples_db.get_sqla_engine_with_context() as engine: + engine.execute( + f"CREATE TABLE IF NOT EXISTS {CTAS_SCHEMA_NAME}.test_table AS SELECT 1 as c1, 2 as c2" + ) data = self.run_sql( f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table", "3", username="SchemaUser" @@ -303,9 +306,8 @@ def test_sql_json_schema_access(self): self.assertEqual(1, len(data["data"])) db.session.query(Query).delete() - get_example_database().get_sqla_engine().execute( - f"DROP TABLE IF EXISTS {CTAS_SCHEMA_NAME}.test_table" - ) + with get_example_database().get_sqla_engine_with_context() as engine: + engine.execute(f"DROP TABLE IF EXISTS {CTAS_SCHEMA_NAME}.test_table") db.session.commit() def test_queries_endpoint(self): @@ -520,12 +522,10 @@ def test_sqllab_viz_bad_payload(self): def test_sqllab_table_viz(self): self.login("admin") examples_db = get_example_database() - examples_db.get_sqla_engine().execute( - "DROP TABLE IF EXISTS test_sqllab_table_viz" - ) - examples_db.get_sqla_engine().execute( - "CREATE TABLE test_sqllab_table_viz AS SELECT 2 as col" - ) + with examples_db.get_sqla_engine_with_context() as engine: + engine.execute("DROP TABLE IF EXISTS test_sqllab_table_viz") + engine.execute("CREATE TABLE test_sqllab_table_viz AS SELECT 2 as col") + examples_dbid = examples_db.id payload = { @@ -543,9 +543,9 @@ def test_sqllab_table_viz(self): table = db.session.query(SqlaTable).filter_by(id=table_id).one() self.assertEqual([owner.username for owner in table.owners], ["admin"]) db.session.delete(table) - get_example_database().get_sqla_engine().execute( - "DROP TABLE test_sqllab_table_viz" - ) + + with get_example_database().get_sqla_engine_with_context() as engine: + engine.execute("DROP TABLE test_sqllab_table_viz") db.session.commit() @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")