diff --git a/.gitignore b/.gitignore index 50fa6ec..0eb3d5c 100644 --- a/.gitignore +++ b/.gitignore @@ -156,3 +156,5 @@ tmp.py test_dir/ .DS_Store +mapping.json + diff --git a/pyproject.toml b/pyproject.toml index 523aed1..b52f22b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta" [project] name = "sflkit" -version = "0.2.17" +version = "0.2.18" authors = [ { name = "Marius Smytzek", email = "marius.smytzek@cispa.de" }, ] diff --git a/src/sflkit/__init__.py b/src/sflkit/__init__.py index 228fb59..2e96be8 100644 --- a/src/sflkit/__init__.py +++ b/src/sflkit/__init__.py @@ -8,7 +8,7 @@ def instrument_config(conf: Config): - instrumentation = DirInstrumentation(conf.visitor) + instrumentation = DirInstrumentation(conf.visitor, conf.mapping.path) instrumentation.instrument( conf.target_path, conf.instrument_working, diff --git a/src/sflkit/config.py b/src/sflkit/config.py index fc5ca37..7e3fb25 100644 --- a/src/sflkit/config.py +++ b/src/sflkit/config.py @@ -77,6 +77,7 @@ def __init__(self, path: Union[str, configparser.ConfigParser] = None): self.instrument_working = None self.runner = None self.mapping = None + self.mapping_path = None if path: if isinstance(path, configparser.ConfigParser): config = path @@ -143,10 +144,12 @@ def __init__(self, path: Union[str, configparser.ConfigParser] = None): self.metrics = [Spectrum.Ochiai] run_id_generator = IDGenerator() + if "mapping" in events: + self.mapping_path = Path(events["mapping"]) try: self.mapping = EventMapping.load(self) except InstrumentationError: - self.mapping = EventMapping() + self.mapping = EventMapping(path=self.mapping_path) if "passing" in events: self.passing = self.get_event_files( list(csv.reader([events["passing"]]))[0], @@ -198,6 +201,7 @@ def create_from_values( visitor: ASTVisitor = None, passing: List[EventFile] = None, failing: List[EventFile] = None, + mapping: EventMapping = None, instrument_include: List[str] = None, instrument_exclude: List[str] = None, instrument_working: str = None, @@ -220,6 +224,10 @@ def create_from_values( conf.instrument_exclude = instrument_exclude if instrument_exclude else list() conf.instrument_working = instrument_working conf.runner = runner + if mapping: + conf.mapping = mapping + if mapping.path: + conf.mapping_path = mapping.path return conf @staticmethod @@ -263,6 +271,7 @@ def create( metrics=None, passing=None, failing=None, + mapping_path=None, working=None, include=None, exclude=None, @@ -288,6 +297,8 @@ def create( conf["events"]["passing"] = passing if failing: conf["events"]["failing"] = failing + if mapping_path: + conf["events"]["mapping"] = mapping_path if working: conf["instrumentation"]["path"] = working if include: @@ -320,6 +331,8 @@ def write(self, path): conf["events"]["passing"] = ",".join(e.path for e in self.passing) if self.failing: conf["events"]["failing"] = ",".join(e.path for e in self.failing) + if self.mapping_path: + conf["events"]["mapping"] = str(self.mapping_path) if self.instrument_working: conf["instrumentation"]["path"] = str(self.instrument_working) if self.instrument_include: diff --git a/src/sflkit/instrumentation/__init__.py b/src/sflkit/instrumentation/__init__.py index 2003d34..06974da 100644 --- a/src/sflkit/instrumentation/__init__.py +++ b/src/sflkit/instrumentation/__init__.py @@ -1,8 +1,6 @@ -import json from abc import abstractmethod -from typing import List - -from sflkitlib.events import event +from pathlib import Path +from typing import List, Optional from sflkit import Config from sflkit.language.visitor import ASTVisitor @@ -10,9 +8,9 @@ class Instrumentation: - def __init__(self, visitor: ASTVisitor): + def __init__(self, visitor: ASTVisitor, mapping_path: Optional[Path] = None): self.visitor = visitor - self.events = EventMapping() + self.events = EventMapping(path=mapping_path) @abstractmethod def instrument( diff --git a/src/sflkit/instrumentation/dir_instrumentation.py b/src/sflkit/instrumentation/dir_instrumentation.py index 0a6d608..3ecae77 100644 --- a/src/sflkit/instrumentation/dir_instrumentation.py +++ b/src/sflkit/instrumentation/dir_instrumentation.py @@ -2,6 +2,7 @@ import queue import re import shutil +from pathlib import Path from typing import List, Optional, Iterable, Tuple from sflkit.instrumentation import Instrumentation @@ -11,9 +12,9 @@ class DirInstrumentation(Instrumentation): - def __init__(self, visitor: ASTVisitor): - super().__init__(visitor) - self.file_instrumentation = FileInstrumentation(visitor) + def __init__(self, visitor: ASTVisitor, mapping_path: Optional[Path] = None): + super().__init__(visitor, mapping_path) + self.file_instrumentation = FileInstrumentation(visitor, mapping_path) @staticmethod def check_included(element: str, includes: Optional[Iterable[str]]): diff --git a/src/sflkit/instrumentation/file_instrumentation.py b/src/sflkit/instrumentation/file_instrumentation.py index 93ab3c6..f459c39 100644 --- a/src/sflkit/instrumentation/file_instrumentation.py +++ b/src/sflkit/instrumentation/file_instrumentation.py @@ -1,4 +1,5 @@ -from typing import List +from pathlib import Path +from typing import List, Optional from sflkit.instrumentation import Instrumentation from sflkit.language.visitor import ASTVisitor @@ -6,13 +7,13 @@ class FileInstrumentation(Instrumentation): - def __init__(self, visitor: ASTVisitor): - super().__init__(visitor) + def __init__(self, visitor: ASTVisitor, mapping_path: Optional[Path] = None): + super().__init__(visitor, mapping_path) def instrument( self, src: str, dst: str, suffixes: List[str] = None, file: str = "" ): self.visitor.instrument(src, dst, file) self.events = EventMapping( - {event.event_id: event for event in self.visitor.events} + {event.event_id: event for event in self.visitor.events}, self.events.path ) diff --git a/src/sflkit/mapping.py b/src/sflkit/mapping.py index 8fd3260..aecf670 100644 --- a/src/sflkit/mapping.py +++ b/src/sflkit/mapping.py @@ -13,37 +13,45 @@ class InstrumentationError(RuntimeError): class EventMapping: - def __init__(self, mapping: Dict[int, Event] = None): + def __init__( + self, mapping: Dict[int, Event] = None, path: Optional[os.PathLike] = None + ): self.mapping = mapping or dict() + self.path = path def get(self, event_id) -> Optional[Event]: return self.mapping.get(event_id, None) @staticmethod - def get_path(identifier: str): + def get_path(identifier: str) -> Path: return SFLKIT_PATH / f"{identifier}.json" @staticmethod def load(config: Any): if not hasattr(config, "identifier"): raise InstrumentationError(f"Argument does not have an identifier") - return EventMapping.load_from_file(config.identifier()) + return EventMapping.load_from_file( + config.mapping_path or EventMapping.get_path(config.identifier()), + config.target_path, + ) @staticmethod - def load_from_file(identifier: str): - file = EventMapping.get_path(identifier) + def load_from_file(file: Path, path: os.PathLike): if file.exists(): - return EventMapping(load_json(file)) + return EventMapping(load_json(file), file) else: raise InstrumentationError( - f"Cannot find information about instrumentation of {identifier or file}" + f"Cannot find information about instrumentation of {path or file}" ) def write(self, config): if not hasattr(config, "identifier"): raise InstrumentationError(f"Argument does not have an identifier") - SFLKIT_PATH.mkdir(parents=True, exist_ok=True) - file = self.get_path(config.identifier()) + if self.path: + file = self.path + else: + SFLKIT_PATH.mkdir(parents=True, exist_ok=True) + file = self.get_path(config.identifier()) with file.open("w") as fp: json.dump(list(self.mapping.values()), fp, cls=EventEncoder) diff --git a/tests/test_config.py b/tests/test_config.py index bd5aa11..5129d3f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -93,6 +93,7 @@ def test_create_config(self): config.visitor, config.passing, config.failing, + config.mapping, config.instrument_include, config.instrument_exclude, config.instrument_working, diff --git a/tests/test_instrumentation.py b/tests/test_instrumentation.py index b07da51..986c879 100644 --- a/tests/test_instrumentation.py +++ b/tests/test_instrumentation.py @@ -106,6 +106,21 @@ def test_complex_structure_include(self): ) ) + def test_mapping_output(self): + instrument_config( + Config.create( + path=os.path.join(BaseTest.TEST_RESOURCES, "test_instrumentation"), + language="python", + events="line", + predicates="line", + working=BaseTest.TEST_DIR, + mapping_path="mapping.json", + ) + ) + mapping = Path("mapping.json") + self.assertTrue(mapping.exists()) + self.assertTrue(mapping.is_file()) + class TestLib(BaseTest): @classmethod