Skip to content

Commit

Permalink
updated base handling
Browse files Browse the repository at this point in the history
  • Loading branch information
smythi93 committed Jul 28, 2024
1 parent 59db337 commit 17e5a26
Showing 1 changed file with 35 additions and 5 deletions.
40 changes: 35 additions & 5 deletions src/sflkit/runners/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import string
import subprocess
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from typing import List, Dict, Optional, Tuple, Set

from sflkit.logger import LOGGER

Expand Down Expand Up @@ -269,7 +269,10 @@ def __init__(
def _common_base(directory: Path, tests: List[str]) -> Path:
parts = directory.parts
common_bases = {Path(*parts[:i]) for i in range(1, len(parts) + 1)}
leaves_paths = {Path(r.split("::")[0]) for r in tests}
if "::" in tests[0]:
leaves_paths = {Path(r.split("::")[0]) for r in tests}
else:
leaves_paths = {Path(r) for r in tests}
common_bases = set(
filter(
lambda p: all(map(lambda r: Path(p, *r.parts).exists(), leaves_paths)),
Expand All @@ -284,11 +287,28 @@ def _common_base(directory: Path, tests: List[str]) -> Path:
def _normalize_paths(
self,
tests: List[str],
bases: Optional[Set[str]] = None,
directory: Optional[Path] = None,
root_dir: Optional[Path] = None,
):
result = tests
if directory:
if bases:
for base in bases:
common = self._common_base(directory, [base])
if common:
base = common / base
base = self._common_base(base, tests)
if base is not None:
result = []
for r in tests:
path, test = r.split("::", 1)
result.append(
str((base / path).relative_to(directory))
+ "::"
+ test
)
return result
base = self._common_base(directory, tests)
if base is None and root_dir:
base = self._common_base(root_dir, tests)
Expand Down Expand Up @@ -321,11 +341,21 @@ def get_tests(
if k:
c.append("-k")
c.append(k)
bases = set()
if files:
if isinstance(files, (str, os.PathLike)):
c.append(str(files))
str_files = [str(files)]
else:
c += [str(f) for f in files]
str_files = [str(f) for f in files]
common_base = self._common_base(directory, [str(files)])
if common_base:
bases.add(common_base)
elif base:
common_base = self._common_base(base, [str(files)])
if common_base:
bases.add(common_base)
c += str_files

if base:
if not files:
c.append(str(base))
Expand All @@ -347,7 +377,7 @@ def get_tests(
)
LOGGER.info(f"pytest collection finished with {process.returncode}")
tests = PytestStructure.parse_tests(process.stdout.decode("utf8"))
return self._normalize_paths(tests, directory, root_dir)
return self._normalize_paths(tests, bases, directory, root_dir)

@staticmethod
def __get_pytest_result__(
Expand Down

0 comments on commit 17e5a26

Please sign in to comment.