diff --git a/ghstack/github_fake.py b/ghstack/github_fake.py index cb27ae6..d82e969 100644 --- a/ghstack/github_fake.py +++ b/ghstack/github_fake.py @@ -45,6 +45,23 @@ }, ) +CreateIssueCommentInput = TypedDict( + "CreateIssueCommentInput", + {"body": str}, +) + +CreateIssueCommentPayload = TypedDict( + "CreateIssueCommentPayload", + { + "id": int, + }, +) + +UpdateIssueCommentInput = TypedDict( + "UpdateIssueCommentInput", + {"body": str}, +) + CreatePullRequestPayload = TypedDict( "CreatePullRequestPayload", { @@ -57,8 +74,12 @@ class GitHubState: repositories: Dict[GraphQLId, "Repository"] pull_requests: Dict[GraphQLId, "PullRequest"] + # This is very inefficient but whatever + issue_comments: Dict[GraphQLId, "IssueComment"] _next_id: int + # These are indexed by repo id _next_pull_request_number: Dict[GraphQLId, int] + _next_issue_comment_full_database_id: Dict[GraphQLId, int] root: "Root" upstream_sh: Optional[ghstack.shell.Shell] @@ -79,6 +100,14 @@ def pull_request(self, repo: "Repository", number: GitHubNumber) -> "PullRequest ) ) + def issue_comment(self, repo: "Repository", comment_id: int) -> "IssueComment": + for comment in self.issue_comments.values(): + if repo.id == comment._repository and comment.fullDatabaseId == comment_id: + return comment + raise RuntimeError( + f"unrecognized issue comment {comment_id} in repository {repo.nameWithOwner}" + ) + def next_id(self) -> GraphQLId: r = GraphQLId(str(self._next_id)) self._next_id += 1 @@ -89,6 +118,11 @@ def next_pull_request_number(self, repo_id: GraphQLId) -> GitHubNumber: self._next_pull_request_number[repo_id] += 1 return r + def next_issue_comment_full_database_id(self, repo_id: GraphQLId) -> int: + r = self._next_issue_comment_full_database_id[repo_id] + self._next_issue_comment_full_database_id[repo_id] += 1 + return r + def push_hook(self, refs: Sequence[str]) -> None: # updated_refs = set(refs) # for pr in self.pull_requests: @@ -107,8 +141,10 @@ def notify_merged(self, pr_resolved: ghstack.diff.PullRequestResolved) -> None: def __init__(self, upstream_sh: Optional[ghstack.shell.Shell]) -> None: self.repositories = {} self.pull_requests = {} + self.issue_comments = {} self._next_id = 5000 self._next_pull_request_number = {} + self._next_issue_comment_full_database_id = {} self.root = Root() # Populate it with the most important repo ;) @@ -121,6 +157,7 @@ def __init__(self, upstream_sh: Optional[ghstack.shell.Shell]) -> None: ) self.repositories[GraphQLId("1000")] = repo self._next_pull_request_number[GraphQLId("1000")] = 500 + self._next_issue_comment_full_database_id[GraphQLId("1000")] = 1500 self.upstream_sh = upstream_sh if self.upstream_sh is not None: @@ -239,6 +276,16 @@ def repository(self, info: GraphQLResolveInfo) -> Repository: return github_state(info).repositories[self._repository] +@dataclass +class IssueComment(Node): + body: str + fullDatabaseId: int + _repository: GraphQLId + + def repository(self, info: GraphQLResolveInfo) -> Repository: + return github_state(info).repositories[self._repository] + + @dataclass class PullRequestConnection: nodes: List[PullRequest] @@ -253,6 +300,8 @@ def node(self, info: GraphQLResolveInfo, id: GraphQLId) -> Node: return github_state(info).repositories[id] elif id in github_state(info).pull_requests: return github_state(info).pull_requests[id] + elif id in github_state(info).issue_comments: + return github_state(info).issue_comments[id] else: raise RuntimeError("unknown id {}".format(id)) @@ -277,6 +326,7 @@ def set_is_type_of(name: str, cls: Any) -> None: set_is_type_of("Repository", Repository) set_is_type_of("PullRequest", PullRequest) +set_is_type_of("IssueComment", IssueComment) class FakeGitHubEndpoint(ghstack.github.GitHubEndpoint): @@ -366,6 +416,35 @@ def _update_pull( if "body" in input and input["body"] is not None: pr.body = input["body"] + def _create_issue_comment( + self, owner: str, name: str, comment_id: int, input: CreateIssueCommentInput + ) -> CreateIssueCommentPayload: + state = self.state + id = state.next_id() + repo = state.repository(owner, name) + comment_id = state.next_issue_comment_full_database_id(repo.id) + comment = IssueComment( + id=id, + _repository=repo.id, + fullDatabaseId=comment_id, + body=input["body"], + ) + state.issue_comments[id] = comment + # This is only a subset of what the actual REST endpoint + # returns. + return { + "id": comment_id, + } + + def _update_issue_comment( + self, owner: str, name: str, comment_id: int, input: UpdateIssueCommentInput + ) -> None: + state = self.state + repo = state.repository(owner, name) + comment = state.issue_comment(repo, comment_id) + if (r := input.get("body")) is not None: + comment.body = r + # NB: This may have a payload, but we don't # use it so I didn't bother constructing it. def _set_default_branch( @@ -383,14 +462,19 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any: raise ghstack.github.NotFoundError() elif method == "post": - m = re.match(r"^repos/([^/]+)/([^/]+)/pulls$", path) - if m: + if m := re.match(r"^repos/([^/]+)/([^/]+)/pulls$", path): return self._create_pull( m.group(1), m.group(2), cast(CreatePullRequestInput, kwargs) ) + if m := re.match(r"^repos/([^/]+)/([^/]+)/issues/([^/]+)/comments", path): + return self._create_issue_comment( + m.group(1), + m.group(2), + GitHubNumber(int(m.group(3))), + cast(CreateIssueCommentInput, kwargs), + ) elif method == "patch": - m = re.match(r"^repos/([^/]+)/([^/]+)(?:/pulls/([^/]+))?$", path) - if m: + if m := re.match(r"^repos/([^/]+)/([^/]+)(?:/pulls/([^/]+))?$", path): owner, name, number = m.groups() if number is not None: return self._update_pull( @@ -403,6 +487,13 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any: return self._set_default_branch( owner, name, cast(SetDefaultBranchInput, kwargs) ) + if m := re.match(r"^repos/([^/]+)/([^/]+)/issues/comments/([^/]+)$", path): + return self._update_issue_comment( + m.group(1), + m.group(2), + int(m.group(3)), + cast(UpdateIssueCommentInput, kwargs), + ) raise NotImplementedError( "FakeGitHubEndpoint REST {} {} not implemented".format(method.upper(), path) ) diff --git a/ghstack/submit.py b/ghstack/submit.py index d95c66a..8b9f66f 100644 --- a/ghstack/submit.py +++ b/ghstack/submit.py @@ -110,6 +110,17 @@ class PreBranchState: RE_GHSTACK_SOURCE_ID = re.compile(r"^ghstack-source-id: (.+)\n?", re.MULTILINE) +# When we make a GitHub PR using --direct, we submit an extra comment which +# contains the links to the rest of the PRs in the stack. We don't put this +# inside the pull request body, because if you squash merge the PR, that body +# gets put into the commit message, but the stack information is just line +# noise and shouldn't go there. +# +# We can technically find the ghstack commit by querying GitHub API for all +# comments, but this is a more efficient way of getting it. +RE_GHSTACK_COMMENT_ID = re.compile(r"^ghstack-comment-id: (.+)\n?", re.MULTILINE) + + # repo layout: # - gh/username/23/head -- what we think GitHub's current tip for commit is # - gh/username/23/base -- what we think base commit for commit is @@ -175,6 +186,8 @@ class DiffWithGitHubMetadata: username: str # Really ought not to be optional, but for BC reasons it might be remote_source_id: Optional[str] + # Guaranteed to be set for --direct PRs + comment_id: Optional[int] title: str body: str closed: bool @@ -856,6 +869,8 @@ def elaborate_diff( remote_summary = ghstack.git.split_header(rev_list)[0] m_remote_source_id = RE_GHSTACK_SOURCE_ID.search(remote_summary.commit_msg) remote_source_id = m_remote_source_id.group(1) if m_remote_source_id else None + m_comment_id = RE_GHSTACK_COMMENT_ID.search(remote_summary.commit_msg) + comment_id = int(m_comment_id.group(1)) if m_comment_id else None return DiffWithGitHubMetadata( diff=diff, @@ -866,6 +881,7 @@ def elaborate_diff( username=username, ghnum=gh_number, remote_source_id=remote_source_id, + comment_id=comment_id, pull_request_resolved=diff.pull_request_resolved, head_ref=r["headRefName"], base_ref=r["baseRefName"], @@ -932,10 +948,15 @@ def process_commit( commit_msg = self._update_source_id(diff.summary, elab_diff) else: # Need to insert metadata for the first time - commit_msg = ( - f"{strip_mentions(diff.summary.rstrip())}\n\n" - f"ghstack-source-id: {diff.source_id}\n" - f"Pull Request resolved: {pull_request_resolved.url()}" + commit_msg = "".join( + [ + f"{strip_mentions(diff.summary.rstrip())}\n\n", + f"ghstack-source-id: {diff.source_id}\n", + f"ghstack-comment-id: {elab_diff.comment_id}\n" + if self.direct + else "", + f"Pull Request resolved: {pull_request_resolved.url()}", + ] ) return DiffMeta( @@ -1382,6 +1403,14 @@ def _create_pull_request( ) number = r["number"] + comment_id = None + if self.direct: + rc = self.github.post( + f"repos/{self.repo_owner}/{self.repo_name}/issues/{number}/comments", + body=f"{self.stack_header}:\n* (to be filled)", + ) + comment_id = rc["id"] + logging.info("Opened PR #{}".format(number)) pull_request_resolved = ghstack.diff.PullRequestResolved( @@ -1396,6 +1425,7 @@ def _create_pull_request( number=number, username=self.username, remote_source_id=diff.source_id, # in sync + comment_id=comment_id, title=title, body=body, closed=False, @@ -1432,18 +1462,26 @@ def push_updates( base_kwargs["base"] = s.base else: assert s.base == s.elab_diff.base_ref + stack_desc = self._format_stack(diffs_to_submit, s.number) self.github.patch( "repos/{owner}/{repo}/pulls/{number}".format( owner=self.repo_owner, repo=self.repo_name, number=s.number ), + # NB: this substitution does nothing on direct PRs body=RE_STACK.sub( - self._format_stack(diffs_to_submit, s.number), + stack_desc, s.body, ), title=s.title, **base_kwargs, ) + if s.elab_diff.comment_id is not None: + self.github.patch( + f"repos/{self.repo_owner}/{self.repo_name}/issues/comments/{s.elab_diff.comment_id}", + body=stack_desc, + ) + # It is VERY important that we do base updates BEFORE real # head updates, otherwise GitHub will spuriously think that # the user pushed a number of patches as part of the PR, @@ -1703,16 +1741,21 @@ def _default_title_and_body( # Don't store ghstack-source-id in the PR body; it will become # stale quickly commit_body = RE_GHSTACK_SOURCE_ID.sub("", commit_body) + # Comment ID is not necessary; source of truth is orig commit + commit_body = RE_GHSTACK_COMMENT_ID.sub("", commit_body) # Don't store Pull request resolved in the PR body; it's # unnecessary commit_body = ghstack.diff.re_pull_request_resolved_w_sp(self.github_url).sub( "", commit_body ) - if starts_with_bullet(commit_body): - commit_body = f"----\n\n{commit_body}" - pr_body = "{}:\n* (to be filled)\n\n{}{}".format( - self.stack_header, commit_body, extra - ) + if self.direct: + pr_body = f"{commit_body}{extra}" + else: + if starts_with_bullet(commit_body): + commit_body = f"----\n\n{commit_body}" + pr_body = "{}:\n* (to be filled)\n\n{}{}".format( + self.stack_header, commit_body, extra + ) return title, pr_body def _git_push(self, branches: Sequence[str], force: bool = False) -> None: diff --git a/test_ghstack.py b/test_ghstack.py index 22b2755..95ca841 100644 --- a/test_ghstack.py +++ b/test_ghstack.py @@ -38,6 +38,9 @@ DIRECT = False +# TODO: replicate github commit list + + @contextlib.contextmanager def use_direct() -> Iterator[None]: global DIRECT @@ -408,10 +411,6 @@ def test_direct_simple(self) -> None: """\ [O] #500 Commit A (gh/ezyang/1/head -> master) - Stack: - * #501 - * __->__ #500 - * c3ca023 (gh/ezyang/1/next, gh/ezyang/1/head) @@ -421,10 +420,6 @@ def test_direct_simple(self) -> None: [O] #501 Commit B (gh/ezyang/2/head -> gh/ezyang/1/head) - Stack: - * __->__ #501 - * #500 - * 09a6970 (gh/ezyang/2/next, gh/ezyang/2/head) @@ -699,9 +694,6 @@ def test_direct_amend(self) -> None: """\ [O] #500 Commit A (gh/ezyang/1/head -> master) - Stack: - * __->__ #500 - * e3902de (gh/ezyang/1/next, gh/ezyang/1/head) @@ -857,10 +849,6 @@ def test_direct_multi(self) -> None: """\ [O] #500 Commit A (gh/ezyang/1/head -> master) - Stack: - * #501 - * __->__ #500 - * c5b379e (gh/ezyang/1/next, gh/ezyang/1/head) @@ -870,10 +858,6 @@ def test_direct_multi(self) -> None: [O] #501 Commit B (gh/ezyang/2/head -> gh/ezyang/1/head) - Stack: - * __->__ #501 - * #500 - * fd9fc99 (gh/ezyang/2/next, gh/ezyang/2/head) @@ -948,10 +932,6 @@ def test_direct_amend_top(self) -> None: """\ [O] #500 Commit A (gh/ezyang/1/head -> master) - Stack: - * #501 - * __->__ #500 - * c3ca023 (gh/ezyang/1/next, gh/ezyang/1/head) @@ -961,10 +941,6 @@ def test_direct_amend_top(self) -> None: [O] #501 Commit B (gh/ezyang/2/head -> gh/ezyang/1/head) - Stack: - * __->__ #501 - * #500 - * 20bbb07 (gh/ezyang/2/next, gh/ezyang/2/head) @@ -1051,10 +1027,6 @@ def test_direct_amend_bottom(self) -> None: """\ [O] #500 Commit A (gh/ezyang/1/head -> master) - Stack: - * #501 - * __->__ #500 - * f22b24c (gh/ezyang/1/next, gh/ezyang/1/head) @@ -1066,10 +1038,6 @@ def test_direct_amend_bottom(self) -> None: [O] #501 Commit B (gh/ezyang/2/head -> gh/ezyang/1/head) - Stack: - * __->__ #501 - * #500 - * 165ebd2 (gh/ezyang/2/next, gh/ezyang/2/head) @@ -1158,10 +1126,6 @@ def test_direct_amend_all(self) -> None: """\ [O] #500 Commit A (gh/ezyang/1/head -> master) - Stack: - * #501 - * __->__ #500 - * 9d56b39 (gh/ezyang/1/next, gh/ezyang/1/head) @@ -1173,10 +1137,6 @@ def test_direct_amend_all(self) -> None: [O] #501 Commit B (gh/ezyang/2/head -> gh/ezyang/1/head) - Stack: - * __->__ #501 - * #500 - * e3873c9 (gh/ezyang/2/next, gh/ezyang/2/head) @@ -1275,10 +1235,6 @@ def test_direct_rebase(self) -> None: """\ [O] #500 Commit A (gh/ezyang/1/head -> master) - Stack: - * #501 - * __->__ #500 - * ad37802 (gh/ezyang/1/next, gh/ezyang/1/head) @@ -1292,10 +1248,6 @@ def test_direct_rebase(self) -> None: [O] #501 Commit B (gh/ezyang/2/head -> gh/ezyang/1/head) - Stack: - * __->__ #501 - * #500 - * 1d1ca2d (gh/ezyang/2/next, gh/ezyang/2/head) @@ -1385,10 +1337,6 @@ def test_direct_cherry_pick(self) -> None: """\ [O] #500 Commit A (gh/ezyang/1/head -> master) - Stack: - * #501 - * __->__ #500 - * 2949b6b (gh/ezyang/1/next, gh/ezyang/1/head) @@ -1398,9 +1346,6 @@ def test_direct_cherry_pick(self) -> None: [O] #501 Commit B (gh/ezyang/2/head -> master) - Stack: - * __->__ #501 - * fd891f3 (gh/ezyang/2/next, gh/ezyang/2/head) @@ -1485,10 +1430,6 @@ def test_direct_reorder(self) -> None: """\ [O] #500 Commit A (gh/ezyang/1/head -> gh/ezyang/2/head) - Stack: - * __->__ #500 - * #501 - * 3a17667 (gh/ezyang/1/next, gh/ezyang/1/head) @@ -1504,10 +1445,6 @@ def test_direct_reorder(self) -> None: [O] #501 Commit B (gh/ezyang/2/head -> master) - Stack: - * #500 - * __->__ #501 - * 5f812b3 (gh/ezyang/2/next, gh/ezyang/2/head)