diff --git a/docs/source/helpers.rst b/docs/source/helpers.rst index e4b94d2b0..3cf5abfbd 100644 --- a/docs/source/helpers.rst +++ b/docs/source/helpers.rst @@ -32,3 +32,18 @@ Functions that assist in traversing an existing LibCST tree. .. autofunction:: libcst.helpers.get_full_name_for_node .. autofunction:: libcst.helpers.get_full_name_for_node_or_raise .. autofunction:: libcst.helpers.ensure_type + +Node fields filtering Helpers +----------------------------- + +Functions that assist when handling CST nodes' fields. + +.. autofunction:: libcst.helpers.filter_node_fields + +And lower-level functions: + +.. autofunction:: libcst.helpers.get_node_fields +.. autofunction:: libcst.helpers.is_whitespace_node_field +.. autofunction:: libcst.helpers.is_syntax_node_field +.. autofunction:: libcst.helpers.is_default_node_field +.. autofunction:: libcst.helpers.get_field_default_value diff --git a/libcst/helpers/__init__.py b/libcst/helpers/__init__.py index c7fdf9b15..817acc39b 100644 --- a/libcst/helpers/__init__.py +++ b/libcst/helpers/__init__.py @@ -25,6 +25,14 @@ insert_header_comments, ModuleNameAndPackage, ) +from libcst.helpers.node_fields import ( + filter_node_fields, + get_field_default_value, + get_node_fields, + is_default_node_field, + is_syntax_node_field, + is_whitespace_node_field, +) __all__ = [ "calculate_module_and_package", @@ -42,4 +50,10 @@ "parse_template_statement", "parse_template_expression", "ModuleNameAndPackage", + "get_node_fields", + "get_field_default_value", + "is_whitespace_node_field", + "is_syntax_node_field", + "is_default_node_field", + "filter_node_fields", ] diff --git a/libcst/helpers/node_fields.py b/libcst/helpers/node_fields.py new file mode 100644 index 000000000..418d6cbbd --- /dev/null +++ b/libcst/helpers/node_fields.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from __future__ import annotations + +import dataclasses +from typing import TYPE_CHECKING + +from libcst import IndentedBlock, Module +from libcst._nodes.deep_equals import deep_equals + +if TYPE_CHECKING: + from typing import Sequence + + from libcst import CSTNode + + +def get_node_fields(node: CSTNode) -> Sequence[dataclasses.Field[CSTNode]]: + """ + Returns the sequence of a given CST-node's fields. + """ + return dataclasses.fields(node) + + +def is_whitespace_node_field(node: CSTNode, field: dataclasses.Field[CSTNode]) -> bool: + """ + Returns True if a given CST-node's field is a whitespace-related field + (whitespace, indent, header, footer, etc.). + """ + if "whitespace" in field.name: + return True + if "leading_lines" in field.name: + return True + if "lines_after_decorators" in field.name: + return True + if isinstance(node, (IndentedBlock, Module)) and field.name in [ + "header", + "footer", + ]: + return True + if isinstance(node, IndentedBlock) and field.name == "indent": + return True + return False + + +def is_syntax_node_field(node: CSTNode, field: dataclasses.Field[CSTNode]) -> bool: + """ + Returns True if a given CST-node's field is a syntax-related field + (colon, semicolon, dot, encoding, etc.). + """ + if isinstance(node, Module) and field.name in [ + "encoding", + "default_indent", + "default_newline", + "has_trailing_newline", + ]: + return True + type_str = repr(field.type) + if ( + "Sentinel" in type_str + and field.name not in ["star_arg", "star", "posonly_ind"] + and "whitespace" not in field.name + ): + # This is a value that can optionally be specified, so it's + # definitely syntax. + return True + + for name in ["Semicolon", "Colon", "Comma", "Dot", "AssignEqual"]: + # These are all nodes that exist for separation syntax + if name in type_str: + return True + + return False + + +def get_field_default_value(field: dataclasses.Field[CSTNode]) -> object: + """ + Returns the default value of a CST-node's field. 
+ """ + if field.default_factory is not dataclasses.MISSING: + # pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE, + # dataclasses._DefaultFactory[object]]` is not a function. + return field.default_factory() + return field.default + + +def is_default_node_field(node: CSTNode, field: dataclasses.Field[CSTNode]) -> bool: + """ + Returns True if a given CST-node's field has its default value. + """ + return deep_equals(getattr(node, field.name), get_field_default_value(field)) + + +def filter_node_fields( + node: CSTNode, + *, + show_defaults: bool, + show_syntax: bool, + show_whitespace: bool, +) -> Sequence[dataclasses.Field[CSTNode]]: + """ + Returns a filtered sequence of a CST-node's fields. + + Setting ``show_whitespace`` to ``False`` will filter whitespace fields. + + Setting ``show_defaults`` to ``False`` will filter fields if their value is equal to + the default value; while respecting the value of ``show_whitespace``. + + Setting ``show_syntax`` to ``False`` will filter syntax fields; while respecting + the value of ``show_whitespace`` and ``show_defaults``. + """ + + fields: Sequence[dataclasses.Field[CSTNode]] = dataclasses.fields(node) + # Hide all fields prefixed with "_" + fields = [f for f in fields if f.name[0] != "_"] + # Filter whitespace nodes if needed + if not show_whitespace: + fields = [f for f in fields if not is_whitespace_node_field(node, f)] + # Filter values which aren't changed from their defaults + if not show_defaults: + fields = [f for f in fields if not is_default_node_field(node, f)] + # Filter out syntax fields if needed + if not show_syntax: + fields = [f for f in fields if not is_syntax_node_field(node, f)] + + return fields diff --git a/libcst/helpers/tests/test_node_fields.py b/libcst/helpers/tests/test_node_fields.py new file mode 100644 index 000000000..61d5ec21f --- /dev/null +++ b/libcst/helpers/tests/test_node_fields.py @@ -0,0 +1,314 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from unittest import TestCase + +from libcst import ( + Annotation, + CSTNode, + FunctionDef, + IndentedBlock, + Module, + Param, + parse_module, + Pass, + Semicolon, + SimpleStatementLine, +) + +from libcst.helpers import ( + get_node_fields, + is_default_node_field, + is_syntax_node_field, + is_whitespace_node_field, +) + + +class _NodeFieldsTest(TestCase): + """Node fields related tests.""" + + module: Module + annotation: Annotation + param: Param + _pass: Pass + semicolon: Semicolon + statement: SimpleStatementLine + indent: IndentedBlock + function: FunctionDef + + @classmethod + def setUpClass(cls) -> None: + """Parse a simple CST and reference interesting nodes.""" + cls.module = parse_module( + "def foo(a: str) -> None:\n pass ; pass\n return\n" + ) + # /!\ Direct access to nodes + # This is done for test purposes on a known CST + # -> For "real code", use visitors to do this "the correct way" + + # pyre-ignore[8]: direct access for tests + cls.function = cls.module.body[0] + cls.param = cls.function.params.params[0] + # pyre-ignore[8]: direct access for tests + cls.annotation = cls.param.annotation + # pyre-ignore[8]: direct access for tests + cls.indent = cls.function.body + # pyre-ignore[8]: direct access for tests + cls.statement = cls.indent.body[0] + # pyre-ignore[8]: direct access for tests + cls._pass = cls.statement.body[0] + # pyre-ignore[8]: direct access for tests + cls.semicolon = cls.statement.body[0].semicolon + + def test__cst_correctness(self) -> None: + """Test that the CST is correctly parsed.""" + self.assertIsInstance(self.module, Module) + self.assertIsInstance(self.annotation, Annotation) + self.assertIsInstance(self.param, Param) + self.assertIsInstance(self._pass, Pass) + self.assertIsInstance(self.semicolon, Semicolon) + self.assertIsInstance(self.statement, SimpleStatementLine) 
+ self.assertIsInstance(self.indent, IndentedBlock) + self.assertIsInstance(self.function, FunctionDef) + + +class IsWhitespaceNodeFieldTest(_NodeFieldsTest): + """``is_whitespace_node_field`` tests.""" + + def _check_fields(self, is_filtered_field: dict[str, bool], node: CSTNode) -> None: + fields = get_node_fields(node) + self.assertEqual(len(is_filtered_field), len(fields)) + for field in fields: + self.assertEqual( + is_filtered_field[field.name], + is_whitespace_node_field(node, field), + f"Node ``{node.__class__.__qualname__}`` field '{field.name}' " + f"{'should have' if is_filtered_field[field.name] else 'should not have'} " + "been filtered by ``is_whitespace_node_field``", + ) + + def test_module(self) -> None: + """Check if a CST Module node is correctly filtered.""" + is_filtered_field = { + "body": False, + "header": True, + "footer": True, + "encoding": False, + "default_indent": False, + "default_newline": False, + "has_trailing_newline": False, + } + self._check_fields(is_filtered_field, self.module) + + def test_annotation(self) -> None: + """Check if a CST Annotation node is correctly filtered.""" + is_filtered_field = { + "annotation": False, + "whitespace_before_indicator": True, + "whitespace_after_indicator": True, + } + self._check_fields(is_filtered_field, self.annotation) + + def test_param(self) -> None: + """Check if a CST Param node is correctly filtered.""" + is_filtered_field = { + "name": False, + "annotation": False, + "equal": False, + "default": False, + "comma": False, + "star": False, + "whitespace_after_star": True, + "whitespace_after_param": True, + } + self._check_fields(is_filtered_field, self.param) + + def test_semicolon(self) -> None: + """Check if a CST Semicolon node is correctly filtered.""" + is_filtered_field = { + "whitespace_before": True, + "whitespace_after": True, + } + self._check_fields(is_filtered_field, self.semicolon) + + def test_statement(self) -> None: + """Check if a CST SimpleStatementLine node is 
correctly filtered.""" + is_filtered_field = { + "body": False, + "leading_lines": True, + "trailing_whitespace": True, + } + self._check_fields(is_filtered_field, self.statement) + + def test_indent(self) -> None: + """Check if a CST IndentedBlock node is correctly filtered.""" + is_filtered_field = { + "body": False, + "header": True, + "indent": True, + "footer": True, + } + self._check_fields(is_filtered_field, self.indent) + + def test_function(self) -> None: + """Check if a CST FunctionDef node is correctly filtered.""" + is_filtered_field = { + "name": False, + "params": False, + "body": False, + "decorators": False, + "returns": False, + "asynchronous": False, + "leading_lines": True, + "lines_after_decorators": True, + "whitespace_after_def": True, + "whitespace_after_name": True, + "whitespace_before_params": True, + "whitespace_before_colon": True, + "type_parameters": False, + "whitespace_after_type_parameters": True, + } + self._check_fields(is_filtered_field, self.function) + + +class IsSyntaxNodeFieldTest(_NodeFieldsTest): + """``is_syntax_node_field`` tests.""" + + def _check_fields(self, is_filtered_field: dict[str, bool], node: CSTNode) -> None: + fields = get_node_fields(node) + self.assertEqual(len(is_filtered_field), len(fields)) + for field in fields: + self.assertEqual( + is_filtered_field[field.name], + is_syntax_node_field(node, field), + f"Node ``{node.__class__.__qualname__}`` field '{field.name}' " + f"{'should have' if is_filtered_field[field.name] else 'should not have'} " + "been filtered by ``is_syntax_node_field``", + ) + + def test_module(self) -> None: + """Check if a CST Module node is correctly filtered.""" + is_filtered_field = { + "body": False, + "header": False, + "footer": False, + "encoding": True, + "default_indent": True, + "default_newline": True, + "has_trailing_newline": True, + } + self._check_fields(is_filtered_field, self.module) + + def test_param(self) -> None: + """Check if a CST Param node is correctly 
filtered.""" + is_filtered_field = { + "name": False, + "annotation": False, + "equal": True, + "default": False, + "comma": True, + "star": False, + "whitespace_after_star": False, + "whitespace_after_param": False, + } + self._check_fields(is_filtered_field, self.param) + + def test_pass(self) -> None: + """Check if a CST Pass node is correctly filtered.""" + is_filtered_field = { + "semicolon": True, + } + self._check_fields(is_filtered_field, self._pass) + + +class IsDefaultNodeFieldTest(_NodeFieldsTest): + """``is_default_node_field`` tests.""" + + def _check_fields(self, is_filtered_field: dict[str, bool], node: CSTNode) -> None: + fields = get_node_fields(node) + self.assertEqual(len(is_filtered_field), len(fields)) + for field in fields: + self.assertEqual( + is_filtered_field[field.name], + is_default_node_field(node, field), + f"Node ``{node.__class__.__qualname__}`` field '{field.name}' " + f"{'should have' if is_filtered_field[field.name] else 'should not have'} " + "been filtered by ``is_default_node_field``", + ) + + def test_module(self) -> None: + """Check if a CST Module node is correctly filtered.""" + is_filtered_field = { + "body": False, + "header": True, + "footer": True, + "encoding": True, + "default_indent": True, + "default_newline": True, + "has_trailing_newline": True, + } + self._check_fields(is_filtered_field, self.module) + + def test_annotation(self) -> None: + """Check if a CST Annotation node is correctly filtered.""" + is_filtered_field = { + "annotation": False, + "whitespace_before_indicator": False, + "whitespace_after_indicator": True, + } + self._check_fields(is_filtered_field, self.annotation) + + def test_param(self) -> None: + """Check if a CST Param node is correctly filtered.""" + is_filtered_field = { + "name": False, + "annotation": False, + "equal": True, + "default": True, + "comma": True, + "star": False, + "whitespace_after_star": True, + "whitespace_after_param": True, + } + self._check_fields(is_filtered_field, 
self.param) + + def test_statement(self) -> None: + """Check if a CST SimpleStatementLine node is correctly filtered.""" + is_filtered_field = { + "body": False, + "leading_lines": True, + "trailing_whitespace": True, + } + self._check_fields(is_filtered_field, self.statement) + + def test_indent(self) -> None: + """Check if a CST IndentedBlock node is correctly filtered.""" + is_filtered_field = { + "body": False, + "header": True, + "indent": True, + "footer": True, + } + self._check_fields(is_filtered_field, self.indent) + + def test_function(self) -> None: + """Check if a CST FunctionDef node is correctly filtered.""" + is_filtered_field = { + "name": False, + "params": False, + "body": False, + "decorators": True, + "returns": False, + "asynchronous": True, + "leading_lines": True, + "lines_after_decorators": True, + "whitespace_after_def": True, + "whitespace_after_name": True, + "whitespace_before_params": True, + "whitespace_before_colon": True, + "type_parameters": True, + "whitespace_after_type_parameters": True, + } + self._check_fields(is_filtered_field, self.function) diff --git a/libcst/tool.py b/libcst/tool.py index 5469ef27d..85a977be7 100644 --- a/libcst/tool.py +++ b/libcst/tool.py @@ -22,15 +22,7 @@ import yaml -from libcst import ( - CSTNode, - IndentedBlock, - LIBCST_VERSION, - Module, - parse_module, - PartialParserConfig, -) -from libcst._nodes.deep_equals import deep_equals +from libcst import CSTNode, LIBCST_VERSION, parse_module, PartialParserConfig from libcst._parser.parso.utils import parse_version_string from libcst.codemod import ( CodemodCommand, @@ -40,6 +32,7 @@ gather_files, parallel_exec_transform_with_prettyprint, ) +from libcst.helpers import filter_node_fields _DEFAULT_INDENT: str = " " @@ -54,76 +47,14 @@ def _node_repr_recursive( # noqa: C901 ) -> List[str]: if isinstance(node, CSTNode): # This is a CSTNode, we must pretty-print it. 
+ fields: Sequence["dataclasses.Field[CSTNode]"] = filter_node_fields( + node=node, + show_defaults=show_defaults, + show_syntax=show_syntax, + show_whitespace=show_whitespace, + ) + tokens: List[str] = [node.__class__.__name__] - fields: Sequence["dataclasses.Field[object]"] = dataclasses.fields(node) - - # Hide all fields prefixed with "_" - fields = [f for f in fields if f.name[0] != "_"] - - # Filter whitespace nodes if needed - if not show_whitespace: - - def _is_whitespace(field: "dataclasses.Field[object]") -> bool: - if "whitespace" in field.name: - return True - if "leading_lines" in field.name: - return True - if "lines_after_decorators" in field.name: - return True - if isinstance(node, (IndentedBlock, Module)) and field.name in [ - "header", - "footer", - ]: - return True - if isinstance(node, IndentedBlock) and field.name == "indent": - return True - return False - - fields = [f for f in fields if not _is_whitespace(f)] - # Filter values which aren't changed from their defaults - if not show_defaults: - - def _get_default(fld: "dataclasses.Field[object]") -> object: - if fld.default_factory is not dataclasses.MISSING: - # pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE, - # dataclasses._DefaultFactory[object]]` is not a function. - return fld.default_factory() - return fld.default - - fields = [ - f - for f in fields - if not deep_equals(getattr(node, f.name), _get_default(f)) - ] - # Filter out values which aren't interesting if needed - if not show_syntax: - - def _is_syntax(field: "dataclasses.Field[object]") -> bool: - if isinstance(node, Module) and field.name in [ - "encoding", - "default_indent", - "default_newline", - "has_trailing_newline", - ]: - return True - type_str = repr(field.type) - if ( - "Sentinel" in type_str - and field.name not in ["star_arg", "star", "posonly_ind"] - and "whitespace" not in field.name - ): - # This is a value that can optionally be specified, so its - # definitely syntax. 
- return True - - for name in ["Semicolon", "Colon", "Comma", "Dot", "AssignEqual"]: - # These are all nodes that exist for separation syntax - if name in type_str: - return True - - return False - - fields = [f for f in fields if not _is_syntax(f)] if len(fields) == 0: tokens.append("()") @@ -204,12 +135,12 @@ def dump( from the default contruction of the node while also hiding whitespace and syntax fields. - Setting ``show_default`` to ``True`` will add fields regardless if their + Setting ``show_defaults`` to ``True`` will add fields regardless if their value is different from the default value. Setting ``show_whitespace`` will add whitespace fields and setting ``show_syntax`` will add syntax fields while respecting the value of - ``show_default``. + ``show_defaults``. When all keyword args are set to true, the output of this function is indentical to the __repr__ method of the node.