From 919165c87f110c7991041466619883425247504f Mon Sep 17 00:00:00 2001 From: Lily Kuang Date: Fri, 12 Mar 2021 16:58:22 -0800 Subject: [PATCH] lint mypy --- RELEASING/changelog.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/RELEASING/changelog.py b/RELEASING/changelog.py index 0b178c25bf7c5..11a4fed24b6f1 100644 --- a/RELEASING/changelog.py +++ b/RELEASING/changelog.py @@ -26,7 +26,7 @@ import click try: - from github import BadCredentialsException, Github, PullRequest + from github import BadCredentialsException, Github, PullRequest, Repository except ModuleNotFoundError: print("PyGithub is a required package for this script") exit(1) @@ -73,14 +73,14 @@ def __init__( ) -> None: self._version = version self._logs = logs - self._pr_logs_with_details: Dict[str, Dict[str, str]] = {} + self._pr_logs_with_details: Dict[int, Dict[str, Any]] = {} self._github_login_cache: Dict[str, Optional[str]] = {} - self._github_prs: Dict[str, Any] = {} + self._github_prs: Dict[int, Any] = {} self._wait = 10 github_token = access_token or os.environ.get("GITHUB_TOKEN") self._github = Github(github_token) self._show_risk = risk - self._superset_repo = "" + self._superset_repo: Repository = None def _fetch_github_pr(self, pr_number: int) -> PullRequest: """ @@ -129,18 +129,20 @@ def _has_commit_migrations(self, git_sha: str) -> bool: def _get_pull_request_details(self, git_log: GitLog) -> Dict[str, Any]: pr_number = git_log.pr_number - detail = self._pr_logs_with_details.get(pr_number) - if detail: - return detail + if pr_number: + detail = self._pr_logs_with_details.get(pr_number) + if detail: + return detail + else: + pr_info = self._fetch_github_pr(pr_number) - pr_info = self._fetch_github_pr(pr_number) has_migrations = self._has_commit_migrations(git_log.sha) title = pr_info.title if pr_info else git_log.message pr_type = re.match(r"^(fix|feat|chore|refactor|docs|build|ci|/gmi)", title) if pr_type: pr_type = pr_type.group().strip('"') else: - pr_type = "" + pr_type = None labels = (" | ").join([label.name for label in pr_info.labels]) is_risky = self._is_risk_pull_request(pr_info.labels) @@ -153,11 +155,12 @@ def _get_pull_request_details(self, git_log: GitLog) -> Dict[str, Any]: "is_risky": is_risky or has_migrations, } - self._pr_logs_with_details[pr_number] = detail + if pr_number: + self._pr_logs_with_details[pr_number] = detail return detail - def _is_risk_pull_request(self, labels: [str]) -> bool: + def _is_risk_pull_request(self, labels: List[Any]) -> bool: for label in labels: risk_label = re.match( r"^(blocking|risk|hold|revert|security vulnerability)", label.name @@ -171,7 +174,7 @@ def _get_changelog_version_head(self) -> str: def _parse_change_log( self, changelog: Dict[str, str], pr_info: Dict[str, str], github_login: str, - ) -> Dict[str, str]: + ): formatted_pr = ( f"- [#{pr_info.get('id')}]" f"(https://github.com/{SUPERSET_REPO}/pull/{pr_info.get('id')}) "