Skip to content

Commit

Permalink
feat: dataset REST API for distinct values (apache#10595)
Browse files Browse the repository at this point in the history
* feat: dataset REST API for distinct values

* add tests and fix lint

* fix mypy, and tests

* fix docs

* fix test

* lint

* fix test
  • Loading branch information
dpgaspar authored and Ofeknielsen committed Oct 5, 2020
1 parent b214e69 commit 30170ff
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 30 deletions.
1 change: 1 addition & 0 deletions superset/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class RouteMethod: # pylint: disable=too-few-public-methods
POST = "post"
PUT = "put"
RELATED = "related"
DISTINCT = "distinct"

# Commonly used sets
API_SET = {API_CREATE, API_DELETE, API_GET, API_READ, API_UPDATE}
Expand Down
2 changes: 2 additions & 0 deletions superset/datasets/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | {
RouteMethod.EXPORT,
RouteMethod.RELATED,
RouteMethod.DISTINCT,
"refresh",
"related_objects",
}
Expand Down Expand Up @@ -151,6 +152,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
}
filter_rel_fields = {"database": [["id", DatabaseFilter, lambda: []]]}
allowed_rel_fields = {"database", "owners"}
allowed_distinct_fields = {"schema"}

openapi_spec_component_schemas = (DatasetRelatedObjectsResponse,)

Expand Down
154 changes: 126 additions & 28 deletions superset/views/base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union

from apispec import APISpec
from apispec.exceptions import DuplicateComponentNameError
from flask import Blueprint, Response
from flask_appbuilder import AppBuilder, Model, ModelRestApi
from flask_appbuilder import AppBuilder, ModelRestApi
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.filters import BaseFilter, Filters
from flask_appbuilder.models.sqla.filters import FilterStartsWith
from marshmallow import Schema
from flask_appbuilder.models.sqla.interface import SQLAInterface
from marshmallow import fields, Schema
from sqlalchemy import distinct, func

from superset.stats_logger import BaseStatsLogger
from superset.typing import FlaskResponse
Expand All @@ -41,6 +44,25 @@
}


class RelatedResultResponseSchema(Schema):
    """A single entry of a ``related`` response: an identifier and its label."""

    value = fields.Integer(description="The related item identifier")
    text = fields.String(description="The related item string representation")


class RelatedResponseSchema(Schema):
    """Envelope of a ``related`` response: a total count plus the result items."""

    count = fields.Integer(description="The total number of related values")
    result = fields.List(fields.Nested(RelatedResultResponseSchema))


class DistinctResultResponseSchema(Schema):
    """A single entry of a ``distinct`` response: one distinct column value."""

    text = fields.String(description="The distinct item")


class DistincResponseSchema(Schema):
    """Envelope of a ``distinct`` response: a total count plus the result items."""

    # NOTE: the class name (sic, "Distinc") is used verbatim as the OpenAPI
    # component name via ``schema.__name__``, so it is kept as-is.
    count = fields.Integer(description="The total number of distinct values")
    result = fields.List(fields.Nested(DistinctResultResponseSchema))


def statsd_metrics(f: Callable[..., Any]) -> Callable[..., Any]:
"""
Handle sending all statsd metrics from the REST API
Expand Down Expand Up @@ -78,6 +100,7 @@ class BaseSupersetModelRestApi(ModelRestApi):
"bulk_delete": "delete",
"info": "list",
"related": "list",
"distinct": "list",
"thumbnail": "list",
"refresh": "edit",
"data": "list",
Expand Down Expand Up @@ -112,6 +135,8 @@ class BaseSupersetModelRestApi(ModelRestApi):
""" # pylint: disable=pointless-string-statement
allowed_rel_fields: Set[str] = set()

allowed_distinct_fields: Set[str] = set()

openapi_spec_component_schemas: Tuple[Type[Schema], ...] = tuple()
"""
Add extra schemas to the OpenAPI component schemas section
Expand All @@ -123,15 +148,29 @@ class BaseSupersetModelRestApi(ModelRestApi):
show_columns: List[str]

def __init__(self) -> None:
    """
    Set up the stats logger and the shared OpenAPI schemas, then initialize
    the parent REST API exactly once.
    """
    # Setup statsd
    self.stats_logger = BaseStatsLogger()
    # Add base API spec base query parameter schemas
    if self.apispec_parameter_schemas is None:  # type: ignore
        self.apispec_parameter_schemas = {}
    self.apispec_parameter_schemas["get_related_schema"] = get_related_schema
    # Every API built on this base exposes the related/distinct response
    # schemas, so append them to whatever the subclass declared.
    if self.openapi_spec_component_schemas is None:
        self.openapi_spec_component_schemas = ()
    self.openapi_spec_component_schemas = self.openapi_spec_component_schemas + (
        RelatedResponseSchema,
        DistincResponseSchema,
    )
    # Call the parent initializer a single time, after the schema attributes
    # above are populated (presumably so the parent sees them -- confirm).
    super().__init__()

def add_apispec_components(self, api_spec: APISpec) -> None:
    """
    Register this API's extra component schemas on the OpenAPI spec.

    :param api_spec: the spec object the component schemas are added to
    """
    for schema in self.openapi_spec_component_schemas:
        try:
            api_spec.components.schema(
                schema.__name__, schema=schema,
            )
        except DuplicateComponentNameError:
            # Deliberate best-effort: a component with this name is already
            # registered on the spec, so there is nothing left to add.
            pass
    super().add_apispec_components(api_spec)

def create_blueprint(
Expand All @@ -153,7 +192,7 @@ def _init_properties(self) -> None:
super()._init_properties()

def _get_related_filter(
self, datamodel: Model, column_name: str, value: str
self, datamodel: SQLAInterface, column_name: str, value: str
) -> Filters:
filter_field = self.related_field_filters.get(column_name)
if isinstance(filter_field, str):
Expand All @@ -170,6 +209,18 @@ def _get_related_filter(
)
return filters

def _get_distinct_filter(self, column_name: str, value: str) -> Filters:
    """
    Build the filters used by the ``distinct`` endpoint.

    Starts from the model's base filters and, when a search value is
    given, adds a "starts with" filter on the requested column.

    :param column_name: name of the column distinct values are read from
    :param value: optional prefix the column values are filtered by
    :return: the combined filters
    """
    starts_with = cast(
        RelatedFieldFilter, RelatedFieldFilter(column_name, FilterStartsWith)
    )
    if starts_with:
        searchable = [starts_with.field_name]
    else:
        searchable = None
    combined = self.datamodel.get_filters(searchable)
    combined.add_filter_list(self.base_filters)
    # Without a search value only the base filters apply.
    if not (value and starts_with):
        return combined
    combined.add_filter(starts_with.field_name, starts_with.filter_class, value)
    return combined

def incr_stats(self, action: str, func_name: str) -> None:
"""
Proxy function for statsd.incr to impose a key structure for REST API's
Expand Down Expand Up @@ -251,39 +302,21 @@ def related(self, column_name: str, **kwargs: Any) -> FlaskResponse:
content:
application/json:
schema:
type: object
properties:
page_size:
type: integer
page:
type: integer
filter:
type: string
$ref: '#/components/schemas/get_related_schema'
responses:
200:
description: Related column data
content:
application/json:
schema:
type: object
properties:
count:
type: integer
result:
type: object
properties:
value:
type: integer
text:
type: string
schema:
$ref: "#/components/schemas/RelatedResponseSchema"
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
Expand Down Expand Up @@ -316,3 +349,68 @@ def related(self, column_name: str, **kwargs: Any) -> FlaskResponse:
for value in values
]
return self.response(200, count=count, result=result)

@expose("/distinct/<column_name>", methods=["GET"])
@protect()
@safe
@statsd_metrics
@rison(get_related_schema)
def distinct(self, column_name: str, **kwargs: Any) -> FlaskResponse:
    """Get distinct values from field data
    ---
    get:
      parameters:
      - in: path
        schema:
          type: string
        name: column_name
      - in: query
        name: q
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/get_related_schema'
      responses:
        200:
          description: Distinct field data
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/DistincResponseSchema"
        400:
          $ref: '#/components/responses/400'
        401:
          $ref: '#/components/responses/401'
        404:
          $ref: '#/components/responses/404'
        500:
          $ref: '#/components/responses/500'
    """
    # Only explicitly allowed columns may be queried for distinct values
    if column_name not in self.allowed_distinct_fields:
        # Record the error under this endpoint's own name, not `related`
        self.incr_stats("error", self.distinct.__name__)
        return self.response_404()
    args = kwargs.get("rison", {})
    # handle pagination
    page, page_size = self._sanitize_page_args(*self._handle_page_args(args))
    # Create generic base filters with added request filter
    filters = self._get_distinct_filter(column_name, args.get("filter"))
    column = getattr(self.datamodel.obj, column_name)
    # Make the query
    query_count = self.appbuilder.get_session.query(func.count(distinct(column)))
    count = self.datamodel.apply_filters(query_count, filters).scalar()
    if count == 0:
        return self.response(200, count=count, result=[])
    query = self.appbuilder.get_session.query(distinct(column))
    # Apply generic base filters with added request filter
    query = self.datamodel.apply_filters(query, filters)
    # Apply sort
    query = self.datamodel.apply_order_by(query, column_name, "asc")
    # Apply pagination
    result = self.datamodel.apply_pagination(query, page, page_size).all()
    # produce response; None values are dropped after pagination, so a page
    # may hold fewer than page_size items
    result = [{"text": item[0]} for item in result if item[0] is not None]
    return self.response(200, count=count, result=result)
91 changes: 89 additions & 2 deletions tests/datasets/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
"""Unit tests for Superset"""
import json
from typing import List
from typing import Any, Dict, List, Tuple, Union
from unittest.mock import patch

import prison
Expand Down Expand Up @@ -129,7 +129,6 @@ def test_get_dataset_related_database_gamma(self):
"""
Dataset API: Test get dataset related databases gamma
"""
example_db = get_example_database()
self.login(username="gamma")
uri = "api/v1/dataset/related/database"
rv = self.client.get(uri)
Expand Down Expand Up @@ -170,6 +169,93 @@ def test_get_dataset_item(self):
self.assertEqual(len(response["result"]["columns"]), 3)
self.assertEqual(len(response["result"]["metrics"]), 2)

def test_get_dataset_distinct_schema(self):
    """
    Dataset API: Test get dataset distinct schema
    """

    def pg_test_query_parameter(query_parameter, expected_response):
        # GET the distinct schemas with the given rison query and compare
        # the whole JSON payload with the expectation.
        uri = f"api/v1/dataset/distinct/schema?q={prison.dumps(query_parameter)}"
        rv = self.client.get(uri)
        response = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(response, expected_response)

    example_db = get_example_database()
    datasets = []
    # NOTE(review): the assertions are assumed to run only on PostgreSQL,
    # matching the helper's name and the schemas inserted below -- confirm.
    try:
        if example_db.backend == "postgresql":
            datasets.append(
                self.insert_dataset("ab_permission", "public", [], get_main_database())
            )
            datasets.append(
                self.insert_dataset(
                    "columns", "information_schema", [], get_main_database()
                )
            )
            expected_response = {
                "count": 5,
                "result": [
                    {"text": ""},
                    {"text": "admin_database"},
                    {"text": "information_schema"},
                    {"text": "public"},
                    {"text": "superset"},
                ],
            }
            self.login(username="admin")
            uri = "api/v1/dataset/distinct/schema"
            rv = self.client.get(uri)
            response = json.loads(rv.data.decode("utf-8"))
            self.assertEqual(rv.status_code, 200)
            self.assertEqual(response, expected_response)

            # Test filter
            query_parameter = {"filter": "inf"}
            pg_test_query_parameter(
                query_parameter,
                {"count": 1, "result": [{"text": "information_schema"}]},
            )

            query_parameter = {"page": 0, "page_size": 1}
            pg_test_query_parameter(
                query_parameter, {"count": 5, "result": [{"text": ""}]},
            )

            query_parameter = {"page": 1, "page_size": 1}
            pg_test_query_parameter(
                query_parameter, {"count": 5, "result": [{"text": "admin_database"}]}
            )
    finally:
        # Always remove the inserted datasets, even when an assertion fails,
        # so they do not leak into other tests.
        for dataset in datasets:
            db.session.delete(dataset)
        db.session.commit()

def test_get_dataset_distinct_not_allowed(self):
    """
    Dataset API: Test get dataset distinct not allowed
    """
    self.login(username="admin")
    # Request distinct values for the `table_name` column and expect a 404
    response = self.client.get("api/v1/dataset/distinct/table_name")
    self.assertEqual(response.status_code, 404)

def test_get_dataset_distinct_gamma(self):
    """
    Dataset API: Test get dataset distinct with gamma
    """
    dataset = self.insert_default_dataset()
    try:
        self.login(username="gamma")
        uri = "api/v1/dataset/distinct/schema"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 200)
        response = json.loads(rv.data.decode("utf-8"))
        # The gamma user is expected to get an empty result
        self.assertEqual(response["count"], 0)
        self.assertEqual(response["result"], [])
    finally:
        # Always remove the inserted dataset, even when an assertion fails
        db.session.delete(dataset)
        db.session.commit()

def test_get_dataset_info(self):
"""
Dataset API: Test get dataset info
Expand Down Expand Up @@ -358,6 +444,7 @@ def test_update_dataset_item(self):
self.assertEqual(rv.status_code, 200)
model = db.session.query(SqlaTable).get(dataset.id)
self.assertEqual(model.description, dataset_data["description"])

db.session.delete(dataset)
db.session.commit()

Expand Down

0 comments on commit 30170ff

Please sign in to comment.