Skip to content

Commit

Permalink
chore(db_engine_specs): Refactor get_index
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley committed Apr 12, 2023
1 parent 976e333 commit ebd79e5
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 22 deletions.
22 changes: 22 additions & 0 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import sqlparse
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from deprecation import deprecated
from flask import current_app
from flask_appbuilder.security.sqla.models import User
from flask_babel import gettext as __, lazy_gettext as _
Expand Down Expand Up @@ -797,6 +798,7 @@ def get_datatype(cls, type_code: Any) -> Optional[str]:
return None

@classmethod
@deprecated(deprecated_in="3.0")
def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Normalizes indexes for more consistency across db engines
Expand Down Expand Up @@ -1179,6 +1181,26 @@ def get_view_names( # pylint: disable=unused-argument
views = {re.sub(f"^{schema}\\.", "", view) for view in views}
return views

@classmethod
def get_indexes(
    cls,
    database: Database,  # pylint: disable=unused-argument
    inspector: Inspector,
    table_name: str,
    schema: Optional[str],
) -> List[Dict[str, Any]]:
    """
    Return the indexes the inspector reports for the given schema/table.

    This base implementation hands back the inspector's result as-is; the
    ``database`` argument is accepted but not used here.

    :param database: The database to inspect (unused in this implementation)
    :param inspector: The SQLAlchemy inspector
    :param table_name: The table to inspect
    :param schema: The schema to inspect
    :returns: The indexes
    """

    indexes: List[Dict[str, Any]] = inspector.get_indexes(table_name, schema)
    return indexes

@classmethod
def get_table_comment(
cls, inspector: Inspector, table_name: str, schema: Optional[str]
Expand Down
22 changes: 22 additions & 0 deletions superset/db_engine_specs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import pandas as pd
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from deprecation import deprecated
from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.exceptions import ValidationError
Expand Down Expand Up @@ -278,6 +279,7 @@ def _truncate_label(cls, label: str) -> str:
return "_" + md5_sha_from_str(label)

@classmethod
@deprecated(deprecated_in="3.0")
def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Normalizes indexes for more consistency across db engines
Expand All @@ -296,6 +298,26 @@ def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]
normalized_idxs.append(ix)
return normalized_idxs

@classmethod
def get_indexes(
    cls,
    database: "Database",
    inspector: Inspector,
    table_name: str,
    schema: Optional[str],
) -> List[Dict[str, Any]]:
    """
    Get the indexes associated with the specified schema/table.

    ``None`` entries are removed from each index's ``column_names`` and
    indexes left without any column are dropped from the result.

    :param database: The database to inspect
    :param inspector: The SQLAlchemy inspector
    :param table_name: The table to inspect
    :param schema: The schema to inspect
    :returns: The indexes, restricted to those with at least one named column
    """

    # The filtering is done here rather than by delegating to
    # `normalize_indexes`: that helper is marked as deprecated, so calling it
    # from this method would raise a deprecation warning on every lookup.
    normalized: List[Dict[str, Any]] = []
    for index in inspector.get_indexes(table_name, schema):
        column_names = [
            column_name
            for column_name in index.get("column_names") or []
            if column_name is not None
        ]
        if column_names:
            # Build a new dict so the inspector's own result is not mutated.
            normalized.append({**index, "column_names": column_names})
    return normalized

@classmethod
def extra_table_metadata(
cls, database: "Database", table_name: str, schema_name: Optional[str]
Expand Down
16 changes: 12 additions & 4 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,10 +561,18 @@ def latest_partition(
)

column_names = indexes[0]["column_names"]
part_fields = [(column_name, True) for column_name in column_names]
sql = cls._partition_query(table_name, database, 1, part_fields)
df = database.get_df(sql, schema)
return column_names, cls._latest_partition_from_df(df)

return column_names, cls._latest_partition_from_df(
df=database.get_df(
sql=cls._partition_query(
table_name,
database,
limit=1,
order_by=[(column_name, True) for column_name in column_names],
),
schema=schema,
)
)

@classmethod
def latest_sub_partition(
Expand Down
3 changes: 1 addition & 2 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,8 +847,7 @@ def get_indexes(
self, table_name: str, schema: Optional[str] = None
) -> List[Dict[str, Any]]:
with self.get_inspector_with_context() as inspector:
indexes = inspector.get_indexes(table_name, schema)
return self.db_engine_spec.normalize_indexes(indexes)
return self.db_engine_spec.get_indexes(self, inspector, table_name, schema)

def get_pk_constraint(
self, table_name: str, schema: Optional[str] = None
Expand Down
23 changes: 23 additions & 0 deletions tests/integration_tests/db_engine_specs/base_engine_spec_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,3 +521,26 @@ def test_validate_parameters_port_closed(is_port_open, is_hostname_valid):
},
)
]


def test_get_indexes():
    """
    DB Eng Specs (base): get_indexes returns the inspector's indexes unchanged
    """
    expected = [
        {
            "name": "partition",
            "column_names": ["a", "b"],
            "unique": False,
        },
    ]

    # Stub inspector whose `get_indexes` always yields the expected payload.
    inspector = mock.Mock()
    inspector.get_indexes.return_value = expected

    result = BaseEngineSpec.get_indexes(
        database=mock.Mock(),
        inspector=inspector,
        table_name="bar",
        schema="foo",
    )

    assert result == expected
83 changes: 67 additions & 16 deletions tests/integration_tests/db_engine_specs/bigquery_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,27 +144,78 @@ def test_extra_table_metadata(self):
)
self.assertEqual(result, expected_result)

def test_normalize_indexes(self):
"""
DB Eng Specs (bigquery): Test extra table metadata
"""
indexes = [{"name": "partition", "column_names": [None], "unique": False}]
normalized_idx = BigQueryEngineSpec.normalize_indexes(indexes)
self.assertEqual(normalized_idx, [])
def test_get_indexes(self):
database = mock.Mock()
inspector = mock.Mock()
schema = "foo"
table_name = "bar"

indexes = [{"name": "partition", "column_names": ["dttm"], "unique": False}]
normalized_idx = BigQueryEngineSpec.normalize_indexes(indexes)
self.assertEqual(normalized_idx, indexes)
inspector.get_indexes = mock.Mock(
return_value=[
{
"name": "partition",
"column_names": [None],
"unique": False,
}
]
)

indexes = [
{"name": "partition", "column_names": ["dttm", None], "unique": False}
assert (
BigQueryEngineSpec.get_indexes(
database,
inspector,
table_name,
schema,
)
== []
)

inspector.get_indexes = mock.Mock(
return_value=[
{
"name": "partition",
"column_names": ["dttm"],
"unique": False,
}
]
)

assert BigQueryEngineSpec.get_indexes(
database,
inspector,
table_name,
schema,
) == [
{
"name": "partition",
"column_names": ["dttm"],
"unique": False,
}
]
normalized_idx = BigQueryEngineSpec.normalize_indexes(indexes)
self.assertEqual(
normalized_idx,
[{"name": "partition", "column_names": ["dttm"], "unique": False}],

inspector.get_indexes = mock.Mock(
return_value=[
{
"name": "partition",
"column_names": ["dttm", None],
"unique": False,
}
]
)

assert BigQueryEngineSpec.get_indexes(
database,
inspector,
table_name,
schema,
) == [
{
"name": "partition",
"column_names": ["dttm"],
"unique": False,
}
]

@mock.patch("superset.db_engine_specs.bigquery.BigQueryEngineSpec.get_engine")
@mock.patch("superset.db_engine_specs.bigquery.pandas_gbq")
@mock.patch("superset.db_engine_specs.bigquery.service_account")
Expand Down

0 comments on commit ebd79e5

Please sign in to comment.