Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add helper functions for common ways of filtering nodes #1137

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/source/helpers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,18 @@ Functions that assist in traversing an existing LibCST tree.
.. autofunction:: libcst.helpers.get_full_name_for_node
.. autofunction:: libcst.helpers.get_full_name_for_node_or_raise
.. autofunction:: libcst.helpers.ensure_type

Node fields filtering Helpers
-----------------------------

Function that assist when handling CST nodes' fields.

.. autofunction:: libcst.helpers.filter_node_fields

And lower level functions:

.. autofunction:: libcst.helpers.get_node_fields
.. autofunction:: libcst.helpers.is_whitespace_node_field
.. autofunction:: libcst.helpers.is_syntax_node_field
.. autofunction:: libcst.helpers.is_default_node_field
.. autofunction:: libcst.helpers.get_field_default_value
14 changes: 14 additions & 0 deletions libcst/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@
insert_header_comments,
ModuleNameAndPackage,
)
from libcst.helpers.node_fields import (
filter_node_fields,
get_field_default_value,
get_node_fields,
is_default_node_field,
is_syntax_node_field,
is_whitespace_node_field,
)

__all__ = [
"calculate_module_and_package",
Expand All @@ -42,4 +50,10 @@
"parse_template_statement",
"parse_template_expression",
"ModuleNameAndPackage",
"get_node_fields",
"get_field_default_value",
"is_whitespace_node_field",
"is_syntax_node_field",
"is_default_node_field",
"filter_node_fields",
]
128 changes: 128 additions & 0 deletions libcst/helpers/node_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING

from libcst import IndentedBlock, Module
from libcst._nodes.deep_equals import deep_equals

if TYPE_CHECKING:
from typing import Sequence

Check warning on line 15 in libcst/helpers/node_fields.py

View check run for this annotation

Codecov / codecov/patch

libcst/helpers/node_fields.py#L15

Added line #L15 was not covered by tests

from libcst import CSTNode

Check warning on line 17 in libcst/helpers/node_fields.py

View check run for this annotation

Codecov / codecov/patch

libcst/helpers/node_fields.py#L17

Added line #L17 was not covered by tests


def get_node_fields(node: CSTNode) -> Sequence[dataclasses.Field[CSTNode]]:
"""
Returns the sequence of a given CST-node's fields.
"""
return dataclasses.fields(node)


def is_whitespace_node_field(node: CSTNode, field: dataclasses.Field[CSTNode]) -> bool:
"""
Returns True if a given CST-node's field is a whitespace-related field
(whitespace, indent, header, footer, etc.).
"""
if "whitespace" in field.name:
return True
if "leading_lines" in field.name:
return True
if "lines_after_decorators" in field.name:
return True
if isinstance(node, (IndentedBlock, Module)) and field.name in [
"header",
"footer",
]:
return True
if isinstance(node, IndentedBlock) and field.name == "indent":
return True
return False


def is_syntax_node_field(node: CSTNode, field: dataclasses.Field[CSTNode]) -> bool:
"""
Returns True if a given CST-node's field is a syntax-related field
(colon, semicolon, dot, encoding, etc.).
"""
if isinstance(node, Module) and field.name in [
"encoding",
"default_indent",
"default_newline",
"has_trailing_newline",
]:
return True
type_str = repr(field.type)
if (
"Sentinel" in type_str
and field.name not in ["star_arg", "star", "posonly_ind"]
and "whitespace" not in field.name
):
# This is a value that can optionally be specified, so its
# definitely syntax.
return True

for name in ["Semicolon", "Colon", "Comma", "Dot", "AssignEqual"]:
# These are all nodes that exist for separation syntax
if name in type_str:
return True

Check warning on line 73 in libcst/helpers/node_fields.py

View check run for this annotation

Codecov / codecov/patch

libcst/helpers/node_fields.py#L73

Added line #L73 was not covered by tests

return False


def get_field_default_value(field: dataclasses.Field[CSTNode]) -> object:
"""
Returns the default value of a CST-node's field.
"""
if field.default_factory is not dataclasses.MISSING:
# pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE,
# dataclasses._DefaultFactory[object]]` is not a function.
return field.default_factory()
return field.default


def is_default_node_field(node: CSTNode, field: dataclasses.Field[CSTNode]) -> bool:
"""
Returns True if a given CST-node's field has its default value.
"""
return deep_equals(getattr(node, field.name), get_field_default_value(field))


def filter_node_fields(
node: CSTNode,
*,
show_defaults: bool,
show_syntax: bool,
show_whitespace: bool,
) -> Sequence[dataclasses.Field[CSTNode]]:
"""
Returns a filtered sequence of a CST-node's fields.

Setting ``show_whitespace`` to ``False`` will filter whitespace fields.

Setting ``show_defaults`` to ``False`` will filter fields if their value is equal to
the default value ; while respecting the value of ``show_whitespace``.

Setting ``show_syntax`` to ``False`` will filter syntax fields ; while respecting
the value of ``show_whitespace`` & ``show_defaults``.
"""

fields: Sequence[dataclasses.Field[CSTNode]] = dataclasses.fields(node)
# Hide all fields prefixed with "_"
fields = [f for f in fields if f.name[0] != "_"]
# Filter whitespace nodes if needed
if not show_whitespace:
fields = [f for f in fields if not is_whitespace_node_field(node, f)]
# Filter values which aren't changed from their defaults
if not show_defaults:
fields = [f for f in fields if not is_default_node_field(node, f)]
# Filter out values which aren't interesting if needed
if not show_syntax:
fields = [f for f in fields if not is_syntax_node_field(node, f)]

return fields
Loading
Loading