Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(api): return total count on related endpoint #16397

Merged
merged 2 commits into from
Aug 24, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions superset/views/base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,12 @@ def related(self, column_name: str, **kwargs: Any) -> FlaskResponse:

# handle pagination
page, page_size = self._handle_page_args(args)

ids = args.get("include_ids")
if page and ids:
# pagination with forced ids is not supported
return self.response_400()
Copy link
Member

@michael-s-molina michael-s-molina Aug 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good idea - will update 👍


try:
datamodel = self.datamodel.get_related_interface(column_name)
except KeyError:
Expand All @@ -504,18 +510,19 @@ def related(self, column_name: str, **kwargs: Any) -> FlaskResponse:
# handle filters
filters = self._get_related_filter(datamodel, column_name, args.get("filter"))
# Make the query
_, rows = datamodel.query(
total_rows, rows = datamodel.query(
filters, order_column, order_direction, page=page, page_size=page_size
)

# produce response
result = self._get_result_from_rows(datamodel, rows, column_name)

# If ids are specified make sure we fetch and include them on the response
ids = args.get("include_ids")
self._add_extra_ids_to_result(datamodel, column_name, ids, result)
if ids:
self._add_extra_ids_to_result(datamodel, column_name, ids, result)
total_rows = len(result)

return self.response(200, count=len(result), result=result)
return self.response(200, count=total_rows, result=result)

@expose("/distinct/<column_name>", methods=["GET"])
@protect()
Expand Down
34 changes: 34 additions & 0 deletions tests/integration_tests/base_api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,40 @@ def test_get_related_owners(self):
for expected_user in expected_users:
assert expected_user in response_users

def test_get_related_owners_paginated(self):
"""
API: Test get related owners with pagination
"""
self.login(username="admin")
page_size = 1
argument = {"page_size": page_size}
uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}"
rv = self.client.get(uri)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
users = db.session.query(security_manager.user_model).all()

# the count should correspond with the total number of users
assert response["count"] == len(users)

# the length of the result should be at most equal to the page size
assert len(response["result"]) == min(page_size, len(users))

# make sure all received users are included in the full set of users
all_users = [str(user) for user in users]
for received_user in [result["text"] for result in response["result"]]:
assert received_user in all_users

def test_get_ids_related_owners_paginated(self):
"""
API: Test get related owners with pagination returns 400
"""
self.login(username="admin")
argument = {"page": 1, "page_size": 1, "include_ids": [2]}
uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}"
rv = self.client.get(uri)
assert rv.status_code == 400

def test_get_filter_related_owners(self):
"""
API: Test get filter related owners
Expand Down