diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 87b3d5dd3a28a..18e5de48c10d6 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -144,8 +144,9 @@ def get_columns_description(
         with database.get_raw_connection(catalog=catalog, schema=schema) as conn:
             cursor = conn.cursor()
             query = database.apply_limit_to_sql(query, limit=1)
-            cursor.execute(query)
-            db_engine_spec.execute(cursor, query, database)
+            mutated_query = database.mutate_sql_based_on_config(query)
+            cursor.execute(mutated_query)
+            db_engine_spec.execute(cursor, mutated_query, database)
             result = db_engine_spec.fetch_data(cursor, limit=1)
             result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
             return result_set.columns
diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py
index aaad26b85d723..a2f21e7fc729e 100644
--- a/tests/integration_tests/datasource_tests.py
+++ b/tests/integration_tests/datasource_tests.py
@@ -68,6 +68,24 @@ def create_test_table_context(database: Database):
         engine.execute(f"DROP TABLE {full_table_name}")
 
 
+@contextmanager
+def create_and_cleanup_table(table=None):
+    if table is None:
+        table = SqlaTable(
+            table_name="dummy_sql_table",
+            database=get_example_database(),
+            schema=get_example_default_schema(),
+            sql="select 123 as intcol, 'abc' as strcol",
+        )
+    db.session.add(table)
+    db.session.commit()
+    try:
+        yield table
+    finally:
+        db.session.delete(table)
+        db.session.commit()
+
+
 class TestDatasource(SupersetTestCase):
     def setUp(self):
         db.session.begin(subtransactions=True)
@@ -123,37 +141,22 @@ def test_always_filter_main_dttm(self):
             sql=sql,
         )
-        db.session.add(table)
-        db.session.commit()
+        with create_and_cleanup_table(table):
+            table.always_filter_main_dttm = False
+            result = str(table.get_sqla_query(**query_obj).sqla_query.whereclause)
+            assert "default_dttm" not in result and "additional_dttm" in result
 
-        table.always_filter_main_dttm = False
-        result = str(table.get_sqla_query(**query_obj).sqla_query.whereclause)
-        assert "default_dttm" not in result and "additional_dttm" in result
-
-        table.always_filter_main_dttm = True
-        result = str(table.get_sqla_query(**query_obj).sqla_query.whereclause)
-        assert "default_dttm" in result and "additional_dttm" in result
-
-        db.session.delete(table)
-        db.session.commit()
+            table.always_filter_main_dttm = True
+            result = str(table.get_sqla_query(**query_obj).sqla_query.whereclause)
+            assert "default_dttm" in result and "additional_dttm" in result
 
     def test_external_metadata_for_virtual_table(self):
         self.login(ADMIN_USERNAME)
-        table = SqlaTable(
-            table_name="dummy_sql_table",
-            database=get_example_database(),
-            schema=get_example_default_schema(),
-            sql="select 123 as intcol, 'abc' as strcol",
-        )
-        db.session.add(table)
-        db.session.commit()
-        table = self.get_table(name="dummy_sql_table")
-        url = f"/datasource/external_metadata/table/{table.id}/"
-        resp = self.get_json_resp(url)
-        assert {o.get("column_name") for o in resp} == {"intcol", "strcol"}
-        db.session.delete(table)
-        db.session.commit()
+        with create_and_cleanup_table() as table:
+            url = f"/datasource/external_metadata/table/{table.id}/"
+            resp = self.get_json_resp(url)
+            assert {o.get("column_name") for o in resp} == {"intcol", "strcol"}
 
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_external_metadata_by_name_for_physical_table(self):
@@ -178,31 +181,42 @@ def test_external_metadata_by_name_for_virtual_table(self):
         self.login(ADMIN_USERNAME)
-        table = SqlaTable(
-            table_name="dummy_sql_table",
-            database=get_example_database(),
-            schema=get_example_default_schema(),
-            sql="select 123 as intcol, 'abc' as strcol",
-        )
-        db.session.add(table)
-        db.session.commit()
+        with create_and_cleanup_table() as tbl:
+            params = prison.dumps(
+                {
+                    "datasource_type": "table",
+                    "database_name": tbl.database.database_name,
+                    "schema_name": tbl.schema,
+                    "table_name": tbl.table_name,
+                    "normalize_columns": tbl.normalize_columns,
+                    "always_filter_main_dttm": tbl.always_filter_main_dttm,
+                }
+            )
+            url = f"/datasource/external_metadata_by_name/?q={params}"
+            resp = self.get_json_resp(url)
+            assert {o.get("column_name") for o in resp} == {"intcol", "strcol"}
 
-        tbl = self.get_table(name="dummy_sql_table")
-        params = prison.dumps(
-            {
-                "datasource_type": "table",
-                "database_name": tbl.database.database_name,
-                "schema_name": tbl.schema,
-                "table_name": tbl.table_name,
-                "normalize_columns": tbl.normalize_columns,
-                "always_filter_main_dttm": tbl.always_filter_main_dttm,
-            }
-        )
-        url = f"/datasource/external_metadata_by_name/?q={params}"
-        resp = self.get_json_resp(url)
-        assert {o.get("column_name") for o in resp} == {"intcol", "strcol"}
-        db.session.delete(tbl)
-        db.session.commit()
+    def test_external_metadata_by_name_for_virtual_table_uses_mutator(self):
+        self.login(ADMIN_USERNAME)
+        with create_and_cleanup_table() as tbl:
+            app.config["SQL_QUERY_MUTATOR"] = (
+                lambda sql, **kwargs: "SELECT 456 as intcol, 'def' as mutated_strcol"
+            )
+
+            params = prison.dumps(
+                {
+                    "datasource_type": "table",
+                    "database_name": tbl.database.database_name,
+                    "schema_name": tbl.schema,
+                    "table_name": tbl.table_name,
+                    "normalize_columns": tbl.normalize_columns,
+                    "always_filter_main_dttm": tbl.always_filter_main_dttm,
+                }
+            )
+            url = f"/datasource/external_metadata_by_name/?q={params}"
+            resp = self.get_json_resp(url)
+            assert {o.get("column_name") for o in resp} == {"intcol", "mutated_strcol"}
+            app.config["SQL_QUERY_MUTATOR"] = None
 
     def test_external_metadata_by_name_from_sqla_inspector(self):
         self.login(ADMIN_USERNAME)
@@ -278,15 +292,10 @@ def test_external_metadata_for_virtual_table_template_params(self):
             sql="select {{ foo }} as intcol",
             template_params=json.dumps({"foo": "123"}),
         )
-        db.session.add(table)
-        db.session.commit()
-
-        table = self.get_table(name="dummy_sql_table_with_template_params")
-        url = f"/datasource/external_metadata/table/{table.id}/"
-        resp = self.get_json_resp(url)
-        assert {o.get("column_name") for o in resp} == {"intcol"}
-        db.session.delete(table)
-        db.session.commit()
+        with create_and_cleanup_table(table) as tbl:
+            url = f"/datasource/external_metadata/table/{tbl.id}/"
+            resp = self.get_json_resp(url)
+            assert {o.get("column_name") for o in resp} == {"intcol"}
 
     def test_external_metadata_for_malicious_virtual_table(self):
         self.login(ADMIN_USERNAME)
diff --git a/tests/unit_tests/connectors/sqla/utils_test.py b/tests/unit_tests/connectors/sqla/utils_test.py
new file mode 100644
index 0000000000000..75d5a1fe32914
--- /dev/null
+++ b/tests/unit_tests/connectors/sqla/utils_test.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from pytest_mock import MockerFixture
+
+from superset.connectors.sqla.utils import get_columns_description
+
+
+# Returns column descriptions when given valid database, catalog, schema, and query
+def test_returns_column_descriptions(mocker: MockerFixture) -> None:
+    database = mocker.MagicMock()
+    cursor = mocker.MagicMock()
+
+    result_set = mocker.MagicMock()
+    db_engine_spec = mocker.MagicMock()
+
+    CURSOR_DESCR = (
+        ("foo", "string"),
+        ("bar", "string"),
+        ("baz", "string"),
+        ("type_generic", "string"),
+        ("is_dttm", "boolean"),
+    )
+    cursor.description = CURSOR_DESCR
+
+    database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value = cursor
+    database.db_engine_spec = db_engine_spec
+    database.apply_limit_to_sql.return_value = "SELECT * FROM table LIMIT 1"
+    database.mutate_sql_based_on_config.return_value = "SELECT * FROM table LIMIT 1"
+    db_engine_spec.fetch_data.return_value = [("col1", "col1", "STRING", None, False)]
+    db_engine_spec.get_datatype.return_value = "STRING"
+    db_engine_spec.get_column_spec.return_value.is_dttm = False
+    db_engine_spec.get_column_spec.return_value.generic_type = "STRING"
+
+    mocker.patch("superset.result_set.SupersetResultSet", return_value=result_set)
+
+    columns = get_columns_description(
+        database, "catalog", "schema", "SELECT * FROM table"
+    )
+
+    assert columns == [
+        {
+            "column_name": "foo",
+            "name": "foo",
+            "type": "STRING",
+            "type_generic": "STRING",
+            "is_dttm": False,
+        },
+        {
+            "column_name": "bar",
+            "name": "bar",
+            "type": "STRING",
+            "type_generic": "STRING",
+            "is_dttm": False,
+        },
+        {
+            "column_name": "baz",
+            "name": "baz",
+            "type": "STRING",
+            "type_generic": "STRING",
+            "is_dttm": False,
+        },
+        {
+            "column_name": "type_generic",
+            "name": "type_generic",
+            "type": "STRING",
+            "type_generic": "STRING",
+            "is_dttm": False,
+        },
+        {
+            "column_name": "is_dttm",
+            "name": "is_dttm",
+            "type": "STRING",
+            "type_generic": "STRING",
+            "is_dttm": False,
+        },
+    ]
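
Note: the behavioral change above is that get_columns_description() now runs the preview query through Database.mutate_sql_based_on_config() before executing it, so a deployment's SQL_QUERY_MUTATOR also applies when Superset introspects the columns of a virtual dataset, which is what the new integration test exercises by setting app.config["SQL_QUERY_MUTATOR"]. For reference, a minimal sketch of such a mutator in superset_config.py, assuming the same (sql, **kwargs) calling convention the test's lambda uses; the comment-prefix behavior is purely illustrative and not part of this patch:

# superset_config.py -- illustrative sketch only, not part of this change.
def SQL_QUERY_MUTATOR(sql: str, **kwargs) -> str:
    # Tag every statement Superset issues so mutated queries are easy to spot
    # in database logs; any string-to-string transformation is valid here.
    return f"-- issued-by-superset\n{sql}"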