Skip to content

Commit

Permalink
Merge pull request #1177 from yosupo06/refactor/nit
Browse files Browse the repository at this point in the history
nit refactor
  • Loading branch information
yosupo06 authored Jun 3, 2024
2 parents 8e131c3 + 46d3315 commit 9479ae3
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 50 deletions.
88 changes: 39 additions & 49 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
import argparse
import platform
from logging import Logger, basicConfig, getLogger
from logging import basicConfig, getLogger
from os import getenv
from pathlib import Path
from typing import List
Expand All @@ -19,20 +19,18 @@
from subprocess import (PIPE, STDOUT, CalledProcessError,
TimeoutExpired, check_call, run)
from tempfile import TemporaryDirectory
from typing import Any, Iterator, List, MutableMapping, Union, Optional
from typing import Any, Iterator, List, MutableMapping, Optional

from enum import Enum
import toml

logger: Logger = getLogger(__name__)

logger = getLogger(__name__)

CASENAME_LEN_LIMIT = 40
STACK_SIZE = 2 ** 28 # 256 MB


def casename(name: Union[str, Path], i: int) -> str:
def casename(name: str | Path, i: int) -> str:
"""(random, 1) -> random_01"""
return Path(name).stem + '_' + str(i).zfill(2)

Expand All @@ -49,6 +47,17 @@ def param_to_str(key: str, value: object):
raise RuntimeError('Unsupported type of params: {}'.format(key))


def find_problem_dir(rootdir: Path, problem_name: Path) -> Optional[Path]:
tomls = list(rootdir.glob('**/{}/info.toml'.format(problem_name)))
if len(tomls) == 0:
logger.error('Cannot find problem: {}'.format(problem_name))
return None
if len(tomls) >= 2:
logger.error('Found multiple problem dirs: {}'.format(problem_name))
return None
return tomls[0].parent


def compile(src: Path, rootdir: Path, opts: list[str] = []):
if src.suffix == '.cpp':
# use clang for msys2 clang environment
Expand Down Expand Up @@ -130,6 +139,24 @@ def logging_result(result: str, start: datetime, end: datetime, message: str):


class Problem:
class Mode(Enum):
DEFAULT = 1
DEV = 2
TEST = 3
CLEAN = 5

def force_generate(self):
return self == self.DEV or self == self.TEST

def verify(self):
return self == self.DEV or self == self.TEST

def rewrite_hash(self):
return self == self.DEV

def ignore_warning(self):
return self == self.DEV

rootdir: Path # /path/to/librar-checker-problems
basedir: Path # /path/to/librar-checker-problems/sample/aplusb
ignore_warning: bool = False
Expand All @@ -142,11 +169,11 @@ def __init__(self, rootdir: Path, basedir: Path):
self.rootdir = rootdir
self.basedir = basedir
tomlpath = basedir / 'info.toml'
self.config = toml.load(tomlpath)
self.config = toml.load(tomlpath) # type: ignore
self.checker = basedir / \
self.config.get('checker', 'checker.cpp')
self.config.get('checker', 'checker.cpp') # type: ignore
self.verifier = basedir / \
self.config.get('verifier', 'verifier.cpp')
self.config.get('verifier', 'verifier.cpp') # type: ignore

def warning(self, message: str):
logger.warning(message)
Expand All @@ -163,7 +190,7 @@ def health_check(self):
self.warning('too long casename: {}'.format(cn))

gendir = self.basedir / 'gen'
gens = []
gens: list[str] = []
for test in self.config['tests']:
gen = gendir / test['name']
if gen.suffix == '.cpp':
Expand Down Expand Up @@ -215,12 +242,12 @@ def compile_solutions(self):
compile(self.basedir / 'sol' / name, self.rootdir, opts)

def check_all_solutions_used(self) -> bool:
sol_names = set()
sol_names: set[str] = set()
sol_names.add('correct.cpp')
for sol in self.config.get('solutions', []):
sol_names.add(sol['name'])

file_names = set()
file_names: set[str] = set()
file_names.update(p.name for p in (self.basedir / 'sol').glob('*.cpp'))
file_names.update(p.name for p in (self.basedir / 'sol').glob('*.py'))
return sol_names == file_names
Expand Down Expand Up @@ -260,7 +287,7 @@ def verify_inputs(self):
logger.error('verify failed: {}'.format(inname))
exit(1)

def make_outputs(self, check):
def make_outputs(self, check: bool):
indir = self.basedir / 'in'
outdir = self.basedir / 'out'
soldir = self.basedir / 'sol'
Expand Down Expand Up @@ -363,17 +390,6 @@ def problem_version(self) -> str:
all_hash.update(h)
return all_hash.hexdigest()

# return "version" of testcase
def testcase_version(self) -> str:
all_hash = hashlib.sha256()
with open(str(self.checker), 'rb') as f:
all_hash.update(hashlib.sha256(f.read()).digest())
with open(str(self.basedir / 'hash.json'), 'r') as f:
cases = json.load(f)
for name, sha in sorted(cases.items(), key=lambda x: x[0]):
all_hash.update(sha.encode('ascii'))
return all_hash.hexdigest()

def judge(self, src: Path, config: dict):
indir = self.basedir / 'in'
outdir = self.basedir / 'out'
Expand Down Expand Up @@ -485,21 +501,6 @@ def clean(self):
if (self.basedir / 'out').exists():
shutil.rmtree(self.basedir / 'out')

class Mode(Enum):
DEFAULT = 1
DEV = 2
TEST = 3
CLEAN = 5

def force_generate(self):
return self == self.DEV or self == self.TEST

def verify(self):
return self == self.DEV or self == self.TEST

def rewrite_hash(self):
return self == self.DEV

def generate(self, mode: Mode):
if mode == self.Mode.DEV:
self.ignore_warning = True
Expand Down Expand Up @@ -548,17 +549,6 @@ def generate(self, mode: Mode):
self.assert_hashes()


def find_problem_dir(rootdir: Path, problem_name: Path) -> Optional[Path]:
tomls = list(rootdir.glob('**/{}/info.toml'.format(problem_name)))
if len(tomls) == 0:
logger.error('Cannot find problem: {}'.format(problem_name))
return None
if len(tomls) >= 2:
logger.error('Found multiple problem dirs: {}'.format(problem_name))
return None
return tomls[0].parent


def main(args: List[str]):
try:
import colorlog
Expand Down
9 changes: 8 additions & 1 deletion generate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from shutil import copy
from pathlib import Path
from tempfile import TemporaryDirectory
from generate import Problem, param_to_str
from generate import Problem, param_to_str, casename
from typing import List

logger = getLogger(__name__)
Expand Down Expand Up @@ -276,6 +276,13 @@ def test_list_depending_files(self):
self.assertTrue(find_verifier)


class TestCasename(unittest.TestCase):
# select problem by problem id
def test_casename(self):
self.assertEqual(casename('example', 0), "example_00")
self.assertEqual(casename('example', 1), "example_01")
self.assertEqual(casename('random', 10), "random_10")

class TestParam(unittest.TestCase):
# select problem by problem id
def test_convert_integer(self):
Expand Down

0 comments on commit 9479ae3

Please sign in to comment.