From 855eb17a29f9d78a6e73418bb0e92a30ace0e73f Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Sun, 28 Jul 2024 14:53:08 +0800 Subject: [PATCH] :sparkles: Feature: add graphql pagination support (#121) --- README.md | 75 +++++++++++++++++- githubkit/exception.py | 24 +++++- githubkit/github.py | 24 ++---- githubkit/graphql/__init__.py | 102 ++++++++++++++++++------ githubkit/graphql/paginator.py | 120 +++++++++++++++++++++++++++++ tests/test_graphql/test_graphql.py | 28 +++++++ 6 files changed, 329 insertions(+), 44 deletions(-) create mode 100644 githubkit/graphql/paginator.py diff --git a/README.md b/README.md index dc3fc5945..5b633c8aa 100644 --- a/README.md +++ b/README.md @@ -457,7 +457,7 @@ Current supported versions are: (you can find it in the section `[[tool.codegen. - 2022-11-28 (latest) - ghec-2022-11-28 -### Pagination +### Rest API Pagination Pagination type checking is also supported: @@ -509,6 +509,79 @@ Simple async call: data: Dict[str, Any] = await github.async_graphql(query, variables={"foo": "bar"}) ``` +### GraphQL Pagination + +githubkit also provides a helper function to paginate the GraphQL API. + +First, You must accept a `cursor` parameter and return a `pageInfo` object in your query. For example: + +```graphql +query ($owner: String!, $repo: String!, $cursor: String) { + repository(owner: $owner, name: $repo) { + issues(first: 10, after: $cursor) { + nodes { + number + } + pageInfo { + hasNextPage + endCursor + } + } + } +} +``` + +The `pageInfo` object in your query must be one of the following types depending on the direction of the pagination: + +For forward pagination, use: + +```graphql +pageInfo { + hasNextPage + endCursor +} +``` + +For backward pagination, use: + +```graphql +pageInfo { + hasPreviousPage + startCursor +} +``` + +If you provide all 4 properties in a `pageInfo`, githubkit will default to forward pagination. + +Then, you can iterate over the paginated results by using the graphql `paginate` method: + +```python +for result in github.graphql.paginate( + query, variables={"owner": "owner", "repo": "repo"} +): + print(result) +``` + +Note that the `result` is a dict containing the list of nodes/edges for each page and the `pageInfo` object. You should iterate over the `nodes` or `edges` list to get the actual data. For example: + +```python +for result in g.graphql.paginate(query, {"owner": "owner", "repo": "repo"}): + for issue in result["repository"]["issues"]["nodes"]: + print(issue) +``` + +You can also provide a initial cursor value to start pagination from a specific point: + +```python +for result in github.graphql.paginate( + query, variables={"owner": "owner", "repo": "repo", "cursor": "initial_cursor"} +): + print(result) +``` + +> [!NOTE] +> Nested pagination is not supported. + ### Auto Retry By default, githubkit will retry the request when specific exception encountered. When rate limit exceeded, githubkit will retry once after GitHub suggested waiting time. When server error encountered (http status >= 500), githubkit will retry max three times. diff --git a/githubkit/exception.py b/githubkit/exception.py index fa36ffd98..cf36659e5 100644 --- a/githubkit/exception.py +++ b/githubkit/exception.py @@ -73,7 +73,11 @@ class SecondaryRateLimitExceeded(RateLimitExceeded): """API request failed with secondary rate limit exceeded""" -class GraphQLFailed(GitHubException): +class GraphQLError(GitHubException): + """Simple GraphQL request error""" + + +class GraphQLFailed(GraphQLError): """GraphQL request with errors in response""" def __init__(self, response: "GraphQLResponse"): @@ -83,6 +87,24 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({self.response.errors!r})" +class GraphQLPaginationError(GraphQLError): + """GraphQL paginate response error""" + + def __init__(self, response: "GraphQLResponse"): + self.response = response + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.response})" + + +class GraphQLMissingPageInfo(GraphQLPaginationError): + """GraphQL paginate response missing PageInfo object""" + + +class GraphQLMissingCursorChange(GraphQLPaginationError): + """GraphQL paginate response missing cursor change""" + + class WebhookTypeNotFound(GitHubException): """Webhook event type not found""" diff --git a/githubkit/github.py b/githubkit/github.py index 8273f1e5f..e423fd5c0 100644 --- a/githubkit/github.py +++ b/githubkit/github.py @@ -17,9 +17,9 @@ from .response import Response from .paginator import Paginator from .auth import BaseAuthStrategy +from .graphql import GraphQLNamespace from .typing import RetryDecisionFunc from .versions import RestVersionSwitcher, WebhooksVersionSwitcher -from .graphql import GraphQLResponse, build_graphql_request, parse_graphql_response if TYPE_CHECKING: import httpx @@ -133,27 +133,15 @@ def rest(self) -> RestVersionSwitcher: webhooks = WebhooksVersionSwitcher() # graphql - def graphql( - self, query: str, variables: Optional[Dict[str, Any]] = None - ) -> Dict[str, Any]: - json = build_graphql_request(query, variables) - - return parse_graphql_response( - self, - self.request("POST", "/graphql", json=json, response_model=GraphQLResponse), - ) + @cached_property + def graphql(self) -> GraphQLNamespace: + return GraphQLNamespace(self) + # alias for graphql.arequest async def async_graphql( self, query: str, variables: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: - json = build_graphql_request(query, variables) - - return parse_graphql_response( - self, - await self.arequest( - "POST", "/graphql", json=json, response_model=GraphQLResponse - ), - ) + return await self.graphql.arequest(query, variables) # rest pagination paginate = Paginator diff --git a/githubkit/graphql/__init__.py b/githubkit/graphql/__init__.py index ade1ed4b2..11dcf9cf0 100644 --- a/githubkit/graphql/__init__.py +++ b/githubkit/graphql/__init__.py @@ -1,7 +1,9 @@ +from weakref import ref from typing import TYPE_CHECKING, Any, Dict, Optional, cast from githubkit.exception import GraphQLFailed, PrimaryRateLimitExceeded +from .paginator import Paginator as Paginator from .models import GraphQLError as GraphQLError from .models import SourceLocation as SourceLocation from .models import GraphQLResponse as GraphQLResponse @@ -11,27 +13,79 @@ from githubkit.response import Response -def build_graphql_request( - query: str, variables: Optional[Dict[str, Any]] = None -) -> Dict[str, Any]: - json: Dict[str, Any] = {"query": query} - if variables: - json["variables"] = variables - return json - - -def parse_graphql_response( - github: "GitHubCore", response: "Response[GraphQLResponse]" -) -> Dict[str, Any]: - response_data = response.parsed_data - if response_data.errors: - # check rate limit exceeded - # https://docs.github.com/en/graphql/overview/rate-limits-and-node-limits-for-the-graphql-api#exceeding-the-rate-limit - # x-ratelimit-remaining may not be 0, ignore it - # https://github.com/octokit/plugin-throttling.js/pull/636 - if any(error.type == "RATE_LIMITED" for error in response_data.errors): - raise PrimaryRateLimitExceeded( - response, github._extract_retry_after(response) - ) - raise GraphQLFailed(response_data) - return cast(Dict[str, Any], response_data.data) +class GraphQLNamespace: + def __init__(self, github: "GitHubCore") -> None: + self._github_ref = ref(github) + + @property + def _github(self) -> "GitHubCore": + if g := self._github_ref(): + return g + raise RuntimeError( + "GitHub client has already been collected. " + "Do not use the namespace after the client has been collected." + ) + + @staticmethod + def build_graphql_request( + query: str, variables: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + json: Dict[str, Any] = {"query": query} + if variables: + json["variables"] = variables + return json + + def parse_graphql_response( + self, response: "Response[GraphQLResponse]" + ) -> Dict[str, Any]: + response_data = response.parsed_data + if response_data.errors: + # check rate limit exceeded + # https://docs.github.com/en/graphql/overview/rate-limits-and-node-limits-for-the-graphql-api#exceeding-the-rate-limit + # x-ratelimit-remaining may not be 0, ignore it + # https://github.com/octokit/plugin-throttling.js/pull/636 + if any(error.type == "RATE_LIMITED" for error in response_data.errors): + raise PrimaryRateLimitExceeded( + response, self._github._extract_retry_after(response) + ) + raise GraphQLFailed(response_data) + return cast(Dict[str, Any], response_data.data) + + def _request( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> "Response[GraphQLResponse]": + json = self.build_graphql_request(query, variables) + + return self._github.request( + "POST", "/graphql", json=json, response_model=GraphQLResponse + ) + + def request( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + return self.parse_graphql_response(self._request(query, variables)) + + async def _arequest( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> "Response[GraphQLResponse]": + json = self.build_graphql_request(query, variables) + + return await self._github.arequest( + "POST", "/graphql", json=json, response_model=GraphQLResponse + ) + + async def arequest( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + return self.parse_graphql_response(await self._arequest(query, variables)) + + # backport for calling graphql directly + def __call__( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + return self.request(query, variables) + + def paginate( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> Paginator: + return Paginator(self, query, variables) diff --git a/githubkit/graphql/paginator.py b/githubkit/graphql/paginator.py new file mode 100644 index 000000000..5285c77fc --- /dev/null +++ b/githubkit/graphql/paginator.py @@ -0,0 +1,120 @@ +from weakref import ref +from typing_extensions import Self +from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict + +from githubkit.exception import GraphQLMissingPageInfo, GraphQLMissingCursorChange + +if TYPE_CHECKING: + from githubkit.response import Response + + from . import GraphQLNamespace + from .models import GraphQLResponse + +CURSOR_VARNAME = "cursor" + + +class PageInfo(TypedDict, total=False): + """PageInfo object returned by the GraphQL API. + + See: https://docs.github.com/en/graphql/reference/objects#pageinfo + """ + + hasNextPage: bool + hasPreviousPage: bool + startCursor: str + endCursor: str + + +class Paginator: + def __init__( + self, + graphql: "GraphQLNamespace", + query: str, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + self._graphql_ref = ref(graphql) + self.query = query + + self._has_next_page: bool = True + self._current_variables = variables.copy() if variables is not None else {} + + @property + def _graphql(self) -> "GraphQLNamespace": + if g := self._graphql_ref(): + return g + raise RuntimeError( + "GraphQL client has already been collected. " + "Do not use the paginator after the client has been collected." + ) + + def __next__(self) -> Dict[str, Any]: + if not self._has_next_page: + raise StopIteration + + return self._get_next_page() + + def __iter__(self: Self) -> Self: + return self + + async def __anext__(self) -> Dict[str, Any]: + if not self._has_next_page: + raise StopAsyncIteration + + return await self._aget_next_page() + + def __aiter__(self: Self) -> Self: + return self + + def _extract_page_info(self, data: Dict[str, Any]) -> Optional[PageInfo]: + if "pageInfo" in data: + return data["pageInfo"] + + for value in data.values(): + if isinstance(value, dict): + return self._extract_page_info(value) + + # not found + return None + + def _extract_has_next_page(self, page_info: PageInfo) -> bool: + return ( + page_info["hasNextPage"] + if "hasNextPage" in page_info + else page_info["hasPreviousPage"] # type: ignore + ) + + def _extract_cursor(self, page_info: PageInfo) -> str: + return ( + page_info["endCursor"] # type: ignore + if "hasNextPage" in page_info + else page_info["startCursor"] # type: ignore + ) + + def _do_update(self, response: "Response[GraphQLResponse]") -> Dict[str, Any]: + result = self._graphql.parse_graphql_response(response) + + page_info = self._extract_page_info(result) + if page_info is None: + raise GraphQLMissingPageInfo(response.parsed_data) + + self._has_next_page = self._extract_has_next_page(page_info) + next_cursor = self._extract_cursor(page_info) + + # make sure we don't request the same page again + if ( + self._has_next_page + and CURSOR_VARNAME in self._current_variables + and next_cursor == self._current_variables[CURSOR_VARNAME] + ): + raise GraphQLMissingCursorChange(response.parsed_data) + + self._current_variables[CURSOR_VARNAME] = next_cursor + return result + + def _get_next_page(self) -> Dict[str, Any]: + response = self._graphql._request(self.query, self._current_variables) + return self._do_update(response) + + async def _aget_next_page(self) -> Dict[str, Any]: + response = await self._graphql._arequest(self.query, self._current_variables) + return self._do_update(response) diff --git a/tests/test_graphql/test_graphql.py b/tests/test_graphql/test_graphql.py index 49c104466..a371fec41 100644 --- a/tests/test_graphql/test_graphql.py +++ b/tests/test_graphql/test_graphql.py @@ -12,6 +12,21 @@ } } """ +TEST_PAGINATION_QUERY = """ +query($owner:String!, $repo: String!, $cursor: String) { + repository(owner:$owner, name:$repo) { + issues(first: 10, after: $cursor) { + nodes { + number + } + pageInfo { + hasNextPage + endCursor + } + } + } +} +""" TEST_VARS = {"owner": "yanyongyu", "repo": "githubkit"} TEST_RESULT = {"repository": {"owner": {"login": "yanyongyu"}, "name": "githubkit"}} @@ -25,3 +40,16 @@ def test_graphql(g: GitHub): async def test_async_graphql(g: GitHub): result = await g.async_graphql(TEST_QUERY, TEST_VARS) assert result == TEST_RESULT + + +def test_paginate(g: GitHub): + paginator = g.graphql.paginate(TEST_PAGINATION_QUERY, TEST_VARS) + for result in paginator: + assert isinstance(result["repository"]["issues"]["nodes"], list) + + +@pytest.mark.anyio +async def test_async_paginate(g: GitHub): + paginator = g.graphql.paginate(TEST_PAGINATION_QUERY, TEST_VARS) + async for result in paginator: + assert isinstance(result["repository"]["issues"]["nodes"], list)