Skip to content

Commit

Permalink
update other test
Browse files Browse the repository at this point in the history
  • Loading branch information
hughhhh committed Oct 20, 2022
1 parent ca24e25 commit c006017
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 99 deletions.
27 changes: 14 additions & 13 deletions tests/integration_tests/datasets/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,9 +615,10 @@ def test_create_dataset_same_name_different_schema(self):
return

example_db = get_example_database()
example_db.get_sqla_engine().execute(
f"CREATE TABLE {CTAS_SCHEMA_NAME}.birth_names AS SELECT 2 as two"
)
with example_db.get_sqla_engine_with_context() as engine:
engine.execute(
f"CREATE TABLE {CTAS_SCHEMA_NAME}.birth_names AS SELECT 2 as two"
)

self.login(username="admin")
table_data = {
Expand All @@ -635,9 +636,8 @@ def test_create_dataset_same_name_different_schema(self):
uri = f'api/v1/dataset/{data.get("id")}'
rv = self.client.delete(uri)
assert rv.status_code == 200
example_db.get_sqla_engine().execute(
f"DROP TABLE {CTAS_SCHEMA_NAME}.birth_names"
)
with example_db.get_sqla_engine_with_context() as engine:
engine.execute(f"DROP TABLE {CTAS_SCHEMA_NAME}.birth_names")

def test_create_dataset_validate_database(self):
"""
Expand Down Expand Up @@ -703,13 +703,14 @@ def test_create_dataset_validate_view_exists(
mock_get_table.return_value = None

example_db = get_example_database()
engine = example_db.get_sqla_engine()
dialect = engine.dialect

with patch.object(
dialect, "get_view_names", wraps=dialect.get_view_names
) as patch_get_view_names:
patch_get_view_names.return_value = ["test_case_view"]
with example_db.get_sqla_engine_with_context() as engine:
engine = engine
dialect = engine.dialect

with patch.object(
dialect, "get_view_names", wraps=dialect.get_view_names
) as patch_get_view_names:
patch_get_view_names.return_value = ["test_case_view"]

self.login(username="admin")
table_data = {
Expand Down
27 changes: 14 additions & 13 deletions tests/integration_tests/fixtures/energy_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,22 @@
def load_energy_table_data():
with app.app_context():
database = get_example_database()
df = _get_dataframe()
df.to_sql(
ENERGY_USAGE_TBL_NAME,
database.get_sqla_engine(),
if_exists="replace",
chunksize=500,
index=False,
dtype={"source": String(255), "target": String(255), "value": Float()},
method="multi",
schema=get_example_default_schema(),
)
with database.get_sqla_engine_with_context() as engine:
df = _get_dataframe()
df.to_sql(
ENERGY_USAGE_TBL_NAME,
engine,
if_exists="replace",
chunksize=500,
index=False,
dtype={"source": String(255), "target": String(255), "value": Float()},
method="multi",
schema=get_example_default_schema(),
)
yield
with app.app_context():
engine = get_example_database().get_sqla_engine()
engine.execute("DROP TABLE IF EXISTS energy_usage")
with get_example_database().get_sqla_engine_with_context() as engine:
engine.execute("DROP TABLE IF EXISTS energy_usage")


@pytest.fixture()
Expand Down
21 changes: 11 additions & 10 deletions tests/integration_tests/fixtures/unicode_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,17 @@
@pytest.fixture(scope="session")
def load_unicode_data():
with app.app_context():
_get_dataframe().to_sql(
UNICODE_TBL_NAME,
get_example_database().get_sqla_engine(),
if_exists="replace",
chunksize=500,
dtype={"phrase": String(500)},
index=False,
method="multi",
schema=get_example_default_schema(),
)
with get_example_database().get_sqla_engine_with_context() as engine:
_get_dataframe().to_sql(
UNICODE_TBL_NAME,
engine,
if_exists="replace",
chunksize=500,
dtype={"phrase": String(500)},
index=False,
method="multi",
schema=get_example_default_schema(),
)

yield
with app.app_context():
Expand Down
21 changes: 11 additions & 10 deletions tests/integration_tests/fixtures/world_bank_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,17 @@ def load_world_bank_data():
"country_name": String(255),
"region": String(255),
}
_get_dataframe(database).to_sql(
WB_HEALTH_POPULATION,
get_example_database().get_sqla_engine(),
if_exists="replace",
chunksize=500,
dtype=dtype,
index=False,
method="multi",
schema=get_example_default_schema(),
)
with database.get_sqla_engine_with_context() as engine:
_get_dataframe(database).to_sql(
WB_HEALTH_POPULATION,
engine,
if_exists="replace",
chunksize=500,
dtype=dtype,
index=False,
method="multi",
schema=get_example_default_schema(),
)

yield
with app.app_context():
Expand Down
55 changes: 39 additions & 16 deletions tests/integration_tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,13 @@ def test_database_schema_postgres(self):
sqlalchemy_uri = "postgresql+psycopg2://postgres.airbnb.io:5439/prod"
model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)

db = make_url(model.get_sqla_engine().url).database
self.assertEqual("prod", db)
with model.get_sqla_engine_with_context() as engine:
db = make_url(engine.url).database
self.assertEqual("prod", db)

db = make_url(model.get_sqla_engine(schema="foo").url).database
self.assertEqual("prod", db)
with model.get_sqla_engine_with_context(schema="foo") as engine:
db = make_url(engine.url).database
self.assertEqual("prod", db)

@unittest.skipUnless(
SupersetTestCase.is_module_installed("thrift"), "thrift not installed"
Expand All @@ -95,11 +97,14 @@ def test_database_schema_postgres(self):
def test_database_schema_hive(self):
sqlalchemy_uri = "hive://hive@hive.airbnb.io:10000/default?auth=NOSASL"
model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)
db = make_url(model.get_sqla_engine().url).database
self.assertEqual("default", db)

db = make_url(model.get_sqla_engine(schema="core_db").url).database
self.assertEqual("core_db", db)
with model.get_sqla_engine_with_context() as engine:
db = make_url(engine.url).database
self.assertEqual("default", db)

with model.get_sqla_engine_with_context(schema="core_db") as engine:
db = make_url(engine.url).database
self.assertEqual("core_db", db)

@unittest.skipUnless(
SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
Expand All @@ -108,11 +113,13 @@ def test_database_schema_mysql(self):
sqlalchemy_uri = "mysql://root@localhost/superset"
model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)

db = make_url(model.get_sqla_engine().url).database
self.assertEqual("superset", db)
with model.get_sqla_engine_with_context() as engine:
db = make_url(engine.url).database
self.assertEqual("superset", db)

db = make_url(model.get_sqla_engine(schema="staging").url).database
self.assertEqual("staging", db)
with model.get_sqla_engine_with_context(schema="staging") as engine:
db = make_url(engine.url).database
self.assertEqual("staging", db)

@unittest.skipUnless(
SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
Expand All @@ -124,12 +131,14 @@ def test_database_impersonate_user(self):

with override_user(example_user):
model.impersonate_user = True
username = make_url(model.get_sqla_engine().url).username
self.assertEqual(example_user.username, username)
with model.get_sqla_engine_with_context() as engine:
username = make_url(engine.url).username
self.assertEqual(example_user.username, username)

model.impersonate_user = False
username = make_url(model.get_sqla_engine().url).username
self.assertNotEqual(example_user.username, username)
with model.get_sqla_engine_with_context() as engine:
username = make_url(engine.url).username
self.assertNotEqual(example_user.username, username)

@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_presto(self, mocked_create_engine):
Expand Down Expand Up @@ -373,6 +382,20 @@ def test_get_sqla_engine(self, mocked_create_engine):
with self.assertRaises(SupersetException):
model.get_sqla_engine()

# todo(hughhh): update this test
# @mock.patch("superset.models.core.create_engine")
# def test_get_sqla_engine_with_context(self, mocked_create_engine):
# model = Database(
# database_name="test_database",
# sqlalchemy_uri="mysql://root@localhost",
# )
# model.db_engine_spec.get_dbapi_exception_mapping = mock.Mock(
# return_value={Exception: SupersetException}
# )
# mocked_create_engine.side_effect = Exception()
# with self.assertRaises(SupersetException):
# model.get_sqla_engine()


class TestSqlaTableModel(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
Expand Down
13 changes: 4 additions & 9 deletions tests/integration_tests/reports/commands_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,10 @@ def assert_log(state: str, error_message: Optional[str] = None):

@contextmanager
def create_test_table_context(database: Database):
database.get_sqla_engine().execute(
"CREATE TABLE test_table AS SELECT 1 as first, 2 as second"
)
database.get_sqla_engine().execute(
"INSERT INTO test_table (first, second) VALUES (1, 2)"
)
database.get_sqla_engine().execute(
"INSERT INTO test_table (first, second) VALUES (3, 4)"
)
with database.get_sqla_engine_with_context() as engine:
engine.execute("CREATE TABLE test_table AS SELECT 1 as first, 2 as second")
engine.execute("INSERT INTO test_table (first, second) VALUES (1, 2)")
engine.execute("INSERT INTO test_table (first, second) VALUES (3, 4)")

yield db.session
database.get_sqla_engine().execute("DROP TABLE test_table")
Expand Down
56 changes: 28 additions & 28 deletions tests/integration_tests/sqllab_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,19 +207,21 @@ def test_sql_json_cta_dynamic_db(self, ctas_method):
# assertions
db.session.commit()
examples_db = get_example_database()
engine = examples_db.get_sqla_engine()
data = engine.execute(
f"SELECT * FROM admin_database.{tmp_table_name}"
).fetchall()
names_count = engine.execute(f"SELECT COUNT(*) FROM birth_names").first()
self.assertEqual(
names_count[0], len(data)
) # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True

# cleanup
engine.execute(f"DROP {ctas_method} admin_database.{tmp_table_name}")
examples_db.allow_ctas = old_allow_ctas
db.session.commit()
with examples_db.get_sqla_engine_with_context() as engine:
data = engine.execute(
f"SELECT * FROM admin_database.{tmp_table_name}"
).fetchall()
names_count = engine.execute(
f"SELECT COUNT(*) FROM birth_names"
).first()
self.assertEqual(
names_count[0], len(data)
) # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True

# cleanup
engine.execute(f"DROP {ctas_method} admin_database.{tmp_table_name}")
examples_db.allow_ctas = old_allow_ctas
db.session.commit()

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_multi_sql(self):
Expand Down Expand Up @@ -275,9 +277,10 @@ def test_sql_json_schema_access(self):
"SchemaUser", ["SchemaPermission", "Gamma", "sql_lab"]
)

examples_db.get_sqla_engine().execute(
f"CREATE TABLE IF NOT EXISTS {CTAS_SCHEMA_NAME}.test_table AS SELECT 1 as c1, 2 as c2"
)
with examples_db.get_sqla_engine_with_context() as engine:
engine.execute(
f"CREATE TABLE IF NOT EXISTS {CTAS_SCHEMA_NAME}.test_table AS SELECT 1 as c1, 2 as c2"
)

data = self.run_sql(
f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table", "3", username="SchemaUser"
Expand All @@ -303,9 +306,8 @@ def test_sql_json_schema_access(self):
self.assertEqual(1, len(data["data"]))

db.session.query(Query).delete()
get_example_database().get_sqla_engine().execute(
f"DROP TABLE IF EXISTS {CTAS_SCHEMA_NAME}.test_table"
)
with get_example_database().get_sqla_engine_with_context() as engine:
engine.execute(f"DROP TABLE IF EXISTS {CTAS_SCHEMA_NAME}.test_table")
db.session.commit()

def test_queries_endpoint(self):
Expand Down Expand Up @@ -520,12 +522,10 @@ def test_sqllab_viz_bad_payload(self):
def test_sqllab_table_viz(self):
self.login("admin")
examples_db = get_example_database()
examples_db.get_sqla_engine().execute(
"DROP TABLE IF EXISTS test_sqllab_table_viz"
)
examples_db.get_sqla_engine().execute(
"CREATE TABLE test_sqllab_table_viz AS SELECT 2 as col"
)
with examples_db.get_sqla_engine_with_context() as engine:
engine.execute("DROP TABLE IF EXISTS test_sqllab_table_viz")
engine.execute("CREATE TABLE test_sqllab_table_viz AS SELECT 2 as col")

examples_dbid = examples_db.id

payload = {
Expand All @@ -543,9 +543,9 @@ def test_sqllab_table_viz(self):
table = db.session.query(SqlaTable).filter_by(id=table_id).one()
self.assertEqual([owner.username for owner in table.owners], ["admin"])
db.session.delete(table)
get_example_database().get_sqla_engine().execute(
"DROP TABLE test_sqllab_table_viz"
)

with get_example_database().get_sqla_engine_with_context() as engine:
engine.execute("DROP TABLE test_sqllab_table_viz")
db.session.commit()

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
Expand Down

0 comments on commit c006017

Please sign in to comment.