Skip to content

Commit

Permalink
feature(pre-commit): Keeping up with the Joneses
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley committed Jun 1, 2023
1 parent f6e769a commit b1cda7c
Show file tree
Hide file tree
Showing 448 changed files with 3,084 additions and 3,305 deletions.
22 changes: 18 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,28 @@
# limitations under the License.
#
repos:
- repo: https://github.com/MarcoGorelli/auto-walrus
rev: v0.2.2
hooks:
- id: auto-walrus
- repo: https://github.com/asottile/pyupgrade
rev: v3.4.0
hooks:
- id: pyupgrade
args:
- --py39-plus
- repo: https://github.com/hadialqattan/pycln
rev: v2.1.2
hooks:
- id: pycln
args:
- --disable-all-dunder-policy
- --exclude=superset/config.py
- --extend-exclude=tests/integration_tests/superset_test_config.*.py
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/MarcoGorelli/auto-walrus
rev: v0.2.2
hooks:
- id: auto-walrus
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.3.0
hooks:
Expand Down
29 changes: 15 additions & 14 deletions RELEASING/changelog.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import os
import re
import sys
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Union
from typing import Any, Optional, Union

import click
from click.core import Context
Expand Down Expand Up @@ -67,15 +68,15 @@ class GitChangeLog:
def __init__(
self,
version: str,
logs: List[GitLog],
logs: list[GitLog],
access_token: Optional[str] = None,
risk: Optional[bool] = False,
) -> None:
self._version = version
self._logs = logs
self._pr_logs_with_details: Dict[int, Dict[str, Any]] = {}
self._github_login_cache: Dict[str, Optional[str]] = {}
self._github_prs: Dict[int, Any] = {}
self._pr_logs_with_details: dict[int, dict[str, Any]] = {}
self._github_login_cache: dict[str, Optional[str]] = {}
self._github_prs: dict[int, Any] = {}
self._wait = 10
github_token = access_token or os.environ.get("GITHUB_TOKEN")
self._github = Github(github_token)
Expand Down Expand Up @@ -126,7 +127,7 @@ def _has_commit_migrations(self, git_sha: str) -> bool:
"superset/migrations/versions/" in file.filename for file in commit.files
)

def _get_pull_request_details(self, git_log: GitLog) -> Dict[str, Any]:
def _get_pull_request_details(self, git_log: GitLog) -> dict[str, Any]:
pr_number = git_log.pr_number
if pr_number:
detail = self._pr_logs_with_details.get(pr_number)
Expand Down Expand Up @@ -156,7 +157,7 @@ def _get_pull_request_details(self, git_log: GitLog) -> Dict[str, Any]:

return detail

def _is_risk_pull_request(self, labels: List[Any]) -> bool:
def _is_risk_pull_request(self, labels: list[Any]) -> bool:
for label in labels:
risk_label = re.match(SUPERSET_RISKY_LABELS, label.name)
if risk_label is not None:
Expand All @@ -174,8 +175,8 @@ def _get_changelog_version_head(self) -> str:

def _parse_change_log(
self,
changelog: Dict[str, str],
pr_info: Dict[str, str],
changelog: dict[str, str],
pr_info: dict[str, str],
github_login: str,
) -> None:
formatted_pr = (
Expand Down Expand Up @@ -227,7 +228,7 @@ def __repr__(self) -> str:
result += f"**{key}** {changelog[key]}\n"
return result

def __iter__(self) -> Iterator[Dict[str, Any]]:
def __iter__(self) -> Iterator[dict[str, Any]]:
for log in self._logs:
yield {
"pr_number": log.pr_number,
Expand All @@ -250,20 +251,20 @@ class GitLogs:

def __init__(self, git_ref: str) -> None:
self._git_ref = git_ref
self._logs: List[GitLog] = []
self._logs: list[GitLog] = []

@property
def git_ref(self) -> str:
return self._git_ref

@property
def logs(self) -> List[GitLog]:
def logs(self) -> list[GitLog]:
return self._logs

def fetch(self) -> None:
self._logs = list(map(self._parse_log, self._git_logs()))[::-1]

def diff(self, git_logs: "GitLogs") -> List[GitLog]:
def diff(self, git_logs: "GitLogs") -> list[GitLog]:
return [log for log in git_logs.logs if log not in self._logs]

def __repr__(self) -> str:
Expand All @@ -284,7 +285,7 @@ def _git_checkout(self, git_ref: str) -> None:
print(f"Could not checkout {git_ref}")
sys.exit(1)

def _git_logs(self) -> List[str]:
def _git_logs(self) -> list[str]:
# let's get current git ref so we can revert it back
current_git_ref = self._git_get_current_head()
self._git_checkout(self._git_ref)
Expand Down
8 changes: 4 additions & 4 deletions RELEASING/generate_email.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Dict, List
from typing import Any

from click.core import Context

Expand All @@ -34,7 +34,7 @@
PROJECT_DESCRIPTION = "Apache Superset is a modern, enterprise-ready business intelligence web application"


def string_comma_to_list(message: str) -> List[str]:
def string_comma_to_list(message: str) -> list[str]:
if not message:
return []
return [element.strip() for element in message.split(",")]
Expand All @@ -52,15 +52,15 @@ def render_template(template_file: str, **kwargs: Any) -> str:
return template.render(kwargs)


class BaseParameters(object):
class BaseParameters:
def __init__(
self,
version: str,
version_rc: str,
) -> None:
self.version = version
self.version_rc = version_rc
self.template_arguments: Dict[str, Any] = {}
self.template_arguments: dict[str, Any] = {}

def __repr__(self) -> str:
return f"Apache Credentials: {self.version}/{self.version_rc}"
Expand Down
7 changes: 3 additions & 4 deletions docker/pythonpath_dev/superset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#
import logging
import os
from datetime import timedelta
from typing import Optional

from cachelib.file import FileSystemCache
Expand All @@ -42,7 +41,7 @@ def get_env_variable(var_name: str, default: Optional[str] = None) -> str:
error_msg = "The environment variable {} was missing, abort...".format(
var_name
)
raise EnvironmentError(error_msg)
raise OSError(error_msg)


DATABASE_DIALECT = get_env_variable("DATABASE_DIALECT")
Expand All @@ -53,7 +52,7 @@ def get_env_variable(var_name: str, default: Optional[str] = None) -> str:
DATABASE_DB = get_env_variable("DATABASE_DB")

# The SQLAlchemy connection string.
SQLALCHEMY_DATABASE_URI = "%s://%s:%s@%s:%s/%s" % (
SQLALCHEMY_DATABASE_URI = "{}://{}:{}@{}:{}/{}".format(
DATABASE_DIALECT,
DATABASE_USER,
DATABASE_PASSWORD,
Expand All @@ -80,7 +79,7 @@ def get_env_variable(var_name: str, default: Optional[str] = None) -> str:
DATA_CACHE_CONFIG = CACHE_CONFIG


class CeleryConfig(object):
class CeleryConfig:
broker_url = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}"
imports = ("superset.sql_lab",)
result_backend = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_RESULTS_DB}"
Expand Down
24 changes: 11 additions & 13 deletions scripts/benchmark_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from inspect import getsource
from pathlib import Path
from types import ModuleType
from typing import Any, Dict, List, Set, Type
from typing import Any

import click
from flask import current_app
Expand All @@ -48,12 +48,10 @@ def import_migration_script(filepath: Path) -> ModuleType:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return module
raise Exception(
"No module spec found in location: `{path}`".format(path=str(filepath))
)
raise Exception(f"No module spec found in location: `{str(filepath)}`")


def extract_modified_tables(module: ModuleType) -> Set[str]:
def extract_modified_tables(module: ModuleType) -> set[str]:
"""
Extract the tables being modified by a migration script.
Expand All @@ -62,7 +60,7 @@ def extract_modified_tables(module: ModuleType) -> Set[str]:
actually traversing the AST.
"""

tables: Set[str] = set()
tables: set[str] = set()
for function in {"upgrade", "downgrade"}:
source = getsource(getattr(module, function))
tables.update(re.findall(r'alter_table\(\s*"(\w+?)"\s*\)', source, re.DOTALL))
Expand All @@ -72,11 +70,11 @@ def extract_modified_tables(module: ModuleType) -> Set[str]:
return tables


def find_models(module: ModuleType) -> List[Type[Model]]:
def find_models(module: ModuleType) -> list[type[Model]]:
"""
Find all models in a migration script.
"""
models: List[Type[Model]] = []
models: list[type[Model]] = []
tables = extract_modified_tables(module)

# add models defined explicitly in the migration script
Expand Down Expand Up @@ -123,7 +121,7 @@ def find_models(module: ModuleType) -> List[Type[Model]]:
sorter: TopologicalSorter[Any] = TopologicalSorter()
for model in models:
inspector = inspect(model)
dependent_tables: List[str] = []
dependent_tables: list[str] = []
for column in inspector.columns.values():
for foreign_key in column.foreign_keys:
if foreign_key.column.table.name != model.__tablename__:
Expand Down Expand Up @@ -174,30 +172,30 @@ def main(

print("\nIdentifying models used in the migration:")
models = find_models(module)
model_rows: Dict[Type[Model], int] = {}
model_rows: dict[type[Model], int] = {}
for model in models:
rows = session.query(model).count()
print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})")
model_rows[model] = rows
session.close()

print("Benchmarking migration")
results: Dict[str, float] = {}
results: dict[str, float] = {}
start = time.time()
upgrade(revision=revision)
duration = time.time() - start
results["Current"] = duration
print(f"Migration on current DB took: {duration:.2f} seconds")

min_entities = 10
new_models: Dict[Type[Model], List[Model]] = defaultdict(list)
new_models: dict[type[Model], list[Model]] = defaultdict(list)
while min_entities <= limit:
downgrade(revision=down_revision)
print(f"Running with at least {min_entities} entities of each model")
for model in models:
missing = min_entities - model_rows[model]
if missing > 0:
entities: List[Model] = []
entities: list[Model] = []
print(f"- Adding {missing} entities to the {model.__name__} model")
bar = ChargingBar("Processing", max=missing)
try:
Expand Down
23 changes: 11 additions & 12 deletions scripts/cancel_github_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,21 @@
./cancel_github_workflows.py 1024 --include-last
"""
import os
from typing import Any, Dict, Iterable, Iterator, List, Optional, Union
from collections.abc import Iterable, Iterator
from typing import Any, Literal, Optional, Union

import click
import requests
from click.exceptions import ClickException
from dateutil import parser
from typing_extensions import Literal

github_token = os.environ.get("GITHUB_TOKEN")
github_repo = os.environ.get("GITHUB_REPOSITORY", "apache/superset")


def request(
method: Literal["GET", "POST", "DELETE", "PUT"], endpoint: str, **kwargs: Any
) -> Dict[str, Any]:
) -> dict[str, Any]:
resp = requests.request(
method,
f"https://api.github.com/{endpoint.lstrip('/')}",
Expand All @@ -61,8 +61,8 @@ def request(

def list_runs(
repo: str,
params: Optional[Dict[str, str]] = None,
) -> Iterator[Dict[str, Any]]:
params: Optional[dict[str, str]] = None,
) -> Iterator[dict[str, Any]]:
"""List all github workflow runs.
Returns:
An iterator that will iterate through all pages of matching runs."""
Expand All @@ -77,16 +77,15 @@ def list_runs(
params={**params, "per_page": 100, "page": page},
)
total_count = result["total_count"]
for item in result["workflow_runs"]:
yield item
yield from result["workflow_runs"]
page += 1


def cancel_run(repo: str, run_id: Union[str, int]) -> Dict[str, Any]:
def cancel_run(repo: str, run_id: Union[str, int]) -> dict[str, Any]:
return request("POST", f"/repos/{repo}/actions/runs/{run_id}/cancel")


def get_pull_request(repo: str, pull_number: Union[str, int]) -> Dict[str, Any]:
def get_pull_request(repo: str, pull_number: Union[str, int]) -> dict[str, Any]:
return request("GET", f"/repos/{repo}/pulls/{pull_number}")


Expand All @@ -96,7 +95,7 @@ def get_runs(
user: Optional[str] = None,
statuses: Iterable[str] = ("queued", "in_progress"),
events: Iterable[str] = ("pull_request", "push"),
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""Get workflow runs associated with the given branch"""
return [
item
Expand All @@ -108,7 +107,7 @@ def get_runs(
]


def print_commit(commit: Dict[str, Any], branch: str) -> None:
def print_commit(commit: dict[str, Any], branch: str) -> None:
"""Print out commit message for verification"""
indented_message = " \n".join(commit["message"].split("\n"))
date_str = (
Expand Down Expand Up @@ -155,7 +154,7 @@ def print_commit(commit: Dict[str, Any], branch: str) -> None:
def cancel_github_workflows(
branch_or_pull: Optional[str],
repo: str,
event: List[str],
event: list[str],
include_last: bool,
include_running: bool,
) -> None:
Expand Down
Loading

0 comments on commit b1cda7c

Please sign in to comment.