From 46d3315d63f1a58b1e4015b5c17fd389ff13b6d6 Mon Sep 17 00:00:00 2001 From: Kohei Morita Date: Mon, 3 Jun 2024 07:57:52 +0900 Subject: [PATCH] nit refactor --- generate.py | 88 +++++++++++++++++++++--------------------------- generate_test.py | 9 ++++- 2 files changed, 47 insertions(+), 50 deletions(-) diff --git a/generate.py b/generate.py index 38f8f345..29410c26 100755 --- a/generate.py +++ b/generate.py @@ -3,7 +3,7 @@ import sys import argparse import platform -from logging import Logger, basicConfig, getLogger +from logging import basicConfig, getLogger from os import getenv from pathlib import Path from typing import List @@ -19,20 +19,18 @@ from subprocess import (PIPE, STDOUT, CalledProcessError, TimeoutExpired, check_call, run) from tempfile import TemporaryDirectory -from typing import Any, Iterator, List, MutableMapping, Union, Optional +from typing import Any, Iterator, List, MutableMapping, Optional from enum import Enum import toml -logger: Logger = getLogger(__name__) - logger = getLogger(__name__) CASENAME_LEN_LIMIT = 40 STACK_SIZE = 2 ** 28 # 256 MB -def casename(name: Union[str, Path], i: int) -> str: +def casename(name: str | Path, i: int) -> str: """(random, 1) -> random_01""" return Path(name).stem + '_' + str(i).zfill(2) @@ -49,6 +47,17 @@ def param_to_str(key: str, value: object): raise RuntimeError('Unsupported type of params: {}'.format(key)) +def find_problem_dir(rootdir: Path, problem_name: Path) -> Optional[Path]: + tomls = list(rootdir.glob('**/{}/info.toml'.format(problem_name))) + if len(tomls) == 0: + logger.error('Cannot find problem: {}'.format(problem_name)) + return None + if len(tomls) >= 2: + logger.error('Found multiple problem dirs: {}'.format(problem_name)) + return None + return tomls[0].parent + + def compile(src: Path, rootdir: Path, opts: list[str] = []): if src.suffix == '.cpp': # use clang for msys2 clang environment @@ -130,6 +139,24 @@ def logging_result(result: str, start: datetime, end: datetime, message: str): class Problem: + class Mode(Enum): + DEFAULT = 1 + DEV = 2 + TEST = 3 + CLEAN = 5 + + def force_generate(self): + return self == self.DEV or self == self.TEST + + def verify(self): + return self == self.DEV or self == self.TEST + + def rewrite_hash(self): + return self == self.DEV + + def ignore_warning(self): + return self == self.DEV + rootdir: Path # /path/to/librar-checker-problems basedir: Path # /path/to/librar-checker-problems/sample/aplusb ignore_warning: bool = False @@ -142,11 +169,11 @@ def __init__(self, rootdir: Path, basedir: Path): self.rootdir = rootdir self.basedir = basedir tomlpath = basedir / 'info.toml' - self.config = toml.load(tomlpath) + self.config = toml.load(tomlpath) # type: ignore self.checker = basedir / \ - self.config.get('checker', 'checker.cpp') + self.config.get('checker', 'checker.cpp') # type: ignore self.verifier = basedir / \ - self.config.get('verifier', 'verifier.cpp') + self.config.get('verifier', 'verifier.cpp') # type: ignore def warning(self, message: str): logger.warning(message) @@ -163,7 +190,7 @@ def health_check(self): self.warning('too long casename: {}'.format(cn)) gendir = self.basedir / 'gen' - gens = [] + gens: list[str] = [] for test in self.config['tests']: gen = gendir / test['name'] if gen.suffix == '.cpp': @@ -215,12 +242,12 @@ def compile_solutions(self): compile(self.basedir / 'sol' / name, self.rootdir, opts) def check_all_solutions_used(self) -> bool: - sol_names = set() + sol_names: set[str] = set() sol_names.add('correct.cpp') for sol in self.config.get('solutions', []): sol_names.add(sol['name']) - file_names = set() + file_names: set[str] = set() file_names.update(p.name for p in (self.basedir / 'sol').glob('*.cpp')) file_names.update(p.name for p in (self.basedir / 'sol').glob('*.py')) return sol_names == file_names @@ -260,7 +287,7 @@ def verify_inputs(self): logger.error('verify failed: {}'.format(inname)) exit(1) - def make_outputs(self, check): + def make_outputs(self, check: bool): indir = self.basedir / 'in' outdir = self.basedir / 'out' soldir = self.basedir / 'sol' @@ -363,17 +390,6 @@ def problem_version(self) -> str: all_hash.update(h) return all_hash.hexdigest() - # return "version" of testcase - def testcase_version(self) -> str: - all_hash = hashlib.sha256() - with open(str(self.checker), 'rb') as f: - all_hash.update(hashlib.sha256(f.read()).digest()) - with open(str(self.basedir / 'hash.json'), 'r') as f: - cases = json.load(f) - for name, sha in sorted(cases.items(), key=lambda x: x[0]): - all_hash.update(sha.encode('ascii')) - return all_hash.hexdigest() - def judge(self, src: Path, config: dict): indir = self.basedir / 'in' outdir = self.basedir / 'out' @@ -485,21 +501,6 @@ def clean(self): if (self.basedir / 'out').exists(): shutil.rmtree(self.basedir / 'out') - class Mode(Enum): - DEFAULT = 1 - DEV = 2 - TEST = 3 - CLEAN = 5 - - def force_generate(self): - return self == self.DEV or self == self.TEST - - def verify(self): - return self == self.DEV or self == self.TEST - - def rewrite_hash(self): - return self == self.DEV - def generate(self, mode: Mode): if mode == self.Mode.DEV: self.ignore_warning = True @@ -548,17 +549,6 @@ def generate(self, mode: Mode): self.assert_hashes() -def find_problem_dir(rootdir: Path, problem_name: Path) -> Optional[Path]: - tomls = list(rootdir.glob('**/{}/info.toml'.format(problem_name))) - if len(tomls) == 0: - logger.error('Cannot find problem: {}'.format(problem_name)) - return None - if len(tomls) >= 2: - logger.error('Found multiple problem dirs: {}'.format(problem_name)) - return None - return tomls[0].parent - - def main(args: List[str]): try: import colorlog diff --git a/generate_test.py b/generate_test.py index 41181a0b..8a7f2d35 100755 --- a/generate_test.py +++ b/generate_test.py @@ -9,7 +9,7 @@ from shutil import copy from pathlib import Path from tempfile import TemporaryDirectory -from generate import Problem, param_to_str +from generate import Problem, param_to_str, casename from typing import List logger = getLogger(__name__) @@ -276,6 +276,13 @@ def test_list_depending_files(self): self.assertTrue(find_verifier) +class TestCasename(unittest.TestCase): + # select problem by problem id + def test_casename(self): + self.assertEqual(casename('example', 0), "example_00") + self.assertEqual(casename('example', 1), "example_01") + self.assertEqual(casename('random', 10), "random_10") + class TestParam(unittest.TestCase): # select problem by problem id def test_convert_integer(self):