From 30170ff1e86240ee835f2de620b6454dea2fe0a9 Mon Sep 17 00:00:00 2001 From: Daniel Vaz Gaspar Date: Mon, 17 Aug 2020 15:46:59 +0100 Subject: [PATCH] feat: dataset REST API for distinct values (#10595) * feat: dataset REST API for distinct values * add tests and fix lint * fix mypy, and tests * fix docs * fix test * lint * fix test --- superset/constants.py | 1 + superset/datasets/api.py | 2 + superset/views/base_api.py | 154 +++++++++++++++++++++++++++++------- tests/datasets/api_tests.py | 91 ++++++++++++++++++++- 4 files changed, 218 insertions(+), 30 deletions(-) diff --git a/superset/constants.py b/superset/constants.py index dbe6fd12086bd..ea14a38a0be2a 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -55,6 +55,7 @@ class RouteMethod: # pylint: disable=too-few-public-methods POST = "post" PUT = "put" RELATED = "related" + DISTINCT = "distinct" # Commonly used sets API_SET = {API_CREATE, API_DELETE, API_GET, API_READ, API_UPDATE} diff --git a/superset/datasets/api.py b/superset/datasets/api.py index 0d7973b4a79d2..b04cef8184454 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -67,6 +67,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | { RouteMethod.EXPORT, RouteMethod.RELATED, + RouteMethod.DISTINCT, "refresh", "related_objects", } @@ -151,6 +152,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): } filter_rel_fields = {"database": [["id", DatabaseFilter, lambda: []]]} allowed_rel_fields = {"database", "owners"} + allowed_distinct_fields = {"schema"} openapi_spec_component_schemas = (DatasetRelatedObjectsResponse,) diff --git a/superset/views/base_api.py b/superset/views/base_api.py index 56fdba0197b9d..4d7fe7d267c95 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -19,12 +19,15 @@ from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union from apispec import APISpec +from apispec.exceptions import DuplicateComponentNameError from flask import Blueprint, Response -from flask_appbuilder import AppBuilder, Model, ModelRestApi +from flask_appbuilder import AppBuilder, ModelRestApi from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.filters import BaseFilter, Filters from flask_appbuilder.models.sqla.filters import FilterStartsWith -from marshmallow import Schema +from flask_appbuilder.models.sqla.interface import SQLAInterface +from marshmallow import fields, Schema +from sqlalchemy import distinct, func from superset.stats_logger import BaseStatsLogger from superset.typing import FlaskResponse @@ -41,6 +44,25 @@ } +class RelatedResultResponseSchema(Schema): + value = fields.Integer(description="The related item identifier") + text = fields.String(description="The related item string representation") + + +class RelatedResponseSchema(Schema): + count = fields.Integer(description="The total number of related values") + result = fields.List(fields.Nested(RelatedResultResponseSchema)) + + +class DistinctResultResponseSchema(Schema): + text = fields.String(description="The distinct item") + + +class DistincResponseSchema(Schema): + count = fields.Integer(description="The total number of distinct values") + result = fields.List(fields.Nested(DistinctResultResponseSchema)) + + def statsd_metrics(f: Callable[..., Any]) -> Callable[..., Any]: """ Handle sending all statsd metrics from the REST API @@ -78,6 +100,7 @@ class BaseSupersetModelRestApi(ModelRestApi): "bulk_delete": "delete", "info": "list", "related": "list", + "distinct": "list", "thumbnail": "list", "refresh": "edit", "data": "list", @@ -112,6 +135,8 @@ class BaseSupersetModelRestApi(ModelRestApi): """ # pylint: disable=pointless-string-statement allowed_rel_fields: Set[str] = set() + allowed_distinct_fields: Set[str] = set() + openapi_spec_component_schemas: Tuple[Type[Schema], ...] = tuple() """ Add extra schemas to the OpenAPI component schemas section @@ -123,15 +148,29 @@ class BaseSupersetModelRestApi(ModelRestApi): show_columns: List[str] def __init__(self) -> None: - super().__init__() + # Setup statsd self.stats_logger = BaseStatsLogger() + # Add base API spec base query parameter schemas + if self.apispec_parameter_schemas is None: # type: ignore + self.apispec_parameter_schemas = {} + self.apispec_parameter_schemas["get_related_schema"] = get_related_schema + if self.openapi_spec_component_schemas is None: + self.openapi_spec_component_schemas = () + self.openapi_spec_component_schemas = self.openapi_spec_component_schemas + ( + RelatedResponseSchema, + DistincResponseSchema, + ) + super().__init__() def add_apispec_components(self, api_spec: APISpec) -> None: for schema in self.openapi_spec_component_schemas: - api_spec.components.schema( - schema.__name__, schema=schema, - ) + try: + api_spec.components.schema( + schema.__name__, schema=schema, + ) + except DuplicateComponentNameError: + pass super().add_apispec_components(api_spec) def create_blueprint( @@ -153,7 +192,7 @@ def _init_properties(self) -> None: super()._init_properties() def _get_related_filter( - self, datamodel: Model, column_name: str, value: str + self, datamodel: SQLAInterface, column_name: str, value: str ) -> Filters: filter_field = self.related_field_filters.get(column_name) if isinstance(filter_field, str): @@ -170,6 +209,18 @@ def _get_related_filter( ) return filters + def _get_distinct_filter(self, column_name: str, value: str) -> Filters: + filter_field = RelatedFieldFilter(column_name, FilterStartsWith) + filter_field = cast(RelatedFieldFilter, filter_field) + search_columns = [filter_field.field_name] if filter_field else None + filters = self.datamodel.get_filters(search_columns) + filters.add_filter_list(self.base_filters) + if value and filter_field: + filters.add_filter( + filter_field.field_name, filter_field.filter_class, value + ) + return filters + def incr_stats(self, action: str, func_name: str) -> None: """ Proxy function for statsd.incr to impose a key structure for REST API's @@ -251,39 +302,21 @@ def related(self, column_name: str, **kwargs: Any) -> FlaskResponse: content: application/json: schema: - type: object - properties: - page_size: - type: integer - page: - type: integer - filter: - type: string + $ref: '#/components/schemas/get_related_schema' responses: 200: description: Related column data content: application/json: schema: - type: object - properties: - count: - type: integer - result: - type: object - properties: - value: - type: integer - text: - type: string + schema: + $ref: "#/components/schemas/RelatedResponseSchema" 400: $ref: '#/components/responses/400' 401: $ref: '#/components/responses/401' 404: $ref: '#/components/responses/404' - 422: - $ref: '#/components/responses/422' 500: $ref: '#/components/responses/500' """ @@ -316,3 +349,68 @@ def related(self, column_name: str, **kwargs: Any) -> FlaskResponse: for value in values ] return self.response(200, count=count, result=result) + + @expose("/distinct/", methods=["GET"]) + @protect() + @safe + @statsd_metrics + @rison(get_related_schema) + def distinct(self, column_name: str, **kwargs: Any) -> FlaskResponse: + """Get distinct values from field data + --- + get: + parameters: + - in: path + schema: + type: string + name: column_name + - in: query + name: q + content: + application/json: + schema: + $ref: '#/components/schemas/get_related_schema' + responses: + 200: + description: Distinct field data + content: + application/json: + schema: + schema: + $ref: "#/components/schemas/DistincResponseSchema" + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + if column_name not in self.allowed_distinct_fields: + self.incr_stats("error", self.related.__name__) + return self.response_404() + args = kwargs.get("rison", {}) + # handle pagination + page, page_size = self._sanitize_page_args(*self._handle_page_args(args)) + # Create generic base filters with added request filter + filters = self._get_distinct_filter(column_name, args.get("filter")) + # Make the query + query_count = self.appbuilder.get_session.query( + func.count(distinct(getattr(self.datamodel.obj, column_name))) + ) + count = self.datamodel.apply_filters(query_count, filters).scalar() + if count == 0: + return self.response(200, count=count, result=[]) + query = self.appbuilder.get_session.query( + distinct(getattr(self.datamodel.obj, column_name)) + ) + # Apply generic base filters with added request filter + query = self.datamodel.apply_filters(query, filters) + # Apply sort + query = self.datamodel.apply_order_by(query, column_name, "asc") + # Apply pagination + result = self.datamodel.apply_pagination(query, page, page_size).all() + # produce response + result = [{"text": item[0]} for item in result if item[0] is not None] + return self.response(200, count=count, result=result) diff --git a/tests/datasets/api_tests.py b/tests/datasets/api_tests.py index a7795f962af31..798e0dde8ab52 100644 --- a/tests/datasets/api_tests.py +++ b/tests/datasets/api_tests.py @@ -16,7 +16,7 @@ # under the License. """Unit tests for Superset""" import json -from typing import List +from typing import Any, Dict, List, Tuple, Union from unittest.mock import patch import prison @@ -129,7 +129,6 @@ def test_get_dataset_related_database_gamma(self): """ Dataset API: Test get dataset related databases gamma """ - example_db = get_example_database() self.login(username="gamma") uri = "api/v1/dataset/related/database" rv = self.client.get(uri) @@ -170,6 +169,93 @@ def test_get_dataset_item(self): self.assertEqual(len(response["result"]["columns"]), 3) self.assertEqual(len(response["result"]["metrics"]), 2) + def test_get_dataset_distinct_schema(self): + """ + Dataset API: Test get dataset distinct schema + """ + + def pg_test_query_parameter(query_parameter, expected_response): + uri = f"api/v1/dataset/distinct/schema?q={prison.dumps(query_parameter)}" + rv = self.client.get(uri) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + self.assertEqual(response, expected_response) + + example_db = get_example_database() + datasets = [] + if example_db.backend == "postgresql": + datasets.append( + self.insert_dataset("ab_permission", "public", [], get_main_database()) + ) + datasets.append( + self.insert_dataset( + "columns", "information_schema", [], get_main_database() + ) + ) + expected_response = { + "count": 5, + "result": [ + {"text": ""}, + {"text": "admin_database"}, + {"text": "information_schema"}, + {"text": "public"}, + {"text": "superset"}, + ], + } + self.login(username="admin") + uri = "api/v1/dataset/distinct/schema" + rv = self.client.get(uri) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + self.assertEqual(response, expected_response) + + # Test filter + query_parameter = {"filter": "inf"} + pg_test_query_parameter( + query_parameter, + {"count": 1, "result": [{"text": "information_schema"}]}, + ) + + query_parameter = {"page": 0, "page_size": 1} + pg_test_query_parameter( + query_parameter, {"count": 5, "result": [{"text": ""}]}, + ) + + query_parameter = {"page": 1, "page_size": 1} + pg_test_query_parameter( + query_parameter, {"count": 5, "result": [{"text": "admin_database"}]} + ) + + for dataset in datasets: + db.session.delete(dataset) + db.session.commit() + + def test_get_dataset_distinct_not_allowed(self): + """ + Dataset API: Test get dataset distinct not allowed + """ + self.login(username="admin") + uri = "api/v1/dataset/distinct/table_name" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + def test_get_dataset_distinct_gamma(self): + """ + Dataset API: Test get dataset distinct with gamma + """ + dataset = self.insert_default_dataset() + + self.login(username="gamma") + uri = "api/v1/dataset/distinct/schema" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(response["count"], 0) + self.assertEqual(response["result"], []) + + db.session.delete(dataset) + db.session.commit() + def test_get_dataset_info(self): """ Dataset API: Test get dataset info @@ -358,6 +444,7 @@ def test_update_dataset_item(self): self.assertEqual(rv.status_code, 200) model = db.session.query(SqlaTable).get(dataset.id) self.assertEqual(model.description, dataset_data["description"]) + db.session.delete(dataset) db.session.commit()