Skip to content

Commit

Permalink
ci: Allow upstream git refs to be used for benchmarking (#3730)
Browse files Browse the repository at this point in the history
Had trouble running the benchmarking tool when local branch names didn't
match remote branch names. Fix it so that we check for upstream branch
names and use them.

E.g. if I run the tool on `local-branch-name`, it finds
`origin/user/remote-branch-name` then runs the action on
`user/remote-branch-name`.
  • Loading branch information
desmondcheongzx authored Jan 30, 2025
1 parent 684505b commit 3462732
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions tools/git_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,36 @@ def get_latest_run(workflow: Workflow) -> WorkflowRun:
raise RuntimeError("Unable to list all workflow invocations")


def get_name_and_commit_hash(branch_name: Optional[str]) -> tuple[str, str]:
branch_name = branch_name or "HEAD"
name = (
subprocess.check_output(["git", "rev-parse", "--abbrev-ref", branch_name], stderr=subprocess.STDOUT)
def get_name_and_commit_hash(local_branch_name: Optional[str]) -> tuple[str, str]:
local_branch_name = local_branch_name or "HEAD"
remote_branch_name = local_branch_name

try:
# Check if the branch has a remote tracking branch.
local_branch_name = (
subprocess.check_output(
["git", "rev-parse", "--abbrev-ref", f"{local_branch_name}@{{upstream}}"], stderr=subprocess.STDOUT
)
.strip()
.decode("utf-8")
)
# Strip the upstream name from the branch to get the branch name on the remote repo.
remote_branch_name = local_branch_name.split("/", 1)[1]
except subprocess.CalledProcessError:
local_branch_name = (
subprocess.check_output(["git", "rev-parse", "--abbrev-ref", local_branch_name], stderr=subprocess.STDOUT)
.strip()
.decode("utf-8")
)
remote_branch_name = local_branch_name

commit_hash = (
subprocess.check_output(["git", "rev-parse", local_branch_name], stderr=subprocess.STDOUT)
.strip()
.decode("utf-8")
)
commit_hash = (
subprocess.check_output(["git", "rev-parse", branch_name], stderr=subprocess.STDOUT).strip().decode("utf-8")
)
return name, commit_hash
# Return the remote branch name for the github action.
return remote_branch_name, commit_hash


def parse_questions(questions: Optional[str], total_number_of_questions: int) -> list[int]:
Expand Down

0 comments on commit 3462732

Please sign in to comment.