Change logic of direct filesystem access linting #2766

Merged 22 commits on Oct 3, 2024
69 changes: 16 additions & 53 deletions src/databricks/labs/ucx/source_code/linters/directfs.py
@@ -2,7 +2,7 @@
from abc import ABC
from collections.abc import Iterable

from astroid import Attribute, Call, Const, InferenceError, JoinedStr, Name, NodeNG # type: ignore
from astroid import Call, InferenceError, NodeNG # type: ignore
from sqlglot.expressions import Alter, Create, Delete, Drop, Expression, Identifier, Insert, Literal, Select

from databricks.labs.ucx.source_code.base import (
@@ -16,7 +16,7 @@
DfsaSqlCollector,
DirectFsAccess,
)
from databricks.labs.ucx.source_code.python.python_ast import Tree, TreeVisitor
from databricks.labs.ucx.source_code.python.python_ast import Tree, TreeVisitor, TreeHelper
from databricks.labs.ucx.source_code.python.python_infer import InferredValue
from databricks.labs.ucx.source_code.sql.sql_parser import SqlParser, SqlExpression

@@ -68,43 +68,37 @@ class _DetectDirectFsAccessVisitor(TreeVisitor):
def __init__(self, session_state: CurrentSessionState, prevent_spark_duplicates: bool) -> None:
self._session_state = session_state
self._directfs_nodes: list[DirectFsAccessNode] = []
self._reported_locations: set[tuple[int, int]] = set()
self._prevent_spark_duplicates = prevent_spark_duplicates

def visit_call(self, node: Call):
for arg in node.args:
self._visit_arg(arg)
self._visit_arg(node, arg)

def _visit_arg(self, arg: NodeNG):
def _visit_arg(self, call: Call, arg: NodeNG):
try:
for inferred in InferredValue.infer_from_node(arg, self._session_state):
if not inferred.is_inferred():
logger.debug(f"Could not infer value of {arg.as_string()}")
continue
self._check_str_constant(arg, inferred)
self._check_str_arg(call, arg, inferred)
except InferenceError as e:
logger.debug(f"Could not infer value of {arg.as_string()}", exc_info=e)

def visit_const(self, node: Const):
# Constant strings yield Advisories
if isinstance(node.value, str):
self._check_str_constant(node, InferredValue([node]))

def _check_str_constant(self, source_node: NodeNG, inferred: InferredValue):
if self._already_reported(source_node, inferred):
return
# don't report on JoinedStr fragments
if isinstance(source_node.parent, JoinedStr):
return
def _check_str_arg(self, call_node: Call, arg_node: NodeNG, inferred: InferredValue):
value = inferred.as_string()
for pattern in DIRECT_FS_ACCESS_PATTERNS:
if not pattern.matches(value):
continue
# avoid false positives with relative URLs
if self._is_http_call_parameter(source_node):
# only capture 'open' calls or calls originating from spark or dbutils
# because there is no other known way to manipulate data directly from file system
Collaborator: there's also with open('/dbfs/mnt/x', 'r') as f: ..., but we're not handling it here yet. Make sure to add that.

Contributor (author): done

tree = Tree(call_node)
is_open = TreeHelper.get_call_name(call_node) == "open" and tree.is_builtin()
is_from_db_utils = False if is_open else tree.is_from_module("dbutils")
is_from_spark = False if is_open or is_from_db_utils else tree.is_from_module("spark")
if not (is_open or is_from_db_utils or is_from_spark):
return
# avoid duplicate advices that are reported by SparkSqlPyLinter
if self._prevent_spark_duplicates and Tree(source_node).is_from_module("spark"):
if self._prevent_spark_duplicates and is_from_spark:
return
# since we're normally filtering out spark calls, we're dealing with dfsas we know little about
# notably we don't know is_read or is_write
@@ -113,39 +107,8 @@ def _check_str_constant(self, source_node: NodeNG, inferred: InferredValue):
is_read=True,
is_write=False,
)
self._directfs_nodes.append(DirectFsAccessNode(dfsa, source_node))
self._reported_locations.add((source_node.lineno, source_node.col_offset))

@classmethod
def _is_http_call_parameter(cls, source_node: NodeNG):
if not isinstance(source_node.parent, Call):
return False
# for now we only cater for ws.api_client.do
return cls._is_ws_api_client_do_call(source_node)

@classmethod
def _is_ws_api_client_do_call(cls, source_node: NodeNG):
assert isinstance(source_node.parent, Call)
func = source_node.parent.func
if not isinstance(func, Attribute) or func.attrname != "do":
return False
expr = func.expr
if not isinstance(expr, Attribute) or expr.attrname != "api_client":
return False
expr = expr.expr
if not isinstance(expr, Name):
return False
for value in InferredValue.infer_from_node(expr):
if not value.is_inferred():
continue
for node in value.nodes:
return Tree(node).is_instance_of("WorkspaceClient")
# at this point is seems safer to assume that expr.expr is a workspace than the opposite
return True

def _already_reported(self, source_node: NodeNG, inferred: InferredValue):
all_nodes = [source_node] + inferred.nodes
return any((node.lineno, node.col_offset) in self._reported_locations for node in all_nodes)
self._directfs_nodes.append(DirectFsAccessNode(dfsa, arg_node))
return

@property
def directfs_nodes(self):
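For context, a sketch of the behaviour change, drawn from the updated unit tests further down: only builtin open() calls and calls originating from spark or dbutils are now reported, while bare string constants and calls to unknown functions no longer are. This is illustrative linter input rather than runnable code; load_data is a hypothetical user-defined function, and spark and dbutils only exist inside a Databricks runtime.

SOME_CONSTANT = "/dbfs/mnt"           # no longer flagged: a bare string constant
load_data(SOME_CONSTANT)              # no longer flagged: load_data is an unknown callable

spark.table(SOME_CONSTANT)            # flagged: the call originates from spark
dbutils.fs("/dbfs/mnt/data")          # flagged: the call originates from dbutils
with open("/dbfs/mnt/data") as f:     # flagged: builtin open() with a direct filesystem path
    f.read()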
24 changes: 23 additions & 1 deletion src/databricks/labs/ucx/source_code/python/python_ast.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import builtins
import sys
from abc import ABC
import logging
import re
@@ -25,7 +27,6 @@
)

logger = logging.getLogger(__name__)

missing_handlers: set[str] = set()


@@ -289,6 +290,16 @@ def renumber_node(node: NodeNG, offset: int) -> None:
start = start + num_lines if start > 0 else start - num_lines
return self

def is_builtin(self) -> bool:
if isinstance(self._node, Name):
name = self._node.name
return name in dir(builtins) or name in sys.stdlib_module_names or name in sys.builtin_module_names
if isinstance(self._node, Call):
return Tree(self._node.func).is_builtin()
if isinstance(self._node, Attribute):
return Tree(self._node.expr).is_builtin()
return False # not supported yet


class _LocalTree(Tree):

@@ -298,6 +309,17 @@ def is_from_module_visited(self, name: str, visited_nodes: set[NodeNG]) -> bool:

class TreeHelper(ABC):

@classmethod
def get_call_name(cls, call: Call) -> str:
if not isinstance(call, Call):
return ""
func = call.func
if isinstance(func, Name):
return func.name
if isinstance(func, Attribute):
return func.attrname
return "" # not supported yet

@classmethod
def extract_call_by_name(cls, call: Call, name: str) -> Call | None:
"""Given a call-chain, extract its sub-call by method name (if it has one)"""
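A minimal sketch of how the two new helpers behave, mirroring the added unit tests in test_python_ast.py (the imports assume the package layout shown above):

from astroid import Assign, Call

from databricks.labs.ucx.source_code.python.python_ast import Tree, TreeHelper

# Parse a snippet and pull out the call on the right-hand side of the assignment.
tree = Tree.normalize_and_parse("x = open()")
assign = next(node for node in tree.node.get_children() if isinstance(node, Assign))
call = assign.value
assert isinstance(call, Call)

assert TreeHelper.get_call_name(call) == "open"  # resolves the name being called
assert Tree(call).is_builtin()                   # open is a Python builtin

# A call to an unknown module is not considered builtin.
other = Tree.normalize_and_parse("import stuff; x = stuff()")
assign = next(node for node in other.node.get_children() if isinstance(node, Assign))
assert not Tree(assign.value).is_builtin()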
20 changes: 14 additions & 6 deletions tests/unit/source_code/linters/test_directfs.py
@@ -29,11 +29,14 @@ def test_matches_dfsa_pattern(path, matches):
"code, expected",
[
('SOME_CONSTANT = "not a file system path"', 0),
('SOME_CONSTANT = ("/dbfs/mnt", "dbfs:/", "/mnt/")', 3),
('SOME_CONSTANT = ("/dbfs/mnt", "dbfs:/", "/mnt/")', 0),
('# "/dbfs/mnt"', 0),
('SOME_CONSTANT = "/dbfs/mnt"', 1),
('SOME_CONSTANT = "/dbfs/mnt"; load_data(SOME_CONSTANT)', 1),
('SOME_CONSTANT = "/dbfs/mnt"', 0),
('SOME_CONSTANT = "/dbfs/mnt"; load_data(SOME_CONSTANT)', 0),
('SOME_CONSTANT = "/dbfs/mnt"; spark.table(SOME_CONSTANT)', 1),
('SOME_CONSTANT = ("/dbfs/mnt", "dbfs:/", "/mnt/"); [dbutils.fs(path) for path in SOME_CONSTANT]', 3),
('SOME_CONSTANT = 42; load_data(SOME_CONSTANT)', 0),
('SOME_CONSTANT = "/dbfs/mnt"; dbutils.fs(SOME_CONSTANT)', 1),
],
)
def test_detects_dfsa_paths(code, expected):
@@ -47,9 +50,14 @@ def test_detects_dfsa_paths(code, expected):
@pytest.mark.parametrize(
"code, expected",
[
("load_data('/dbfs/mnt/data')", 1),
("load_data('/data')", 1),
("load_data('/dbfs/mnt/data', '/data')", 2),
("load_data('/dbfs/mnt/data')", 0),
(
"""with open('/dbfs/mnt/data') as f:
f.read()""",
1,
),
("dbutils.fs('/data')", 1),
("dbutils.fs('/dbfs/mnt/data', '/data')", 2),
("# load_data('/dbfs/mnt/data', '/data')", 0),
('spark.read.parquet("/mnt/foo/bar")', 1),
('spark.read.parquet("dbfs:/mnt/foo/bar")', 1),
29 changes: 29 additions & 0 deletions tests/unit/source_code/python/test_python_ast.py
@@ -235,6 +235,35 @@ def test_counts_lines(source: str, line_count: int):
assert tree.line_count() == line_count


@pytest.mark.parametrize(
"source, name, is_builtin",
[
("x = open()", "open", True),
("import datetime; x = datetime.datetime.now()", "now", True),
("import stuff; x = stuff()", "stuff", False),
(
"""def stuff():
pass
x = stuff()""",
"stuff",
False,
),
],
)
def test_is_builtin(source, name, is_builtin):
tree = Tree.normalize_and_parse(source)
nodes = list(tree.node.get_children())
for node in nodes:
if isinstance(node, Assign):
call = node.value
assert isinstance(call, Call)
func_name = TreeHelper.get_call_name(call)
assert func_name == name
assert Tree(call).is_builtin() == is_builtin
return
assert False # could not locate call


def test_first_statement_is_none():
node = Const("xyz")
assert not Tree(node).first_statement()
@@ -5,20 +5,12 @@
-- COMMAND ----------
-- DBTITLE 1,A Python cell that references DBFS
-- MAGIC %python
-- ucx[direct-filesystem-access:+1:7:+1:18] The use of direct filesystem references is deprecated: dbfs:/...
-- MAGIC DBFS = "dbfs:/..."
-- ucx[direct-filesystem-access:+1:7:+1:18] The use of direct filesystem references is deprecated: /dbfs/mnt
-- MAGIC DBFS = "/dbfs/mnt"
-- ucx[direct-filesystem-access:+1:7:+1:14] The use of direct filesystem references is deprecated: /mnt/
-- MAGIC DBFS = "/mnt/"
-- ucx[direct-filesystem-access:+1:7:+1:18] The use of direct filesystem references is deprecated: dbfs:/...
-- MAGIC DBFS = "dbfs:/..."
-- ucx[direct-filesystem-access:+1:10:+1:26] The use of direct filesystem references is deprecated: /dbfs/mnt/data
-- MAGIC load_data('/dbfs/mnt/data')
-- ucx[direct-filesystem-access:+1:10:+1:17] The use of direct filesystem references is deprecated: /data
-- MAGIC load_data('/data')
-- ucx[direct-filesystem-access:+2:10:+2:26] The use of direct filesystem references is deprecated: /dbfs/mnt/data
-- ucx[direct-filesystem-access:+1:28:+1:35] The use of direct filesystem references is deprecated: /data
-- MAGIC load_data('/dbfs/mnt/data', '/data')
-- MAGIC # load_data('/dbfs/mnt/data', '/data')
-- ucx[direct-filesystem-access:+1:0:+1:34] The use of direct filesystem references is deprecated: /mnt/foo/bar
@@ -1,6 +1,4 @@
# ucx[direct-filesystem-access:+1:6:+1:26] The use of direct filesystem references is deprecated: dbfs:/mnt/foo/bar1
DBFS1="dbfs:/mnt/foo/bar1"
# ucx[direct-filesystem-access:+1:16:+1:36] The use of direct filesystem references is deprecated: dbfs:/mnt/foo/bar2
systems=[DBFS1, "dbfs:/mnt/foo/bar2"]
for system in systems:
# ucx[direct-filesystem-access:+2:4:+2:30] The use of direct filesystem references is deprecated: dbfs:/mnt/foo/bar1