From a51e9f46569a399e719a16c8118b98c7641a3f7b Mon Sep 17 00:00:00 2001 From: Andrei Kashin Date: Mon, 5 Feb 2024 15:41:11 +0000 Subject: [PATCH 1/2] Add type annotations to zkasm-result.py This should make the code safer and easier to read --- ci/zkasm-result.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/ci/zkasm-result.py b/ci/zkasm-result.py index fa712b9bfc56..9329e10acdc1 100644 --- a/ci/zkasm-result.py +++ b/ci/zkasm-result.py @@ -4,6 +4,7 @@ import sys import argparse import json +from typing import Any, TextIO from operator import countOf @@ -20,7 +21,7 @@ class TestResult: # Number of cycles it took to execute the test if it was successful. cycles: int | None - def to_csv_record(self): + def to_csv_record(self) -> dict[str, Any]: return { "Test": self.test_name, "Status": self.status, @@ -28,7 +29,7 @@ def to_csv_record(self): } @staticmethod - def from_csv_record(record): + def from_csv_record(record: dict[str, Any]) -> "TestResult": if record["Cycles"]: cycles = int(record["Cycles"]) else: @@ -39,7 +40,9 @@ def from_csv_record(record): ) -def record_failed_compilation_results(tests_path, generated_dir, test_results): +def record_failed_compilation_results( + tests_path: str, generated_dir: str, test_results: dict[str, TestResult] +) -> None: for file in os.listdir(tests_path): if not file.endswith(".wat"): continue @@ -52,7 +55,7 @@ def record_failed_compilation_results(tests_path, generated_dir, test_results): ) -def parse_test_result(result_json): +def parse_test_result(result_json: dict[str, Any]) -> TestResult: test_name, _ = os.path.splitext(os.path.basename(result_json["path"])) status = result_json["status"] if status == "pass": @@ -63,7 +66,7 @@ def parse_test_result(result_json): return TestResult(test_name=test_name, status=status, cycles=cycles) -def read_test_execution_results(input_handle): +def read_test_execution_results(input_handle: TextIO) -> dict[str, TestResult]: test_results = {} for test_result_json in json.load(input_handle): test_result = parse_test_result(test_result_json) @@ -71,13 +74,13 @@ def read_test_execution_results(input_handle): return test_results -def read_summary(filepath): +def read_summary(filepath: str) -> dict[str, Any]: with open(filepath, "r", newline="") as csvfile: reader = csv.DictReader(csvfile) return {row["Suite path"]: row for row in reader} -def write_summary(filepath, summary): +def write_summary(filepath: str, summary: dict[str, Any]) -> None: with open(filepath, "w", newline="") as csvfile: writer = csv.DictWriter( csvfile, @@ -88,13 +91,15 @@ def write_summary(filepath, summary): writer.writerow({"Suite path": path, **value}) -def read_test_results(test_results_path): +def read_test_results(test_results_path: str) -> dict[str, TestResult]: with open(test_results_path, newline="") as csvfile: reader = csv.DictReader(csvfile) return {row["Test"]: TestResult.from_csv_record(row) for row in reader} -def write_test_results(test_results, test_results_path, tests_path): +def write_test_results( + test_results: dict[str, TestResult], test_results_path: str, tests_path: str +) -> None: with open(test_results_path, "w", newline="") as csvfile: writer = csv.DictWriter(csvfile, fieldnames=CSV_FIELD_NAMES) writer.writeheader() @@ -114,7 +119,7 @@ def write_test_results(test_results, test_results_path, tests_path): write_summary(TEST_SUMMARY_FILE_PATH, summary) -def assert_dict_equals(actual, expected): +def assert_dict_equals(actual: dict[str, Any], expected: dict[str, Any]) -> None: if actual == expected: return @@ -129,7 +134,7 @@ def assert_dict_equals(actual, expected): ) -def main(): +def main() -> None: parser = argparse.ArgumentParser( description="Example script to demonstrate flag usage." ) From 9ce34ab8ce1c44d11ae2cee1e86bd8f16b6a29c5 Mon Sep 17 00:00:00 2001 From: Andrei Kashin Date: Tue, 6 Feb 2024 14:17:18 +0000 Subject: [PATCH 2/2] Run mypy in strict mode --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index f36101cce1d6..65b4857b3984 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ select = ["F", "E"] [tool.mypy] files = ["ci/*.py"] +strict = true [tool.pdm.scripts] lint = "ruff ."