From 355dfd2da58944bbbc4903eeea69d4b7d5a7f5a2 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Mon, 23 Aug 2021 13:04:38 +0300 Subject: [PATCH 1/2] fix(api): return total count on related endpoint --- superset/views/base_api.py | 15 +++++++--- tests/integration_tests/base_api_tests.py | 34 +++++++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/superset/views/base_api.py b/superset/views/base_api.py index b730ce12662cb..32779f025b6a1 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -490,6 +490,12 @@ def related(self, column_name: str, **kwargs: Any) -> FlaskResponse: # handle pagination page, page_size = self._handle_page_args(args) + + ids = args.get("include_ids") + if page and ids: + # pagination with forced ids is not supported + return self.response_400() + try: datamodel = self.datamodel.get_related_interface(column_name) except KeyError: @@ -504,7 +510,7 @@ def related(self, column_name: str, **kwargs: Any) -> FlaskResponse: # handle filters filters = self._get_related_filter(datamodel, column_name, args.get("filter")) # Make the query - _, rows = datamodel.query( + total_rows, rows = datamodel.query( filters, order_column, order_direction, page=page, page_size=page_size ) @@ -512,10 +518,11 @@ def related(self, column_name: str, **kwargs: Any) -> FlaskResponse: result = self._get_result_from_rows(datamodel, rows, column_name) # If ids are specified make sure we fetch and include them on the response - ids = args.get("include_ids") - self._add_extra_ids_to_result(datamodel, column_name, ids, result) + if ids: + self._add_extra_ids_to_result(datamodel, column_name, ids, result) + total_rows = len(result) - return self.response(200, count=len(result), result=result) + return self.response(200, count=total_rows, result=result) @expose("/distinct/", methods=["GET"]) @protect() diff --git a/tests/integration_tests/base_api_tests.py b/tests/integration_tests/base_api_tests.py index e6e795f4d6b27..bb816193882f5 100644 --- a/tests/integration_tests/base_api_tests.py +++ b/tests/integration_tests/base_api_tests.py @@ -202,6 +202,40 @@ def test_get_related_owners(self): for expected_user in expected_users: assert expected_user in response_users + def test_get_related_owners_paginated(self): + """ + API: Test get related owners with pagination + """ + self.login(username="admin") + page_size = 1 + argument = {"page_size": page_size} + uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}" + rv = self.client.get(uri) + assert rv.status_code == 200 + response = json.loads(rv.data.decode("utf-8")) + users = db.session.query(security_manager.user_model).all() + + # the count should correspond with the total number of users + assert response["count"] == len(users) + + # the length of the result should be at most equal to the page size + assert len(response["result"]) == min(page_size, len(users)) + + # make sure all received users are included in the full set of users + all_users = [str(user) for user in users] + for received_user in [result["text"] for result in response["result"]]: + assert received_user in all_users + + def test_get_ids_related_owners_paginated(self): + """ + API: Test get related owners with pagination returns 400 + """ + self.login(username="admin") + argument = {"page": 1, "page_size": 1, "include_ids": [2]} + uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}" + rv = self.client.get(uri) + assert rv.status_code == 400 + def test_get_filter_related_owners(self): """ API: Test get filter related owners From dc8a4b9e819a93a723f34eb2bdf28d5ab3aa8197 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Tue, 24 Aug 2021 07:20:25 +0300 Subject: [PATCH 2/2] update response code from 400 to 422 --- superset/views/base_api.py | 2 +- tests/integration_tests/base_api_tests.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/superset/views/base_api.py b/superset/views/base_api.py index 32779f025b6a1..ce922a98b77f2 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -494,7 +494,7 @@ def related(self, column_name: str, **kwargs: Any) -> FlaskResponse: ids = args.get("include_ids") if page and ids: # pagination with forced ids is not supported - return self.response_400() + return self.response_422() try: datamodel = self.datamodel.get_related_interface(column_name) diff --git a/tests/integration_tests/base_api_tests.py b/tests/integration_tests/base_api_tests.py index bb816193882f5..a76346149ef66 100644 --- a/tests/integration_tests/base_api_tests.py +++ b/tests/integration_tests/base_api_tests.py @@ -228,13 +228,13 @@ def test_get_related_owners_paginated(self): def test_get_ids_related_owners_paginated(self): """ - API: Test get related owners with pagination returns 400 + API: Test get related owners with pagination returns 422 """ self.login(username="admin") argument = {"page": 1, "page_size": 1, "include_ids": [2]} uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}" rv = self.client.get(uri) - assert rv.status_code == 400 + assert rv.status_code == 422 def test_get_filter_related_owners(self): """