Skip to content

Commit

Permalink
AST-traversing search, better cli, cache invalidation (#2)
Browse files Browse the repository at this point in the history
* Work in progress

* Visitor-based search is working, much better output, need to update test cases

* Report non-existing files and mypy CompileError's

* Use normalzied paths for cache invalidation

* Ignore vscode folder

* Fix not invalidated cache when randint produce same number, update test cases

* Use non-binary mypy distirbution until we'd move to mypyc build
  • Loading branch information
butvinm authored Aug 18, 2024
1 parent 166b6b1 commit 6da1e47
Show file tree
Hide file tree
Showing 24 changed files with 1,885 additions and 255 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ ipython_config.py
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
poetry.toml

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
Expand Down Expand Up @@ -163,3 +162,4 @@ cython_debug/
#.idea/

tests/results/*.failed
.vscode/
42 changes: 42 additions & 0 deletions dora/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Dora CLI."""

import argparse
import os
import sys

from mypy.errors import CompileError

from dora.search import search


def main() -> None:
"""CLI entry point."""
parser = argparse.ArgumentParser(description='Search source files by type expressions.')
parser.add_argument(
'-t',
'--type-expression',
metavar='<type_expression>',
help='The type expression to search for. If not provided, all types in the file will be listed.',
)
parser.add_argument(
'paths',
metavar='paths',
nargs='+',
help='The source files to search in.',
)
args = parser.parse_args()

for path in args.paths:
if not os.path.exists(path):
parser.error(f'The path "{path}" does not exist.')

try:
for search_result in search(args.paths, args.type_expression):
print(search_result, end='\n\n')
except CompileError as e:
print(e, file=sys.stderr)
exit(1)


if __name__ == '__main__':
main()
76 changes: 0 additions & 76 deletions dora/main.py

This file was deleted.

205 changes: 205 additions & 0 deletions dora/search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
"""Search engine."""

from random import random
from typing import Any, Generator, Iterable

from mypy.build import BuildManager, BuildResult, BuildSource, build
from mypy.find_sources import create_source_list
from mypy.nodes import MypyFile, Node
from mypy.options import Options
from mypy.plugin import Plugin, ReportConfigContext
from mypy.traverser import TraverserVisitor


class DoraPlugin(Plugin):
"""Plugin to force mypy revalidate source files.
Inspired by MypycPlugin from mypyc.
"""

def __init__(self, sources: list[BuildSource], options: Options) -> None:
"""Initialize the plugin.
Args:
sources: The build sources whose cache should be invalidated.
options: The mypy options
"""
super().__init__(options)
self._sources = {source.path for source in sources}

def report_config_data(self, ctx: ReportConfigContext) -> int | None:
"""Force revalidation of the source file.
Args:
ctx: The report configuration context.
Returns:
A random number to force revalidation of the source file.
"""
if ctx.path in self._sources:
return random() # noqa: S311

return None


class SearchResult:
"""Occurrence of a type expression in a source file."""

def __init__(self, mypy_file: MypyFile, node: Node, type_expression: str) -> None:
"""Initialize the search result.
Args:
mypy_file: The source file where the type expression was found.
node: The node where the type expression was found.
type_expression: The type expression that was found.
"""
self.mypy_file = mypy_file
self.node = node
self.type_expression = type_expression

def __str__(self) -> str:
"""Render the search result as a string.
Returns:
A string representation of the search result.
"""
node_type = self.node.__class__.__name__
node_text = self._extract_node_text(self.mypy_file.path, self.node)
column_pointer_offset = ' ' * self.node.column
result_text = f'{self.mypy_file.path}:{self.node.line}:{self.node.column}\n'
result_text += f'{column_pointer_offset}{self.type_expression} ({node_type})\n'
result_text += f'{column_pointer_offset}v\n'
result_text += node_text
return result_text

@classmethod
def _extract_node_text(cls, path: str, node: Node) -> str:
"""Extract the text of a node from the source file.
Args:
path: The path to the source file.
node: The node with location context.
Returns:
Node occurrence in the file.
"""
# probably would be to slow
# we can probably provide file content as an argument
with open(path, 'r') as f:
lines = f.readlines()

line = node.line - 1
end_line = node.end_line or node.line
lines = lines[line:end_line]
return ''.join(lines)


def search(paths: list[str], type_expression: str | None) -> Iterable[SearchResult]:
"""Search for a type expression in a source file.
Args:
paths: The source files to search in.
type_expression: The type expression to search for.
Returns:
Found occurrences of the type expression.
"""
options = Options()
options.export_types = True
options.preserve_asts = True

sources = create_source_list(paths, options)

build_result = build(
sources=sources,
options=options,
extra_plugins=[DoraPlugin(sources, options)],
)
return _search(sources, type_expression, build_result)


def _search(
sources: list[BuildSource],
type_expression: str | None,
build_result: BuildResult,
) -> Generator[SearchResult, None, None]:
"""Search for a type expression in a source file.
Args:
sources: The source files to search in.
type_expression: The type expression to search for.
build_result: The build result obtained from mypy.build.build().
Yields:
Found occurrences of the type expression.
"""
for bs in sources:
state = build_result.graph.get(bs.module)
if state is None:
continue

if state.tree is None:
continue

visitor = SearchVisitor(state.tree, type_expression, build_result.manager)
state.tree.accept(visitor)
yield from visitor.search_results


class SearchVisitor(TraverserVisitor):
"""Performs a search for a type expression in a single ??? source file."""

def __init__(
self,
mypy_file: MypyFile,
type_expression: str | None,
manager: BuildManager,
) -> None:
"""Initialize the search visitor.
Args:
mypy_file: Search source file and AST root.
type_expression: The type expression to search for.
manager: The mypy BuildManager obtained from mypy.build.build() result.
"""
super().__init__()
self.mypy_file = mypy_file
self.type_expression = type_expression
self.manager = manager
self.search_results: list[SearchResult] = []

def generic_visit(self, name: str, o: Node) -> None:
"""Check type_expression against given node.
Args:
name: Visitor name (visit_*: e.g. visit_var)
o: Target node.
Returns:
Far traversing result.
"""
node_type = self.manager.all_types.get(o)
if node_type is not None:
if self.type_expression is None:
type_expression = str(node_type)
else:
type_expression = self.type_expression

if str(node_type) == type_expression:
self.search_results.append(SearchResult(self.mypy_file, o, type_expression))

return super().__getattribute__(name)(o)

def __getattribute__(self, name: str) -> Any:
"""Mock behavior of all possible visit_* methods.
Args:
name: Arg name.
Returns:
Visit method mock if name visit_* method acquired.
"""
if name.startswith('visit_'):
return lambda o: self.generic_visit(name, o)

return super().__getattribute__(name)
Loading

0 comments on commit 6da1e47

Please sign in to comment.