From fd0e69d3898303a6af726f2741ea8b14f9ef0401 Mon Sep 17 00:00:00 2001
From: Max Taylor
Date: Mon, 25 Jul 2022 15:03:26 -0400
Subject: [PATCH] Added multiprocessing

---
 sa4u_z3/tu.py | 125 +++++++++++++++++++++++++++++++++++---------------
 1 file changed, 88 insertions(+), 37 deletions(-)

diff --git a/sa4u_z3/tu.py b/sa4u_z3/tu.py
index f30bde8..56f9e8c 100644
--- a/sa4u_z3/tu.py
+++ b/sa4u_z3/tu.py
@@ -1,7 +1,9 @@
 import ccsyspath
 import clang.cindex as cindex
 import json
+import multiprocessing.pool
 import os
+import queue
 import time
 import z3
 from dataclasses import dataclass
@@ -10,6 +12,12 @@
 
 _tu_filename_to_stu: Dict[str, 'SerializedTU'] = {}
 
+# Number of threads to concurrently read translation units.
+_NUM_WORKERS = 8
+
+# Maximum number of parsed TUs that may wait in the queue to be analyzed.
+_MAX_WAITING_TUS = 128
+
 
 @dataclass
 class SerializedTU:
@@ -21,48 +29,91 @@ class SerializedTU:
 
 def translation_units(compile_commands: cindex.CompilationDatabase, cache_path: Optional[str]) -> Iterator[Union[cindex.TranslationUnit, SerializedTU]]:
     '''Returns an iterator over a translation unit for each file in the compilation database.'''
-    compile_command: cindex.CompileCommand
-    for compile_command in compile_commands.getAllCompileCommands():
-        if cache_path:
-            full_path = os.path.join(
-                compile_command.directory,
-                compile_command.filename,
-            )
-            serialized_tu = read_tu(
-                cache_path,
-                full_path,
-            )
-            modified_time = os.path.getmtime(full_path)
-            if serialized_tu.serialization_time >= modified_time:
-                log(LogLevel.INFO, f'Using cached analysis for {full_path}')
-                yield serialized_tu
-                continue
+    with multiprocessing.pool.ThreadPool(processes=_NUM_WORKERS) as pool:
+        commands = list(compile_commands.getAllCompileCommands())
+        # Workers block in q.put() once _MAX_WAITING_TUS results are pending,
+        # which bounds how many parsed TUs are held in memory at once.
+        q: queue.Queue[Union[cindex.TranslationUnit, SerializedTU, Exception, None]] = queue.Queue(
+            maxsize=_MAX_WAITING_TUS,
+        )
 
-        log(
-            LogLevel.INFO,
-            f'parsing {compile_command.filename}'
+        pool.starmap_async(
+            _mp_parse_tu,
+            [(cmd, q, cache_path) for cmd in commands],
         )
-        try:
-            if 'lua' in compile_command.filename:
-                continue
-            os.chdir(compile_command.directory)
-            translation_unit = cindex.TranslationUnit.from_source(
-                os.path.join(compile_command.directory,
-                             compile_command.filename),
-                args=[arg for arg in compile_command.arguments
-                      if arg != compile_command.filename] + ['-I' + inc.decode() for inc in ccsyspath.system_include_paths('clang')],
-            )
-            for diag in translation_unit.diagnostics:
-                log(
-                    LogLevel.WARNING,
-                    f'Parsing: {compile_command.filename}: {diag}'
-                )
-            yield translation_unit
-        except cindex.TranslationUnitLoadError:
+
+        # Each worker puts exactly one item (a TU, None, or an exception),
+        # so take one item off the queue per compile command.
+        for _ in commands:
+            item = q.get()
+            if isinstance(item, Exception):
+                raise item
+            elif item is not None:
+                yield item
+
+
+def _mp_parse_tu(compile_command: cindex.CompileCommand, q: queue.Queue, cache_path: Optional[str]) -> None:
+    '''Parse a translation unit, and place the parsed result in the queue.
+
+    Any exception is put in the queue instead of being raised, so that the
+    consumer of the queue can re-raise it.
+    '''
+    try:
+        q.put(_parse_tu(compile_command, cache_path))
+    except Exception as err:
+        q.put(err)
+
+
+def _parse_tu(compile_command: cindex.CompileCommand, cache_path: Optional[str] = None) -> Optional[Union[cindex.TranslationUnit, SerializedTU]]:
+    '''Parses the translation unit.
+
+    Returns the cached analysis when it is at least as new as the source file,
+    otherwise the freshly parsed translation unit. Returns None when the file
+    is skipped or cannot be parsed.
+    '''
+    if cache_path:
+        full_path = os.path.join(
+            compile_command.directory,
+            compile_command.filename,
+        )
+        serialized_tu = read_tu(
+            cache_path,
+            full_path,
+        )
+        modified_time = os.path.getmtime(full_path)
+        if serialized_tu.serialization_time >= modified_time:
+            log(LogLevel.INFO, f'Using cached analysis for {full_path}')
+            return serialized_tu
+
+    log(
+        LogLevel.INFO,
+        f'parsing {compile_command.filename}',
+    )
+    try:
+        if 'lua' in compile_command.filename:
+            return None
+        # os.chdir() is process-wide and would race between the pool's threads,
+        # so clang is told to resolve relative paths itself instead.
+        translation_unit = cindex.TranslationUnit.from_source(
+            os.path.join(compile_command.directory,
+                         compile_command.filename),
+            args=[arg for arg in compile_command.arguments
+                  if arg != compile_command.filename]
+            + ['-I' + inc.decode() for inc in ccsyspath.system_include_paths('clang')]
+            + [f'-working-directory={compile_command.directory}'],
+        )
+        for diag in translation_unit.diagnostics:
             log(
                 LogLevel.WARNING,
-                f'could not parse {os.path.join(compile_command.directory, compile_command.filename)}',
+                f'Parsing: {compile_command.filename}: {diag}'
             )
+        return translation_unit
+    except cindex.TranslationUnitLoadError:
+        log(
+            LogLevel.WARNING,
+            f'could not parse {os.path.join(compile_command.directory, compile_command.filename)}',
+        )
+        return None
 
 
 def read_tu(path: str, file_path: str) -> SerializedTU: