From be4aa5d6ae52817c24009f8c73b6529eab681135 Mon Sep 17 00:00:00 2001 From: Yongjie Zhao Date: Tue, 12 Jul 2022 18:19:15 +0800 Subject: [PATCH 01/12] feat: samples with filter support pagination --- .../src/components/Chart/chartAction.js | 3 +- .../DataTablesPane/test/SamplesPane.test.tsx | 8 +- superset/datasets/api.py | 30 +++++- superset/datasets/commands/samples.py | 97 +++++++++++++++---- superset/datasets/schemas.py | 12 +++ tests/integration_tests/datasets/api_tests.py | 51 ++++++++-- 6 files changed, 166 insertions(+), 35 deletions(-) diff --git a/superset-frontend/src/components/Chart/chartAction.js b/superset-frontend/src/components/Chart/chartAction.js index 139d91cd1d703..be887a4b3d505 100644 --- a/superset-frontend/src/components/Chart/chartAction.js +++ b/superset-frontend/src/components/Chart/chartAction.js @@ -602,10 +602,11 @@ export const getDatasourceSamples = async ( datasourceType, datasourceId, force, + jsonPayload, ) => { const endpoint = `/api/v1/explore/samples?force=${force}&datasource_type=${datasourceType}&datasource_id=${datasourceId}`; try { - const response = await SupersetClient.get({ endpoint }); + const response = await SupersetClient.post({ endpoint, jsonPayload }); return response.json.result; } catch (err) { const clientError = await getClientErrorObject(err); diff --git a/superset-frontend/src/explore/components/DataTablesPane/test/SamplesPane.test.tsx b/superset-frontend/src/explore/components/DataTablesPane/test/SamplesPane.test.tsx index 0aa0b03a06779..6ea8fafe378e8 100644 --- a/superset-frontend/src/explore/components/DataTablesPane/test/SamplesPane.test.tsx +++ b/superset-frontend/src/explore/components/DataTablesPane/test/SamplesPane.test.tsx @@ -29,8 +29,8 @@ import { SamplesPane } from '../components'; import { createSamplesPaneProps } from './fixture'; describe('SamplesPane', () => { - fetchMock.get( - 'end:/api/v1/explore/samples?force=false&datasource_type=table&datasource_id=34', + fetchMock.post( + '/api/v1/explore/samples?force=false&datasource_type=table&datasource_id=34', { result: { data: [], @@ -40,7 +40,7 @@ describe('SamplesPane', () => { }, ); - fetchMock.get( + fetchMock.post( 'end:/api/v1/explore/samples?force=true&datasource_type=table&datasource_id=35', { result: { @@ -54,7 +54,7 @@ describe('SamplesPane', () => { }, ); - fetchMock.get( + fetchMock.post( 'end:/api/v1/explore/samples?force=false&datasource_type=table&datasource_id=36', 400, ); diff --git a/superset/datasets/api.py b/superset/datasets/api.py index f6890655ed321..c484260ac7374 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -60,9 +60,11 @@ DatasetPostSchema, DatasetPutSchema, DatasetRelatedObjectsResponse, + DatasetSamplesQuerySchema, get_delete_ids_schema, get_export_ids_schema, ) +from superset.exceptions import QueryClauseValidationException from superset.utils.core import json_int_dttm_ser, parse_boolean_string from superset.views.base import DatasourceFilter, generate_download_headers from superset.views.base_api import ( @@ -224,7 +226,10 @@ class DatasetRestApi(BaseSupersetModelRestApi): apispec_parameter_schemas = { "get_export_ids_schema": get_export_ids_schema, } - openapi_spec_component_schemas = (DatasetRelatedObjectsResponse,) + openapi_spec_component_schemas = ( + DatasetRelatedObjectsResponse, + DatasetSamplesQuerySchema, + ) @expose("/", methods=["POST"]) @protect() @@ -776,7 +781,7 @@ def import_(self) -> Response: command.run() return self.response(200, message="OK") - @expose("//samples") + @expose("//samples", 
methods=["POST"]) @protect() @safe @statsd_metrics @@ -787,7 +792,7 @@ def import_(self) -> Response: def samples(self, pk: int) -> Response: """get samples from a Dataset --- - get: + post: description: >- get samples from a Dataset parameters: @@ -799,6 +804,13 @@ def samples(self, pk: int) -> Response: schema: type: boolean name: force + requestBody: + description: Filter Schema + required: false + content: + application/json: + schema: + $ref: '#/components/schemas/DatasetSamplesQuerySchema' responses: 200: description: Dataset samples @@ -822,7 +834,15 @@ def samples(self, pk: int) -> Response: """ try: force = parse_boolean_string(request.args.get("force")) - rv = SamplesDatasetCommand(pk, force).run() + page = request.args.get("page") + per_page = request.args.get("per_page") + rv = SamplesDatasetCommand( + pk, + force, + payload=request.json, + page=page, + per_page=per_page, + ).run() response_data = simplejson.dumps( {"result": rv}, default=json_int_dttm_ser, @@ -837,3 +857,5 @@ def samples(self, pk: int) -> Response: return self.response_403() except DatasetSamplesFailedError as ex: return self.response_400(message=str(ex)) + except ValidationError as ex: + return self.response_400(message=str(ex)) diff --git a/superset/datasets/commands/samples.py b/superset/datasets/commands/samples.py index e252cfb62f6c8..410597ceadbb8 100644 --- a/superset/datasets/commands/samples.py +++ b/superset/datasets/commands/samples.py @@ -14,10 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import logging -from typing import Any, Dict, Optional +from typing import Any, cast, Dict, Optional -from superset import security_manager +from superset import app, security_manager from superset.commands.base import BaseCommand from superset.common.chart_data import ChartDataResultType from superset.common.query_context_factory import QueryContextFactory @@ -30,40 +29,79 @@ DatasetSamplesFailedError, ) from superset.datasets.dao import DatasetDAO +from superset.datasets.schemas import DatasetSamplesQuerySchema from superset.exceptions import SupersetSecurityException -from superset.utils.core import QueryStatus - -logger = logging.getLogger(__name__) +from superset.utils.core import DatasourceDict, QueryStatus class SamplesDatasetCommand(BaseCommand): - def __init__(self, model_id: int, force: bool): + def __init__( + self, + model_id: int, + force: bool, + *, + payload: Optional[DatasetSamplesQuerySchema] = None, + page: Optional[int] = None, + per_page: Optional[int] = None, + ): self._model_id = model_id self._force = force self._model: Optional[SqlaTable] = None + self._payload = payload + self._page = page + self._per_page = per_page def run(self) -> Dict[str, Any]: self.validate() - if not self._model: - raise DatasetNotFoundError() + limit_clause = self.get_limit_clause(self._page, self._per_page) + self._model = cast(SqlaTable, self._model) + datasource: DatasourceDict = { + "type": self._model.type, + "id": self._model.id, + } - qc_instance = QueryContextFactory().create( - datasource={ - "type": self._model.type, - "id": self._model.id, - }, - queries=[{}], + # constructing samples query + samples_instance = QueryContextFactory().create( + datasource=datasource, + queries=[ + {**self._payload, **limit_clause} if self._payload else limit_clause + ], result_type=ChartDataResultType.SAMPLES, force=self._force, ) - results = qc_instance.get_payload() + + # constructing count(*) query + count_star_payload 
= { + "metrics": [ + { + "expressionType": "SQL", + "sqlExpression": "COUNT(*)", + "label": "COUNT(*)", + } + ] + } + count_star_instance = QueryContextFactory().create( + datasource=datasource, + queries=[count_star_payload], + result_type=ChartDataResultType.FULL, + force=self._force, + ) + samples_results = samples_instance.get_payload() + count_star_results = count_star_instance.get_payload() + try: - sample_data = results["queries"][0] - error_msg = sample_data.get("error") - if sample_data.get("status") == QueryStatus.FAILED and error_msg: + sample_data = samples_results["queries"][0] + count_star_data = count_star_results["queries"][0] + failed_status = ( + sample_data.get("status") == QueryStatus.FAILED + or count_star_data.get("status") == QueryStatus.FAILED + ) + error_msg = sample_data.get("error") or count_star_data.get("error") + if failed_status and error_msg: cache_key = sample_data.get("cache_key") QueryCacheManager.delete(cache_key, region=CacheRegion.DATA) raise DatasetSamplesFailedError(error_msg) + sample_data["dataset_count_star"] = count_star_data["data"][0]["COUNT(*)"] return sample_data except (IndexError, KeyError) as exc: raise DatasetSamplesFailedError from exc @@ -78,3 +116,24 @@ def validate(self) -> None: security_manager.raise_for_ownership(self._model) except SupersetSecurityException as ex: raise DatasetForbiddenError() from ex + + @staticmethod + def get_limit_clause( + page: Optional[int], per_page: Optional[int] + ) -> Dict[str, int]: + samples_row_limit = app.config.get("SAMPLES_ROW_LIMIT", 1000) + limit = samples_row_limit + offset = 0 + + if isinstance(page, int) and isinstance(per_page, int): + limit = int(per_page) + if limit < 0 or limit > samples_row_limit: + # reset limit value if input is invalid + limit = samples_row_limit + + offset = (int(page) - 1) * limit + if offset < 0: + # reset offset value if input is invalid + offset = 0 + + return {"row_offset": offset, "row_limit": limit} diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py index 8a44da458f564..7e38ad017db44 100644 --- a/superset/datasets/schemas.py +++ b/superset/datasets/schemas.py @@ -23,6 +23,7 @@ from marshmallow.validate import Length from marshmallow_sqlalchemy import SQLAlchemyAutoSchema +from superset.charts.schemas import ChartDataFilterSchema from superset.datasets.models import Dataset get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}} @@ -231,3 +232,14 @@ class Meta: # pylint: disable=too-few-public-methods model = Dataset load_instance = True include_relationships = True + + +class DatasetSamplesQuerySchema(Schema): + filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False) + + @pre_load + # pylint: disable=no-self-use, unused-argument + def handle_none(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + if data is None: + return {} + return data diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index d8e756e98efaa..0364466188121 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -2101,9 +2101,9 @@ def test_get_dataset_samples(self): # 1. should cache data # feeds data - self.client.get(uri) + self.client.post(uri) # get from cache - rv = self.client.get(uri) + rv = self.client.post(uri) rv_data = json.loads(rv.data) assert rv.status_code == 200 assert "result" in rv_data @@ -2114,9 +2114,9 @@ def test_get_dataset_samples(self): # 2. 
should through cache uri2 = f"api/v1/dataset/{dataset.id}/samples?force=true" # feeds data - self.client.get(uri2) + self.client.post(uri2) # force query - rv2 = self.client.get(uri2) + rv2 = self.client.post(uri2) rv_data2 = json.loads(rv2.data) assert rv_data2["result"]["cached_dttm"] is None cache_key2 = rv_data2["result"]["cache_key"] @@ -2149,7 +2149,7 @@ def test_get_dataset_samples_with_failed_cc(self): ) uri = f"api/v1/dataset/{dataset.id}/samples" dataset.columns.append(failed_column) - rv = self.client.get(uri) + rv = self.client.post(uri) assert rv.status_code == 400 rv_data = json.loads(rv.data) assert "message" in rv_data @@ -2171,7 +2171,7 @@ def test_get_dataset_samples_on_virtual_dataset(self): self.login(username="admin") uri = f"api/v1/dataset/{virtual_dataset.id}/samples" - rv = self.client.get(uri) + rv = self.client.post(uri) assert rv.status_code == 200 rv_data = json.loads(rv.data) cache_key = rv_data["result"]["cache_key"] @@ -2179,8 +2179,45 @@ def test_get_dataset_samples_on_virtual_dataset(self): # remove original column in dataset virtual_dataset.sql = "SELECT 'foo' as foo" - rv = self.client.get(uri) + rv = self.client.post(uri) + assert rv.status_code == 400 + + db.session.delete(virtual_dataset) + db.session.commit() + + def test_get_dataset_samples_with_filters(self): + virtual_dataset = SqlaTable( + table_name="virtual_dataset", + sql=("SELECT 'foo' as foo, 'bar' as bar UNION ALL SELECT 'foo2', 'bar2'"), + database=get_example_database(), + ) + TableColumn(column_name="foo", type="VARCHAR(255)", table=virtual_dataset) + TableColumn(column_name="bar", type="VARCHAR(255)", table=virtual_dataset) + SqlMetric(metric_name="count", expression="count(*)", table=virtual_dataset) + + self.login(username="admin") + uri = f"api/v1/dataset/{virtual_dataset.id}/samples" + rv = self.client.post(uri, json=None) + assert rv.status_code == 200 + + rv = self.client.post(uri, json={}) + assert rv.status_code == 200 + + rv = self.client.post(uri, json={"foo": "bar"}) assert rv.status_code == 400 + rv = self.client.post( + uri, json={"filters": [{"col": "foo", "op": "INVALID", "val": "foo2"}]} + ) + assert rv.status_code == 400 + + rv = self.client.post( + uri, json={"filters": [{"col": "foo", "op": "==", "val": "foo2"}]} + ) + assert rv.status_code == 200 + rv_data = json.loads(rv.data) + assert rv_data["result"]["colnames"] == ["foo", "bar"] + assert rv_data["result"]["rowcount"] == 1 + db.session.delete(virtual_dataset) db.session.commit() From 3167490293aa4f95bab5f9cd7ed17d9d6bce8072 Mon Sep 17 00:00:00 2001 From: Yongjie Zhao Date: Tue, 19 Jul 2022 17:56:33 +0800 Subject: [PATCH 02/12] move samples to datasource --- .../src/components/Chart/chartAction.js | 2 +- superset/datasets/api.py | 94 +----------- superset/datasets/commands/samples.py | 139 ------------------ superset/datasets/schemas.py | 12 -- superset/datasource/dao.py | 3 + superset/explore/api.py | 84 +---------- superset/explore/commands/samples.py | 93 ------------ superset/views/datasource/schemas.py | 29 +++- superset/views/datasource/utils.py | 116 +++++++++++++++ superset/views/datasource/views.py | 24 +++ 10 files changed, 177 insertions(+), 419 deletions(-) delete mode 100644 superset/datasets/commands/samples.py delete mode 100644 superset/explore/commands/samples.py create mode 100644 superset/views/datasource/utils.py diff --git a/superset-frontend/src/components/Chart/chartAction.js b/superset-frontend/src/components/Chart/chartAction.js index be887a4b3d505..044593eb37461 100644 --- 
a/superset-frontend/src/components/Chart/chartAction.js +++ b/superset-frontend/src/components/Chart/chartAction.js @@ -604,7 +604,7 @@ export const getDatasourceSamples = async ( force, jsonPayload, ) => { - const endpoint = `/api/v1/explore/samples?force=${force}&datasource_type=${datasourceType}&datasource_id=${datasourceId}`; + const endpoint = `/datasource/samples?force=${force}&datasource_type=${datasourceType}&datasource_id=${datasourceId}`; try { const response = await SupersetClient.post({ endpoint, jsonPayload }); return response.json.result; diff --git a/superset/datasets/api.py b/superset/datasets/api.py index c484260ac7374..e25e8252f9443 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -21,9 +21,8 @@ from typing import Any from zipfile import is_zipfile, ZipFile -import simplejson import yaml -from flask import make_response, request, Response, send_file +from flask import request, Response, send_file from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_babel import ngettext @@ -46,13 +45,11 @@ DatasetInvalidError, DatasetNotFoundError, DatasetRefreshFailedError, - DatasetSamplesFailedError, DatasetUpdateFailedError, ) from superset.datasets.commands.export import ExportDatasetsCommand from superset.datasets.commands.importers.dispatcher import ImportDatasetsCommand from superset.datasets.commands.refresh import RefreshDatasetCommand -from superset.datasets.commands.samples import SamplesDatasetCommand from superset.datasets.commands.update import UpdateDatasetCommand from superset.datasets.dao import DatasetDAO from superset.datasets.filters import DatasetCertifiedFilter, DatasetIsNullOrEmptyFilter @@ -60,12 +57,10 @@ DatasetPostSchema, DatasetPutSchema, DatasetRelatedObjectsResponse, - DatasetSamplesQuerySchema, get_delete_ids_schema, get_export_ids_schema, ) -from superset.exceptions import QueryClauseValidationException -from superset.utils.core import json_int_dttm_ser, parse_boolean_string +from superset.utils.core import parse_boolean_string from superset.views.base import DatasourceFilter, generate_download_headers from superset.views.base_api import ( BaseSupersetModelRestApi, @@ -95,7 +90,6 @@ class DatasetRestApi(BaseSupersetModelRestApi): "bulk_delete", "refresh", "related_objects", - "samples", } list_columns = [ "id", @@ -226,10 +220,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): apispec_parameter_schemas = { "get_export_ids_schema": get_export_ids_schema, } - openapi_spec_component_schemas = ( - DatasetRelatedObjectsResponse, - DatasetSamplesQuerySchema, - ) + openapi_spec_component_schemas = (DatasetRelatedObjectsResponse,) @expose("/", methods=["POST"]) @protect() @@ -780,82 +771,3 @@ def import_(self) -> Response: ) command.run() return self.response(200, message="OK") - - @expose("//samples", methods=["POST"]) - @protect() - @safe - @statsd_metrics - @event_logger.log_this_with_context( - action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.samples", - log_to_statsd=False, - ) - def samples(self, pk: int) -> Response: - """get samples from a Dataset - --- - post: - description: >- - get samples from a Dataset - parameters: - - in: path - schema: - type: integer - name: pk - - in: query - schema: - type: boolean - name: force - requestBody: - description: Filter Schema - required: false - content: - application/json: - schema: - $ref: '#/components/schemas/DatasetSamplesQuerySchema' - responses: - 200: - description: Dataset samples - 
content: - application/json: - schema: - type: object - properties: - result: - $ref: '#/components/schemas/ChartDataResponseResult' - 401: - $ref: '#/components/responses/401' - 403: - $ref: '#/components/responses/403' - 404: - $ref: '#/components/responses/404' - 422: - $ref: '#/components/responses/422' - 500: - $ref: '#/components/responses/500' - """ - try: - force = parse_boolean_string(request.args.get("force")) - page = request.args.get("page") - per_page = request.args.get("per_page") - rv = SamplesDatasetCommand( - pk, - force, - payload=request.json, - page=page, - per_page=per_page, - ).run() - response_data = simplejson.dumps( - {"result": rv}, - default=json_int_dttm_ser, - ignore_nan=True, - ) - resp = make_response(response_data, 200) - resp.headers["Content-Type"] = "application/json; charset=utf-8" - return resp - except DatasetNotFoundError: - return self.response_404() - except DatasetForbiddenError: - return self.response_403() - except DatasetSamplesFailedError as ex: - return self.response_400(message=str(ex)) - except ValidationError as ex: - return self.response_400(message=str(ex)) diff --git a/superset/datasets/commands/samples.py b/superset/datasets/commands/samples.py deleted file mode 100644 index 410597ceadbb8..0000000000000 --- a/superset/datasets/commands/samples.py +++ /dev/null @@ -1,139 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from typing import Any, cast, Dict, Optional - -from superset import app, security_manager -from superset.commands.base import BaseCommand -from superset.common.chart_data import ChartDataResultType -from superset.common.query_context_factory import QueryContextFactory -from superset.common.utils.query_cache_manager import QueryCacheManager -from superset.connectors.sqla.models import SqlaTable -from superset.constants import CacheRegion -from superset.datasets.commands.exceptions import ( - DatasetForbiddenError, - DatasetNotFoundError, - DatasetSamplesFailedError, -) -from superset.datasets.dao import DatasetDAO -from superset.datasets.schemas import DatasetSamplesQuerySchema -from superset.exceptions import SupersetSecurityException -from superset.utils.core import DatasourceDict, QueryStatus - - -class SamplesDatasetCommand(BaseCommand): - def __init__( - self, - model_id: int, - force: bool, - *, - payload: Optional[DatasetSamplesQuerySchema] = None, - page: Optional[int] = None, - per_page: Optional[int] = None, - ): - self._model_id = model_id - self._force = force - self._model: Optional[SqlaTable] = None - self._payload = payload - self._page = page - self._per_page = per_page - - def run(self) -> Dict[str, Any]: - self.validate() - limit_clause = self.get_limit_clause(self._page, self._per_page) - self._model = cast(SqlaTable, self._model) - datasource: DatasourceDict = { - "type": self._model.type, - "id": self._model.id, - } - - # constructing samples query - samples_instance = QueryContextFactory().create( - datasource=datasource, - queries=[ - {**self._payload, **limit_clause} if self._payload else limit_clause - ], - result_type=ChartDataResultType.SAMPLES, - force=self._force, - ) - - # constructing count(*) query - count_star_payload = { - "metrics": [ - { - "expressionType": "SQL", - "sqlExpression": "COUNT(*)", - "label": "COUNT(*)", - } - ] - } - count_star_instance = QueryContextFactory().create( - datasource=datasource, - queries=[count_star_payload], - result_type=ChartDataResultType.FULL, - force=self._force, - ) - samples_results = samples_instance.get_payload() - count_star_results = count_star_instance.get_payload() - - try: - sample_data = samples_results["queries"][0] - count_star_data = count_star_results["queries"][0] - failed_status = ( - sample_data.get("status") == QueryStatus.FAILED - or count_star_data.get("status") == QueryStatus.FAILED - ) - error_msg = sample_data.get("error") or count_star_data.get("error") - if failed_status and error_msg: - cache_key = sample_data.get("cache_key") - QueryCacheManager.delete(cache_key, region=CacheRegion.DATA) - raise DatasetSamplesFailedError(error_msg) - sample_data["dataset_count_star"] = count_star_data["data"][0]["COUNT(*)"] - return sample_data - except (IndexError, KeyError) as exc: - raise DatasetSamplesFailedError from exc - - def validate(self) -> None: - # Validate/populate model exists - self._model = DatasetDAO.find_by_id(self._model_id) - if not self._model: - raise DatasetNotFoundError() - # Check ownership - try: - security_manager.raise_for_ownership(self._model) - except SupersetSecurityException as ex: - raise DatasetForbiddenError() from ex - - @staticmethod - def get_limit_clause( - page: Optional[int], per_page: Optional[int] - ) -> Dict[str, int]: - samples_row_limit = app.config.get("SAMPLES_ROW_LIMIT", 1000) - limit = samples_row_limit - offset = 0 - - if isinstance(page, int) and isinstance(per_page, int): - limit = int(per_page) - if limit < 0 or limit > samples_row_limit: - # reset limit 
value if input is invalid - limit = samples_row_limit - - offset = (int(page) - 1) * limit - if offset < 0: - # reset offset value if input is invalid - offset = 0 - - return {"row_offset": offset, "row_limit": limit} diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py index 7e38ad017db44..8a44da458f564 100644 --- a/superset/datasets/schemas.py +++ b/superset/datasets/schemas.py @@ -23,7 +23,6 @@ from marshmallow.validate import Length from marshmallow_sqlalchemy import SQLAlchemyAutoSchema -from superset.charts.schemas import ChartDataFilterSchema from superset.datasets.models import Dataset get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}} @@ -232,14 +231,3 @@ class Meta: # pylint: disable=too-few-public-methods model = Dataset load_instance = True include_relationships = True - - -class DatasetSamplesQuerySchema(Schema): - filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False) - - @pre_load - # pylint: disable=no-self-use, unused-argument - def handle_none(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: - if data is None: - return {} - return data diff --git a/superset/datasource/dao.py b/superset/datasource/dao.py index c8df4c8d8d968..d06da9bf84a34 100644 --- a/superset/datasource/dao.py +++ b/superset/datasource/dao.py @@ -20,10 +20,13 @@ from sqlalchemy.orm import Session +from superset import security_manager from superset.connectors.sqla.models import SqlaTable from superset.dao.base import BaseDAO from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError +from superset.datasets.commands.exceptions import DatasetForbiddenError from superset.datasets.models import Dataset +from superset.exceptions import SupersetSecurityException from superset.models.sql_lab import Query, SavedQuery from superset.tables.models import Table from superset.utils.core import DatasourceType diff --git a/superset/explore/api.py b/superset/explore/api.py index 237eb67dbbe79..7cce592d361fc 100644 --- a/superset/explore/api.py +++ b/superset/explore/api.py @@ -16,22 +16,14 @@ # under the License. 
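As a quick illustration of the pagination arithmetic in the get_limit_clause helper shown above (added in patch 01 and moved to superset/views/datasource/utils.py later in this series), here is a hedged sketch of the values it returns, assuming the default SAMPLES_ROW_LIMIT of 1000 that the code falls back to:

# Illustrative only; mirrors the logic of get_limit_clause as written in this patch,
# assuming SAMPLES_ROW_LIMIT = 1000.
get_limit_clause(page=3, per_page=50)        # -> {"row_offset": 100, "row_limit": 50}
get_limit_clause(page=1, per_page=5000)      # per_page exceeds SAMPLES_ROW_LIMIT, cap applies
                                             # -> {"row_offset": 0, "row_limit": 1000}
get_limit_clause(page=None, per_page=None)   # no pagination args
                                             # -> {"row_offset": 0, "row_limit": 1000}

In other words, per_page becomes row_limit (capped at SAMPLES_ROW_LIMIT) and row_offset is (page - 1) * row_limit, clamped at zero.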
import logging -import simplejson -from flask import g, make_response, request, Response +from flask import g, request, Response from flask_appbuilder.api import BaseApi, expose, protect, safe from superset.charts.commands.exceptions import ChartNotFoundError from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod -from superset.dao.exceptions import DatasourceNotFound from superset.explore.commands.get import GetExploreCommand from superset.explore.commands.parameters import CommandParameters -from superset.explore.commands.samples import SamplesDatasourceCommand -from superset.explore.exceptions import ( - DatasetAccessDeniedError, - DatasourceForbiddenError, - DatasourceSamplesFailedError, - WrongEndpointError, -) +from superset.explore.exceptions import DatasetAccessDeniedError, WrongEndpointError from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError from superset.explore.schemas import ExploreContextSchema from superset.extensions import event_logger @@ -39,16 +31,13 @@ TemporaryCacheAccessDeniedError, TemporaryCacheResourceNotFoundError, ) -from superset.utils.core import json_int_dttm_ser, parse_boolean_string logger = logging.getLogger(__name__) class ExploreRestApi(BaseApi): method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP - include_route_methods = {RouteMethod.GET} | { - "samples", - } + include_route_methods = {RouteMethod.GET} allow_browser_login = True class_permission_name = "Explore" resource_name = "explore" @@ -146,70 +135,3 @@ def get(self) -> Response: return self.response(403, message=str(ex)) except TemporaryCacheResourceNotFoundError as ex: return self.response(404, message=str(ex)) - - @expose("/samples", methods=["GET"]) - @protect() - @safe - @event_logger.log_this_with_context( - action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.samples", - log_to_statsd=False, - ) - def samples(self) -> Response: - """get samples from a Datasource - --- - get: - description: >- - get samples from a Datasource - parameters: - - in: path - schema: - type: integer - name: pk - - in: query - schema: - type: boolean - name: force - responses: - 200: - description: Datasource samples - content: - application/json: - schema: - type: object - properties: - result: - $ref: '#/components/schemas/ChartDataResponseResult' - 401: - $ref: '#/components/responses/401' - 403: - $ref: '#/components/responses/403' - 404: - $ref: '#/components/responses/404' - 422: - $ref: '#/components/responses/422' - 500: - $ref: '#/components/responses/500' - """ - try: - force = parse_boolean_string(request.args.get("force")) - rv = SamplesDatasourceCommand( - user=g.user, - datasource_type=request.args.get("datasource_type", type=str), - datasource_id=request.args.get("datasource_id", type=int), - force=force, - ).run() - - response_data = simplejson.dumps( - {"result": rv}, - default=json_int_dttm_ser, - ignore_nan=True, - ) - resp = make_response(response_data, 200) - resp.headers["Content-Type"] = "application/json; charset=utf-8" - return resp - except DatasourceNotFound: - return self.response_404() - except DatasourceForbiddenError: - return self.response_403() - except DatasourceSamplesFailedError as ex: - return self.response_400(message=str(ex)) diff --git a/superset/explore/commands/samples.py b/superset/explore/commands/samples.py deleted file mode 100644 index 7fda5c1bc1509..0000000000000 --- a/superset/explore/commands/samples.py +++ /dev/null @@ -1,93 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one 
-# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import logging -from typing import Any, Dict, Optional - -from flask_appbuilder.security.sqla.models import User - -from superset import db, security_manager -from superset.commands.base import BaseCommand -from superset.common.chart_data import ChartDataResultType -from superset.common.query_context_factory import QueryContextFactory -from superset.common.utils.query_cache_manager import QueryCacheManager -from superset.constants import CacheRegion -from superset.dao.exceptions import DatasourceNotFound -from superset.datasource.dao import Datasource, DatasourceDAO -from superset.exceptions import SupersetSecurityException -from superset.explore.exceptions import ( - DatasourceForbiddenError, - DatasourceSamplesFailedError, -) -from superset.utils.core import DatasourceType, QueryStatus - -logger = logging.getLogger(__name__) - - -class SamplesDatasourceCommand(BaseCommand): - def __init__( - self, - user: User, - datasource_id: Optional[int], - datasource_type: Optional[str], - force: bool, - ): - self._actor = user - self._datasource_id = datasource_id - self._datasource_type = datasource_type - self._force = force - self._model: Optional[Datasource] = None - - def run(self) -> Dict[str, Any]: - self.validate() - if not self._model: - raise DatasourceNotFound() - - qc_instance = QueryContextFactory().create( - datasource={ - "type": self._model.type, - "id": self._model.id, - }, - queries=[{}], - result_type=ChartDataResultType.SAMPLES, - force=self._force, - ) - results = qc_instance.get_payload() - try: - sample_data = results["queries"][0] - error_msg = sample_data.get("error") - if sample_data.get("status") == QueryStatus.FAILED and error_msg: - cache_key = sample_data.get("cache_key") - QueryCacheManager.delete(cache_key, region=CacheRegion.DATA) - raise DatasourceSamplesFailedError(error_msg) - return sample_data - except (IndexError, KeyError) as exc: - raise DatasourceSamplesFailedError from exc - - def validate(self) -> None: - # Validate/populate model exists - if self._datasource_type and self._datasource_id: - self._model = DatasourceDAO.get_datasource( - session=db.session, - datasource_type=DatasourceType(self._datasource_type), - datasource_id=self._datasource_id, - ) - - # Check ownership - try: - security_manager.raise_for_ownership(self._model) - except SupersetSecurityException as ex: - raise DatasourceForbiddenError() from ex diff --git a/superset/views/datasource/schemas.py b/superset/views/datasource/schemas.py index 64b2b854bb148..9951aa5fbea95 100644 --- a/superset/views/datasource/schemas.py +++ b/superset/views/datasource/schemas.py @@ -14,11 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Any +from typing import Any, Dict -from marshmallow import fields, post_load, Schema +from marshmallow import fields, post_load, pre_load, Schema, validate from typing_extensions import TypedDict +from superset import app +from superset.charts.schemas import ChartDataFilterSchema +from superset.utils.core import DatasourceType + class ExternalMetadataParams(TypedDict): datasource_type: str @@ -54,3 +58,24 @@ def normalize( schema_name=data.get("schema_name", ""), table_name=data["table_name"], ) + + +class SamplesPayloadSchema(Schema): + filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False) + + @pre_load + # pylint: disable=no-self-use, unused-argument + def handle_none(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + if data is None: + return {} + return data + + +class SamplesRequestSchema(Schema): + datasource_type = fields.String( + validate=validate.OneOf([e.value for e in DatasourceType]), required=True + ) + datasource_id = fields.Integer(required=True) + force = fields.Boolean() + page = fields.Integer(load_default=1) + per_page = fields.Integer(load_default=app.config.get("SAMPLES_ROW_LIMIT", 1000)) diff --git a/superset/views/datasource/utils.py b/superset/views/datasource/utils.py new file mode 100644 index 0000000000000..80e1a3c3d1abb --- /dev/null +++ b/superset/views/datasource/utils.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
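Before the implementation in utils.py below, here is a hedged sketch of how a client could exercise the new endpoint once this series is applied. The endpoint path and parameter names come from SamplesRequestSchema and SamplesPayloadSchema above; the host, ids, column names, and auth handling are purely illustrative:

import requests

# Illustrative request against the new POST /datasource/samples endpoint.
# Query args are validated by SamplesRequestSchema; the JSON body is validated
# by SamplesPayloadSchema (only "filters" is accepted).
resp = requests.post(
    "http://localhost:8088/datasource/samples",  # hypothetical host
    params={
        "datasource_type": "table",  # must be a valid DatasourceType value
        "datasource_id": 42,         # hypothetical dataset id
        "force": "false",            # "true" bypasses the samples cache
        "page": 2,                   # 1-based page index
        "per_page": 100,             # capped at SAMPLES_ROW_LIMIT
    },
    json={"filters": [{"col": "foo", "op": "==", "val": "bar"}]},  # hypothetical filter
    cookies={"session": "<superset-session-cookie>"},  # placeholder; auth depends on deployment
)
result = resp.json()["result"]
rows, total = result["data"], result["total_count"]

This mirrors what the frontend does in getDatasourceSamples via SupersetClient.post with a jsonPayload, as changed in chartAction.js earlier in this series.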
+from typing import Any, Dict, Optional + +from superset import app, db +from superset.common.chart_data import ChartDataResultType +from superset.common.query_context_factory import QueryContextFactory +from superset.common.utils.query_cache_manager import QueryCacheManager +from superset.constants import CacheRegion +from superset.datasets.commands.exceptions import DatasetSamplesFailedError +from superset.datasource.dao import DatasourceDAO +from superset.utils.core import QueryStatus +from superset.views.datasource.schemas import SamplesPayloadSchema + + +def get_limit_clause(page: Optional[int], per_page: Optional[int]) -> Dict[str, int]: + samples_row_limit = app.config.get("SAMPLES_ROW_LIMIT", 1000) + limit = samples_row_limit + offset = 0 + + if isinstance(page, int) and isinstance(per_page, int): + limit = int(per_page) + if limit < 0 or limit > samples_row_limit: + # reset limit value if input is invalid + limit = samples_row_limit + + offset = (int(page) - 1) * limit + if offset < 0: + # reset offset value if input is invalid + offset = 0 + + return {"row_offset": offset, "row_limit": limit} + + +def get_samples( + datasource_type: str, + datasource_id: int, + force: bool = False, + page: int = 1, + per_page: int = 1000, + payload: Optional[SamplesPayloadSchema] = None, +) -> Dict[str, Any]: + datasource = DatasourceDAO.get_datasource( + session=db.session, + datasource_type=datasource_type, + datasource_id=datasource_id, + ) + + limit_clause = get_limit_clause(page, per_page) + + # constructing samples query + samples_instance = QueryContextFactory().create( + datasource={ + "type": datasource.type, + "id": datasource.id, + }, + queries=[{**payload, **limit_clause} if payload else limit_clause], + result_type=ChartDataResultType.SAMPLES, + force=force, + ) + + # constructing count(*) query + count_star_metric = { + "metrics": [ + { + "expressionType": "SQL", + "sqlExpression": "COUNT(*)", + "label": "COUNT(*)", + } + ] + } + count_star_instance = QueryContextFactory().create( + datasource={ + "type": datasource.type, + "id": datasource.id, + }, + queries=[{**payload, **count_star_metric} if payload else count_star_metric], + result_type=ChartDataResultType.FULL, + force=force, + ) + samples_results = samples_instance.get_payload() + count_star_results = count_star_instance.get_payload() + + try: + sample_data = samples_results["queries"][0] + count_star_data = count_star_results["queries"][0] + failed_status = ( + sample_data.get("status") == QueryStatus.FAILED + or count_star_data.get("status") == QueryStatus.FAILED + ) + error_msg = sample_data.get("error") or count_star_data.get("error") + if failed_status and error_msg: + cache_key = sample_data.get("cache_key") + QueryCacheManager.delete(cache_key, region=CacheRegion.DATA) + raise DatasetSamplesFailedError(error_msg) + + sample_data["page"] = page + sample_data["per_page"] = per_page + sample_data["total_count"] = count_star_data["data"][0]["COUNT(*)"] + return sample_data + except (IndexError, KeyError) as exc: + raise DatasetSamplesFailedError from exc diff --git a/superset/views/datasource/views.py b/superset/views/datasource/views.py index 4e43068c6f048..60ee4baddcca2 100644 --- a/superset/views/datasource/views.py +++ b/superset/views/datasource/views.py @@ -50,7 +50,10 @@ ExternalMetadataParams, ExternalMetadataSchema, get_external_metadata_schema, + SamplesPayloadSchema, + SamplesRequestSchema, ) +from superset.views.datasource.utils import get_samples from superset.views.utils import sanitize_datasource_data @@ 
-179,3 +182,24 @@ def external_metadata_by_name(self, **kwargs: Any) -> FlaskResponse: except (NoResultFound, NoSuchTableError) as ex: raise DatasetNotFoundError() from ex return self.json_response(external_metadata) + + @expose("/samples", methods=["POST"]) + @has_access_api + @api + @handle_api_exception + def samples(self) -> FlaskResponse: + try: + params = SamplesRequestSchema().load(request.args) + payload = SamplesPayloadSchema().load(request.json) + except ValidationError as err: + return json_error_response(err.messages, status=400) + + rv = get_samples( + datasource_type=params["datasource_type"], + datasource_id=params["datasource_id"], + force=params["force"], + page=params["page"], + per_page=params["per_page"], + payload=payload, + ) + return self.json_response({"result": rv}) From 63ab07b6972e4257fe235301714328ef292f12b8 Mon Sep 17 00:00:00 2001 From: Yongjie Zhao Date: Tue, 19 Jul 2022 18:09:34 +0800 Subject: [PATCH 03/12] fix lint --- superset/datasource/dao.py | 3 --- superset/views/datasource/utils.py | 7 ++----- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/superset/datasource/dao.py b/superset/datasource/dao.py index d06da9bf84a34..c8df4c8d8d968 100644 --- a/superset/datasource/dao.py +++ b/superset/datasource/dao.py @@ -20,13 +20,10 @@ from sqlalchemy.orm import Session -from superset import security_manager from superset.connectors.sqla.models import SqlaTable from superset.dao.base import BaseDAO from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError -from superset.datasets.commands.exceptions import DatasetForbiddenError from superset.datasets.models import Dataset -from superset.exceptions import SupersetSecurityException from superset.models.sql_lab import Query, SavedQuery from superset.tables.models import Table from superset.utils.core import DatasourceType diff --git a/superset/views/datasource/utils.py b/superset/views/datasource/utils.py index 80e1a3c3d1abb..4bb84ba18f909 100644 --- a/superset/views/datasource/utils.py +++ b/superset/views/datasource/utils.py @@ -38,15 +38,12 @@ def get_limit_clause(page: Optional[int], per_page: Optional[int]) -> Dict[str, # reset limit value if input is invalid limit = samples_row_limit - offset = (int(page) - 1) * limit - if offset < 0: - # reset offset value if input is invalid - offset = 0 + offset = max((int(page) - 1) * limit, 0) return {"row_offset": offset, "row_limit": limit} -def get_samples( +def get_samples( # pylint: disable=too-many-arguments,too-many-locals datasource_type: str, datasource_id: int, force: bool = False, From a8bc3049209fca63a654c816a3ba849f8440b715 Mon Sep 17 00:00:00 2001 From: Yongjie Zhao Date: Tue, 19 Jul 2022 20:24:03 +0800 Subject: [PATCH 04/12] fix test --- superset/views/datasource/schemas.py | 2 +- tests/integration_tests/datasets/api_tests.py | 138 ------------- tests/integration_tests/datasource_tests.py | 182 +++++++++++++++++- 3 files changed, 180 insertions(+), 142 deletions(-) diff --git a/superset/views/datasource/schemas.py b/superset/views/datasource/schemas.py index 9951aa5fbea95..366a093fb50bd 100644 --- a/superset/views/datasource/schemas.py +++ b/superset/views/datasource/schemas.py @@ -76,6 +76,6 @@ class SamplesRequestSchema(Schema): validate=validate.OneOf([e.value for e in DatasourceType]), required=True ) datasource_id = fields.Integer(required=True) - force = fields.Boolean() + force = fields.Boolean(load_default=False) page = fields.Integer(load_default=1) per_page = 
fields.Integer(load_default=app.config.get("SAMPLES_ROW_LIMIT", 1000)) diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 0364466188121..46739f9631b1b 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -27,9 +27,7 @@ import yaml from sqlalchemy.sql import func -from superset.common.utils.query_cache_manager import QueryCacheManager from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn -from superset.constants import CacheRegion from superset.dao.exceptions import ( DAOCreateFailedError, DAODeleteFailedError, @@ -2085,139 +2083,3 @@ def test_get_datasets_is_certified_filter(self): db.session.delete(table_w_certification) db.session.commit() - - @pytest.mark.usefixtures("create_datasets") - def test_get_dataset_samples(self): - """ - Dataset API: Test get dataset samples - """ - if backend() == "sqlite": - return - - dataset = self.get_fixture_datasets()[0] - - self.login(username="admin") - uri = f"api/v1/dataset/{dataset.id}/samples" - - # 1. should cache data - # feeds data - self.client.post(uri) - # get from cache - rv = self.client.post(uri) - rv_data = json.loads(rv.data) - assert rv.status_code == 200 - assert "result" in rv_data - assert rv_data["result"]["cached_dttm"] is not None - cache_key1 = rv_data["result"]["cache_key"] - assert QueryCacheManager.has(cache_key1, region=CacheRegion.DATA) - - # 2. should through cache - uri2 = f"api/v1/dataset/{dataset.id}/samples?force=true" - # feeds data - self.client.post(uri2) - # force query - rv2 = self.client.post(uri2) - rv_data2 = json.loads(rv2.data) - assert rv_data2["result"]["cached_dttm"] is None - cache_key2 = rv_data2["result"]["cache_key"] - assert QueryCacheManager.has(cache_key2, region=CacheRegion.DATA) - - # 3. 
data precision - assert "colnames" in rv_data2["result"] - assert "coltypes" in rv_data2["result"] - assert "data" in rv_data2["result"] - - eager_samples = dataset.database.get_df( - f"select * from {dataset.table_name}" - f' limit {self.app.config["SAMPLES_ROW_LIMIT"]}' - ).to_dict(orient="records") - assert eager_samples == rv_data2["result"]["data"] - - @pytest.mark.usefixtures("create_datasets") - def test_get_dataset_samples_with_failed_cc(self): - if backend() == "sqlite": - return - - dataset = self.get_fixture_datasets()[0] - - self.login(username="admin") - failed_column = TableColumn( - column_name="DUMMY CC", - type="VARCHAR(255)", - table=dataset, - expression="INCORRECT SQL", - ) - uri = f"api/v1/dataset/{dataset.id}/samples" - dataset.columns.append(failed_column) - rv = self.client.post(uri) - assert rv.status_code == 400 - rv_data = json.loads(rv.data) - assert "message" in rv_data - if dataset.database.db_engine_spec.engine_name == "PostgreSQL": - assert "INCORRECT SQL" in rv_data.get("message") - - def test_get_dataset_samples_on_virtual_dataset(self): - if backend() == "sqlite": - return - - virtual_dataset = SqlaTable( - table_name="virtual_dataset", - sql=("SELECT 'foo' as foo, 'bar' as bar"), - database=get_example_database(), - ) - TableColumn(column_name="foo", type="VARCHAR(255)", table=virtual_dataset) - TableColumn(column_name="bar", type="VARCHAR(255)", table=virtual_dataset) - SqlMetric(metric_name="count", expression="count(*)", table=virtual_dataset) - - self.login(username="admin") - uri = f"api/v1/dataset/{virtual_dataset.id}/samples" - rv = self.client.post(uri) - assert rv.status_code == 200 - rv_data = json.loads(rv.data) - cache_key = rv_data["result"]["cache_key"] - assert QueryCacheManager.has(cache_key, region=CacheRegion.DATA) - - # remove original column in dataset - virtual_dataset.sql = "SELECT 'foo' as foo" - rv = self.client.post(uri) - assert rv.status_code == 400 - - db.session.delete(virtual_dataset) - db.session.commit() - - def test_get_dataset_samples_with_filters(self): - virtual_dataset = SqlaTable( - table_name="virtual_dataset", - sql=("SELECT 'foo' as foo, 'bar' as bar UNION ALL SELECT 'foo2', 'bar2'"), - database=get_example_database(), - ) - TableColumn(column_name="foo", type="VARCHAR(255)", table=virtual_dataset) - TableColumn(column_name="bar", type="VARCHAR(255)", table=virtual_dataset) - SqlMetric(metric_name="count", expression="count(*)", table=virtual_dataset) - - self.login(username="admin") - uri = f"api/v1/dataset/{virtual_dataset.id}/samples" - rv = self.client.post(uri, json=None) - assert rv.status_code == 200 - - rv = self.client.post(uri, json={}) - assert rv.status_code == 200 - - rv = self.client.post(uri, json={"foo": "bar"}) - assert rv.status_code == 400 - - rv = self.client.post( - uri, json={"filters": [{"col": "foo", "op": "INVALID", "val": "foo2"}]} - ) - assert rv.status_code == 400 - - rv = self.client.post( - uri, json={"filters": [{"col": "foo", "op": "==", "val": "foo2"}]} - ) - assert rv.status_code == 200 - rv_data = json.loads(rv.data) - assert rv_data["result"]["colnames"] == ["foo", "bar"] - assert rv_data["result"]["rowcount"] == 1 - - db.session.delete(virtual_dataset) - db.session.commit() diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index 8e4d269b20126..7482a9841d440 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -23,13 +23,15 @@ import pytest from superset import app, db 
-from superset.connectors.sqla.models import SqlaTable +from superset.common.utils.query_cache_manager import QueryCacheManager +from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn +from superset.constants import CacheRegion from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError from superset.datasets.commands.exceptions import DatasetNotFoundError from superset.exceptions import SupersetGenericDBErrorException from superset.models.core import Database -from superset.utils.core import DatasourceType, get_example_default_schema -from superset.utils.database import get_example_database +from superset.utils.core import backend, get_example_default_schema +from superset.utils.database import get_example_database, get_main_database from tests.integration_tests.base_tests import db_insert_temp_object, SupersetTestCase from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, @@ -416,3 +418,177 @@ def test_get_datasource_invalid_datasource_failed(self): self.login(username="admin") resp = self.get_json_resp("/datasource/get/druid/500000/", raise_on_error=False) self.assertEqual(resp.get("error"), "'druid' is not a valid DatasourceType") + + @pytest.fixture() + def create_datasets(self): + with self.create_app().app_context(): + if backend() == "sqlite": + yield + return + + datasets = [] + admin = self.get_user("admin") + main_db = get_main_database() + for tables_name in self.fixture_tables_names: + datasets.append(self.insert_dataset(tables_name, [admin.id], main_db)) + + yield datasets + + # rollback changes + for dataset in datasets: + db.session.delete(dataset) + db.session.commit() + + def test_get_dataset_samples(self): + """ + Dataset API: Test get dataset samples + """ + if backend() == "sqlite": + return + + dataset = SqlaTable( + table_name="virtual_dataset", + sql=("SELECT 'foo' as foo, 'bar' as bar"), + database=get_example_database(), + ) + TableColumn(column_name="foo", type="VARCHAR(255)", table=dataset) + TableColumn(column_name="bar", type="VARCHAR(255)", table=dataset) + SqlMetric(metric_name="count", expression="count(*)", table=dataset) + + self.login(username="admin") + uri = f"datasource/samples?datasource_id={dataset.id}&datasource_type=table" + + # 1. should cache data + # feeds data + self.client.post(uri) + # get from cache + rv = self.client.post(uri) + rv_data = json.loads(rv.data) + assert rv.status_code == 200 + assert "result" in rv_data + assert rv_data["result"]["cached_dttm"] is not None + cache_key1 = rv_data["result"]["cache_key"] + assert QueryCacheManager.has(cache_key1, region=CacheRegion.DATA) + + # 2. should through cache + uri2 = f"datasource/samples?datasource_id={dataset.id}&datasource_type=table&force=true" + # feeds data + self.client.post(uri2) + # force query + rv2 = self.client.post(uri2) + rv_data2 = json.loads(rv2.data) + assert rv_data2["result"]["cached_dttm"] is None + cache_key2 = rv_data2["result"]["cache_key"] + assert QueryCacheManager.has(cache_key2, region=CacheRegion.DATA) + + # 3. 
data precision + assert "colnames" in rv_data2["result"] + assert "coltypes" in rv_data2["result"] + assert "data" in rv_data2["result"] + + eager_samples = dataset.database.get_df( + f"select * from ({dataset.sql}) as tbl" + f' limit {self.app.config["SAMPLES_ROW_LIMIT"]}' + ).to_dict(orient="records") + assert eager_samples == rv_data2["result"]["data"] + + db.session.delete(dataset) + db.session.commit() + + def test_get_dataset_samples_with_failed_cc(self): + if backend() == "sqlite": + return + + dataset = SqlaTable( + table_name="virtual_dataset", + sql=("SELECT 'foo' as foo, 'bar' as bar"), + database=get_example_database(), + ) + TableColumn(column_name="foo", type="VARCHAR(255)", table=dataset) + TableColumn(column_name="bar", type="VARCHAR(255)", table=dataset) + SqlMetric(metric_name="count", expression="count(*)", table=dataset) + + self.login(username="admin") + failed_column = TableColumn( + column_name="DUMMY CC", + type="VARCHAR(255)", + table=dataset, + expression="INCORRECT SQL", + ) + uri = f"datasource/samples?datasource_id={dataset.id}&datasource_type=table" + dataset.columns.append(failed_column) + rv = self.client.post(uri) + assert rv.status_code == 422 + rv_data = json.loads(rv.data) + assert "error" in rv_data + if dataset.database.db_engine_spec.engine_name == "PostgreSQL": + assert "INCORRECT SQL" in rv_data.get("error") + + db.session.delete(dataset) + db.session.commit() + + def test_get_datasource_samples_on_virtual_dataset(self): + if backend() == "sqlite": + return + + virtual_dataset = SqlaTable( + table_name="virtual_dataset", + sql=("SELECT 'foo' as foo, 'bar' as bar"), + database=get_example_database(), + ) + TableColumn(column_name="foo", type="VARCHAR(255)", table=virtual_dataset) + TableColumn(column_name="bar", type="VARCHAR(255)", table=virtual_dataset) + SqlMetric(metric_name="count", expression="count(*)", table=virtual_dataset) + + self.login(username="admin") + uri = f"datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table" + rv = self.client.post(uri) + assert rv.status_code == 200 + rv_data = json.loads(rv.data) + cache_key = rv_data["result"]["cache_key"] + assert QueryCacheManager.has(cache_key, region=CacheRegion.DATA) + + # remove original column in dataset + virtual_dataset.sql = "SELECT 'foo' as foo" + rv = self.client.post(uri) + assert rv.status_code == 422 + + db.session.delete(virtual_dataset) + db.session.commit() + + def test_get_datasource_samples_with_filters(self): + virtual_dataset = SqlaTable( + table_name="virtual_dataset", + sql=("SELECT 'foo' as foo, 'bar' as bar UNION ALL SELECT 'foo2', 'bar2'"), + database=get_example_database(), + ) + TableColumn(column_name="foo", type="VARCHAR(255)", table=virtual_dataset) + TableColumn(column_name="bar", type="VARCHAR(255)", table=virtual_dataset) + SqlMetric(metric_name="count", expression="count(*)", table=virtual_dataset) + + self.login(username="admin") + uri = f"datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table" + rv = self.client.post(uri, json=None) + assert rv.status_code == 200 + + rv = self.client.post(uri, json={}) + assert rv.status_code == 200 + + rv = self.client.post(uri, json={"foo": "bar"}) + assert rv.status_code == 400 + + rv = self.client.post( + uri, json={"filters": [{"col": "foo", "op": "INVALID", "val": "foo2"}]} + ) + assert rv.status_code == 400 + + rv = self.client.post( + uri, json={"filters": [{"col": "foo", "op": "==", "val": "foo2"}]} + ) + assert rv.status_code == 200 + rv_data = json.loads(rv.data) + 
assert rv_data["result"]["colnames"] == ["foo", "bar"] + assert rv_data["result"]["rowcount"] == 1 + + db.session.delete(virtual_dataset) + db.session.commit() From bac09f626fbce22824e818521401f0d0d6c06775 Mon Sep 17 00:00:00 2001 From: Yongjie Zhao Date: Tue, 19 Jul 2022 20:27:33 +0800 Subject: [PATCH 05/12] remove unused function --- tests/integration_tests/datasource_tests.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index 7482a9841d440..0cb7e80c70f99 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -419,26 +419,6 @@ def test_get_datasource_invalid_datasource_failed(self): resp = self.get_json_resp("/datasource/get/druid/500000/", raise_on_error=False) self.assertEqual(resp.get("error"), "'druid' is not a valid DatasourceType") - @pytest.fixture() - def create_datasets(self): - with self.create_app().app_context(): - if backend() == "sqlite": - yield - return - - datasets = [] - admin = self.get_user("admin") - main_db = get_main_database() - for tables_name in self.fixture_tables_names: - datasets.append(self.insert_dataset(tables_name, [admin.id], main_db)) - - yield datasets - - # rollback changes - for dataset in datasets: - db.session.delete(dataset) - db.session.commit() - def test_get_dataset_samples(self): """ Dataset API: Test get dataset samples From 1a7b381f7faef61b868255200598e307a93d48ae Mon Sep 17 00:00:00 2001 From: Yongjie Zhao Date: Tue, 19 Jul 2022 21:20:13 +0800 Subject: [PATCH 06/12] fix frontend --- .../cypress/integration/explore/control.test.ts | 2 +- .../components/DataTablesPane/test/SamplesPane.test.tsx | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/superset-frontend/cypress-base/cypress/integration/explore/control.test.ts b/superset-frontend/cypress-base/cypress/integration/explore/control.test.ts index f1adec9e4488f..a4b85de4deecc 100644 --- a/superset-frontend/cypress-base/cypress/integration/explore/control.test.ts +++ b/superset-frontend/cypress-base/cypress/integration/explore/control.test.ts @@ -129,7 +129,7 @@ describe('Test datatable', () => { }); it('Datapane loads view samples', () => { cy.intercept( - 'api/v1/explore/samples?force=false&datasource_type=table&datasource_id=*', + 'datasource/samples?force=false&datasource_type=table&datasource_id=*', ).as('Samples'); cy.contains('Samples') .click() diff --git a/superset-frontend/src/explore/components/DataTablesPane/test/SamplesPane.test.tsx b/superset-frontend/src/explore/components/DataTablesPane/test/SamplesPane.test.tsx index 6ea8fafe378e8..391540f4d8539 100644 --- a/superset-frontend/src/explore/components/DataTablesPane/test/SamplesPane.test.tsx +++ b/superset-frontend/src/explore/components/DataTablesPane/test/SamplesPane.test.tsx @@ -30,7 +30,7 @@ import { createSamplesPaneProps } from './fixture'; describe('SamplesPane', () => { fetchMock.post( - '/api/v1/explore/samples?force=false&datasource_type=table&datasource_id=34', + 'end:/datasource/samples?force=false&datasource_type=table&datasource_id=34', { result: { data: [], @@ -41,7 +41,7 @@ describe('SamplesPane', () => { ); fetchMock.post( - 'end:/api/v1/explore/samples?force=true&datasource_type=table&datasource_id=35', + 'end:/datasource/samples?force=true&datasource_type=table&datasource_id=35', { result: { data: [ @@ -55,7 +55,7 @@ describe('SamplesPane', () => { ); fetchMock.post( - 
'end:/api/v1/explore/samples?force=false&datasource_type=table&datasource_id=36', + 'end:/datasource/samples?force=false&datasource_type=table&datasource_id=36', 400, ); From f90308ce0072065a4ea1dfc4836c0d4dd95021fa Mon Sep 17 00:00:00 2001 From: Yongjie Zhao Date: Thu, 21 Jul 2022 17:02:29 +0800 Subject: [PATCH 07/12] more unit tst --- superset/views/datasource/schemas.py | 5 +- superset/views/datasource/utils.py | 1 + tests/integration_tests/conftest.py | 97 ++++++ tests/integration_tests/datasource_tests.py | 335 +++++++++++--------- 4 files changed, 293 insertions(+), 145 deletions(-) diff --git a/superset/views/datasource/schemas.py b/superset/views/datasource/schemas.py index 366a093fb50bd..4c97f17e88fe8 100644 --- a/superset/views/datasource/schemas.py +++ b/superset/views/datasource/schemas.py @@ -78,4 +78,7 @@ class SamplesRequestSchema(Schema): datasource_id = fields.Integer(required=True) force = fields.Boolean(load_default=False) page = fields.Integer(load_default=1) - per_page = fields.Integer(load_default=app.config.get("SAMPLES_ROW_LIMIT", 1000)) + per_page = fields.Integer( + validate=validate.Range(min=1, max=app.config.get("SAMPLES_ROW_LIMIT", 1000)), + load_default=app.config.get("SAMPLES_ROW_LIMIT", 1000), + ) diff --git a/superset/views/datasource/utils.py b/superset/views/datasource/utils.py index 4bb84ba18f909..3c0fb7abfeaa4 100644 --- a/superset/views/datasource/utils.py +++ b/superset/views/datasource/utils.py @@ -59,6 +59,7 @@ def get_samples( # pylint: disable=too-many-arguments,too-many-locals limit_clause = get_limit_clause(page, per_page) + # todo(yongjie): Constructing count(*) and samples in the same query_context, then remove query_type==SAMPLES # constructing samples query samples_instance = QueryContextFactory().create( datasource={ diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index ea46039d8412e..5d7bf1acaf3dc 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -206,3 +206,100 @@ def wrapper(*args, **kwargs): return functools.update_wrapper(wrapper, test_fn) return decorate + + +@pytest.fixture +def virtual_dataset(): + from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn + + dataset = SqlaTable( + table_name="virtual_dataset", + sql=( + "SELECT 0 as col1, 'a' as col2, 1.0 as col3, NULL as col4, '2000-01-01 00:00:00' as col5 " + "UNION ALL " + "SELECT 1, 'b', 1.1, NULL, '2000-01-02 00:00:00' " + "UNION ALL " + "SELECT 2 as col1, 'c' as col2, 1.2, NULL, '2000-01-03 00:00:00' " + "UNION ALL " + "SELECT 3 as col1, 'd' as col2, 1.3, NULL, '2000-01-04 00:00:00' " + "UNION ALL " + "SELECT 4 as col1, 'e' as col2, 1.4, NULL, '2000-01-05 00:00:00' " + "UNION ALL " + "SELECT 5 as col1, 'f' as col2, 1.5, NULL, '2000-01-06 00:00:00' " + "UNION ALL " + "SELECT 6 as col1, 'g' as col2, 1.6, NULL, '2000-01-07 00:00:00' " + "UNION ALL " + "SELECT 7 as col1, 'h' as col2, 1.7, NULL, '2000-01-08 00:00:00' " + "UNION ALL " + "SELECT 8 as col1, 'i' as col2, 1.8, NULL, '2000-01-09 00:00:00' " + "UNION ALL " + "SELECT 9 as col1, 'j' as col2, 1.9, NULL, '2000-01-10 00:00:00' " + ), + database=get_example_database(), + ) + TableColumn(column_name="col1", type="INTEGER", table=dataset) + TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset) + TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset) + TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset) + # Different database dialect datetime type is not consistent, so temporarily use 
varchar + TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset) + + SqlMetric(metric_name="count", expression="count(*)", table=dataset) + db.session.merge(dataset) + + yield dataset + + db.session.delete(dataset) + db.session.commit() + + +@pytest.fixture +def physical_dataset(): + from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn + + example_database = get_example_database() + engine = example_database.get_sqla_engine() + engine.execute( + """ + CREATE TABLE IF NOT EXISTS physical_dataset( + col1 INTEGER, + col2 VARCHAR(255), + col3 DECIMAL(4,2), + col4 VARCHAR(255), + col5 VARCHAR(255) + ); + INSERT INTO physical_dataset values + (0, 'a', 1.0, NULL, '2000-01-01 00:00:00'), + (1, 'b', 1.1, NULL, '2000-01-02 00:00:00'), + (2, 'c', 1.2, NULL, '2000-01-03 00:00:00'), + (3, 'd', 1.3, NULL, '2000-01-04 00:00:00'), + (4, 'e', 1.4, NULL, '2000-01-05 00:00:00'), + (5, 'f', 1.5, NULL, '2000-01-06 00:00:00'), + (6, 'g', 1.6, NULL, '2000-01-07 00:00:00'), + (7, 'h', 1.7, NULL, '2000-01-08 00:00:00'), + (8, 'i', 1.8, NULL, '2000-01-09 00:00:00'), + (9, 'j', 1.9, NULL, '2000-01-10 00:00:00'); + """ + ) + + dataset = SqlaTable( + table_name="physical_dataset", + database=example_database, + ) + TableColumn(column_name="col1", type="INTEGER", table=dataset) + TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset) + TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset) + TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset) + TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset) + SqlMetric(metric_name="count", expression="count(*)", table=dataset) + db.session.merge(dataset) + + yield dataset + + engine.execute( + """ + DROP TABLE physical_dataset; + """ + ) + db.session.delete(dataset) + db.session.commit() diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index 0cb7e80c70f99..13abed8dd95ba 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -419,156 +419,203 @@ def test_get_datasource_invalid_datasource_failed(self): resp = self.get_json_resp("/datasource/get/druid/500000/", raise_on_error=False) self.assertEqual(resp.get("error"), "'druid' is not a valid DatasourceType") - def test_get_dataset_samples(self): - """ - Dataset API: Test get dataset samples - """ - if backend() == "sqlite": - return - - dataset = SqlaTable( - table_name="virtual_dataset", - sql=("SELECT 'foo' as foo, 'bar' as bar"), - database=get_example_database(), - ) - TableColumn(column_name="foo", type="VARCHAR(255)", table=dataset) - TableColumn(column_name="bar", type="VARCHAR(255)", table=dataset) - SqlMetric(metric_name="count", expression="count(*)", table=dataset) - self.login(username="admin") - uri = f"datasource/samples?datasource_id={dataset.id}&datasource_type=table" - - # 1. should cache data - # feeds data - self.client.post(uri) - # get from cache - rv = self.client.post(uri) - rv_data = json.loads(rv.data) - assert rv.status_code == 200 - assert "result" in rv_data - assert rv_data["result"]["cached_dttm"] is not None - cache_key1 = rv_data["result"]["cache_key"] - assert QueryCacheManager.has(cache_key1, region=CacheRegion.DATA) - - # 2. 
should through cache - uri2 = f"datasource/samples?datasource_id={dataset.id}&datasource_type=table&force=true" - # feeds data - self.client.post(uri2) - # force query - rv2 = self.client.post(uri2) - rv_data2 = json.loads(rv2.data) - assert rv_data2["result"]["cached_dttm"] is None - cache_key2 = rv_data2["result"]["cache_key"] - assert QueryCacheManager.has(cache_key2, region=CacheRegion.DATA) - - # 3. data precision - assert "colnames" in rv_data2["result"] - assert "coltypes" in rv_data2["result"] - assert "data" in rv_data2["result"] - - eager_samples = dataset.database.get_df( - f"select * from ({dataset.sql}) as tbl" - f' limit {self.app.config["SAMPLES_ROW_LIMIT"]}' - ).to_dict(orient="records") - assert eager_samples == rv_data2["result"]["data"] - - db.session.delete(dataset) - db.session.commit() - - def test_get_dataset_samples_with_failed_cc(self): - if backend() == "sqlite": - return - - dataset = SqlaTable( - table_name="virtual_dataset", - sql=("SELECT 'foo' as foo, 'bar' as bar"), - database=get_example_database(), - ) - TableColumn(column_name="foo", type="VARCHAR(255)", table=dataset) - TableColumn(column_name="bar", type="VARCHAR(255)", table=dataset) - SqlMetric(metric_name="count", expression="count(*)", table=dataset) +def test_get_samples(test_client, login_as_admin, virtual_dataset): + """ + Dataset API: Test get dataset samples + """ + if backend() == "sqlite": + return + + uri = f"datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table" + + # 1. should cache data + # feeds data + test_client.post(uri) + # get from cache + rv = test_client.post(uri) + rv_data = json.loads(rv.data) + assert rv.status_code == 200 + assert len(rv_data["result"]["data"]) == 10 + assert QueryCacheManager.has( + rv_data["result"]["cache_key"], + region=CacheRegion.DATA, + ) + assert rv_data["result"]["is_cached"] + + # 2. should read through cache data + uri2 = f"datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&force=true" + # feeds data + test_client.post(uri2) + # force query + rv2 = test_client.post(uri2) + rv_data2 = json.loads(rv2.data) + assert rv2.status_code == 200 + assert len(rv_data2["result"]["data"]) == 10 + assert QueryCacheManager.has( + rv_data2["result"]["cache_key"], + region=CacheRegion.DATA, + ) + assert not rv_data2["result"]["is_cached"] - self.login(username="admin") - failed_column = TableColumn( - column_name="DUMMY CC", - type="VARCHAR(255)", - table=dataset, - expression="INCORRECT SQL", - ) - uri = f"datasource/samples?datasource_id={dataset.id}&datasource_type=table" - dataset.columns.append(failed_column) - rv = self.client.post(uri) - assert rv.status_code == 422 - rv_data = json.loads(rv.data) - assert "error" in rv_data - if dataset.database.db_engine_spec.engine_name == "PostgreSQL": - assert "INCORRECT SQL" in rv_data.get("error") - - db.session.delete(dataset) - db.session.commit() - - def test_get_datasource_samples_on_virtual_dataset(self): - if backend() == "sqlite": - return - - virtual_dataset = SqlaTable( - table_name="virtual_dataset", - sql=("SELECT 'foo' as foo, 'bar' as bar"), - database=get_example_database(), - ) - TableColumn(column_name="foo", type="VARCHAR(255)", table=virtual_dataset) - TableColumn(column_name="bar", type="VARCHAR(255)", table=virtual_dataset) - SqlMetric(metric_name="count", expression="count(*)", table=virtual_dataset) + # 3. 
data precision + assert "colnames" in rv_data2["result"] + assert "coltypes" in rv_data2["result"] + assert "data" in rv_data2["result"] - self.login(username="admin") - uri = f"datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table" - rv = self.client.post(uri) - assert rv.status_code == 200 - rv_data = json.loads(rv.data) - cache_key = rv_data["result"]["cache_key"] - assert QueryCacheManager.has(cache_key, region=CacheRegion.DATA) - - # remove original column in dataset - virtual_dataset.sql = "SELECT 'foo' as foo" - rv = self.client.post(uri) - assert rv.status_code == 422 - - db.session.delete(virtual_dataset) - db.session.commit() - - def test_get_datasource_samples_with_filters(self): - virtual_dataset = SqlaTable( - table_name="virtual_dataset", - sql=("SELECT 'foo' as foo, 'bar' as bar UNION ALL SELECT 'foo2', 'bar2'"), - database=get_example_database(), - ) - TableColumn(column_name="foo", type="VARCHAR(255)", table=virtual_dataset) - TableColumn(column_name="bar", type="VARCHAR(255)", table=virtual_dataset) - SqlMetric(metric_name="count", expression="count(*)", table=virtual_dataset) + eager_samples = virtual_dataset.database.get_df( + f"select * from ({virtual_dataset.sql}) as tbl" + f' limit {app.config["SAMPLES_ROW_LIMIT"]}' + ) + # the col3 is Decimal + eager_samples["col3"] = eager_samples["col3"].apply(float) + eager_samples = eager_samples.to_dict(orient="records") + assert eager_samples == rv_data2["result"]["data"] - self.login(username="admin") - uri = f"datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table" - rv = self.client.post(uri, json=None) - assert rv.status_code == 200 - rv = self.client.post(uri, json={}) - assert rv.status_code == 200 +def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_dataset): + if backend() == "sqlite": + return - rv = self.client.post(uri, json={"foo": "bar"}) - assert rv.status_code == 400 + TableColumn( + column_name="DUMMY CC", + type="VARCHAR(255)", + table=virtual_dataset, + expression="INCORRECT SQL", + ) + db.session.merge(virtual_dataset) - rv = self.client.post( - uri, json={"filters": [{"col": "foo", "op": "INVALID", "val": "foo2"}]} - ) - assert rv.status_code == 400 + uri = f"datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table" + rv = test_client.post(uri) + assert rv.status_code == 422 - rv = self.client.post( - uri, json={"filters": [{"col": "foo", "op": "==", "val": "foo2"}]} - ) - assert rv.status_code == 200 - rv_data = json.loads(rv.data) - assert rv_data["result"]["colnames"] == ["foo", "bar"] - assert rv_data["result"]["rowcount"] == 1 + rv_data = json.loads(rv.data) + assert "error" in rv_data + if virtual_dataset.database.db_engine_spec.engine_name == "PostgreSQL": + assert "INCORRECT SQL" in rv_data.get("error") + + +def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset): + if backend() == "sqlite": + return + + uri = ( + f"datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table" + ) + rv = test_client.post(uri) + assert rv.status_code == 200 + rv_data = json.loads(rv.data) + assert QueryCacheManager.has( + rv_data["result"]["cache_key"], region=CacheRegion.DATA + ) + assert len(rv_data["result"]["data"]) == 10 + + +def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset): + uri = ( + f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table" + ) + rv = test_client.post(uri, json=None) + assert rv.status_code == 200 + + rv = 
test_client.post(uri, json={}) + assert rv.status_code == 200 + + rv = test_client.post(uri, json={"foo": "bar"}) + assert rv.status_code == 400 + + rv = test_client.post( + uri, json={"filters": [{"col": "col1", "op": "INVALID", "val": 0}]} + ) + assert rv.status_code == 400 + + rv = test_client.post( + uri, + json={ + "filters": [ + {"col": "col2", "op": "==", "val": "a"}, + {"col": "col1", "op": "==", "val": 0}, + ] + }, + ) + assert rv.status_code == 200 + rv_data = json.loads(rv.data) + assert rv_data["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"] + assert rv_data["result"]["rowcount"] == 1 + + # empty results + rv = test_client.post( + uri, + json={ + "filters": [ + {"col": "col2", "op": "==", "val": "x"}, + ] + }, + ) + assert rv.status_code == 200 + rv_data = json.loads(rv.data) + assert rv_data["result"]["colnames"] == [] + assert rv_data["result"]["rowcount"] == 0 + + +def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset): + # 1. default page, per_page and total_count + uri = ( + f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table" + ) + rv = test_client.post(uri) + rv_data = json.loads(rv.data) + assert rv_data["result"]["page"] == 1 + assert rv_data["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"] + assert rv_data["result"]["total_count"] == 10 + + # 2. incorrect per_page + per_pages = (app.config["SAMPLES_ROW_LIMIT"] + 1, 0, "xx") + for per_page in per_pages: + uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page={per_page}" + rv = test_client.post(uri) + assert rv.status_code == 400 - db.session.delete(virtual_dataset) - db.session.commit() + # 3. incorrect page or datasource_type + uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&page=xx" + rv = test_client.post(uri) + assert rv.status_code == 400 + + uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=xx" + rv = test_client.post(uri) + assert rv.status_code == 400 + + # 4. turning pages + uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=1" + rv = test_client.post(uri) + rv_data = json.loads(rv.data) + assert rv_data["result"]["page"] == 1 + assert rv_data["result"]["per_page"] == 2 + assert rv_data["result"]["total_count"] == 10 + assert [row["col1"] for row in rv_data["result"]["data"]] == [0, 1] + + uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=2" + rv = test_client.post(uri) + rv_data = json.loads(rv.data) + assert rv_data["result"]["page"] == 2 + assert rv_data["result"]["per_page"] == 2 + assert rv_data["result"]["total_count"] == 10 + assert [row["col1"] for row in rv_data["result"]["data"]] == [2, 3] + + # 4. turning pages + uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=1" + rv = test_client.post(uri) + rv_data = json.loads(rv.data) + assert rv_data["result"]["page"] == 1 + assert rv_data["result"]["per_page"] == 2 + assert rv_data["result"]["total_count"] == 10 + assert [row["col1"] for row in rv_data["result"]["data"]] == [0, 1] + + # 5. 
Exceeding the maximum pages + uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=6" + rv = test_client.post(uri) + rv_data = json.loads(rv.data) + assert rv_data["result"]["page"] == 6 + assert rv_data["result"]["per_page"] == 2 + assert rv_data["result"]["total_count"] == 10 + assert [row["col1"] for row in rv_data["result"]["data"]] == [] From 09df8bb24b40b9bc371b515b82bdb7d28e067bef Mon Sep 17 00:00:00 2001 From: Yongjie Zhao Date: Thu, 21 Jul 2022 17:04:44 +0800 Subject: [PATCH 08/12] more unit tst --- tests/integration_tests/datasource_tests.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index 13abed8dd95ba..8e5cc71013827 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -602,15 +602,6 @@ def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset): assert rv_data["result"]["total_count"] == 10 assert [row["col1"] for row in rv_data["result"]["data"]] == [2, 3] - # 4. turning pages - uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=1" - rv = test_client.post(uri) - rv_data = json.loads(rv.data) - assert rv_data["result"]["page"] == 1 - assert rv_data["result"]["per_page"] == 2 - assert rv_data["result"]["total_count"] == 10 - assert [row["col1"] for row in rv_data["result"]["data"]] == [0, 1] - # 5. Exceeding the maximum pages uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=6" rv = test_client.post(uri) From ffdea1084b656e80bd9d6636c6ea4ed94a29c9de Mon Sep 17 00:00:00 2001 From: Yongjie Zhao Date: Thu, 21 Jul 2022 17:07:59 +0800 Subject: [PATCH 09/12] add slash on url --- tests/integration_tests/datasource_tests.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index 8e5cc71013827..4501d4e159b66 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -427,7 +427,9 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset): if backend() == "sqlite": return - uri = f"datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table" + uri = ( + f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table" + ) # 1. should cache data # feeds data @@ -444,7 +446,7 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset): assert rv_data["result"]["is_cached"] # 2. 
should read through cache data - uri2 = f"datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&force=true" + uri2 = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&force=true" # feeds data test_client.post(uri2) # force query @@ -485,7 +487,9 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_data ) db.session.merge(virtual_dataset) - uri = f"datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table" + uri = ( + f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table" + ) rv = test_client.post(uri) assert rv.status_code == 422 @@ -500,7 +504,7 @@ def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_d return uri = ( - f"datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table" + f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table" ) rv = test_client.post(uri) assert rv.status_code == 200 From 1cd99cd073d16236bc666ce0235ae22fa31badbe Mon Sep 17 00:00:00 2001 From: Yongjie Zhao Date: Thu, 21 Jul 2022 17:10:31 +0800 Subject: [PATCH 10/12] remove blank --- tests/integration_tests/datasource_tests.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index 4501d4e159b66..47166135d77c6 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -427,11 +427,10 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset): if backend() == "sqlite": return + # 1. should cache data uri = ( f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table" ) - - # 1. should cache data # feeds data test_client.post(uri) # get from cache From 8bda88f76d66b0fd2e9ae5ef05c851a7eee5ab0a Mon Sep 17 00:00:00 2001 From: Yongjie Zhao Date: Thu, 21 Jul 2022 17:40:17 +0800 Subject: [PATCH 11/12] pylint --- superset/views/datasource/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/superset/views/datasource/utils.py b/superset/views/datasource/utils.py index 3c0fb7abfeaa4..0191db2947c26 100644 --- a/superset/views/datasource/utils.py +++ b/superset/views/datasource/utils.py @@ -59,7 +59,8 @@ def get_samples( # pylint: disable=too-many-arguments,too-many-locals limit_clause = get_limit_clause(page, per_page) - # todo(yongjie): Constructing count(*) and samples in the same query_context, then remove query_type==SAMPLES + # todo(yongjie): Constructing count(*) and samples in the same query_context, + # then remove query_type==SAMPLES # constructing samples query samples_instance = QueryContextFactory().create( datasource={ From 3d8cd901bdcc1a94d627e4841cc1b8b9b3c26c34 Mon Sep 17 00:00:00 2001 From: Yongjie Zhao Date: Thu, 21 Jul 2022 18:10:57 +0800 Subject: [PATCH 12/12] support sqlite --- tests/integration_tests/conftest.py | 7 +++++++ tests/integration_tests/datasource_tests.py | 9 --------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 5d7bf1acaf3dc..6675509d68131 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -259,6 +259,7 @@ def physical_dataset(): example_database = get_example_database() engine = example_database.get_sqla_engine() + # sqlite can only execute one statement at a time engine.execute( """ CREATE TABLE IF NOT EXISTS physical_dataset( @@ -268,6 +269,10 @@ def 
physical_dataset(): col4 VARCHAR(255), col5 VARCHAR(255) ); + """ + ) + engine.execute( + """ INSERT INTO physical_dataset values (0, 'a', 1.0, NULL, '2000-01-01 00:00:00'), (1, 'b', 1.1, NULL, '2000-01-02 00:00:00'), @@ -293,6 +298,8 @@ def physical_dataset(): TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset) SqlMetric(metric_name="count", expression="count(*)", table=dataset) db.session.merge(dataset) + if example_database.backend == "sqlite": + db.session.commit() yield dataset diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index 47166135d77c6..ad4d625cc5ae3 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -424,9 +424,6 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset): """ Dataset API: Test get dataset samples """ - if backend() == "sqlite": - return - # 1. should cache data uri = ( f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table" @@ -475,9 +472,6 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset): def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_dataset): - if backend() == "sqlite": - return - TableColumn( column_name="DUMMY CC", type="VARCHAR(255)", @@ -499,9 +493,6 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_data def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset): - if backend() == "sqlite": - return - uri = ( f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table" )
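
For reference, a minimal sketch of how a client could exercise the paginated samples endpoint this series adds. The host, dataset id, and pre-authenticated session below are assumptions for illustration only; the query parameters, request body, and response keys mirror the integration tests above.

    # Minimal sketch, assuming a Superset instance at localhost:8088 and a
    # requests.Session that already carries a valid login/CSRF session.
    import requests

    session = requests.Session()  # assumed to be authenticated already
    base_url = "http://localhost:8088"

    params = {
        "datasource_id": 1,          # hypothetical dataset id
        "datasource_type": "table",
        "force": "false",
        "page": 2,                   # second page of samples
        "per_page": 2,               # must fall within 1..SAMPLES_ROW_LIMIT
    }
    payload = {
        "filters": [
            {"col": "col2", "op": "==", "val": "a"},
        ]
    }

    resp = session.post(f"{base_url}/datasource/samples", params=params, json=payload)
    resp.raise_for_status()
    result = resp.json()["result"]

    # page, per_page, total_count and data are returned alongside the usual
    # colnames/coltypes/cache_key fields asserted in the tests above.
    print(result["page"], result["per_page"], result["total_count"])
    print(result["data"])

As the pagination tests cover, an out-of-range per_page or a malformed filter operator is rejected with HTTP 400, and requesting a page past the last one returns an empty data list while still reporting the total_count.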