From 0757e39549ae5c379ab0d899a2f0494617dc05d2 Mon Sep 17 00:00:00 2001 From: Daniel Vaz Gaspar Date: Tue, 7 Jul 2020 13:26:54 +0100 Subject: [PATCH] feat(api): bump marshmallow and FAB to version 3 (#9964) * feat(api): bump marshmallow and FAB to version 3 * revert query context tests changes * obey mypy * fix tests * ignore types that collide with marshmallow * preparing for RC2 * fix tests for marshmallow 3 * typing fixes for marshmallow * fix tests and black * fix tests * bump to RC3 and lint * Test RC4 * Final 3.0.0 * Address comments, fix tests, better naming, docs * fix test * couple of fixes, addressing comments * bumping marshmallow --- requirements.txt | 8 +-- setup.py | 2 +- superset/charts/api.py | 39 ++++++++------- superset/charts/commands/update.py | 2 +- superset/charts/schemas.py | 14 +++--- superset/commands/exceptions.py | 6 +-- superset/dashboards/api.py | 28 ++++++----- superset/dashboards/commands/create.py | 2 +- superset/dashboards/commands/exceptions.py | 2 +- superset/dashboards/schemas.py | 14 ++++-- superset/datasets/api.py | 24 +++++---- superset/datasets/commands/create.py | 2 +- superset/datasets/commands/exceptions.py | 42 ++++++++-------- superset/datasets/commands/update.py | 2 +- superset/views/base_api.py | 2 +- superset/views/base_schemas.py | 6 ++- tests/base_api_tests.py | 58 ++++++++++++++++------ tests/charts/api_tests.py | 6 ++- tests/charts/schema_tests.py | 17 +++---- tests/dashboards/api_tests.py | 4 +- tests/datasets/api_tests.py | 9 ++++ tests/query_context_tests.py | 3 +- 22 files changed, 173 insertions(+), 119 deletions(-) diff --git a/requirements.txt b/requirements.txt index 951ef76e604e2..3b6292de8c75d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ aiohttp==3.6.2 # via slackclient alembic==1.4.2 # via flask-migrate amqp==2.5.2 # via kombu -apispec[yaml]==1.3.3 # via flask-appbuilder +apispec[yaml]==3.3.1 # via flask-appbuilder async-timeout==3.0.1 # via aiohttp attrs==19.3.0 # via aiohttp, jsonschema babel==2.8.0 # via flask-babel @@ -29,7 +29,7 @@ decorator==4.4.2 # via retry defusedxml==0.6.0 # via python3-openid dnspython==1.16.0 # via email-validator email-validator==1.1.0 # via flask-appbuilder -flask-appbuilder==2.3.4 # via apache-superset (setup.py) +flask-appbuilder==3.0.0 # via apache-superset (setup.py) flask-babel==1.0.0 # via flask-appbuilder flask-caching==1.8.0 # via apache-superset (setup.py) flask-compress==1.5.0 # via apache-superset (setup.py) @@ -58,7 +58,7 @@ markdown==3.2.2 # via apache-superset (setup.py) markupsafe==1.1.1 # via jinja2, mako, wtforms marshmallow-enum==1.5.1 # via flask-appbuilder marshmallow-sqlalchemy==0.23.0 # via flask-appbuilder -marshmallow==2.21.0 # via flask-appbuilder, marshmallow-enum, marshmallow-sqlalchemy +marshmallow==3.6.1 # via flask-appbuilder, marshmallow-enum, marshmallow-sqlalchemy msgpack==1.0.0 # via apache-superset (setup.py) multidict==4.7.6 # via aiohttp, yarl numpy==1.18.4 # via pandas, pyarrow @@ -100,4 +100,4 @@ yarl==1.4.2 # via aiohttp zipp==3.1.0 # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: -# setuptools \ No newline at end of file +# setuptools diff --git a/setup.py b/setup.py index 141fd448664b8..200d90251457b 100644 --- a/setup.py +++ b/setup.py @@ -78,7 +78,7 @@ def get_git_sha(): "cryptography>=2.4.2", "dataclasses<0.7", "flask>=1.1.0, <2.0.0", - "flask-appbuilder>=2.3.4, <2.4.0", + "flask-appbuilder>=3.0.0, <4.0.0", "flask-caching", "flask-compress", "flask-talisman", diff --git a/superset/charts/api.py b/superset/charts/api.py index a1a8b58feeb62..ad78df007cc7c 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -23,6 +23,7 @@ from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_babel import gettext as _, ngettext +from marshmallow import ValidationError from werkzeug.wrappers import Response as WerkzeugResponse from werkzeug.wsgi import FileWrapper @@ -99,6 +100,8 @@ class ChartRestApi(BaseSupersetModelRestApi): "params", "cache_timeout", ] + show_select_columns = show_columns + ["table.id"] + list_columns = [ "id", "slice_name", @@ -121,6 +124,7 @@ class ChartRestApi(BaseSupersetModelRestApi): "params", "cache_timeout", ] + order_columns = [ "slice_name", "viz_type", @@ -215,13 +219,14 @@ def post(self) -> Response: """ if not request.is_json: return self.response_400(message="Request is not JSON") - item = self.add_model_schema.load(request.json) + try: + item = self.add_model_schema.load(request.json) # This validates custom Schema with custom validations - if item.errors: - return self.response_400(message=item.errors) + except ValidationError as error: + return self.response_400(message=error.messages) try: - new_model = CreateChartCommand(g.user, item.data).run() - return self.response(201, id=new_model.id, result=item.data) + new_model = CreateChartCommand(g.user, item).run() + return self.response(201, id=new_model.id, result=item) except ChartInvalidError as ex: return self.response_422(message=ex.normalized_messages()) except ChartCreateFailedError as ex: @@ -281,13 +286,14 @@ def put( # pylint: disable=too-many-return-statements, arguments-differ """ if not request.is_json: return self.response_400(message="Request is not JSON") - item = self.edit_model_schema.load(request.json) + try: + item = self.edit_model_schema.load(request.json) # This validates custom Schema with custom validations - if item.errors: - return self.response_400(message=item.errors) + except ValidationError as error: + return self.response_400(message=error.messages) try: - changed_model = UpdateChartCommand(g.user, pk, item.data).run() - return self.response(200, id=changed_model.id, result=item.data) + changed_model = UpdateChartCommand(g.user, pk, item).run() + return self.response(200, id=changed_model.id, result=item) except ChartNotFoundError: return self.response_404() except ChartForbiddenError: @@ -442,8 +448,7 @@ def data(self) -> Response: $ref: '#/components/responses/400' 500: $ref: '#/components/responses/500' - """ - + """ if request.is_json: json_body = request.json elif request.form.get("form_data"): @@ -452,13 +457,13 @@ def data(self) -> Response: else: return self.response_400(message="Request is not JSON") try: - query_context, errors = ChartDataQueryContextSchema().load(json_body) - if errors: - return self.response_400( - message=_("Request is incorrect: %(error)s", error=errors) - ) + query_context = ChartDataQueryContextSchema().load(json_body) except KeyError: return self.response_400(message="Request is incorrect") + except ValidationError as error: + return self.response_400( + _("Request is incorrect: %(error)s", error=error.messages) + ) try: query_context.raise_for_access() except SupersetSecurityException: diff --git a/superset/charts/commands/update.py b/superset/charts/commands/update.py index 70055bf4725af..8541930f56944 100644 --- a/superset/charts/commands/update.py +++ b/superset/charts/commands/update.py @@ -58,7 +58,7 @@ def run(self) -> Model: return chart def validate(self) -> None: - exceptions = list() + exceptions: List[ValidationError] = list() dashboard_ids = self._properties.get("dashboards", []) owner_ids: Optional[List[int]] = self._properties.get("owners") diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 20b73e220d77f..60cec21b2b8f8 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -246,8 +246,7 @@ def __init__(self) -> None: class ChartDataPostProcessingOperationOptionsSchema(Schema): - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + pass class ChartDataAggregateOptionsSchema(ChartDataPostProcessingOperationOptionsSchema): @@ -369,7 +368,7 @@ class ChartDataSelectOptionsSchema(ChartDataPostProcessingOperationOptionsSchema "referenced here.", example=["country", "gender", "age"], ) - exclude = fields.List( + exclude = fields.List( # type: ignore fields.String(), description="Columns to exclude from selection.", example=["my_temp_column"], @@ -676,6 +675,9 @@ class ChartDataQueryObjectSchema(Schema): timeseries_limit = fields.Integer( description="Maximum row count for timeseries queries. Default: `0`", ) + timeseries_limit_metric = fields.Integer( + description="Metric used to limit timeseries queries by.", allow_none=True, + ) row_limit = fields.Integer( description='Maximum row count. Default: `config["ROW_LIMIT"]`', validate=[ @@ -744,13 +746,13 @@ class ChartDataQueryContextSchema(Schema): validate=validate.OneOf(choices=("json", "csv")), ) - # pylint: disable=no-self-use + # pylint: disable=no-self-use,unused-argument @post_load - def make_query_context(self, data: Dict[str, Any]) -> QueryContext: + def make_query_context(self, data: Dict[str, Any], **kwargs: Any) -> QueryContext: query_context = QueryContext(**data) return query_context - # pylint: enable=no-self-use + # pylint: enable=no-self-use,unused-argument class ChartDataResponseResult(Schema): diff --git a/superset/commands/exceptions.py b/superset/commands/exceptions.py index cf67ea9227159..d470043bb38b2 100644 --- a/superset/commands/exceptions.py +++ b/superset/commands/exceptions.py @@ -49,7 +49,7 @@ def add_list(self, exceptions: List[ValidationError]) -> None: def normalized_messages(self) -> Dict[Any, Any]: errors: Dict[Any, Any] = {} for exception in self._invalid_exceptions: - errors.update(exception.normalized_messages()) + errors.update(exception.normalized_messages()) # type: ignore return errors @@ -77,11 +77,11 @@ class OwnersNotFoundValidationError(ValidationError): status = 422 def __init__(self) -> None: - super().__init__(_("Owners are invalid"), field_names=["owners"]) + super().__init__([_("Owners are invalid")], field_name="owners") class DatasourceNotFoundValidationError(ValidationError): status = 404 def __init__(self) -> None: - super().__init__(_("Datasource does not exist"), field_names=["datasource_id"]) + super().__init__([_("Datasource does not exist")], field_name="datasource_id") diff --git a/superset/dashboards/api.py b/superset/dashboards/api.py index 7fef29fe684a9..29572cc541d2a 100644 --- a/superset/dashboards/api.py +++ b/superset/dashboards/api.py @@ -21,6 +21,7 @@ from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_babel import ngettext +from marshmallow import ValidationError from werkzeug.wrappers import Response as WerkzeugResponse from werkzeug.wsgi import FileWrapper @@ -118,7 +119,7 @@ class DashboardRestApi(BaseSupersetModelRestApi): "owners.first_name", "owners.last_name", ] - edit_columns = [ + add_columns = [ "dashboard_title", "slug", "owners", @@ -127,9 +128,10 @@ class DashboardRestApi(BaseSupersetModelRestApi): "json_metadata", "published", ] + edit_columns = add_columns + search_columns = ("dashboard_title", "slug", "owners", "published") search_filters = {"dashboard_title": [DashboardTitleOrSlugFilter]} - add_columns = edit_columns base_order = ("changed_on", "desc") add_model_schema = DashboardPostSchema() @@ -197,13 +199,14 @@ def post(self) -> Response: """ if not request.is_json: return self.response_400(message="Request is not JSON") - item = self.add_model_schema.load(request.json) + try: + item = self.add_model_schema.load(request.json) # This validates custom Schema with custom validations - if item.errors: - return self.response_400(message=item.errors) + except ValidationError as error: + return self.response_400(message=error.messages) try: - new_model = CreateDashboardCommand(g.user, item.data).run() - return self.response(201, id=new_model.id, result=item.data) + new_model = CreateDashboardCommand(g.user, item).run() + return self.response(201, id=new_model.id, result=item) except DashboardInvalidError as ex: return self.response_422(message=ex.normalized_messages()) except DashboardCreateFailedError as ex: @@ -263,13 +266,14 @@ def put( # pylint: disable=too-many-return-statements, arguments-differ """ if not request.is_json: return self.response_400(message="Request is not JSON") - item = self.edit_model_schema.load(request.json) + try: + item = self.edit_model_schema.load(request.json) # This validates custom Schema with custom validations - if item.errors: - return self.response_400(message=item.errors) + except ValidationError as error: + return self.response_400(message=error.messages) try: - changed_model = UpdateDashboardCommand(g.user, pk, item.data).run() - return self.response(200, id=changed_model.id, result=item.data) + changed_model = UpdateDashboardCommand(g.user, pk, item).run() + return self.response(200, id=changed_model.id, result=item) except DashboardNotFoundError: return self.response_404() except DashboardForbiddenError: diff --git a/superset/dashboards/commands/create.py b/superset/dashboards/commands/create.py index 83763691a0d26..dff25d6fcc4a5 100644 --- a/superset/dashboards/commands/create.py +++ b/superset/dashboards/commands/create.py @@ -50,7 +50,7 @@ def run(self) -> Model: return dashboard def validate(self) -> None: - exceptions = list() + exceptions: List[ValidationError] = list() owner_ids: Optional[List[int]] = self._properties.get("owners") slug: str = self._properties.get("slug", "") diff --git a/superset/dashboards/commands/exceptions.py b/superset/dashboards/commands/exceptions.py index 13b0d5b74bc91..c6f78ee8b79bc 100644 --- a/superset/dashboards/commands/exceptions.py +++ b/superset/dashboards/commands/exceptions.py @@ -33,7 +33,7 @@ class DashboardSlugExistsValidationError(ValidationError): """ def __init__(self) -> None: - super().__init__(_("Must be unique"), field_names=["slug"]) + super().__init__([_("Must be unique")], field_name="slug") class DashboardInvalidError(CommandInvalidError): diff --git a/superset/dashboards/schemas.py b/superset/dashboards/schemas.py index 0d8c60ad170b0..258770586639f 100644 --- a/superset/dashboards/schemas.py +++ b/superset/dashboards/schemas.py @@ -18,7 +18,7 @@ import re from typing import Any, Dict, Union -from marshmallow import fields, pre_load, Schema +from marshmallow import fields, post_load, Schema from marshmallow.validate import Length, ValidationError from superset.exceptions import SupersetException @@ -92,7 +92,7 @@ def validate_json_metadata(value: Union[bytes, bytearray, str]) -> None: value_obj = json.loads(value) except json.decoder.JSONDecodeError: raise ValidationError("JSON not valid") - errors = DashboardJSONMetadataSchema(strict=True).validate(value_obj, partial=False) + errors = DashboardJSONMetadataSchema().validate(value_obj, partial=False) if errors: raise ValidationError(errors) @@ -110,12 +110,16 @@ class DashboardJSONMetadataSchema(Schema): class BaseDashboardSchema(Schema): - @pre_load - def pre_load(self, data: Dict[str, Any]) -> None: # pylint: disable=no-self-use + # pylint: disable=no-self-use,unused-argument + @post_load + def post_load(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: if data.get("slug"): data["slug"] = data["slug"].strip() data["slug"] = data["slug"].replace(" ", "-") data["slug"] = re.sub(r"[^\w\-]+", "", data["slug"]) + return data + + # pylint: disable=no-self-use,unused-argument class DashboardPostSchema(BaseDashboardSchema): @@ -133,7 +137,7 @@ class DashboardPostSchema(BaseDashboardSchema): ) css = fields.String() json_metadata = fields.String( - description=json_metadata_description, validate=validate_json_metadata + description=json_metadata_description, validate=validate_json_metadata, ) published = fields.Boolean(description=published_description) diff --git a/superset/datasets/api.py b/superset/datasets/api.py index d6e5fe62dc40e..691138192ae5c 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -21,6 +21,7 @@ from flask import g, request, Response from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface +from marshmallow import ValidationError from superset.connectors.sqla.models import SqlaTable from superset.constants import RouteMethod @@ -61,7 +62,6 @@ class DatasetRestApi(BaseSupersetModelRestApi): resource_name = "dataset" allow_browser_login = True - class_permission_name = "TableModelView" include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | { RouteMethod.EXPORT, @@ -179,13 +179,14 @@ def post(self) -> Response: """ if not request.is_json: return self.response_400(message="Request is not JSON") - item = self.add_model_schema.load(request.json) + try: + item = self.add_model_schema.load(request.json) # This validates custom Schema with custom validations - if item.errors: - return self.response_400(message=item.errors) + except ValidationError as error: + return self.response_400(message=error.messages) try: - new_model = CreateDatasetCommand(g.user, item.data).run() - return self.response(201, id=new_model.id, result=item.data) + new_model = CreateDatasetCommand(g.user, item).run() + return self.response(201, id=new_model.id, result=item) except DatasetInvalidError as ex: return self.response_422(message=ex.normalized_messages()) except DatasetCreateFailedError as ex: @@ -245,13 +246,14 @@ def put( # pylint: disable=too-many-return-statements, arguments-differ """ if not request.is_json: return self.response_400(message="Request is not JSON") - item = self.edit_model_schema.load(request.json) + try: + item = self.edit_model_schema.load(request.json) # This validates custom Schema with custom validations - if item.errors: - return self.response_400(message=item.errors) + except ValidationError as error: + return self.response_400(message=error.messages) try: - changed_model = UpdateDatasetCommand(g.user, pk, item.data).run() - return self.response(200, id=changed_model.id, result=item.data) + changed_model = UpdateDatasetCommand(g.user, pk, item).run() + return self.response(200, id=changed_model.id, result=item) except DatasetNotFoundError: return self.response_404() except DatasetForbiddenError: diff --git a/superset/datasets/commands/create.py b/superset/datasets/commands/create.py index 436fdd263663e..8dd0f4a2b8a88 100644 --- a/superset/datasets/commands/create.py +++ b/superset/datasets/commands/create.py @@ -67,7 +67,7 @@ def run(self) -> Model: return dataset def validate(self) -> None: - exceptions = list() + exceptions: List[ValidationError] = list() database_id = self._properties["database"] table_name = self._properties["table_name"] schema = self._properties.get("schema", "") diff --git a/superset/datasets/commands/exceptions.py b/superset/datasets/commands/exceptions.py index 69bd7cd7d3e4f..b651b6feb9fbf 100644 --- a/superset/datasets/commands/exceptions.py +++ b/superset/datasets/commands/exceptions.py @@ -34,7 +34,7 @@ class DatabaseNotFoundValidationError(ValidationError): """ def __init__(self) -> None: - super().__init__(_("Database does not exist"), field_names=["database"]) + super().__init__([_("Database does not exist")], field_name="database") class DatabaseChangeValidationError(ValidationError): @@ -43,7 +43,7 @@ class DatabaseChangeValidationError(ValidationError): """ def __init__(self) -> None: - super().__init__(_("Database not allowed to change"), field_names=["database"]) + super().__init__([_("Database not allowed to change")], field_name="database") class DatasetExistsValidationError(ValidationError): @@ -53,7 +53,7 @@ class DatasetExistsValidationError(ValidationError): def __init__(self, table_name: str) -> None: super().__init__( - get_datasource_exist_error_msg(table_name), field_names=["table_name"] + get_datasource_exist_error_msg(table_name), field_name="table_name" ) @@ -63,7 +63,7 @@ class DatasetColumnNotFoundValidationError(ValidationError): """ def __init__(self) -> None: - super().__init__(_("One or more columns do not exist"), field_names=["columns"]) + super().__init__([_("One or more columns do not exist")], field_name="columns") class DatasetColumnsDuplicateValidationError(ValidationError): @@ -73,7 +73,7 @@ class DatasetColumnsDuplicateValidationError(ValidationError): def __init__(self) -> None: super().__init__( - _("One or more columns are duplicated"), field_names=["columns"] + [_("One or more columns are duplicated")], field_name="columns" ) @@ -83,9 +83,7 @@ class DatasetColumnsExistsValidationError(ValidationError): """ def __init__(self) -> None: - super().__init__( - _("One or more columns already exist"), field_names=["columns"] - ) + super().__init__([_("One or more columns already exist")], field_name="columns") class DatasetMetricsNotFoundValidationError(ValidationError): @@ -94,7 +92,7 @@ class DatasetMetricsNotFoundValidationError(ValidationError): """ def __init__(self) -> None: - super().__init__(_("One or more metrics do not exist"), field_names=["metrics"]) + super().__init__([_("One or more metrics do not exist")], field_name="metrics") class DatasetMetricsDuplicateValidationError(ValidationError): @@ -104,7 +102,7 @@ class DatasetMetricsDuplicateValidationError(ValidationError): def __init__(self) -> None: super().__init__( - _("One or more metrics are duplicated"), field_names=["metrics"] + [_("One or more metrics are duplicated")], field_name="metrics" ) @@ -114,9 +112,7 @@ class DatasetMetricsExistsValidationError(ValidationError): """ def __init__(self) -> None: - super().__init__( - _("One or more metrics already exist"), field_names=["metrics"] - ) + super().__init__([_("One or more metrics already exist")], field_name="metrics") class TableNotFoundValidationError(ValidationError): @@ -126,20 +122,22 @@ class TableNotFoundValidationError(ValidationError): def __init__(self, table_name: str) -> None: super().__init__( - _( - "Table [%(table_name)s] could not be found, " - "please double check your " - "database connection, schema, and " - "table name", - table_name=table_name, - ), - field_names=["table_name"], + [ + _( + "Table [%(table_name)s] could not be found, " + "please double check your " + "database connection, schema, and " + "table name", + table_name=table_name, + ) + ], + field_name="table_name", ) class OwnersNotFoundValidationError(ValidationError): def __init__(self) -> None: - super().__init__(_("Owners are invalid"), field_names=["owners"]) + super().__init__([_("Owners are invalid")], field_name="owners") class DatasetNotFoundError(CommandException): diff --git a/superset/datasets/commands/update.py b/superset/datasets/commands/update.py index 14cc08785297a..dfc3986c09f22 100644 --- a/superset/datasets/commands/update.py +++ b/superset/datasets/commands/update.py @@ -66,7 +66,7 @@ def run(self) -> Model: raise DatasetUpdateFailedError() def validate(self) -> None: - exceptions = list() + exceptions: List[ValidationError] = list() owner_ids: Optional[List[int]] = self._properties.get("owners") # Validate/populate model exists self._model = DatasetDAO.find_by_id(self._model_id) diff --git a/superset/views/base_api.py b/superset/views/base_api.py index 2279adee5fac8..a64fc1b5f15f6 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -113,7 +113,7 @@ class BaseSupersetModelRestApi(ModelRestApi): """ # pylint: disable=pointless-string-statement allowed_rel_fields: Set[str] = set() - openapi_spec_component_schemas: Tuple[Schema, ...] = tuple() + openapi_spec_component_schemas: Tuple[Type[Schema], ...] = tuple() """ Add extra schemas to the OpenAPI component schemas section """ # pylint: disable=pointless-string-statement diff --git a/superset/views/base_schemas.py b/superset/views/base_schemas.py index 87a9190596eb4..659ec3fdfcc1a 100644 --- a/superset/views/base_schemas.py +++ b/superset/views/base_schemas.py @@ -52,11 +52,15 @@ def load( # pylint: disable=arguments-differ self, data: Union[Mapping[str, Any], Iterable[Mapping[str, Any]]], many: Optional[bool] = None, - partial: Optional[Union[bool, Sequence[str], Set[str]]] = None, + partial: Union[bool, Sequence[str], Set[str], None] = None, instance: Optional[Model] = None, **kwargs: Any, ) -> Any: self.instance = instance + if many is None: + many = False + if partial is None: + partial = False return super().load(data, many=many, partial=partial, **kwargs) @post_load diff --git a/tests/base_api_tests.py b/tests/base_api_tests.py index e1f6e4893b2a2..ba7dde5cbea91 100644 --- a/tests/base_api_tests.py +++ b/tests/base_api_tests.py @@ -51,7 +51,10 @@ class Model1Api(BaseSupersetModelRestApi): class TestBaseModelRestApi(SupersetTestCase): def test_default_missing_declaration_get(self): """ - API: Test default missing declaration on get + API: Test default missing declaration on get + + We want to make sure that not declared list_columns will + not render all columns by default but just the model's pk """ # Check get list response self.login(username="admin") @@ -73,6 +76,12 @@ def test_default_missing_declaration_get(self): self.assertEqual(list(response["result"].keys()), ["id"]) def test_default_missing_declaration_put_spec(self): + """ + API: Test default missing declaration on put openapi spec + + We want to make sure that not declared edit_columns will + not render all columns by default but just the model's pk + """ self.login(username="admin") uri = "api/v1/_openapi" rv = self.client.get(uri) @@ -91,6 +100,12 @@ def test_default_missing_declaration_put_spec(self): ) def test_default_missing_declaration_post(self): + """ + API: Test default missing declaration on post + + We want to make sure that not declared add_columns will + not accept all columns by default + """ dashboard_data = { "dashboard_title": "title1", "slug": "slug1", @@ -102,30 +117,41 @@ def test_default_missing_declaration_post(self): self.login(username="admin") uri = "api/v1/model1api/" rv = self.client.post(uri, json=dashboard_data) - # dashboard model accepts all fields are null - self.assertEqual(rv.status_code, 201) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(list(response["result"].keys()), ["id"]) - model = db.session.query(Dashboard).get(response["id"]) - self.assertEqual(model.dashboard_title, None) - self.assertEqual(model.slug, None) - self.assertEqual(model.position_json, None) - self.assertEqual(model.json_metadata, None) - db.session.delete(model) - db.session.commit() + self.assertEqual(rv.status_code, 422) + expected_response = { + "message": { + "css": ["Unknown field."], + "dashboard_title": ["Unknown field."], + "json_metadata": ["Unknown field."], + "position_json": ["Unknown field."], + "published": ["Unknown field."], + "slug": ["Unknown field."], + } + } + self.assertEqual(response, expected_response) def test_default_missing_declaration_put(self): + """ + API: Test default missing declaration on put + + We want to make sure that not declared edit_columns will + not accept all columns by default + """ dashboard = db.session.query(Dashboard).first() dashboard_data = {"dashboard_title": "CHANGED", "slug": "CHANGED"} self.login(username="admin") uri = f"api/v1/model1api/{dashboard.id}" rv = self.client.put(uri, json=dashboard_data) - # dashboard model accepts all fields are null - self.assertEqual(rv.status_code, 200) response = json.loads(rv.data.decode("utf-8")) - changed_dashboard = db.session.query(Dashboard).get(dashboard.id) - self.assertNotEqual(changed_dashboard.dashboard_title, "CHANGED") - self.assertNotEqual(changed_dashboard.slug, "CHANGED") + self.assertEqual(rv.status_code, 422) + expected_response = { + "message": { + "dashboard_title": ["Unknown field."], + "slug": ["Unknown field."], + } + } + self.assertEqual(response, expected_response) class ApiOwnersTestCaseMixin: diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py index 24aa1e8496761..c3e3b50effb11 100644 --- a/tests/charts/api_tests.py +++ b/tests/charts/api_tests.py @@ -338,7 +338,8 @@ def test_create_chart_validate_datasource(self): self.assertEqual(rv.status_code, 400) response = json.loads(rv.data.decode("utf-8")) self.assertEqual( - response, {"message": {"datasource_type": ["Not a valid choice."]}} + response, + {"message": {"datasource_type": ["Must be one of: druid, table, view."]}}, ) chart_data = { "slice_name": "title1", @@ -444,7 +445,8 @@ def test_update_chart_validate_datasource(self): self.assertEqual(rv.status_code, 400) response = json.loads(rv.data.decode("utf-8")) self.assertEqual( - response, {"message": {"datasource_type": ["Not a valid choice."]}} + response, + {"message": {"datasource_type": ["Must be one of: druid, table, view."]}}, ) chart_data = {"datasource_id": 0, "datasource_type": "table"} uri = f"api/v1/chart/{chart.id}" diff --git a/tests/charts/schema_tests.py b/tests/charts/schema_tests.py index 4c998fc27abf4..354ed823c4cd9 100644 --- a/tests/charts/schema_tests.py +++ b/tests/charts/schema_tests.py @@ -18,6 +18,7 @@ """Unit tests for Superset""" from typing import Any, Dict, Tuple +from marshmallow import ValidationError from tests.test_app import app from superset.charts.schemas import ChartDataQueryContextSchema from superset.common.query_context import QueryContext @@ -39,8 +40,7 @@ def test_query_context_limit_and_offset(self): # Use defaults payload["queries"][0].pop("row_limit", None) payload["queries"][0].pop("row_offset", None) - query_context, errors = load_query_context(payload) - self.assertEqual(errors, {}) + query_context = load_query_context(payload) query_object = query_context.queries[0] self.assertEqual(query_object.row_limit, app.config["ROW_LIMIT"]) self.assertEqual(query_object.row_offset, 0) @@ -48,8 +48,7 @@ def test_query_context_limit_and_offset(self): # Valid limit and offset payload["queries"][0]["row_limit"] = 100 payload["queries"][0]["row_offset"] = 200 - query_context, errors = ChartDataQueryContextSchema().load(payload) - self.assertEqual(errors, {}) + query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] self.assertEqual(query_object.row_limit, 100) self.assertEqual(query_object.row_offset, 200) @@ -57,9 +56,10 @@ def test_query_context_limit_and_offset(self): # too low limit and offset payload["queries"][0]["row_limit"] = 0 payload["queries"][0]["row_offset"] = -1 - query_context, errors = ChartDataQueryContextSchema().load(payload) - self.assertIn("row_limit", errors["queries"][0]) - self.assertIn("row_offset", errors["queries"][0]) + with self.assertRaises(ValidationError) as context: + _ = ChartDataQueryContextSchema().load(payload) + self.assertIn("row_limit", context.exception.messages["queries"][0]) + self.assertIn("row_offset", context.exception.messages["queries"][0]) def test_query_context_null_timegrain(self): self.login(username="admin") @@ -68,5 +68,4 @@ def test_query_context_null_timegrain(self): payload = get_query_context(table.name, table.id, table.type) payload["queries"][0]["extras"]["time_grain_sqla"] = None - _, errors = ChartDataQueryContextSchema().load(payload) - self.assertEqual(errors, {}) + _ = ChartDataQueryContextSchema().load(payload) diff --git a/tests/dashboards/api_tests.py b/tests/dashboards/api_tests.py index 145706c784a57..5637e06d6480d 100644 --- a/tests/dashboards/api_tests.py +++ b/tests/dashboards/api_tests.py @@ -40,7 +40,7 @@ class TestDashboardApi(SupersetTestCase, ApiOwnersTestCaseMixin): "slug": "slug1_changed", "position_json": '{"b": "B"}', "css": "css_changed", - "json_metadata": '{"a": "A"}', + "json_metadata": '{"refresh_frequency": 30}', "published": False, } @@ -473,7 +473,7 @@ def test_create_dashboard(self): "owners": [admin_id], "position_json": '{"a": "A"}', "css": "css", - "json_metadata": '{"b": "B"}', + "json_metadata": '{"refresh_frequency": 30}', "published": True, } self.login(username="admin") diff --git a/tests/datasets/api_tests.py b/tests/datasets/api_tests.py index 08c751f01dd7c..431c53d0a752c 100644 --- a/tests/datasets/api_tests.py +++ b/tests/datasets/api_tests.py @@ -371,6 +371,11 @@ def test_update_dataset_create_column(self): self.login(username="admin") rv = self.get_assert_metric(uri, "get") data = json.loads(rv.data.decode("utf-8")) + + for column in data["result"]["columns"]: + column.pop("changed_on", None) + column.pop("created_on", None) + data["result"]["columns"].append(new_column_data) rv = self.client.put(uri, json={"columns": data["result"]["columns"]}) @@ -404,6 +409,10 @@ def test_update_dataset_update_column(self): # Get current cols and alter one rv = self.get_assert_metric(uri, "get") resp_columns = json.loads(rv.data.decode("utf-8"))["result"]["columns"] + for column in resp_columns: + column.pop("changed_on", None) + column.pop("created_on", None) + resp_columns[0]["groupby"] = False resp_columns[0]["filterable"] = False v = self.client.put(uri, json={"columns": resp_columns}) diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py index 2de62f88f635a..4b625b5a06c2d 100644 --- a/tests/query_context_tests.py +++ b/tests/query_context_tests.py @@ -39,8 +39,7 @@ def test_schema_deserialization(self): payload = get_query_context( table.name, table.id, table.type, add_postprocessing_operations=True ) - query_context, errors = ChartDataQueryContextSchema().load(payload) - self.assertDictEqual(errors, {}) + query_context = ChartDataQueryContextSchema().load(payload) self.assertEqual(len(query_context.queries), len(payload["queries"])) for query_idx, query in enumerate(query_context.queries): payload_query = payload["queries"][query_idx]