Skip to content

Commit

Permalink
Added multiprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
obicons committed Jul 25, 2022
1 parent ac0d268 commit fd0e69d
Showing 1 changed file with 78 additions and 37 deletions.
115 changes: 78 additions & 37 deletions sa4u_z3/tu.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import ccsyspath
import clang.cindex as cindex
import json
import multiprocessing.pool
import os
import queue
import time
import z3
from dataclasses import dataclass
Expand All @@ -10,6 +12,12 @@

_tu_filename_to_stu: Dict[str, 'SerializedTU'] = {}

# Number of threads to concurrently read translation units.
_NUM_WORKERS = 8

# Maximum number of waiting TUs that need analyzed.
_MAX_WAITING_TUS = 128


@dataclass
class SerializedTU:
Expand All @@ -21,48 +29,81 @@ class SerializedTU:

def translation_units(compile_commands: cindex.CompilationDatabase, cache_path: Optional[str]) -> Iterator[Union[cindex.TranslationUnit, SerializedTU]]:
'''Returns an iterator over a translation unit for each file in the compilation database.'''
compile_command: cindex.CompileCommand
for compile_command in compile_commands.getAllCompileCommands():
if cache_path:
full_path = os.path.join(
compile_command.directory,
compile_command.filename,
)
serialized_tu = read_tu(
cache_path,
full_path,
)
modified_time = os.path.getmtime(full_path)
if serialized_tu.serialization_time >= modified_time:
log(LogLevel.INFO, f'Using cached analysis for {full_path}')
yield serialized_tu
continue
with multiprocessing.pool.ThreadPool(processes=_NUM_WORKERS) as pool:
compile_commands = list(compile_commands.getAllCompileCommands())
q: queue.Queue[Union[cindex.TranslationUnit, SerializedTU]] = queue.Queue(
maxsize=_MAX_WAITING_TUS,
)

log(
LogLevel.INFO,
f'parsing {compile_command.filename}'
pool.starmap_async(
_mp_parse_tu,
[(cmd, q, cache_path) for cmd in compile_commands],
)
try:
if 'lua' in compile_command.filename:
continue
os.chdir(compile_command.directory)
translation_unit = cindex.TranslationUnit.from_source(
os.path.join(compile_command.directory,
compile_command.filename),
args=[arg for arg in compile_command.arguments
if arg != compile_command.filename] + ['-I' + inc.decode() for inc in ccsyspath.system_include_paths('clang')],
)
for diag in translation_unit.diagnostics:
log(
LogLevel.WARNING,
f'Parsing: {compile_command.filename}: {diag}'
)
yield translation_unit
except cindex.TranslationUnitLoadError:

for _ in enumerate(compile_commands):
item = q.get()
if isinstance(item, Exception):
raise item
elif item is not None:
yield item

# compile_command: cindex.CompileCommand
# for compile_command in compile_commands.getAllCompileCommands():
# maybe_tu = _parse_tu(compile_command, cache_path)
# if maybe_tu:
# yield maybe_tu


def _mp_parse_tu(compile_command: cindex.CompileCommand, q: multiprocessing.Queue, cache_path: Optional[str]):
'''Parse a translation unit, and place the parsed result in the queue.'''
try:
q.put(_parse_tu(compile_command, cache_path))
except Exception as err:
q.put(err)


def _parse_tu(compile_command: cindex.CompileCommand, cache_path: Optional[str] = None) -> Optional[Union[cindex.TranslationUnit, SerializedTU]]:
'''Parses the translation unit.'''
if cache_path:
full_path = os.path.join(
compile_command.directory,
compile_command.filename,
)
serialized_tu = read_tu(
cache_path,
full_path,
)
modified_time = os.path.getmtime(full_path)
if serialized_tu.serialization_time >= modified_time:
log(LogLevel.INFO, f'Using cached analysis for {full_path}')
return serialized_tu

log(
LogLevel.INFO,
f'parsing {compile_command.filename}',
)
try:
if 'lua' in compile_command.filename:
return None
os.chdir(compile_command.directory)
translation_unit = cindex.TranslationUnit.from_source(
os.path.join(compile_command.directory,
compile_command.filename),
args=[arg for arg in compile_command.arguments
if arg != compile_command.filename] + ['-I' + inc.decode() for inc in ccsyspath.system_include_paths('clang')],
)
for diag in translation_unit.diagnostics:
log(
LogLevel.WARNING,
f'could not parse {os.path.join(compile_command.directory, compile_command.filename)}',
f'Parsing: {compile_command.filename}: {diag}'
)
return translation_unit
except cindex.TranslationUnitLoadError:
log(
LogLevel.WARNING,
f'could not parse {os.path.join(compile_command.directory, compile_command.filename)}',
)
return None


def read_tu(path: str, file_path: str) -> SerializedTU:
Expand Down

0 comments on commit fd0e69d

Please sign in to comment.