diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index cb1673c8287a7..7ac5675e6ecd2 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1932,7 +1932,10 @@ def fetch_metadata(self, commit: bool = True) -> MetadataResult: :return: Tuple with lists of added, removed and modified column names. """ new_columns = self.external_metadata() - metrics = [] + metrics = [ + SqlMetric(**metric) + for metric in self.database.get_metrics(self.table_name, self.schema) + ] any_date_col = None db_engine_spec = self.db_engine_spec @@ -1989,14 +1992,6 @@ def fetch_metadata(self, commit: bool = True) -> MetadataResult: columns.extend([col for col in old_columns if col.expression]) self.columns = columns - metrics.append( - SqlMetric( - metric_name="count", - verbose_name="COUNT(*)", - metric_type="count", - expression="COUNT(*)", - ) - ) if not self.main_dttm_col: self.main_dttm_col = any_date_col self.add_missing_metrics(metrics) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index e95e39c1fb50f..368770e2612f5 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -154,6 +154,21 @@ class LimitMethod: # pylint: disable=too-few-public-methods FORCE_LIMIT = "force_limit" +class MetricType(TypedDict, total=False): + """ + Type for metrics return by `get_metrics`. + """ + + metric_name: str + expression: str + verbose_name: Optional[str] + metric_type: Optional[str] + description: Optional[str] + d3format: Optional[str] + warning_text: Optional[str] + extra: Optional[str] + + class BaseEngineSpec: # pylint: disable=too-many-public-methods """Abstract class for database engine specific configurations @@ -1054,6 +1069,26 @@ def get_columns( """ return inspector.get_columns(table_name, schema) + @classmethod + def get_metrics( # pylint: disable=unused-argument + cls, + database: "Database", + inspector: Inspector, + table_name: str, + schema: Optional[str], + ) -> List[MetricType]: + """ + Get all metrics from a given schema and table. + """ + return [ + { + "metric_name": "count", + "verbose_name": "COUNT(*)", + "metric_type": "count", + "expression": "COUNT(*)", + } + ] + @classmethod def where_latest_partition( # pylint: disable=too-many-arguments,unused-argument cls, diff --git a/superset/models/core.py b/superset/models/core.py index 617c23ef9edb1..b5a4aa6537da2 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -55,7 +55,7 @@ from superset import app, db_engine_specs, is_feature_enabled from superset.databases.utils import make_url_safe -from superset.db_engine_specs.base import TimeGrain +from superset.db_engine_specs.base import MetricType, TimeGrain from superset.extensions import cache_manager, encrypted_field_factory, security_manager from superset.models.helpers import AuditMixinNullable, ImportExportMixin from superset.models.tags import FavStarUpdater @@ -693,6 +693,13 @@ def get_columns( ) -> List[Dict[str, Any]]: return self.db_engine_spec.get_columns(self.inspector, table_name, schema) + def get_metrics( + self, + table_name: str, + schema: Optional[str] = None, + ) -> List[MetricType]: + return self.db_engine_spec.get_metrics(self, self.inspector, table_name, schema) + def get_indexes( self, table_name: str, schema: Optional[str] = None ) -> List[Dict[str, Any]]: diff --git a/tests/unit_tests/models/__init__.py b/tests/unit_tests/models/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/unit_tests/models/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py new file mode 100644 index 0000000000000..3338ddcb61441 --- /dev/null +++ b/tests/unit_tests/models/core_test.py @@ -0,0 +1,72 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=import-outside-toplevel + +from typing import List, Optional + +from pytest_mock import MockFixture +from sqlalchemy.engine.reflection import Inspector + + +def test_get_metrics(mocker: MockFixture) -> None: + """ + Tests for ``get_metrics``. + """ + from superset.db_engine_specs.base import MetricType + from superset.db_engine_specs.sqlite import SqliteEngineSpec + from superset.models.core import Database + + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + assert database.get_metrics("table") == [ + { + "expression": "COUNT(*)", + "metric_name": "count", + "metric_type": "count", + "verbose_name": "COUNT(*)", + } + ] + + class CustomSqliteEngineSpec(SqliteEngineSpec): + @classmethod + def get_metrics( + cls, + database: Database, + inspector: Inspector, + table_name: str, + schema: Optional[str], + ) -> List[MetricType]: + return [ + { + "expression": "COUNT(DISTINCT user_id)", + "metric_name": "count_distinct_user_id", + "metric_type": "count_distinct", + "verbose_name": "COUNT(DISTINCT user_id)", + }, + ] + + database.get_db_engine_spec_for_backend = mocker.MagicMock( # type: ignore + return_value=CustomSqliteEngineSpec + ) + assert database.get_metrics("table") == [ + { + "expression": "COUNT(DISTINCT user_id)", + "metric_name": "count_distinct_user_id", + "metric_type": "count_distinct", + "verbose_name": "COUNT(DISTINCT user_id)", + }, + ]