diff --git a/dmoj/problem.py b/dmoj/problem.py
index 096293f12..39274ff28 100644
--- a/dmoj/problem.py
+++ b/dmoj/problem.py
@@ -1,6 +1,9 @@
+import itertools
 import os
+import re
 import subprocess
 import zipfile
+from collections import defaultdict
 from functools import partial
 
 import yaml
@@ -15,6 +18,9 @@
 from dmoj.result import Result
 from dmoj.utils.module import load_module_from_file
 
+DEFAULT_TEST_CASE_INPUT_PATTERN = r'^(?=.*?\.in|in).*?(?:(?:^|\W)(?P<batch>\d+)[^\d\s]+)?(?P<case>\d+)[^\d\s]*$'
+DEFAULT_TEST_CASE_OUTPUT_PATTERN = r'^(?=.*?\.out|out).*?(?:(?:^|\W)(?P<batch>\d+)[^\d\s]+)?(?P<case>\d+)[^\d\s]*$'
+
 
 class Problem:
     def __init__(self, problem_id, time_limit, memory_limit):
@@ -45,6 +51,116 @@ def __init__(self, problem_id, time_limit, memory_limit):
             raise InvalidInitException(str(e))
 
         self.problem_data.archive = self._resolve_archive_files()
+        self._resolve_test_cases()
+
+    def _match_test_cases(self, filenames, input_case_pattern, output_case_pattern, case_points):
+        """Pair input and output files by the (batch, case) position parsed from their names.
+
+        :param filenames: iterable of candidate file names; it is scanned more than once.
+        :param input_case_pattern: compiled regex recognising input files.
+        :param output_case_pattern: compiled regex recognising output files.
+        :param case_points: iterator yielding the points for each emitted top-level case.
+        :return: list of case dicts, sorted by batch/case id; batches carry a ``batched`` list.
+        :raises InvalidInitException: if a matched pair has no case number, or if several
+            unbatched pairs share a case id.
+        """
+        def try_match_int(match, group):
+            try:
+                val = match.group(group)
+            except IndexError:
+                return None
+
+            try:
+                return int(val)
+            except (ValueError, TypeError):
+                return val
+
+        def parse_position(pattern, filename):
+            match = pattern.match(filename)
+            if not match:
+                return None
+
+            # Allow batches and case numbers to be alphanumeric, in which case we will sort them lexicographically.
+            # Still attempt to process them as integers first, though, since most problems will use this format.
+            return try_match_int(match, 'batch'), try_match_int(match, 'case')
+
+        # Match all cases with the same (batch, position) mapping.
+        groups = defaultdict(list)
+        batch_ids = set()
+        for a in filenames:
+            a_parse = parse_position(input_case_pattern, a)
+            if a_parse is None:
+                continue
+
+            for b in filenames:
+                b_parse = parse_position(output_case_pattern, b)
+                if a_parse == b_parse:
+                    batch, case = a_parse
+                    if case is None:
+                        raise InvalidInitException('test case format yielded no case number')
+                    if batch is not None:
+                        batch_ids.add(batch)
+                    # Test `batch is None` explicitly: `batch or case` would file batch 0 under its case id.
+                    groups[case if batch is None else batch].append((case, a, b))
+
+        test_cases = []
+        for batch_or_case_id in sorted(groups.keys()):
+            group_cases = groups[batch_or_case_id]
+            if batch_or_case_id in batch_ids:
+                test_cases.append({
+                    'batched': [{
+                        'in': input_file,
+                        'out': output_file,
+                    } for _, input_file, output_file in sorted(group_cases)],
+                    'points': next(case_points),
+                })
+            else:
+                if len(group_cases) > 1:
+                    raise InvalidInitException('problem has conflicting test cases: %s' % group_cases)
+                _, input_file, output_file = group_cases[0]
+                test_cases.append({
+                    'in': input_file,
+                    'out': output_file,
+                    'points': next(case_points),
+                })
+
+        return test_cases
+
+    def _problem_file_list(self):
+        """Return the names of the files in the problem archive.
+
+        :raises InvalidInitException: if the problem has no archive.
+        """
+        # We *could* support testcase format specifiers without an archive, but it's harder and most problems should be
+        # using archives in the first place.
+        if not self.problem_data.archive:
+            raise InvalidInitException('can only use test case format specifiers if `archive` is set')
+        return self.problem_data.archive.namelist()
+
+    def _resolve_test_cases(self):
+        """Fill in ``config['test_cases']`` by matching file names, unless it is already a list.
+
+        If the ``test_cases`` node is truthy, its ``input_format``, ``output_format`` and ``case_points``
+        entries are used when truthy; otherwise the default patterns and one point per case apply.
+        """
+        test_cases = self.config.test_cases
+
+        # We support several ways for specifying cases. The first is a list of cases, and requires no extra work.
+        if test_cases is not None and isinstance(test_cases.unwrap(), list):
+            return
+
+        def get_with_default(name, default):
+            if not test_cases:
+                return default
+            return test_cases[name] or default
+
+        # If the `test_cases` node is None, we try to guess the testcase name format.
+        self.config['test_cases'] = self._match_test_cases(
+            self._problem_file_list(),
+            re.compile(get_with_default('input_format', DEFAULT_TEST_CASE_INPUT_PATTERN), re.IGNORECASE),
+            re.compile(get_with_default('output_format', DEFAULT_TEST_CASE_OUTPUT_PATTERN), re.IGNORECASE),
+            iter(get_with_default('case_points', itertools.repeat(1))),
+        )
 
     def load_checker(self, name):
         if name in self._checkers:
@@ -74,7 +190,8 @@ def __init__(self, problem_id, **kwargs):
     def __missing__(self, key):
         base = get_problem_root(self.problem_id)
         try:
-            return open(os.path.join(base, key), 'rb').read()
+            with open(os.path.join(base, key), 'rb') as f:
+                return f.read()
         except IOError:
             if self.archive:
                 zipinfo = self.archive.getinfo(key)
diff --git a/dmoj/tests/test_problem.py b/dmoj/tests/test_problem.py
index 3f2203a0e..0f52f6cd8 100644
--- a/dmoj/tests/test_problem.py
+++ b/dmoj/tests/test_problem.py
@@ -3,7 +3,7 @@
 from unittest import mock
 
 from dmoj.config import InvalidInitException
-from dmoj.problem import Problem
+from dmoj.problem import Problem, ProblemDataManager
 
 
 class ProblemTest(unittest.TestCase):
@@ -12,6 +12,37 @@ def setUp(self):
         data_mock = self.data_patch.start()
         data_mock.side_effect = lambda problem: self.problem_data
 
+    def test_test_case_matching(self):
+        class MockProblem(Problem):
+            def _resolve_archive_files(self):
+                return None
+
+            def _problem_file_list(self):
+                return [
+                    's2.1-1.in', 's2.1-1.out',
+                    's2.1.2.in', 's2.1.2.out',
+                    's3.4.in', 's3.4.out',
+                    '5.in', '5.OUT',
+                    '6-1.in', '6-1.OUT',
+                    '6.2.in', '6.2.OUT',
+                    'foo/a.b.c.6.3.in', 'foo/a.b.c.6.3.OUT',
+                    'bar.in.7', 'bar.out.7',
+                    'INPUT8.txt', 'OUTPUT8.txt',
+                    '.DS_Store',
+                ]
+
+        self.problem_data = ProblemDataManager('foo')
+        self.problem_data.update({'init.yml': 'archive: foo.zip'})
+        self.assertEqual(MockProblem('test', 2, 16384).config.test_cases.unwrap(),
+                         [{'batched': [{'in': 's2.1-1.in', 'out': 's2.1-1.out'},
+                                       {'in': 's2.1.2.in', 'out': 's2.1.2.out'}], 'points': 1},
+                          {'in': 's3.4.in', 'out': 's3.4.out', 'points': 1},
+                          {'in': '5.in', 'out': '5.OUT', 'points': 1}, {
+                             'batched': [{'in': '6-1.in', 'out': '6-1.OUT'}, {'in': '6.2.in', 'out': '6.2.OUT'},
+                                         {'in': 'foo/a.b.c.6.3.in', 'out': 'foo/a.b.c.6.3.OUT'}], 'points': 1},
+                          {'in': 'bar.in.7', 'out': 'bar.out.7', 'points': 1},
+                          {'in': 'INPUT8.txt', 'out': 'OUTPUT8.txt', 'points': 1}])
+
     def test_no_init(self):
         self.problem_data = {}
         with self.assertRaises(InvalidInitException):