Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Change get_table_names/get_view_names return type #22085

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,7 +1016,7 @@ def get_table_names( # pylint: disable=unused-argument
database: "Database",
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
) -> Set[str]:
"""
Get all the real table names within the specified schema.

Expand All @@ -1030,21 +1030,21 @@ def get_table_names( # pylint: disable=unused-argument
"""

try:
tables = inspector.get_table_names(schema)
tables = set(inspector.get_table_names(schema))
except Exception as ex:
raise cls.get_dbapi_mapped_exception(ex) from ex

if schema and cls.try_remove_schema_from_table_name:
tables = [re.sub(f"^{schema}\\.", "", table) for table in tables]
return sorted(tables)
tables = {re.sub(f"^{schema}\\.", "", table) for table in tables}
return tables

@classmethod
def get_view_names( # pylint: disable=unused-argument
cls,
database: "Database",
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
) -> Set[str]:
"""
Get all the view names within the specified schema.

Expand All @@ -1058,13 +1058,13 @@ def get_view_names( # pylint: disable=unused-argument
"""

try:
views = inspector.get_view_names(schema)
views = set(inspector.get_view_names(schema))
except Exception as ex:
raise cls.get_dbapi_mapped_exception(ex) from ex

if schema and cls.try_remove_schema_from_table_name:
views = [re.sub(f"^{schema}\\.", "", view) for view in views]
return sorted(views)
views = {re.sub(f"^{schema}\\.", "", view) for view in views}
return views

@classmethod
def get_table_comment(
Expand Down
12 changes: 5 additions & 7 deletions superset/db_engine_specs/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

from datetime import datetime
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from typing import Any, Dict, Optional, Set, TYPE_CHECKING

from sqlalchemy.engine.reflection import Inspector

Expand Down Expand Up @@ -103,9 +103,7 @@ def get_table_names(
database: "Database",
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
tables = set(super().get_table_names(database, inspector, schema))
views = set(cls.get_view_names(database, inspector, schema))
actual_tables = tables - views

return list(actual_tables)
) -> Set[str]:
return super().get_table_names(
database, inspector, schema
) - cls.get_view_names(database, inspector, schema)
6 changes: 3 additions & 3 deletions superset/db_engine_specs/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING
from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING

from flask_babel import gettext as __
from sqlalchemy.engine.reflection import Inspector
Expand Down Expand Up @@ -75,5 +75,5 @@ def convert_dttm(
@classmethod
def get_table_names(
cls, database: Database, inspector: Inspector, schema: Optional[str]
) -> List[str]:
return sorted(inspector.get_table_names(schema))
) -> Set[str]:
return set(inspector.get_table_names(schema))
10 changes: 5 additions & 5 deletions superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Pattern, Set, Tuple, TYPE_CHECKING

from flask_babel import gettext as __
from sqlalchemy.dialects.postgresql import ARRAY, DOUBLE_PRECISION, ENUM, JSON
Expand Down Expand Up @@ -228,11 +228,11 @@ def query_cost_formatter(
@classmethod
def get_table_names(
cls, database: "Database", inspector: PGInspector, schema: Optional[str]
) -> List[str]:
) -> Set[str]:
"""Need to consider foreign tables for PostgreSQL"""
tables = inspector.get_table_names(schema)
tables.extend(inspector.get_foreign_table_names(schema))
return sorted(tables)
return set(inspector.get_table_names(schema)) | set(
inspector.get_foreign_table_names(schema)
)

@classmethod
def convert_dttm(
Expand Down
28 changes: 18 additions & 10 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,18 @@
from datetime import datetime
from distutils.version import StrictVersion
from textwrap import dedent
from typing import Any, cast, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING, Union
from typing import (
Any,
cast,
Dict,
List,
Optional,
Pattern,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
from urllib import parse

import pandas as pd
Expand Down Expand Up @@ -396,7 +407,7 @@ def get_table_names(
database: Database,
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
) -> Set[str]:
"""
Get all the real table names within the specified schema.

Expand All @@ -414,20 +425,17 @@ def get_table_names(
:returns: The physical table names
"""

return sorted(
list(
set(super().get_table_names(database, inspector, schema))
- set(cls.get_view_names(database, inspector, schema))
)
)
return super().get_table_names(
database, inspector, schema
) - cls.get_view_names(database, inspector, schema)

@classmethod
def get_view_names(
cls,
database: Database,
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
) -> Set[str]:
"""
Get all the view names within the specified schema.

Expand Down Expand Up @@ -469,7 +477,7 @@ def get_view_names(
cursor.execute(sql, params)
results = cursor.fetchall()

return sorted([row[0] for row in results])
return {row[0] for row in results}

@classmethod
def _create_column_info(
Expand Down
6 changes: 3 additions & 3 deletions superset/db_engine_specs/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING
from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING

from flask_babel import gettext as __
from sqlalchemy.engine.reflection import Inspector
Expand Down Expand Up @@ -88,6 +88,6 @@ def convert_dttm(
@classmethod
def get_table_names(
cls, database: "Database", inspector: Inspector, schema: Optional[str]
) -> List[str]:
) -> Set[str]:
"""Need to disregard the schema for Sqlite"""
return sorted(inspector.get_table_names())
return set(inspector.get_table_names())
32 changes: 20 additions & 12 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument
cache: bool = False,
cache_timeout: Optional[int] = None,
force: bool = False,
) -> List[Tuple[str, str]]:
) -> Set[Tuple[str, str]]:
"""Parameters need to be passed as keyword arguments.

For unused parameters, they are referenced in
Expand All @@ -556,13 +556,17 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument
:param cache: whether cache is enabled for the function
:param cache_timeout: timeout in seconds for the cache
:param force: whether to force refresh the cache
:return: list of tables
:return: The table/schema pairs
"""
try:
tables = self.db_engine_spec.get_table_names(
database=self, inspector=self.inspector, schema=schema
)
return [(table, schema) for table in tables]
return {
(table, schema)
for table in self.db_engine_spec.get_table_names(
database=self,
inspector=self.inspector,
schema=schema,
)
}
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)

Expand All @@ -576,7 +580,7 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument
cache: bool = False,
cache_timeout: Optional[int] = None,
force: bool = False,
) -> List[Tuple[str, str]]:
) -> Set[Tuple[str, str]]:
"""Parameters need to be passed as keyword arguments.

For unused parameters, they are referenced in
Expand All @@ -586,13 +590,17 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument
:param cache: whether cache is enabled for the function
:param cache_timeout: timeout in seconds for the cache
:param force: whether to force refresh the cache
:return: list of views
:return: set of views
"""
try:
views = self.db_engine_spec.get_view_names(
database=self, inspector=self.inspector, schema=schema
)
return [(view, schema) for view in views]
return {
(view, schema)
for view in self.db_engine_spec.get_view_names(
database=self,
inspector=self.inspector,
schema=schema,
)
}
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)

Expand Down
8 changes: 4 additions & 4 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,29 +1185,29 @@ def tables( # pylint: disable=no-self-use
tables = security_manager.get_datasources_accessible_by_user(
database=database,
schema=schema_parsed,
datasource_names=[
datasource_names=sorted(
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_table_names_in_schema(
schema=schema_parsed,
force=force_refresh_parsed,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
],
),
)

views = security_manager.get_datasources_accessible_by_user(
database=database,
schema=schema_parsed,
datasource_names=[
datasource_names=sorted(
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_view_names_in_schema(
schema=schema_parsed,
force=force_refresh_parsed,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
],
),
)
except SupersetException as ex:
return json_error_response(ex.message, ex.status)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/datasets/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ def test_create_dataset_validate_view_exists(
with patch.object(
dialect, "get_view_names", wraps=dialect.get_view_names
) as patch_get_view_names:
patch_get_view_names.return_value = ["test_case_view"]
patch_get_view_names.return_value = {"test_case_view"}

self.login(username="admin")
table_data = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,11 @@ def test_get_table_names(self):

""" Make sure base engine spec removes schema name from table name
ie. when try_remove_schema_from_table_name == True. """
base_result_expected = ["table", "table_2"]
base_result_expected = {"table", "table_2"}
base_result = BaseEngineSpec.get_table_names(
database=mock.ANY, schema="schema", inspector=inspector
)
self.assertListEqual(base_result_expected, base_result)
assert base_result_expected == base_result

@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_column_datatype_to_string(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/db_engine_specs/postgres_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def test_get_table_names(self):
inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"])
inspector.get_foreign_table_names = mock.Mock(return_value=["table_3"])

pg_result_expected = ["schema.table", "table_2", "table_3"]
pg_result_expected = {"schema.table", "table_2", "table_3"}
pg_result = PostgresEngineSpec.get_table_names(
database=mock.ANY, schema="schema", inspector=inspector
)
self.assertListEqual(pg_result_expected, pg_result)
assert pg_result_expected == pg_result

def test_time_exp_literal_no_grain(self):
"""
Expand Down
10 changes: 5 additions & 5 deletions tests/integration_tests/db_engine_specs/presto_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_get_view_names_with_schema(self):
).strip(),
{"schema": schema},
)
assert result == ["a", "d"]
assert result == {"a", "d"}

def test_get_view_names_without_schema(self):
database = mock.MagicMock()
Expand All @@ -76,7 +76,7 @@ def test_get_view_names_without_schema(self):
).strip(),
{},
)
assert result == ["a", "d"]
assert result == {"a", "d"}

def verify_presto_column(self, column, expected_results):
inspector = mock.Mock()
Expand Down Expand Up @@ -669,10 +669,10 @@ def test_get_table_names(
mock_get_view_names,
mock_get_table_names,
):
mock_get_view_names.return_value = ["view1", "view2"]
mock_get_table_names.return_value = ["table1", "table2", "view1", "view2"]
mock_get_view_names.return_value = {"view1", "view2"}
mock_get_table_names.return_value = {"table1", "table2", "view1", "view2"}
tables = PrestoEngineSpec.get_table_names(mock.Mock(), mock.Mock(), None)
assert tables == ["table1", "table2"]
assert tables == {"table1", "table2"}

def test_get_full_name(self):
names = [
Expand Down