From ebd79e5f77b4e4168f48e7a4e85129d272fd26bc Mon Sep 17 00:00:00 2001 From: John Bodley Date: Wed, 12 Apr 2023 15:22:57 +1200 Subject: [PATCH] chore(db_engine_specs): Refactor get_index --- superset/db_engine_specs/base.py | 22 +++++ superset/db_engine_specs/bigquery.py | 22 +++++ superset/db_engine_specs/presto.py | 16 +++- superset/models/core.py | 3 +- .../db_engine_specs/base_engine_spec_tests.py | 23 +++++ .../db_engine_specs/bigquery_tests.py | 83 +++++++++++++++---- 6 files changed, 147 insertions(+), 22 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index fb76d9363d535..ed58e8cb86b38 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -43,6 +43,7 @@ import sqlparse from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin +from deprecation import deprecated from flask import current_app from flask_appbuilder.security.sqla.models import User from flask_babel import gettext as __, lazy_gettext as _ @@ -797,6 +798,7 @@ def get_datatype(cls, type_code: Any) -> Optional[str]: return None @classmethod + @deprecated(deprecated_in="3.0") def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ Normalizes indexes for more consistency across db engines @@ -1179,6 +1181,26 @@ def get_view_names( # pylint: disable=unused-argument views = {re.sub(f"^{schema}\\.", "", view) for view in views} return views + @classmethod + def get_indexes( + cls, + database: Database, # pylint: disable=unused-argument + inspector: Inspector, + table_name: str, + schema: Optional[str], + ) -> List[Dict[str, Any]]: + """ + Get the indexes associated with the specified schema/table. + + :param database: The database to inspect + :param inspector: The SQLAlchemy inspector + :param table_name: The table to inspect + :param schema: The schema to inspect + :returns: The indexes + """ + + return inspector.get_indexes(table_name, schema) + @classmethod def get_table_comment( cls, inspector: Inspector, table_name: str, schema: Optional[str] diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index f344fcac2095a..976b2ee3167c5 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -23,6 +23,7 @@ import pandas as pd from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin +from deprecation import deprecated from flask_babel import gettext as __ from marshmallow import fields, Schema from marshmallow.exceptions import ValidationError @@ -278,6 +279,7 @@ def _truncate_label(cls, label: str) -> str: return "_" + md5_sha_from_str(label) @classmethod + @deprecated(deprecated_in="3.0") def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ Normalizes indexes for more consistency across db engines @@ -296,6 +298,26 @@ def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any] normalized_idxs.append(ix) return normalized_idxs + @classmethod + def get_indexes( + cls, + database: "Database", + inspector: Inspector, + table_name: str, + schema: Optional[str], + ) -> List[Dict[str, Any]]: + """ + Get the indexes associated with the specified schema/table. + + :param database: The database to inspect + :param inspector: The SQLAlchemy inspector + :param table_name: The table to inspect + :param schema: The schema to inspect + :returns: The indexes + """ + + return cls.normalize_indexes(inspector.get_indexes(table_name, schema)) + @classmethod def extra_table_metadata( cls, database: "Database", table_name: str, schema_name: Optional[str] diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index a9b4f7aa60d61..12183c2b19137 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -561,10 +561,18 @@ def latest_partition( ) column_names = indexes[0]["column_names"] - part_fields = [(column_name, True) for column_name in column_names] - sql = cls._partition_query(table_name, database, 1, part_fields) - df = database.get_df(sql, schema) - return column_names, cls._latest_partition_from_df(df) + + return column_names, cls._latest_partition_from_df( + df=database.get_df( + sql=cls._partition_query( + table_name, + database, + limit=1, + order_by=[(column_name, True) for column_name in column_names], + ), + schema=schema, + ) + ) @classmethod def latest_sub_partition( diff --git a/superset/models/core.py b/superset/models/core.py index d7a38cdc033a5..d3daad52d3060 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -847,8 +847,7 @@ def get_indexes( self, table_name: str, schema: Optional[str] = None ) -> List[Dict[str, Any]]: with self.get_inspector_with_context() as inspector: - indexes = inspector.get_indexes(table_name, schema) - return self.db_engine_spec.normalize_indexes(indexes) + return self.db_engine_spec.get_indexes(self, inspector, table_name, schema) def get_pk_constraint( self, table_name: str, schema: Optional[str] = None diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py index 6048ac8f19968..188fc94946a90 100644 --- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py +++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py @@ -521,3 +521,26 @@ def test_validate_parameters_port_closed(is_port_open, is_hostname_valid): }, ) ] + + +def test_get_indexes(): + indexes = [ + { + "name": "partition", + "column_names": ["a", "b"], + "unique": False, + }, + ] + + inspector = mock.Mock() + inspector.get_indexes = mock.Mock(return_value=indexes) + + assert ( + BaseEngineSpec.get_indexes( + database=mock.Mock(), + inspector=inspector, + table_name="bar", + schema="foo", + ) + == indexes + ) diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py index bf02765840586..2f4f1c70cc82f 100644 --- a/tests/integration_tests/db_engine_specs/bigquery_tests.py +++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py @@ -144,27 +144,78 @@ def test_extra_table_metadata(self): ) self.assertEqual(result, expected_result) - def test_normalize_indexes(self): - """ - DB Eng Specs (bigquery): Test extra table metadata - """ - indexes = [{"name": "partition", "column_names": [None], "unique": False}] - normalized_idx = BigQueryEngineSpec.normalize_indexes(indexes) - self.assertEqual(normalized_idx, []) + def test_get_indexes(self): + database = mock.Mock() + inspector = mock.Mock() + schema = "foo" + table_name = "bar" - indexes = [{"name": "partition", "column_names": ["dttm"], "unique": False}] - normalized_idx = BigQueryEngineSpec.normalize_indexes(indexes) - self.assertEqual(normalized_idx, indexes) + inspector.get_indexes = mock.Mock( + return_value=[ + { + "name": "partition", + "column_names": [None], + "unique": False, + } + ] + ) - indexes = [ - {"name": "partition", "column_names": ["dttm", None], "unique": False} + assert ( + BigQueryEngineSpec.get_indexes( + database, + inspector, + table_name, + schema, + ) + == [] + ) + + inspector.get_indexes = mock.Mock( + return_value=[ + { + "name": "partition", + "column_names": ["dttm"], + "unique": False, + } + ] + ) + + assert BigQueryEngineSpec.get_indexes( + database, + inspector, + table_name, + schema, + ) == [ + { + "name": "partition", + "column_names": ["dttm"], + "unique": False, + } ] - normalized_idx = BigQueryEngineSpec.normalize_indexes(indexes) - self.assertEqual( - normalized_idx, - [{"name": "partition", "column_names": ["dttm"], "unique": False}], + + inspector.get_indexes = mock.Mock( + return_value=[ + { + "name": "partition", + "column_names": ["dttm", None], + "unique": False, + } + ] ) + assert BigQueryEngineSpec.get_indexes( + database, + inspector, + table_name, + schema, + ) == [ + { + "name": "partition", + "column_names": ["dttm"], + "unique": False, + } + ] + @mock.patch("superset.db_engine_specs.bigquery.BigQueryEngineSpec.get_engine") @mock.patch("superset.db_engine_specs.bigquery.pandas_gbq") @mock.patch("superset.db_engine_specs.bigquery.service_account")