diff --git a/src/databricks/labs/ucx/source_code/base.py b/src/databricks/labs/ucx/source_code/base.py index b0f9c91b21..85f5b598f6 100644 --- a/src/databricks/labs/ucx/source_code/base.py +++ b/src/databricks/labs/ucx/source_code/base.py @@ -12,15 +12,15 @@ from pathlib import Path from typing import Any, BinaryIO, TextIO -from astroid import AstroidSyntaxError, NodeNG # type: ignore -from sqlglot import Expression, parse as parse_sql, ParseError as SqlParseError +from astroid import NodeNG # type: ignore +from sqlglot import Expression, parse as parse_sql +from sqlglot.errors import SqlglotError from databricks.sdk.service import compute from databricks.sdk.service.workspace import Language from databricks.labs.blueprint.paths import WorkspacePath -from databricks.labs.ucx.source_code.python.python_ast import Tree if sys.version_info >= (3, 11): from typing import Self @@ -137,12 +137,13 @@ class SqlLinter(Linter): def lint(self, code: str) -> Iterable[Advice]: try: + # TODO: unify with SqlParser.walk_expressions(...) expressions = parse_sql(code, read='databricks') for expression in expressions: if not expression: continue yield from self.lint_expression(expression) - except SqlParseError as e: + except SqlglotError as e: logger.debug(f"Failed to parse SQL: {code}", exc_info=e) yield self.sql_parse_failure(code) @@ -162,16 +163,6 @@ def sql_parse_failure(code: str) -> Failure: def lint_expression(self, expression: Expression) -> Iterable[Advice]: ... -class PythonLinter(Linter): - - def lint(self, code: str) -> Iterable[Advice]: - tree = Tree.normalize_and_parse(code) - yield from self.lint_tree(tree) - - @abstractmethod - def lint_tree(self, tree: Tree) -> Iterable[Advice]: ... 
- - class Fixer(ABC): @property @@ -271,20 +262,6 @@ class UsedTableNode: node: NodeNG -class TablePyCollector(TableCollector, ABC): - - def collect_tables(self, source_code: str) -> Iterable[UsedTable]: - try: - tree = Tree.normalize_and_parse(source_code) - for table_node in self.collect_tables_from_tree(tree): - yield table_node.table - except AstroidSyntaxError as e: - logger.warning('syntax-error', exc_info=e) - - @abstractmethod - def collect_tables_from_tree(self, tree: Tree) -> Iterable[UsedTableNode]: ... - - class TableSqlCollector(TableCollector, ABC): ... @@ -309,17 +286,6 @@ class DfsaCollector(ABC): def collect_dfsas(self, source_code: str) -> Iterable[DirectFsAccess]: ... -class DfsaPyCollector(DfsaCollector, ABC): - - def collect_dfsas(self, source_code: str) -> Iterable[DirectFsAccess]: - tree = Tree.normalize_and_parse(source_code) - for dfsa_node in self.collect_dfsas_from_tree(tree): - yield dfsa_node.dfsa - - @abstractmethod - def collect_dfsas_from_tree(self, tree: Tree) -> Iterable[DirectFsAccessNode]: ... - - class DfsaSqlCollector(DfsaCollector, ABC): ... 
@@ -395,83 +361,6 @@ def collect_tables(self, source_code: str) -> Iterable[UsedTable]: yield from collector.collect_tables(source_code) -class PythonSequentialLinter(Linter, DfsaCollector, TableCollector): - - def __init__( - self, - linters: list[PythonLinter], - dfsa_collectors: list[DfsaPyCollector], - table_collectors: list[TablePyCollector], - ): - self._linters = linters - self._dfsa_collectors = dfsa_collectors - self._table_collectors = table_collectors - self._tree: Tree | None = None - - def lint(self, code: str) -> Iterable[Advice]: - try: - tree = self._parse_and_append(code) - yield from self.lint_tree(tree) - except AstroidSyntaxError as e: - yield Failure('syntax-error', str(e), 0, 0, 0, 0) - - def lint_tree(self, tree: Tree) -> Iterable[Advice]: - for linter in self._linters: - yield from linter.lint_tree(tree) - - def _parse_and_append(self, code: str) -> Tree: - tree = Tree.normalize_and_parse(code) - self.append_tree(tree) - return tree - - def append_tree(self, tree: Tree) -> None: - self._make_tree().append_tree(tree) - - def append_nodes(self, nodes: list[NodeNG]) -> None: - self._make_tree().append_nodes(nodes) - - def append_globals(self, globs: dict) -> None: - self._make_tree().append_globals(globs) - - def process_child_cell(self, code: str) -> None: - try: - this_tree = self._make_tree() - tree = Tree.normalize_and_parse(code) - this_tree.append_tree(tree) - except AstroidSyntaxError as e: - # error already reported when linting enclosing notebook - logger.warning(f"Failed to parse Python cell: {code}", exc_info=e) - - def collect_dfsas(self, source_code: str) -> Iterable[DirectFsAccess]: - try: - tree = self._parse_and_append(source_code) - for dfsa_node in self.collect_dfsas_from_tree(tree): - yield dfsa_node.dfsa - except AstroidSyntaxError as e: - logger.warning('syntax-error', exc_info=e) - - def collect_dfsas_from_tree(self, tree: Tree) -> Iterable[DirectFsAccessNode]: - for collector in self._dfsa_collectors: - yield from 
collector.collect_dfsas_from_tree(tree) - - def collect_tables(self, source_code: str) -> Iterable[UsedTable]: - try: - tree = self._parse_and_append(source_code) - for table_node in self.collect_tables_from_tree(tree): - yield table_node.table - except AstroidSyntaxError as e: - logger.warning('syntax-error', exc_info=e) - - def collect_tables_from_tree(self, tree: Tree) -> Iterable[UsedTableNode]: - for collector in self._table_collectors: - yield from collector.collect_tables_from_tree(tree) - - def _make_tree(self) -> Tree: - if self._tree is None: - self._tree = Tree.new_module() - return self._tree - - SUPPORTED_EXTENSION_LANGUAGES = { '.py': Language.PYTHON, '.sql': Language.SQL, diff --git a/src/databricks/labs/ucx/source_code/jobs.py b/src/databricks/labs/ucx/source_code/jobs.py index 5c00795442..5565f1ef11 100644 --- a/src/databricks/labs/ucx/source_code/jobs.py +++ b/src/databricks/labs/ucx/source_code/jobs.py @@ -34,7 +34,6 @@ SourceInfo, UsedTable, LineageAtom, - PythonSequentialLinter, read_text, ) from databricks.labs.ucx.source_code.directfs_access import ( @@ -52,7 +51,7 @@ ) from databricks.labs.ucx.source_code.linters.context import LinterContext from databricks.labs.ucx.source_code.notebooks.cells import CellLanguage -from databricks.labs.ucx.source_code.python.python_ast import Tree +from databricks.labs.ucx.source_code.python.python_ast import Tree, PythonSequentialLinter from databricks.labs.ucx.source_code.notebooks.sources import FileLinter, Notebook from databricks.labs.ucx.source_code.path_lookup import PathLookup from databricks.labs.ucx.source_code.used_table import UsedTablesCrawler @@ -638,8 +637,12 @@ def _collect_from_notebook( if cell.language is CellLanguage.PYTHON: if inherited_tree is None: inherited_tree = Tree.new_module() - tree = Tree.normalize_and_parse(cell.original_code) - inherited_tree.append_tree(tree) + maybe_tree = Tree.maybe_normalized_parse(cell.original_code) + if maybe_tree.failure: + 
logger.warning(maybe_tree.failure.message) + continue + assert maybe_tree.tree is not None + inherited_tree.append_tree(maybe_tree.tree) def _collect_from_source( self, diff --git a/src/databricks/labs/ucx/source_code/linters/context.py b/src/databricks/labs/ucx/source_code/linters/context.py index 498cd2c0a8..27686b414c 100644 --- a/src/databricks/labs/ucx/source_code/linters/context.py +++ b/src/databricks/labs/ucx/source_code/linters/context.py @@ -8,16 +8,18 @@ Linter, SqlSequentialLinter, CurrentSessionState, - PythonSequentialLinter, - PythonLinter, SqlLinter, - TablePyCollector, TableSqlCollector, TableCollector, DfsaCollector, - DfsaPyCollector, DfsaSqlCollector, ) +from databricks.labs.ucx.source_code.python.python_ast import ( + PythonLinter, + TablePyCollector, + DfsaPyCollector, + PythonSequentialLinter, +) from databricks.labs.ucx.source_code.linters.directfs import DirectFsAccessPyLinter, DirectFsAccessSqlLinter from databricks.labs.ucx.source_code.linters.imports import DbutilsPyLinter diff --git a/src/databricks/labs/ucx/source_code/linters/directfs.py b/src/databricks/labs/ucx/source_code/linters/directfs.py index 4961d9c361..d5a8272079 100644 --- a/src/databricks/labs/ucx/source_code/linters/directfs.py +++ b/src/databricks/labs/ucx/source_code/linters/directfs.py @@ -9,14 +9,18 @@ Advice, Deprecation, CurrentSessionState, - PythonLinter, SqlLinter, - DfsaPyCollector, DirectFsAccessNode, DfsaSqlCollector, DirectFsAccess, ) -from databricks.labs.ucx.source_code.python.python_ast import Tree, TreeVisitor, TreeHelper +from databricks.labs.ucx.source_code.python.python_ast import ( + Tree, + TreeVisitor, + TreeHelper, + PythonLinter, + DfsaPyCollector, +) from databricks.labs.ucx.source_code.python.python_infer import InferredValue from databricks.labs.ucx.source_code.sql.sql_parser import SqlParser, SqlExpression diff --git a/src/databricks/labs/ucx/source_code/linters/imports.py b/src/databricks/labs/ucx/source_code/linters/imports.py index 
b8be0bbed4..906a24cd19 100644 --- a/src/databricks/labs/ucx/source_code/linters/imports.py +++ b/src/databricks/labs/ucx/source_code/linters/imports.py @@ -17,8 +17,8 @@ NodeNG, ) -from databricks.labs.ucx.source_code.base import Advice, Advisory, CurrentSessionState, PythonLinter -from databricks.labs.ucx.source_code.python.python_ast import Tree, NodeBase, TreeVisitor +from databricks.labs.ucx.source_code.base import Advice, Advisory, CurrentSessionState +from databricks.labs.ucx.source_code.python.python_ast import Tree, NodeBase, TreeVisitor, PythonLinter from databricks.labs.ucx.source_code.python.python_infer import InferredValue from databricks.labs.ucx.source_code.path_lookup import PathLookup diff --git a/src/databricks/labs/ucx/source_code/linters/pyspark.py b/src/databricks/labs/ucx/source_code/linters/pyspark.py index f5d0b5a647..ce3832d601 100644 --- a/src/databricks/labs/ucx/source_code/linters/pyspark.py +++ b/src/databricks/labs/ucx/source_code/linters/pyspark.py @@ -11,20 +11,24 @@ Advisory, Deprecation, CurrentSessionState, - PythonLinter, SqlLinter, Fixer, UsedTable, UsedTableNode, - TablePyCollector, TableSqlCollector, - DfsaPyCollector, DfsaSqlCollector, ) from databricks.labs.ucx.source_code.linters.directfs import DIRECT_FS_ACCESS_PATTERNS, DirectFsAccessNode from databricks.labs.ucx.source_code.python.python_infer import InferredValue from databricks.labs.ucx.source_code.linters.from_table import FromTableSqlLinter -from databricks.labs.ucx.source_code.python.python_ast import Tree, TreeHelper, MatchingVisitor +from databricks.labs.ucx.source_code.python.python_ast import ( + Tree, + TreeHelper, + MatchingVisitor, + PythonLinter, + TablePyCollector, + DfsaPyCollector, +) logger = logging.getLogger(__name__) @@ -408,7 +412,12 @@ def lint_tree(self, tree: Tree) -> Iterable[Advice]: yield from matcher.lint(self._from_table, self._index, self._session_state, node) def apply(self, code: str) -> str: - tree = Tree.parse(code) + maybe_tree = 
Tree.maybe_parse(code) + if not maybe_tree.tree: + assert maybe_tree.failure is not None + logger.warning(maybe_tree.failure.message) + return code + tree = maybe_tree.tree # we won't be doing it like this in production, but for the sake of the example for node in tree.walk(): matcher = self._find_matcher(node) @@ -477,7 +486,12 @@ def lint_tree(self, tree: Tree) -> Iterable[Advice]: def apply(self, code: str) -> str: if not self._sql_fixer: return code - tree = Tree.normalize_and_parse(code) + maybe_tree = Tree.maybe_normalized_parse(code) + if maybe_tree.failure: + logger.warning(maybe_tree.failure.message) + return code + assert maybe_tree.tree is not None + tree = maybe_tree.tree for _call_node, query in self._visit_call_nodes(tree): if not isinstance(query, Const) or not isinstance(query.value, str): continue diff --git a/src/databricks/labs/ucx/source_code/linters/spark_connect.py b/src/databricks/labs/ucx/source_code/linters/spark_connect.py index dc67840745..78a0ba514c 100644 --- a/src/databricks/labs/ucx/source_code/linters/spark_connect.py +++ b/src/databricks/labs/ucx/source_code/linters/spark_connect.py @@ -6,12 +6,11 @@ from databricks.labs.ucx.source_code.base import ( Advice, Failure, - PythonLinter, CurrentSessionState, ) from databricks.sdk.service.compute import DataSecurityMode -from databricks.labs.ucx.source_code.python.python_ast import Tree, TreeHelper +from databricks.labs.ucx.source_code.python.python_ast import Tree, TreeHelper, PythonLinter @dataclass diff --git a/src/databricks/labs/ucx/source_code/linters/table_creation.py b/src/databricks/labs/ucx/source_code/linters/table_creation.py index 4c27865016..c90f59153b 100644 --- a/src/databricks/labs/ucx/source_code/linters/table_creation.py +++ b/src/databricks/labs/ucx/source_code/linters/table_creation.py @@ -7,9 +7,8 @@ from databricks.labs.ucx.source_code.base import ( Advice, - PythonLinter, ) -from databricks.labs.ucx.source_code.python.python_ast import Tree, TreeHelper +from 
databricks.labs.ucx.source_code.python.python_ast import Tree, TreeHelper, PythonLinter @dataclass diff --git a/src/databricks/labs/ucx/source_code/notebooks/sources.py b/src/databricks/labs/ucx/source_code/notebooks/sources.py index db257054cc..5889166f89 100644 --- a/src/databricks/labs/ucx/source_code/notebooks/sources.py +++ b/src/databricks/labs/ucx/source_code/notebooks/sources.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import cast -from astroid import AstroidSyntaxError, Module, NodeNG # type: ignore +from astroid import Module, NodeNG # type: ignore from databricks.sdk.service.workspace import Language @@ -17,7 +17,6 @@ Advice, Failure, Linter, - PythonSequentialLinter, CurrentSessionState, Advisory, file_language, @@ -37,7 +36,7 @@ UnresolvedPath, ) from databricks.labs.ucx.source_code.notebooks.magic import MagicLine -from databricks.labs.ucx.source_code.python.python_ast import Tree, NodeBase +from databricks.labs.ucx.source_code.python.python_ast import Tree, NodeBase, PythonSequentialLinter from databricks.labs.ucx.source_code.notebooks.cells import ( CellLanguage, Cell, @@ -196,13 +195,15 @@ def _load_tree_from_notebook(self, notebook: Notebook, register_trees: bool) -> continue def _load_tree_from_python_cell(self, python_cell: PythonCell, register_trees: bool) -> Iterable[Advice]: - try: - tree = Tree.normalize_and_parse(python_cell.original_code) - if register_trees: - self._python_trees[python_cell] = tree - yield from self._load_children_from_tree(tree) - except AstroidSyntaxError as e: - yield Failure('syntax-error', str(e), 0, 0, 0, 0) + maybe_tree = Tree.maybe_normalized_parse(python_cell.original_code) + if maybe_tree.failure: + yield maybe_tree.failure + return + assert maybe_tree.tree is not None + tree = maybe_tree.tree + if register_trees: + self._python_trees[python_cell] = tree + yield from self._load_children_from_tree(tree) def _load_children_from_tree(self, tree: Tree) -> Iterable[Advice]: assert isinstance(tree.node, Module)
diff --git a/src/databricks/labs/ucx/source_code/python/python_analyzer.py b/src/databricks/labs/ucx/source_code/python/python_analyzer.py index 853026d637..d831452487 100644 --- a/src/databricks/labs/ucx/source_code/python/python_analyzer.py +++ b/src/databricks/labs/ucx/source_code/python/python_analyzer.py @@ -60,11 +60,12 @@ def build_graph(self) -> list[DependencyProblem]: return problems def build_inherited_context(self, child_path: Path) -> InheritedContext: - try: - tree, nodes, _ = self._parse_and_extract_nodes() - except AstroidSyntaxError: - logger.debug(f"Could not parse Python code: {self._python_code}", exc_info=True) + tree, nodes, problems = self._parse_and_extract_nodes() + if tree is None: + # TODO: bubble up problems via InheritedContext + logger.warning(f"Failed to parse: {problems}") return InheritedContext(None, False) + assert tree is not None, "no problems should yield a tree" if len(nodes) == 0: return InheritedContext(tree, False) context = InheritedContext(Tree.new_module(), False) @@ -92,9 +93,13 @@ def build_inherited_context(self, child_path: Path) -> InheritedContext: context.tree.append_globals(globs) return context - def _parse_and_extract_nodes(self) -> tuple[Tree, list[NodeBase], Iterable[DependencyProblem]]: + def _parse_and_extract_nodes(self) -> tuple[Tree | None, list[NodeBase], Iterable[DependencyProblem]]: problems: list[DependencyProblem] = [] - tree = Tree.normalize_and_parse(self._python_code) + maybe_tree = Tree.maybe_normalized_parse(self._python_code) + if maybe_tree.failure: + return None, [], [DependencyProblem(maybe_tree.failure.code, maybe_tree.failure.message)] + assert maybe_tree.tree is not None + tree = maybe_tree.tree syspath_changes = SysPathChange.extract_from_tree(self._context.session_state, tree) run_calls = DbutilsPyLinter.list_dbutils_notebook_run_calls(tree) import_sources: list[ImportSource] diff --git a/src/databricks/labs/ucx/source_code/python/python_ast.py
b/src/databricks/labs/ucx/source_code/python/python_ast.py index c67f3dd029..15ee7574c1 100644 --- a/src/databricks/labs/ucx/source_code/python/python_ast.py +++ b/src/databricks/labs/ucx/source_code/python/python_ast.py @@ -2,10 +2,11 @@ import builtins import sys -from abc import ABC +from abc import ABC, abstractmethod import logging import re from collections.abc import Iterable +from dataclasses import dataclass from typing import TypeVar, cast from astroid import ( # type: ignore @@ -24,6 +25,19 @@ NodeNG, parse, Uninferable, + AstroidSyntaxError, +) + +from databricks.labs.ucx.source_code.base import ( + Failure, + Linter, + Advice, + TableCollector, + UsedTable, + UsedTableNode, + DfsaCollector, + DirectFsAccess, + DirectFsAccessNode, ) logger = logging.getLogger(__name__) @@ -33,18 +47,61 @@ T = TypeVar("T", bound=NodeNG) +@dataclass(frozen=True) +class MaybeTree: + tree: Tree | None + failure: Failure | None + + def walk(self) -> Iterable[NodeNG]: + # mainly a helper method for unit testing + if self.tree is None: + assert self.failure is not None + logger.warning(self.failure.message) + return [] + return self.tree.walk() + + def first_statement(self) -> NodeNG | None: + # mainly a helper method for unit testing + if self.tree is None: + assert self.failure is not None + logger.warning(self.failure.message) + return None + return self.tree.first_statement() + + class Tree: + @classmethod + def maybe_parse(cls, code: str) -> MaybeTree: + try: + root = parse(code) + tree = Tree(root) + return MaybeTree(tree, None) + except AstroidSyntaxError as e: + return cls._definitely_failure('syntax-error', code, e) + except SystemError as e: + # see https://github.com/databrickslabs/ucx/issues/2976 + return cls._definitely_failure('system-error', code, e) + @staticmethod - def parse(code: str) -> Tree: - root = parse(code) - return Tree(root) + def _definitely_failure(message_code: str, source_code: str, e: Exception) -> MaybeTree: + return MaybeTree( + None, + 
Failure( + code=message_code, + message=f"Failed to parse code `{source_code}`: {e}. Report this as an issue on UCX GitHub.", + # Lines and columns are both 0-based: the first line is line 0. + start_line=0, + start_col=0, + end_line=1, + end_col=1, + ), + ) @classmethod - def normalize_and_parse(cls, code: str) -> Tree: + def maybe_normalized_parse(cls, code: str) -> MaybeTree: code = cls.normalize(code) - root = parse(code) - return Tree(root) + return cls.maybe_parse(code) @classmethod def normalize(cls, code: str) -> str: @@ -496,3 +553,131 @@ def node(self) -> NodeNG: def __repr__(self): return f"<{self.__class__.__name__}: {repr(self._node)}>" + + +class PythonLinter(Linter): + + def lint(self, code: str) -> Iterable[Advice]: + maybe_tree = Tree.maybe_normalized_parse(code) + if maybe_tree.failure: + yield maybe_tree.failure + return + assert maybe_tree.tree is not None + yield from self.lint_tree(maybe_tree.tree) + + @abstractmethod + def lint_tree(self, tree: Tree) -> Iterable[Advice]: ... + + +class TablePyCollector(TableCollector, ABC): + + def collect_tables(self, source_code: str) -> Iterable[UsedTable]: + maybe_tree = Tree.maybe_normalized_parse(source_code) + if maybe_tree.failure: + logger.warning(maybe_tree.failure.message) + return + assert maybe_tree.tree is not None + for table_node in self.collect_tables_from_tree(maybe_tree.tree): + yield table_node.table + + @abstractmethod + def collect_tables_from_tree(self, tree: Tree) -> Iterable[UsedTableNode]: ... + + +class DfsaPyCollector(DfsaCollector, ABC): + + def collect_dfsas(self, source_code: str) -> Iterable[DirectFsAccess]: + maybe_tree = Tree.maybe_normalized_parse(source_code) + if maybe_tree.failure: + logger.warning(maybe_tree.failure.message) + return + assert maybe_tree.tree is not None + for dfsa_node in self.collect_dfsas_from_tree(maybe_tree.tree): + yield dfsa_node.dfsa + + @abstractmethod + def collect_dfsas_from_tree(self, tree: Tree) -> Iterable[DirectFsAccessNode]: ... 
+ + +class PythonSequentialLinter(Linter, DfsaCollector, TableCollector): + + def __init__( + self, + linters: list[PythonLinter], + dfsa_collectors: list[DfsaPyCollector], + table_collectors: list[TablePyCollector], + ): + self._linters = linters + self._dfsa_collectors = dfsa_collectors + self._table_collectors = table_collectors + self._tree: Tree | None = None + + def lint(self, code: str) -> Iterable[Advice]: + maybe_tree = self._parse_and_append(code) + if maybe_tree.failure: + yield maybe_tree.failure + return + assert maybe_tree.tree is not None + yield from self.lint_tree(maybe_tree.tree) + + def lint_tree(self, tree: Tree) -> Iterable[Advice]: + for linter in self._linters: + yield from linter.lint_tree(tree) + + def _parse_and_append(self, code: str) -> MaybeTree: + maybe_tree = Tree.maybe_normalized_parse(code) + if maybe_tree.failure: + return maybe_tree + assert maybe_tree.tree is not None + self.append_tree(maybe_tree.tree) + return maybe_tree + + def append_tree(self, tree: Tree) -> None: + self._make_tree().append_tree(tree) + + def append_nodes(self, nodes: list[NodeNG]) -> None: + self._make_tree().append_nodes(nodes) + + def append_globals(self, globs: dict) -> None: + self._make_tree().append_globals(globs) + + def process_child_cell(self, code: str) -> None: + this_tree = self._make_tree() + maybe_tree = Tree.maybe_normalized_parse(code) + if maybe_tree.failure: + # TODO: bubble up this error + logger.warning(maybe_tree.failure.message) + return + assert maybe_tree.tree is not None + this_tree.append_tree(maybe_tree.tree) + + def collect_dfsas(self, source_code: str) -> Iterable[DirectFsAccess]: + maybe_tree = self._parse_and_append(source_code) + if maybe_tree.failure: + logger.warning(maybe_tree.failure.message) + return + assert maybe_tree.tree is not None + for dfsa_node in self.collect_dfsas_from_tree(maybe_tree.tree): + yield dfsa_node.dfsa + + def collect_dfsas_from_tree(self, tree: Tree) -> Iterable[DirectFsAccessNode]: + for collector 
in self._dfsa_collectors: + yield from collector.collect_dfsas_from_tree(tree) + + def collect_tables(self, source_code: str) -> Iterable[UsedTable]: + maybe_tree = self._parse_and_append(source_code) + if maybe_tree.failure: + logger.warning(maybe_tree.failure.message) + return + assert maybe_tree.tree is not None + for table_node in self.collect_tables_from_tree(maybe_tree.tree): + yield table_node.table + + def collect_tables_from_tree(self, tree: Tree) -> Iterable[UsedTableNode]: + for collector in self._table_collectors: + yield from collector.collect_tables_from_tree(tree) + + def _make_tree(self) -> Tree: + if self._tree is None: + self._tree = Tree.new_module() + return self._tree diff --git a/tests/integration/source_code/message_codes.py b/tests/integration/source_code/message_codes.py index f118b4e871..138dc5a2cb 100644 --- a/tests/integration/source_code/message_codes.py +++ b/tests/integration/source_code/message_codes.py @@ -11,7 +11,10 @@ def main(): product_info = ProductInfo.from_class(Advice) source_code = product_info.version_file().parent for file in source_code.glob("**/*.py"): - tree = Tree.parse(file.read_text()) + maybe_tree = Tree.maybe_parse(file.read_text()) + if not maybe_tree.tree: + continue + tree = maybe_tree.tree # recursively detect values of "code" kwarg in calls for node in tree.walk(): if not isinstance(node, astroid.Call): diff --git a/tests/unit/source_code/linters/test_pyspark.py b/tests/unit/source_code/linters/test_pyspark.py index e4a1df596c..526e9eb3a4 100644 --- a/tests/unit/source_code/linters/test_pyspark.py +++ b/tests/unit/source_code/linters/test_pyspark.py @@ -576,40 +576,45 @@ def test_direct_cloud_access_to_volumes_reports_nothing(empty_index, fs_function def test_get_full_function_name_for_member_function() -> None: - tree = Tree.parse("value.attr()") - node = tree.first_statement() + tree = Tree.maybe_parse("value.attr()") + assert tree.tree is not None + node = tree.tree.first_statement() assert 
isinstance(node, Expr) assert isinstance(node.value, Call) assert TreeHelper.get_full_function_name(node.value) == 'value.attr' def test_get_full_function_name_for_member_member_function() -> None: - tree = Tree.parse("value1.value2.attr()") - node = tree.first_statement() + tree = Tree.maybe_parse("value1.value2.attr()") + assert tree.tree is not None + node = tree.tree.first_statement() assert isinstance(node, Expr) assert isinstance(node.value, Call) assert TreeHelper.get_full_function_name(node.value) == 'value1.value2.attr' def test_get_full_function_name_for_chained_function() -> None: - tree = Tree.parse("value.attr1().attr2()") - node = tree.first_statement() + tree = Tree.maybe_parse("value.attr1().attr2()") + assert tree.tree is not None + node = tree.tree.first_statement() assert isinstance(node, Expr) assert isinstance(node.value, Call) assert TreeHelper.get_full_function_name(node.value) == 'value.attr1.attr2' def test_get_full_function_name_for_global_function() -> None: - tree = Tree.parse("name()") - node = tree.first_statement() + tree = Tree.maybe_parse("name()") + assert tree.tree is not None + node = tree.tree.first_statement() assert isinstance(node, Expr) assert isinstance(node.value, Call) assert TreeHelper.get_full_function_name(node.value) == 'name' def test_get_full_function_name_for_non_method() -> None: - tree = Tree.parse("not_a_function") - node = tree.first_statement() + tree = Tree.maybe_parse("not_a_function") + assert tree.tree is not None + node = tree.tree.first_statement() assert isinstance(node, Expr) assert TreeHelper.get_full_function_name(node.value) is None @@ -617,8 +622,9 @@ def test_get_full_function_name_for_non_method() -> None: def test_apply_table_name_matcher_with_missing_constant(migration_index) -> None: from_table = FromTableSqlLinter(migration_index, CurrentSessionState('old')) matcher = SparkCallMatcher('things', 1, 1, 0) - tree = Tree.parse("call('some.things')") - node = tree.first_statement() + tree = 
Tree.maybe_parse("call('some.things')") + assert tree.tree is not None + node = tree.tree.first_statement() assert isinstance(node, Expr) assert isinstance(node.value, Call) matcher.apply(from_table, migration_index, node.value) @@ -630,8 +636,9 @@ def test_apply_table_name_matcher_with_missing_constant(migration_index) -> None def test_apply_table_name_matcher_with_existing_constant(migration_index) -> None: from_table = FromTableSqlLinter(migration_index, CurrentSessionState('old')) matcher = SparkCallMatcher('things', 1, 1, 0) - tree = Tree.parse("call('old.things')") - node = tree.first_statement() + tree = Tree.maybe_parse("call('old.things')") + assert tree.tree is not None + node = tree.tree.first_statement() assert isinstance(node, Expr) assert isinstance(node.value, Call) matcher.apply(from_table, migration_index, node.value) diff --git a/tests/unit/source_code/linters/test_python_imports.py b/tests/unit/source_code/linters/test_python_imports.py index 98dcca2c83..55837928ef 100644 --- a/tests/unit/source_code/linters/test_python_imports.py +++ b/tests/unit/source_code/linters/test_python_imports.py @@ -14,8 +14,9 @@ def test_linter_returns_empty_list_of_dbutils_notebook_run_calls() -> None: - tree = Tree.parse('') - assert not DbutilsPyLinter.list_dbutils_notebook_run_calls(tree) + tree = Tree.maybe_parse('') + assert tree.tree is not None + assert not DbutilsPyLinter.list_dbutils_notebook_run_calls(tree.tree) def test_linter_returns_list_of_dbutils_notebook_run_calls() -> None: @@ -24,34 +25,40 @@ def test_linter_returns_list_of_dbutils_notebook_run_calls() -> None: for i in z: ww = dbutils.notebook.run("toto") """ - tree = Tree.parse(code) - calls = DbutilsPyLinter.list_dbutils_notebook_run_calls(tree) + tree = Tree.maybe_parse(code) + assert tree.tree is not None + calls = DbutilsPyLinter.list_dbutils_notebook_run_calls(tree.tree) assert {"toto", "stuff"} == {str(call.node.args[0].value) for call in calls} def 
test_linter_returns_empty_list_of_imports() -> None: - tree = Tree.parse('') - assert not ImportSource.extract_from_tree(tree, DependencyProblem.from_node)[0] + tree = Tree.maybe_parse('') + assert tree.tree is not None + assert not ImportSource.extract_from_tree(tree.tree, DependencyProblem.from_node)[0] def test_linter_returns_import() -> None: - tree = Tree.parse('import x') - assert ["x"] == [node.name for node in ImportSource.extract_from_tree(tree, DependencyProblem.from_node)[0]] + tree = Tree.maybe_parse('import x') + assert tree.tree is not None + assert ["x"] == [node.name for node in ImportSource.extract_from_tree(tree.tree, DependencyProblem.from_node)[0]] def test_linter_returns_import_from() -> None: - tree = Tree.parse('from x import z') - assert ["x"] == [node.name for node in ImportSource.extract_from_tree(tree, DependencyProblem.from_node)[0]] + tree = Tree.maybe_parse('from x import z') + assert tree.tree is not None + assert ["x"] == [node.name for node in ImportSource.extract_from_tree(tree.tree, DependencyProblem.from_node)[0]] def test_linter_returns_import_module() -> None: - tree = Tree.parse('importlib.import_module("x")') - assert ["x"] == [node.name for node in ImportSource.extract_from_tree(tree, DependencyProblem.from_node)[0]] + tree = Tree.maybe_parse('importlib.import_module("x")') + assert tree.tree is not None + assert ["x"] == [node.name for node in ImportSource.extract_from_tree(tree.tree, DependencyProblem.from_node)[0]] def test_linter_returns__import__() -> None: - tree = Tree.parse('importlib.__import__("x")') - assert ["x"] == [node.name for node in ImportSource.extract_from_tree(tree, DependencyProblem.from_node)[0]] + tree = Tree.maybe_parse('importlib.__import__("x")') + assert tree.tree is not None + assert ["x"] == [node.name for node in ImportSource.extract_from_tree(tree.tree, DependencyProblem.from_node)[0]] def test_linter_returns_appended_absolute_paths() -> None: @@ -60,8 +67,9 @@ def 
test_linter_returns_appended_absolute_paths() -> None: sys.path.append("absolute_path_1") sys.path.append("absolute_path_2") """ - tree = Tree.parse(code) - appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree) + tree = Tree.maybe_parse(code) + assert tree.tree is not None + appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree.tree) assert ["absolute_path_1", "absolute_path_2"] == [p.path for p in appended] @@ -71,8 +79,9 @@ def test_linter_returns_appended_absolute_paths_with_sys_alias() -> None: stuff.path.append("absolute_path_1") stuff.path.append("absolute_path_2") """ - tree = Tree.parse(code) - appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree) + tree = Tree.maybe_parse(code) + assert tree.tree is not None + appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree.tree) assert ["absolute_path_1", "absolute_path_2"] == [p.path for p in appended] @@ -81,8 +90,9 @@ def test_linter_returns_appended_absolute_paths_with_sys_path_alias() -> None: from sys import path as stuff stuff.append("absolute_path") """ - tree = Tree.parse(code) - appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree) + tree = Tree.maybe_parse(code) + assert tree.tree is not None + appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree.tree) assert "absolute_path" in [p.path for p in appended] @@ -92,8 +102,9 @@ def test_linter_returns_appended_relative_paths() -> None: import os sys.path.append(os.path.abspath("relative_path")) """ - tree = Tree.parse(code) - appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree) + tree = Tree.maybe_parse(code) + assert tree.tree is not None + appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree.tree) assert "relative_path" in [p.path for p in appended] @@ -103,8 +114,9 @@ def test_linter_returns_appended_relative_paths_with_os_alias() -> None: import os as stuff sys.path.append(stuff.path.abspath("relative_path")) """ - 
tree = Tree.parse(code) - appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree) + tree = Tree.maybe_parse(code) + assert tree.tree is not None + appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree.tree) assert "relative_path" in [p.path for p in appended] @@ -114,8 +126,9 @@ def test_linter_returns_appended_relative_paths_with_os_path_alias() -> None: from os import path as stuff sys.path.append(stuff.abspath("relative_path")) """ - tree = Tree.parse(code) - appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree) + tree = Tree.maybe_parse(code) + assert tree.tree is not None + appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree.tree) assert "relative_path" in [p.path for p in appended] @@ -125,8 +138,9 @@ def test_linter_returns_appended_relative_paths_with_os_path_abspath_import() -> from os.path import abspath sys.path.append(abspath("relative_path")) """ - tree = Tree.parse(code) - appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree) + tree = Tree.maybe_parse(code) + assert tree.tree is not None + appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree.tree) assert "relative_path" in [p.path for p in appended] @@ -136,8 +150,9 @@ def test_linter_returns_appended_relative_paths_with_os_path_abspath_alias() -> from os.path import abspath as stuff sys.path.append(stuff("relative_path")) """ - tree = Tree.parse(code) - appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree) + tree = Tree.maybe_parse(code) + assert tree.tree is not None + appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree.tree) assert "relative_path" in [p.path for p in appended] @@ -147,8 +162,9 @@ def test_linter_returns_inferred_paths() -> None: path = "absolute_path_1" sys.path.append(path) """ - tree = Tree.parse(code) - appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree) + tree = Tree.maybe_parse(code) + assert tree.tree is not None + 
appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree.tree) assert ["absolute_path_1"] == [p.path for p in appended] @@ -188,8 +204,9 @@ def foo(): return "bar" ], ) def test_infers_dbutils_notebook_run_dynamic_value(code, expected) -> None: - tree = Tree.parse(code) - calls = DbutilsPyLinter.list_dbutils_notebook_run_calls(tree) + tree = Tree.maybe_parse(code) + assert tree.tree is not None + calls = DbutilsPyLinter.list_dbutils_notebook_run_calls(tree.tree) all_paths: list[str] = [] for call in calls: _, paths = call.get_notebook_paths(CurrentSessionState()) diff --git a/tests/unit/source_code/linters/test_spark_connect.py b/tests/unit/source_code/linters/test_spark_connect.py index 25c285b76b..4932a86e83 100644 --- a/tests/unit/source_code/linters/test_spark_connect.py +++ b/tests/unit/source_code/linters/test_spark_connect.py @@ -236,7 +236,7 @@ def test_logging_shared(session_state) -> None: end_line=6, end_col=24, ), - ] == list(chain.from_iterable([logging_matcher.lint(node) for node in Tree.parse(code).walk()])) + ] == list(chain.from_iterable([logging_matcher.lint(node) for node in Tree.maybe_parse(code).walk()])) def test_logging_serverless(session_state) -> None: @@ -248,6 +248,9 @@ def test_logging_serverless(session_state) -> None: """ + maybe_tree = Tree.maybe_parse(code) + assert maybe_tree.tree is not None + tree = maybe_tree.tree assert [ Failure( code='spark-logging-in-shared-clusters', @@ -266,7 +269,7 @@ def test_logging_serverless(session_state) -> None: end_line=2, end_col=38, ), - ] == list(chain.from_iterable([logging_matcher.lint(node) for node in Tree.parse(code).walk()])) + ] == list(chain.from_iterable([logging_matcher.lint(node) for node in tree.walk()])) def test_valid_code() -> None: diff --git a/tests/unit/source_code/notebooks/test_cells.py b/tests/unit/source_code/notebooks/test_cells.py index 07fa2bd5ac..abbc871410 100644 --- a/tests/unit/source_code/notebooks/test_cells.py +++ 
b/tests/unit/source_code/notebooks/test_cells.py @@ -184,22 +184,18 @@ def test_pip_cell_build_dependency_graph_handles_multiline_code() -> None: def test_graph_builder_parse_error( - simple_dependency_resolver: DependencyResolver, mock_path_lookup: PathLookup + simple_dependency_resolver: DependencyResolver, + mock_path_lookup: PathLookup, ) -> None: """Check that internal parsing errors are caught and logged.""" - # Fixture. dependency = Dependency(FileLoader(), Path("")) graph = DependencyGraph(dependency, None, simple_dependency_resolver, mock_path_lookup, CurrentSessionState()) analyser = PythonCodeAnalyzer(graph.new_dependency_graph_context(), "this is not valid python") - # Run the test. - problems = [] - for problem in analyser.build_graph(): - if problem.code == "parse-error" and problem.message.startswith("Could not parse Python code"): - problems.append(problem) + problems = analyser.build_graph() + codes = {_.code for _ in problems} - # Check results. - assert problems + assert codes == {'syntax-error'} def test_parses_python_cell_with_magic_commands(simple_dependency_resolver, mock_path_lookup) -> None: @@ -272,7 +268,9 @@ def test_unsupported_magic_raises_problem(simple_dependency_resolver, mock_path_ source = """ %unsupported stuff '"%#@! 
""" - tree = Tree.normalize_and_parse(source) + maybe_tree = Tree.maybe_normalized_parse(source) + assert maybe_tree.tree, maybe_tree.failure + tree = maybe_tree.tree commands, _ = MagicLine.extract_from_tree(tree, DependencyProblem.from_node) dependency = Dependency(FileLoader(), Path("")) graph = DependencyGraph(dependency, None, simple_dependency_resolver, mock_path_lookup, CurrentSessionState()) diff --git a/tests/unit/source_code/python/test_python_ast.py b/tests/unit/source_code/python/test_python_ast.py index 0d9ddeb6d1..c23b76c4ef 100644 --- a/tests/unit/source_code/python/test_python_ast.py +++ b/tests/unit/source_code/python/test_python_ast.py @@ -1,12 +1,14 @@ import pytest -from astroid import Assign, AstroidSyntaxError, Attribute, Call, Const, Expr, Module, Name # type: ignore +from astroid import Assign, Attribute, Call, Const, Expr, Module, Name # type: ignore from databricks.labs.ucx.source_code.python.python_ast import Tree, TreeHelper from databricks.labs.ucx.source_code.python.python_infer import InferredValue def test_extracts_root() -> None: - tree = Tree.parse("o.m1().m2().m3()") + maybe_tree = Tree.maybe_parse("o.m1().m2().m3()") + assert maybe_tree.tree is not None, maybe_tree.failure + tree = maybe_tree.tree stmt = tree.first_statement() root = Tree(stmt).root assert root == tree.node @@ -14,12 +16,13 @@ def test_extracts_root() -> None: def test_no_first_statement() -> None: - tree = Tree.parse("") - assert not tree.first_statement() + maybe_tree = Tree.maybe_parse("") + assert maybe_tree.tree is not None + assert maybe_tree.tree.first_statement() is None def test_extract_call_by_name() -> None: - tree = Tree.parse("o.m1().m2().m3()") + tree = Tree.maybe_parse("o.m1().m2().m3()") stmt = tree.first_statement() assert isinstance(stmt, Expr) assert isinstance(stmt.value, Call) @@ -30,7 +33,7 @@ def test_extract_call_by_name() -> None: def test_extract_call_by_name_none() -> None: - tree = Tree.parse("o.m1().m2().m3()") + tree = 
Tree.maybe_parse("o.m1().m2().m3()") stmt = tree.first_statement() assert isinstance(stmt, Expr) assert isinstance(stmt.value, Call) @@ -56,7 +59,7 @@ def test_extract_call_by_name_none() -> None: ], ) def test_linter_gets_arg(code, arg_index, arg_name, expected) -> None: - tree = Tree.parse(code) + tree = Tree.maybe_parse(code) stmt = tree.first_statement() assert isinstance(stmt, Expr) assert isinstance(stmt.value, Call) @@ -81,7 +84,7 @@ def test_linter_gets_arg(code, arg_index, arg_name, expected) -> None: ], ) def test_args_count(code, expected) -> None: - tree = Tree.parse(code) + tree = Tree.maybe_parse(code) stmt = tree.first_statement() assert isinstance(stmt, Expr) assert isinstance(stmt.value, Call) @@ -92,7 +95,7 @@ def test_args_count(code, expected) -> None: def test_tree_walks_nodes_once() -> None: nodes = set() count = 0 - tree = Tree.parse("o.m1().m2().m3()") + tree = Tree.maybe_parse("o.m1().m2().m3()") for node in tree.walk(): nodes.add(node) count += 1 @@ -119,10 +122,10 @@ def test_parses_incorrectly_indented_code() -> None: ) """ # ensure it would fail if not normalized - with pytest.raises(AstroidSyntaxError): - Tree.parse(source) - Tree.normalize_and_parse(source) - assert True + maybe_tree = Tree.maybe_parse(source) + assert maybe_tree.failure is not None + maybe_tree = Tree.maybe_normalized_parse(source) + assert maybe_tree.failure is None def test_ignores_magic_marker_in_multiline_comment() -> None: @@ -132,15 +135,19 @@ def test_ignores_magic_marker_in_multiline_comment() -> None: version="version" formatted=message_unformatted % (name, version) """ - Tree.normalize_and_parse(source) + Tree.maybe_normalized_parse(source) assert True def test_appends_statements() -> None: source_1 = "a = 'John'" - tree_1 = Tree.normalize_and_parse(source_1) + maybe_tree_1 = Tree.maybe_normalized_parse(source_1) + assert maybe_tree_1.tree is not None, maybe_tree_1.failure + tree_1 = maybe_tree_1.tree source_2 = 'b = f"Hello {a}!"' - tree_2 = 
Tree.normalize_and_parse(source_2) + maybe_tree_2 = Tree.maybe_normalized_parse(source_2) + assert maybe_tree_2.tree is not None, maybe_tree_2.failure + tree_2 = maybe_tree_2.tree tree_3 = tree_1.append_tree(tree_2) nodes = tree_3.locate(Assign, []) tree = Tree(nodes[0].value) # tree_3 only contains tree_2 statements @@ -154,7 +161,9 @@ def test_is_from_module() -> None: df = spark.read.csv("hi") df.write.format("delta").saveAsTable("old.things") """ - tree = Tree.normalize_and_parse(source) + maybe_tree = Tree.maybe_normalized_parse(source) + assert maybe_tree.tree is not None, maybe_tree.failure + tree = maybe_tree.tree save_call = tree.locate( Call, [("saveAsTable", Attribute), ("format", Attribute), ("write", Attribute), ("df", Name)] )[0] @@ -163,7 +172,9 @@ def test_is_from_module() -> None: @pytest.mark.parametrize("source, name, class_name", [("a = 123", "a", "int")]) def test_is_instance_of(source, name, class_name) -> None: - tree = Tree.normalize_and_parse(source) + maybe_tree = Tree.maybe_normalized_parse(source) + assert maybe_tree.tree is not None, maybe_tree.failure + tree = maybe_tree.tree assert isinstance(tree.node, Module) module = tree.node var = module.globals.get(name, None) @@ -181,11 +192,18 @@ def test_supports_recursive_refs_when_checking_module() -> None: source_3 = """ df = df.withColumn(stuff2) """ - main_tree = Tree.normalize_and_parse(source_1) - main_tree.append_tree(Tree.normalize_and_parse(source_2)) - tree = Tree.normalize_and_parse(source_3) - main_tree.append_tree(tree) - assign = tree.locate(Assign, [])[0] + maybe_tree = Tree.maybe_normalized_parse(source_1) + assert maybe_tree.tree is not None, maybe_tree.failure + main_tree = maybe_tree.tree + maybe_tree_2 = Tree.maybe_normalized_parse(source_2) + assert maybe_tree_2.tree is not None, maybe_tree_2.failure + tree_2 = maybe_tree_2.tree + main_tree.append_tree(tree_2) + maybe_tree_3 = Tree.maybe_normalized_parse(source_3) + assert maybe_tree_3.tree is not None, 
maybe_tree_3.failure + tree_3 = maybe_tree_3.tree + main_tree.append_tree(tree_3) + assign = tree_3.locate(Assign, [])[0] assert Tree(assign.value).is_from_module("spark") @@ -193,7 +211,9 @@ def test_renumbers_positively() -> None: source = """df = spark.read.csv("hi") df.write.format("delta").saveAsTable("old.things") """ - tree = Tree.normalize_and_parse(source) + maybe_tree = Tree.maybe_normalized_parse(source) + assert maybe_tree.tree is not None, maybe_tree.failure + tree = maybe_tree.tree nodes = list(tree.node.get_children()) assert len(nodes) == 2 assert nodes[0].lineno == 1 @@ -209,7 +229,9 @@ def test_renumbers_negatively() -> None: source = """df = spark.read.csv("hi") df.write.format("delta").saveAsTable("old.things") """ - tree = Tree.normalize_and_parse(source) + maybe_tree = Tree.maybe_normalized_parse(source) + assert maybe_tree.tree is not None, maybe_tree.failure + tree = maybe_tree.tree nodes = list(tree.node.get_children()) assert len(nodes) == 2 assert nodes[0].lineno == 1 @@ -231,7 +253,9 @@ def test_renumbers_negatively() -> None: ], ) def test_counts_lines(source: str, line_count: int) -> None: - tree = Tree.normalize_and_parse(source) + maybe_tree = Tree.maybe_normalized_parse(source) + assert maybe_tree.tree is not None, maybe_tree.failure + tree = maybe_tree.tree assert tree.line_count() == line_count @@ -251,7 +275,9 @@ def test_counts_lines(source: str, line_count: int) -> None: ], ) def test_is_builtin(source, name, is_builtin) -> None: - tree = Tree.normalize_and_parse(source) + maybe_tree = Tree.maybe_normalized_parse(source) + assert maybe_tree.tree is not None, maybe_tree.failure + tree = maybe_tree.tree nodes = list(tree.node.get_children()) for node in nodes: if isinstance(node, Assign): diff --git a/tests/unit/source_code/python/test_python_infer.py b/tests/unit/source_code/python/test_python_infer.py index a0ffa6fe72..f9c14aca7d 100644 --- a/tests/unit/source_code/python/test_python_infer.py +++ 
b/tests/unit/source_code/python/test_python_infer.py @@ -6,7 +6,9 @@ def test_infers_empty_list() -> None: - tree = Tree.parse("a=[]") + maybe_tree = Tree.maybe_parse("a=[]") + assert maybe_tree.tree is not None, maybe_tree.failure + tree = maybe_tree.tree nodes = tree.locate(Assign, []) tree = Tree(nodes[0].value) values = list(InferredValue.infer_from_node(tree.node)) @@ -14,7 +16,9 @@ def test_infers_empty_list() -> None: def test_infers_empty_tuple() -> None: - tree = Tree.parse("a=tuple()") + maybe_tree = Tree.maybe_parse("a=tuple()") + assert maybe_tree.tree is not None, maybe_tree.failure + tree = maybe_tree.tree nodes = tree.locate(Assign, []) tree = Tree(nodes[0].value) values = list(InferredValue.infer_from_node(tree.node)) @@ -22,7 +26,9 @@ def test_infers_empty_tuple() -> None: def test_infers_empty_set() -> None: - tree = Tree.parse("a={}") + maybe_tree = Tree.maybe_parse("a={}") + assert maybe_tree.tree is not None, maybe_tree.failure + tree = maybe_tree.tree nodes = tree.locate(Assign, []) tree = Tree(nodes[0].value) values = list(InferredValue.infer_from_node(tree.node)) @@ -34,7 +40,9 @@ def test_infers_fstring_value() -> None: value = "abc" fstring = f"Hello {value}!" """ - tree = Tree.parse(source) + maybe_tree = Tree.maybe_parse(source) + assert maybe_tree.tree is not None, maybe_tree.failure + tree = maybe_tree.tree nodes = tree.locate(Assign, []) tree = Tree(nodes[1].value) # value of fstring = ... values = list(InferredValue.infer_from_node(tree.node)) @@ -48,7 +56,9 @@ def test_infers_fstring_dict_value() -> None: value = { "abc": 123 } fstring = f"Hello {value['abc']}!" """ - tree = Tree.parse(source) + maybe_tree = Tree.maybe_parse(source) + assert maybe_tree.tree is not None, maybe_tree.failure + tree = maybe_tree.tree nodes = tree.locate(Assign, []) tree = Tree(nodes[1].value) # value of fstring = ... 
values = list(InferredValue.infer_from_node(tree.node)) @@ -62,7 +72,9 @@ def test_infers_string_format_value() -> None: value = "abc" fstring = "Hello {0}!".format(value) """ - tree = Tree.parse(source) + maybe_tree = Tree.maybe_parse(source) + assert maybe_tree.tree is not None, maybe_tree.failure + tree = maybe_tree.tree nodes = tree.locate(Assign, []) tree = Tree(nodes[1].value) # value of fstring = ... values = list(InferredValue.infer_from_node(tree.node)) @@ -79,7 +91,9 @@ def test_infers_fstring_values() -> None: for value2 in values_2: fstring = f"Hello {value1}, {value2}!" """ - tree = Tree.parse(source) + maybe_tree = Tree.maybe_parse(source) + assert maybe_tree.tree is not None, maybe_tree.failure + tree = maybe_tree.tree nodes = tree.locate(Assign, []) tree = Tree(nodes[2].value) # value of fstring = ... values = list(InferredValue.infer_from_node(tree.node)) @@ -95,7 +109,9 @@ def test_infers_externally_defined_value() -> None: name = "my-widget" value = dbutils.widgets.get(name) """ - tree = Tree.parse(source) + maybe_tree = Tree.maybe_parse(source) + assert maybe_tree.tree is not None, maybe_tree.failure + tree = maybe_tree.tree nodes = tree.locate(Assign, []) tree = Tree(nodes[1].value) # value of value = ... values = list(InferredValue.infer_from_node(tree.node, state)) @@ -110,7 +126,9 @@ def test_infers_externally_defined_values() -> None: for name in ["my-widget-1", "my-widget-2"]: value = dbutils.widgets.get(name) """ - tree = Tree.parse(source) + maybe_tree = Tree.maybe_parse(source) + assert maybe_tree.tree is not None, maybe_tree.failure + tree = maybe_tree.tree nodes = tree.locate(Assign, []) tree = Tree(nodes[0].value) # value of value = ... 
values = list(InferredValue.infer_from_node(tree.node, state)) @@ -125,7 +143,9 @@ def test_fails_to_infer_missing_externally_defined_value() -> None: name = "my-widget" value = dbutils.widgets.get(name) """ - tree = Tree.parse(source) + maybe_tree = Tree.maybe_parse(source) + assert maybe_tree.tree is not None, maybe_tree.failure + tree = maybe_tree.tree nodes = tree.locate(Assign, []) tree = Tree(nodes[1].value) # value of value = ... values = InferredValue.infer_from_node(tree.node, state) @@ -137,7 +157,9 @@ def test_survives_absence_of_externally_defined_values() -> None: name = "my-widget" value = dbutils.widgets.get(name) """ - tree = Tree.parse(source) + maybe_tree = Tree.maybe_parse(source) + assert maybe_tree.tree is not None, maybe_tree.failure + tree = maybe_tree.tree nodes = tree.locate(Assign, []) tree = Tree(nodes[1].value) # value of value = ... values = InferredValue.infer_from_node(tree.node, CurrentSessionState()) @@ -152,7 +174,9 @@ def test_infers_externally_defined_value_set() -> None: name = "my-widget" value = values[name] """ - tree = Tree.parse(source) + maybe_tree = Tree.maybe_parse(source) + assert maybe_tree.tree is not None, maybe_tree.failure + tree = maybe_tree.tree nodes = tree.locate(Assign, []) tree = Tree(nodes[2].value) # value of value = ... 
values = list(InferredValue.infer_from_node(tree.node, state)) diff --git a/tests/unit/source_code/samples/functional/zoo.py b/tests/unit/source_code/samples/functional/zoo.py new file mode 100644 index 0000000000..4e22c645c6 --- /dev/null +++ b/tests/unit/source_code/samples/functional/zoo.py @@ -0,0 +1,9 @@ +# Databricks notebook source +# ucx[session-state] {"data_security_mode": "USER_ISOLATION"} + +# ucx[direct-filesystem-access-in-sql-query:+1:0:+1:86] The use of direct filesystem references is deprecated: /foo/bar +spark.sql("SELECT * FROM db.table LEFT JOIN delta.`/foo/bar` AS t ON t.id = table.id").show() + +# ucx[rdd-in-shared-clusters:+2:29:+2:42] RDD APIs are not supported on Unity Catalog clusters in Shared access mode. Rewrite it using DataFrame API +# ucx[legacy-context-in-shared-clusters:+1:29:+1:40] sc is not supported on Unity Catalog clusters in Shared access mode. Rewrite it using spark +rdd2 = spark.createDataFrame(sc.emptyRDD(), 'foo') diff --git a/tests/unit/source_code/test_jobs.py b/tests/unit/source_code/test_jobs.py index cbd2f87394..d067b3b0c0 100644 --- a/tests/unit/source_code/test_jobs.py +++ b/tests/unit/source_code/test_jobs.py @@ -5,6 +5,7 @@ from unittest.mock import create_autospec import pytest +from databricks.labs.lsql.backends import MockBackend from databricks.sdk.service.compute import LibraryInstallStatus from databricks.sdk.service.jobs import Job, SparkPythonTask from databricks.sdk.service.pipelines import NotebookLibrary, GetPipelineResponse, PipelineLibrary, FileLibrary @@ -17,7 +18,7 @@ from databricks.sdk import WorkspaceClient from databricks.sdk.errors import NotFound from databricks.sdk.service import compute, jobs, pipelines -from databricks.sdk.service.workspace import ExportFormat +from databricks.sdk.service.workspace import ExportFormat, ObjectInfo, Language from databricks.labs.ucx.source_code.linters.files import FileLoader, ImportFileResolver from databricks.labs.ucx.source_code.graph import ( @@ -537,3 
+538,56 @@ def test_linting_walker_populates_paths(dependency_resolver, mock_path_lookup, m advices += 1 assert "UNKNOWN" not in advice.path.as_posix() assert advices + + +def test_workflow_linter_refresh_report(dependency_resolver, mock_path_lookup, migration_index) -> None: + ws = create_autospec(WorkspaceClient) + ws.workspace.get_status.return_value = ObjectInfo(object_id=123, language=Language.PYTHON) + some_things = mock_path_lookup.resolve(Path("functional/zoo.py")) + assert some_things is not None + ws.workspace.download.return_value = some_things.read_bytes() + notebook_task = jobs.NotebookTask( + notebook_path=some_things.absolute().as_posix(), + base_parameters={"a": "b", "c": "dbfs:/mnt/foo"}, + ) + task = jobs.Task( + task_key="test", + job_cluster_key="main", + notebook_task=notebook_task, + ) + settings = jobs.JobSettings( + tasks=[task], + name='some', + job_clusters=[ + jobs.JobCluster( + job_cluster_key="main", + new_cluster=compute.ClusterSpec( + spark_version="15.2.x-photon-scala2.12", + node_type_id="Standard_F4s", + num_workers=2, + data_security_mode=compute.DataSecurityMode.LEGACY_TABLE_ACL, + spark_conf={"spark.databricks.cluster.profile": "singleNode"}, + ), + ), + ], + ) + ws.jobs.list.return_value = [Job(job_id=1), Job(job_id=2, settings=settings)] + ws.jobs.get.return_value = Job(job_id=2, settings=settings) + + sql_backend = MockBackend() + directfs_crawler = DirectFsAccessCrawler.for_paths(sql_backend, "test") + used_tables_crawler = UsedTablesCrawler.for_paths(sql_backend, "test") + linter = WorkflowLinter( + ws, + dependency_resolver, + mock_path_lookup, + migration_index, + directfs_crawler, + used_tables_crawler, + [1], + ) + linter.refresh_report(sql_backend, 'test') + + sql_backend.has_rows_written_for('test.workflow_problems') + sql_backend.has_rows_written_for('hive_metastore.test.used_tables_in_paths') + sql_backend.has_rows_written_for('hive_metastore.test.directfs_in_paths') diff --git 
a/tests/unit/source_code/test_notebook.py b/tests/unit/source_code/test_notebook.py index b255a0e398..58cde4a8e3 100644 --- a/tests/unit/source_code/test_notebook.py +++ b/tests/unit/source_code/test_notebook.py @@ -262,8 +262,9 @@ def test_detects_multiple_calls_to_dbutils_notebook_run_in_python_code() -> None stuff3 = dbutils.notebook.run("where is notebook 2?") """ linter = DbutilsPyLinter(CurrentSessionState()) - tree = Tree.parse(source) - nodes = linter.list_dbutils_notebook_run_calls(tree) + tree = Tree.maybe_parse(source) + assert tree.tree is not None + nodes = linter.list_dbutils_notebook_run_calls(tree.tree) assert len(nodes) == 2 @@ -274,8 +275,9 @@ def test_does_not_detect_partial_call_to_dbutils_notebook_run_in_python_code_() stuff2 = notebook.run("where is notebook 1?") """ linter = DbutilsPyLinter(CurrentSessionState()) - tree = Tree.parse(source) - nodes = linter.list_dbutils_notebook_run_calls(tree) + tree = Tree.maybe_parse(source) + assert tree.tree is not None + nodes = linter.list_dbutils_notebook_run_calls(tree.tree) assert len(nodes) == 0