Skip to content

Commit

Permalink
Change logic of direct filesystem access linting (#2766)
Browse files Browse the repository at this point in the history
## Changes
Our current logic detects all string constants that match a DFSA
pattern, excluding false positives on a per-use-case basis.
This leaves many false positives. Practically, we only care about DFSA
if called from `spark` or `dbutils` modules.
This PR implements this change.

### Linked issues
None

### Functionality
None

### Tests
- [x] added unit tests

---------

Co-authored-by: Eric Vergnaud <eric.vergnaud@databricks.com>
  • Loading branch information
ericvergnaud and ericvergnaud authored Oct 3, 2024
1 parent 9a44fc1 commit 62d2e27
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 70 deletions.
69 changes: 16 additions & 53 deletions src/databricks/labs/ucx/source_code/linters/directfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import ABC
from collections.abc import Iterable

from astroid import Attribute, Call, Const, InferenceError, JoinedStr, Name, NodeNG # type: ignore
from astroid import Call, InferenceError, NodeNG # type: ignore
from sqlglot.expressions import Alter, Create, Delete, Drop, Expression, Identifier, Insert, Literal, Select

from databricks.labs.ucx.source_code.base import (
Expand All @@ -16,7 +16,7 @@
DfsaSqlCollector,
DirectFsAccess,
)
from databricks.labs.ucx.source_code.python.python_ast import Tree, TreeVisitor
from databricks.labs.ucx.source_code.python.python_ast import Tree, TreeVisitor, TreeHelper
from databricks.labs.ucx.source_code.python.python_infer import InferredValue
from databricks.labs.ucx.source_code.sql.sql_parser import SqlParser, SqlExpression

Expand Down Expand Up @@ -68,43 +68,37 @@ class _DetectDirectFsAccessVisitor(TreeVisitor):
def __init__(self, session_state: CurrentSessionState, prevent_spark_duplicates: bool) -> None:
self._session_state = session_state
self._directfs_nodes: list[DirectFsAccessNode] = []
self._reported_locations: set[tuple[int, int]] = set()
self._prevent_spark_duplicates = prevent_spark_duplicates

def visit_call(self, node: Call):
for arg in node.args:
self._visit_arg(arg)
self._visit_arg(node, arg)

def _visit_arg(self, arg: NodeNG):
def _visit_arg(self, call: Call, arg: NodeNG):
try:
for inferred in InferredValue.infer_from_node(arg, self._session_state):
if not inferred.is_inferred():
logger.debug(f"Could not infer value of {arg.as_string()}")
continue
self._check_str_constant(arg, inferred)
self._check_str_arg(call, arg, inferred)
except InferenceError as e:
logger.debug(f"Could not infer value of {arg.as_string()}", exc_info=e)

def visit_const(self, node: Const):
# Constant strings yield Advisories
if isinstance(node.value, str):
self._check_str_constant(node, InferredValue([node]))

def _check_str_constant(self, source_node: NodeNG, inferred: InferredValue):
if self._already_reported(source_node, inferred):
return
# don't report on JoinedStr fragments
if isinstance(source_node.parent, JoinedStr):
return
def _check_str_arg(self, call_node: Call, arg_node: NodeNG, inferred: InferredValue):
value = inferred.as_string()
for pattern in DIRECT_FS_ACCESS_PATTERNS:
if not pattern.matches(value):
continue
# avoid false positives with relative URLs
if self._is_http_call_parameter(source_node):
# only capture 'open' calls or calls originating from spark or dbutils
# because there is no other known way to manipulate data directly from file system
tree = Tree(call_node)
is_open = TreeHelper.get_call_name(call_node) == "open" and tree.is_builtin()
is_from_db_utils = False if is_open else tree.is_from_module("dbutils")
is_from_spark = False if is_open or is_from_db_utils else tree.is_from_module("spark")
if not (is_open or is_from_db_utils or is_from_spark):
return
# avoid duplicate advices that are reported by SparkSqlPyLinter
if self._prevent_spark_duplicates and Tree(source_node).is_from_module("spark"):
if self._prevent_spark_duplicates and is_from_spark:
return
# since we're normally filtering out spark calls, we're dealing with dfsas we know little about
# notably we don't know is_read or is_write
Expand All @@ -113,39 +107,8 @@ def _check_str_constant(self, source_node: NodeNG, inferred: InferredValue):
is_read=True,
is_write=False,
)
self._directfs_nodes.append(DirectFsAccessNode(dfsa, source_node))
self._reported_locations.add((source_node.lineno, source_node.col_offset))

@classmethod
def _is_http_call_parameter(cls, source_node: NodeNG):
if not isinstance(source_node.parent, Call):
return False
# for now we only cater for ws.api_client.do
return cls._is_ws_api_client_do_call(source_node)

@classmethod
def _is_ws_api_client_do_call(cls, source_node: NodeNG):
assert isinstance(source_node.parent, Call)
func = source_node.parent.func
if not isinstance(func, Attribute) or func.attrname != "do":
return False
expr = func.expr
if not isinstance(expr, Attribute) or expr.attrname != "api_client":
return False
expr = expr.expr
if not isinstance(expr, Name):
return False
for value in InferredValue.infer_from_node(expr):
if not value.is_inferred():
continue
for node in value.nodes:
return Tree(node).is_instance_of("WorkspaceClient")
# at this point it seems safer to assume that expr.expr is a workspace client than the opposite
return True

def _already_reported(self, source_node: NodeNG, inferred: InferredValue):
all_nodes = [source_node] + inferred.nodes
return any((node.lineno, node.col_offset) in self._reported_locations for node in all_nodes)
self._directfs_nodes.append(DirectFsAccessNode(dfsa, arg_node))
return

@property
def directfs_nodes(self):
Expand Down
24 changes: 23 additions & 1 deletion src/databricks/labs/ucx/source_code/python/python_ast.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import builtins
import sys
from abc import ABC
import logging
import re
Expand All @@ -25,7 +27,6 @@
)

logger = logging.getLogger(__name__)

missing_handlers: set[str] = set()


Expand Down Expand Up @@ -289,6 +290,16 @@ def renumber_node(node: NodeNG, offset: int) -> None:
start = start + num_lines if start > 0 else start - num_lines
return self

def is_builtin(self) -> bool:
    """Whether this node refers to a Python builtin or a standard-library module.

    A ``Name`` matches when its identifier is a builtin or the name of a
    stdlib/builtin module; ``Call`` and ``Attribute`` nodes are resolved
    recursively through their callee / receiver expression.
    """
    node = self._node
    if isinstance(node, Name):
        identifier = node.name
        if identifier in dir(builtins):
            return True
        return identifier in sys.stdlib_module_names or identifier in sys.builtin_module_names
    if isinstance(node, Call):
        # a call is builtin when the function being called is
        return Tree(node.func).is_builtin()
    if isinstance(node, Attribute):
        # an attribute access is builtin when its receiver is
        return Tree(node.expr).is_builtin()
    # other node kinds (subscripts, literals, ...) are not supported yet
    return False


class _LocalTree(Tree):

Expand All @@ -298,6 +309,17 @@ def is_from_module_visited(self, name: str, visited_nodes: set[NodeNG]) -> bool:

class TreeHelper(ABC):

@classmethod
def get_call_name(cls, call: Call) -> str:
    """Return the simple name of the function being called.

    Yields the identifier for plain calls (``open(...)``) and the attribute
    name for method calls (``spark.table(...)``); returns an empty string for
    anything that is not a ``Call`` or whose callee form is unsupported.
    """
    if not isinstance(call, Call):
        return ""
    callee = call.func
    if isinstance(callee, Name):
        return callee.name
    if isinstance(callee, Attribute):
        return callee.attrname
    # lambdas, subscripts, etc. are not supported yet
    return ""

@classmethod
def extract_call_by_name(cls, call: Call, name: str) -> Call | None:
"""Given a call-chain, extract its sub-call by method name (if it has one)"""
Expand Down
20 changes: 14 additions & 6 deletions tests/unit/source_code/linters/test_directfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@ def test_matches_dfsa_pattern(path, matches):
"code, expected",
[
('SOME_CONSTANT = "not a file system path"', 0),
('SOME_CONSTANT = ("/dbfs/mnt", "dbfs:/", "/mnt/")', 3),
('SOME_CONSTANT = ("/dbfs/mnt", "dbfs:/", "/mnt/")', 0),
('# "/dbfs/mnt"', 0),
('SOME_CONSTANT = "/dbfs/mnt"', 1),
('SOME_CONSTANT = "/dbfs/mnt"; load_data(SOME_CONSTANT)', 1),
('SOME_CONSTANT = "/dbfs/mnt"', 0),
('SOME_CONSTANT = "/dbfs/mnt"; load_data(SOME_CONSTANT)', 0),
('SOME_CONSTANT = "/dbfs/mnt"; spark.table(SOME_CONSTANT)', 1),
('SOME_CONSTANT = ("/dbfs/mnt", "dbfs:/", "/mnt/"); [dbutils.fs(path) for path in SOME_CONSTANT]', 3),
('SOME_CONSTANT = 42; load_data(SOME_CONSTANT)', 0),
('SOME_CONSTANT = "/dbfs/mnt"; dbutils.fs(SOME_CONSTANT)', 1),
],
)
def test_detects_dfsa_paths(code, expected):
Expand All @@ -47,9 +50,14 @@ def test_detects_dfsa_paths(code, expected):
@pytest.mark.parametrize(
"code, expected",
[
("load_data('/dbfs/mnt/data')", 1),
("load_data('/data')", 1),
("load_data('/dbfs/mnt/data', '/data')", 2),
("load_data('/dbfs/mnt/data')", 0),
(
"""with open('/dbfs/mnt/data') as f:
f.read()""",
1,
),
("dbutils.fs('/data')", 1),
("dbutils.fs('/dbfs/mnt/data', '/data')", 2),
("# load_data('/dbfs/mnt/data', '/data')", 0),
('spark.read.parquet("/mnt/foo/bar")', 1),
('spark.read.parquet("dbfs:/mnt/foo/bar")', 1),
Expand Down
29 changes: 29 additions & 0 deletions tests/unit/source_code/python/test_python_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,35 @@ def test_counts_lines(source: str, line_count: int):
assert tree.line_count() == line_count


@pytest.mark.parametrize(
    "source, name, is_builtin",
    [
        ("x = open()", "open", True),
        ("import datetime; x = datetime.datetime.now()", "now", True),
        ("import stuff; x = stuff()", "stuff", False),
        (
            """def stuff():
    pass
x = stuff()""",
            "stuff",
            False,
        ),
    ],
)
def test_is_builtin(source, name, is_builtin):
    """The first call assigned in *source* has the expected name and builtin-ness."""
    tree = Tree.normalize_and_parse(source)
    # locate the first assignment whose value is a call, then check it
    for node in tree.node.get_children():
        if isinstance(node, Assign):
            call = node.value
            assert isinstance(call, Call)
            assert TreeHelper.get_call_name(call) == name
            assert Tree(call).is_builtin() == is_builtin
            return
    # pytest.fail is not stripped under `python -O` (unlike `assert False`)
    # and reports a meaningful message
    pytest.fail("could not locate call")


def test_first_statement_is_none():
node = Const("xyz")
assert not Tree(node).first_statement()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,12 @@
-- COMMAND ----------
-- DBTITLE 1,A Python cell that references DBFS
-- MAGIC %python
-- ucx[direct-filesystem-access:+1:7:+1:18] The use of direct filesystem references is deprecated: dbfs:/...
-- MAGIC DBFS = "dbfs:/..."
-- ucx[direct-filesystem-access:+1:7:+1:18] The use of direct filesystem references is deprecated: /dbfs/mnt
-- MAGIC DBFS = "/dbfs/mnt"
-- ucx[direct-filesystem-access:+1:7:+1:14] The use of direct filesystem references is deprecated: /mnt/
-- MAGIC DBFS = "/mnt/"
-- ucx[direct-filesystem-access:+1:7:+1:18] The use of direct filesystem references is deprecated: dbfs:/...
-- MAGIC DBFS = "dbfs:/..."
-- ucx[direct-filesystem-access:+1:10:+1:26] The use of direct filesystem references is deprecated: /dbfs/mnt/data
-- MAGIC load_data('/dbfs/mnt/data')
-- ucx[direct-filesystem-access:+1:10:+1:17] The use of direct filesystem references is deprecated: /data
-- MAGIC load_data('/data')
-- ucx[direct-filesystem-access:+2:10:+2:26] The use of direct filesystem references is deprecated: /dbfs/mnt/data
-- ucx[direct-filesystem-access:+1:28:+1:35] The use of direct filesystem references is deprecated: /data
-- MAGIC load_data('/dbfs/mnt/data', '/data')
-- MAGIC # load_data('/dbfs/mnt/data', '/data')
-- ucx[direct-filesystem-access:+1:0:+1:34] The use of direct filesystem references is deprecated: /mnt/foo/bar
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
# ucx[direct-filesystem-access:+1:6:+1:26] The use of direct filesystem references is deprecated: dbfs:/mnt/foo/bar1
DBFS1="dbfs:/mnt/foo/bar1"
# ucx[direct-filesystem-access:+1:16:+1:36] The use of direct filesystem references is deprecated: dbfs:/mnt/foo/bar2
systems=[DBFS1, "dbfs:/mnt/foo/bar2"]
for system in systems:
# ucx[direct-filesystem-access:+2:4:+2:30] The use of direct filesystem references is deprecated: dbfs:/mnt/foo/bar1
Expand Down

0 comments on commit 62d2e27

Please sign in to comment.