Skip to content

Commit

Permalink
Put the stack navigation in a separate GitHub comment (#226)
Browse files Browse the repository at this point in the history
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
  • Loading branch information
ezyang authored Dec 16, 2023
1 parent d91fb59 commit e8db781
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 80 deletions.
99 changes: 95 additions & 4 deletions ghstack/github_fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,23 @@
},
)

# Keyword payload of the fake "create issue comment" REST call.
class CreateIssueCommentInput(TypedDict):
    # Text of the comment to create.
    body: str

# Result of the fake "create issue comment" REST call.
class CreateIssueCommentPayload(TypedDict):
    # Integer id of the newly created comment.
    id: int

# Keyword payload of the fake "update issue comment" REST call.
class UpdateIssueCommentInput(TypedDict):
    # Replacement text for the comment.
    body: str

CreatePullRequestPayload = TypedDict(
"CreatePullRequestPayload",
{
Expand All @@ -57,8 +74,12 @@
class GitHubState:
repositories: Dict[GraphQLId, "Repository"]
pull_requests: Dict[GraphQLId, "PullRequest"]
# This is very inefficient but whatever
issue_comments: Dict[GraphQLId, "IssueComment"]
_next_id: int
# These are indexed by repo id
_next_pull_request_number: Dict[GraphQLId, int]
_next_issue_comment_full_database_id: Dict[GraphQLId, int]
root: "Root"
upstream_sh: Optional[ghstack.shell.Shell]

Expand All @@ -79,6 +100,14 @@ def pull_request(self, repo: "Repository", number: GitHubNumber) -> "PullRequest
)
)

def issue_comment(self, repo: "Repository", comment_id: int) -> "IssueComment":
    """Look up an issue comment by repository and integer comment id.

    Scans every stored comment and returns the first one whose
    ``_repository`` equals ``repo.id`` and whose ``fullDatabaseId`` equals
    ``comment_id``.

    Raises:
        RuntimeError: if no stored comment matches.
    """
    matching = (
        candidate
        for candidate in self.issue_comments.values()
        if repo.id == candidate._repository
        and candidate.fullDatabaseId == comment_id
    )
    found = next(matching, None)
    if found is None:
        raise RuntimeError(
            f"unrecognized issue comment {comment_id} in repository {repo.nameWithOwner}"
        )
    return found

def next_id(self) -> GraphQLId:
r = GraphQLId(str(self._next_id))
self._next_id += 1
Expand All @@ -89,6 +118,11 @@ def next_pull_request_number(self, repo_id: GraphQLId) -> GitHubNumber:
self._next_pull_request_number[repo_id] += 1
return r

def next_issue_comment_full_database_id(self, repo_id: GraphQLId) -> int:
    """Hand out the next integer issue-comment id for a repository.

    Returns the counter currently stored for ``repo_id`` and advances the
    stored counter by one. Raises ``KeyError`` if ``repo_id`` has no
    counter entry.
    """
    counters = self._next_issue_comment_full_database_id
    allocated = counters[repo_id]
    counters[repo_id] = allocated + 1
    return allocated

def push_hook(self, refs: Sequence[str]) -> None:
# updated_refs = set(refs)
# for pr in self.pull_requests:
Expand All @@ -107,8 +141,10 @@ def notify_merged(self, pr_resolved: ghstack.diff.PullRequestResolved) -> None:
def __init__(self, upstream_sh: Optional[ghstack.shell.Shell]) -> None:
self.repositories = {}
self.pull_requests = {}
self.issue_comments = {}
self._next_id = 5000
self._next_pull_request_number = {}
self._next_issue_comment_full_database_id = {}
self.root = Root()

# Populate it with the most important repo ;)
Expand All @@ -121,6 +157,7 @@ def __init__(self, upstream_sh: Optional[ghstack.shell.Shell]) -> None:
)
self.repositories[GraphQLId("1000")] = repo
self._next_pull_request_number[GraphQLId("1000")] = 500
self._next_issue_comment_full_database_id[GraphQLId("1000")] = 1500

self.upstream_sh = upstream_sh
if self.upstream_sh is not None:
Expand Down Expand Up @@ -239,6 +276,16 @@ def repository(self, info: GraphQLResolveInfo) -> Repository:
return github_state(info).repositories[self._repository]


@dataclass
class IssueComment(Node):
    """Fake issue comment node, as stored in ``GitHubState.issue_comments``."""

    # Text of the comment.
    body: str
    # Integer id by which the fake REST endpoints address this comment.
    fullDatabaseId: int
    # GraphQL id of the owning repository (key into ``repositories``).
    _repository: GraphQLId

    def repository(self, info: GraphQLResolveInfo) -> Repository:
        """Resolve the repository this comment belongs to."""
        state = github_state(info)
        return state.repositories[self._repository]


@dataclass
class PullRequestConnection:
nodes: List[PullRequest]
Expand All @@ -253,6 +300,8 @@ def node(self, info: GraphQLResolveInfo, id: GraphQLId) -> Node:
return github_state(info).repositories[id]
elif id in github_state(info).pull_requests:
return github_state(info).pull_requests[id]
elif id in github_state(info).issue_comments:
return github_state(info).issue_comments[id]
else:
raise RuntimeError("unknown id {}".format(id))

Expand All @@ -277,6 +326,7 @@ def set_is_type_of(name: str, cls: Any) -> None:

set_is_type_of("Repository", Repository)
set_is_type_of("PullRequest", PullRequest)
set_is_type_of("IssueComment", IssueComment)


class FakeGitHubEndpoint(ghstack.github.GitHubEndpoint):
Expand Down Expand Up @@ -366,6 +416,35 @@ def _update_pull(
if "body" in input and input["body"] is not None:
pr.body = input["body"]

def _create_issue_comment(
    self, owner: str, name: str, comment_id: int, input: CreateIssueCommentInput
) -> CreateIssueCommentPayload:
    """Create an issue comment in the fake GitHub state.

    Allocates a fresh GraphQL id and a fresh per-repository integer comment
    id, stores a new ``IssueComment`` carrying ``input["body"]``, and
    returns the integer id under the ``"id"`` key.

    NOTE(review): the incoming ``comment_id`` value is not consulted. The
    REST dispatcher appears to pass the issue number captured from the
    URL here, so the parameter name is misleading; it is kept for
    interface compatibility.
    """
    state = self.state
    node_id = state.next_id()
    repo = state.repository(owner, name)
    full_database_id = state.next_issue_comment_full_database_id(repo.id)
    state.issue_comments[node_id] = IssueComment(
        id=node_id,
        _repository=repo.id,
        fullDatabaseId=full_database_id,
        body=input["body"],
    )
    # This is only a subset of what the actual REST endpoint
    # returns.
    return {"id": full_database_id}

def _update_issue_comment(
    self, owner: str, name: str, comment_id: int, input: UpdateIssueCommentInput
) -> None:
    """Replace the body of an existing fake issue comment.

    The comment is located by repository (``owner``/``name``) and integer
    ``comment_id``. Lookup failures propagate from the state helpers. A
    missing or ``None`` ``"body"`` leaves the comment untouched.
    """
    state = self.state
    target = state.issue_comment(state.repository(owner, name), comment_id)
    new_body = input.get("body")
    if new_body is None:
        return
    target.body = new_body

# NB: This may have a payload, but we don't
# use it so I didn't bother constructing it.
def _set_default_branch(
Expand All @@ -383,14 +462,19 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any:
raise ghstack.github.NotFoundError()

elif method == "post":
m = re.match(r"^repos/([^/]+)/([^/]+)/pulls$", path)
if m:
if m := re.match(r"^repos/([^/]+)/([^/]+)/pulls$", path):
return self._create_pull(
m.group(1), m.group(2), cast(CreatePullRequestInput, kwargs)
)
if m := re.match(r"^repos/([^/]+)/([^/]+)/issues/([^/]+)/comments", path):
return self._create_issue_comment(
m.group(1),
m.group(2),
GitHubNumber(int(m.group(3))),
cast(CreateIssueCommentInput, kwargs),
)
elif method == "patch":
m = re.match(r"^repos/([^/]+)/([^/]+)(?:/pulls/([^/]+))?$", path)
if m:
if m := re.match(r"^repos/([^/]+)/([^/]+)(?:/pulls/([^/]+))?$", path):
owner, name, number = m.groups()
if number is not None:
return self._update_pull(
Expand All @@ -403,6 +487,13 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any:
return self._set_default_branch(
owner, name, cast(SetDefaultBranchInput, kwargs)
)
if m := re.match(r"^repos/([^/]+)/([^/]+)/issues/comments/([^/]+)$", path):
return self._update_issue_comment(
m.group(1),
m.group(2),
int(m.group(3)),
cast(UpdateIssueCommentInput, kwargs),
)
raise NotImplementedError(
"FakeGitHubEndpoint REST {} {} not implemented".format(method.upper(), path)
)
63 changes: 53 additions & 10 deletions ghstack/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,17 @@ class PreBranchState:
RE_GHSTACK_SOURCE_ID = re.compile(r"^ghstack-source-id: (.+)\n?", re.MULTILINE)


# When we make a GitHub PR using --direct, we submit an extra comment which
# contains the links to the rest of the PRs in the stack.  We don't put this
# inside the pull request body, because if you squash merge the PR, that body
# gets put into the commit message, but the stack information is just line
# noise and shouldn't go there.
#
# We can technically find the ghstack comment by querying the GitHub API for
# all comments on the PR, but recording the comment ID in the commit message
# (as a "ghstack-comment-id: ..." trailer matched by this regex) is a more
# efficient way of getting it.
RE_GHSTACK_COMMENT_ID = re.compile(r"^ghstack-comment-id: (.+)\n?", re.MULTILINE)


# repo layout:
# - gh/username/23/head -- what we think GitHub's current tip for commit is
# - gh/username/23/base -- what we think base commit for commit is
Expand Down Expand Up @@ -175,6 +186,8 @@ class DiffWithGitHubMetadata:
username: str
# Really ought not to be optional, but for BC reasons it might be
remote_source_id: Optional[str]
# Guaranteed to be set for --direct PRs
comment_id: Optional[int]
title: str
body: str
closed: bool
Expand Down Expand Up @@ -856,6 +869,8 @@ def elaborate_diff(
remote_summary = ghstack.git.split_header(rev_list)[0]
m_remote_source_id = RE_GHSTACK_SOURCE_ID.search(remote_summary.commit_msg)
remote_source_id = m_remote_source_id.group(1) if m_remote_source_id else None
m_comment_id = RE_GHSTACK_COMMENT_ID.search(remote_summary.commit_msg)
comment_id = int(m_comment_id.group(1)) if m_comment_id else None

return DiffWithGitHubMetadata(
diff=diff,
Expand All @@ -866,6 +881,7 @@ def elaborate_diff(
username=username,
ghnum=gh_number,
remote_source_id=remote_source_id,
comment_id=comment_id,
pull_request_resolved=diff.pull_request_resolved,
head_ref=r["headRefName"],
base_ref=r["baseRefName"],
Expand Down Expand Up @@ -932,10 +948,15 @@ def process_commit(
commit_msg = self._update_source_id(diff.summary, elab_diff)
else:
# Need to insert metadata for the first time
commit_msg = (
f"{strip_mentions(diff.summary.rstrip())}\n\n"
f"ghstack-source-id: {diff.source_id}\n"
f"Pull Request resolved: {pull_request_resolved.url()}"
commit_msg = "".join(
[
f"{strip_mentions(diff.summary.rstrip())}\n\n",
f"ghstack-source-id: {diff.source_id}\n",
f"ghstack-comment-id: {elab_diff.comment_id}\n"
if self.direct
else "",
f"Pull Request resolved: {pull_request_resolved.url()}",
]
)

return DiffMeta(
Expand Down Expand Up @@ -1382,6 +1403,14 @@ def _create_pull_request(
)
number = r["number"]

comment_id = None
if self.direct:
rc = self.github.post(
f"repos/{self.repo_owner}/{self.repo_name}/issues/{number}/comments",
body=f"{self.stack_header}:\n* (to be filled)",
)
comment_id = rc["id"]

logging.info("Opened PR #{}".format(number))

pull_request_resolved = ghstack.diff.PullRequestResolved(
Expand All @@ -1396,6 +1425,7 @@ def _create_pull_request(
number=number,
username=self.username,
remote_source_id=diff.source_id, # in sync
comment_id=comment_id,
title=title,
body=body,
closed=False,
Expand Down Expand Up @@ -1432,18 +1462,26 @@ def push_updates(
base_kwargs["base"] = s.base
else:
assert s.base == s.elab_diff.base_ref
stack_desc = self._format_stack(diffs_to_submit, s.number)
self.github.patch(
"repos/{owner}/{repo}/pulls/{number}".format(
owner=self.repo_owner, repo=self.repo_name, number=s.number
),
# NB: this substitution does nothing on direct PRs
body=RE_STACK.sub(
self._format_stack(diffs_to_submit, s.number),
stack_desc,
s.body,
),
title=s.title,
**base_kwargs,
)

if s.elab_diff.comment_id is not None:
self.github.patch(
f"repos/{self.repo_owner}/{self.repo_name}/issues/comments/{s.elab_diff.comment_id}",
body=stack_desc,
)

# It is VERY important that we do base updates BEFORE real
# head updates, otherwise GitHub will spuriously think that
# the user pushed a number of patches as part of the PR,
Expand Down Expand Up @@ -1703,16 +1741,21 @@ def _default_title_and_body(
# Don't store ghstack-source-id in the PR body; it will become
# stale quickly
commit_body = RE_GHSTACK_SOURCE_ID.sub("", commit_body)
# Comment ID is not necessary; source of truth is orig commit
commit_body = RE_GHSTACK_COMMENT_ID.sub("", commit_body)
# Don't store Pull request resolved in the PR body; it's
# unnecessary
commit_body = ghstack.diff.re_pull_request_resolved_w_sp(self.github_url).sub(
"", commit_body
)
if starts_with_bullet(commit_body):
commit_body = f"----\n\n{commit_body}"
pr_body = "{}:\n* (to be filled)\n\n{}{}".format(
self.stack_header, commit_body, extra
)
if self.direct:
pr_body = f"{commit_body}{extra}"
else:
if starts_with_bullet(commit_body):
commit_body = f"----\n\n{commit_body}"
pr_body = "{}:\n* (to be filled)\n\n{}{}".format(
self.stack_header, commit_body, extra
)
return title, pr_body

def _git_push(self, branches: Sequence[str], force: bool = False) -> None:
Expand Down
Loading

0 comments on commit e8db781

Please sign in to comment.