Skip to content

Commit

Permalink
Add support for removing AST nodes using transform (#476)
Browse files Browse the repository at this point in the history
* Add support for removing ast nodes using transform

* Always allow node removal, add more tests
  • Loading branch information
georgesittas authored Sep 23, 2022
1 parent 085c130 commit 8cd81c3
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 9 deletions.
11 changes: 5 additions & 6 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from enum import auto

from sqlglot.errors import ParseError
from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list
from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list, list_get


class _Expression(type):
Expand Down Expand Up @@ -351,7 +351,8 @@ def transform(self, fun, *args, copy=True, **kwargs):
Args:
fun (function): a function which takes a node as an argument and returns a
new transformed node or the same node without modifications.
new transformed node or the same node without modifications. If the function
returns None, then the corresponding node will be removed from the syntax tree.
copy (bool): if set to True a new tree instance is constructed, otherwise the tree is
modified in place.
Expand All @@ -361,9 +362,7 @@ def transform(self, fun, *args, copy=True, **kwargs):
node = self.copy() if copy else self
new_node = fun(node, *args, **kwargs)

if new_node is None:
raise ValueError("A transformed node cannot be None")
if not isinstance(new_node, Expression):
if new_node is None or not isinstance(new_node, Expression):
return new_node
if new_node is not node:
new_node.parent = node.parent
Expand Down Expand Up @@ -3009,7 +3008,7 @@ def replace_children(expression, fun):
else:
new_child_nodes.append(cn)

expression.args[k] = new_child_nodes if is_list_arg else new_child_nodes[0]
expression.args[k] = new_child_nodes if is_list_arg else list_get(new_child_nodes, 0)


def column_table_names(expression):
Expand Down
32 changes: 29 additions & 3 deletions tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,6 @@ def fun(node):
self.assertEqual(actual_expression_2.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)")
self.assertIs(actual_expression_2, expression)

with self.assertRaises(ValueError):
parse_one("a").transform(lambda n: None)

def test_transform_no_infinite_recursion(self):
expression = parse_one("a")

Expand All @@ -247,6 +244,35 @@ def fun(node):

self.assertEqual(expression.transform(fun).sql(), "SELECT a, b FROM x")

def test_transform_node_removal(self):
expression = parse_one("SELECT a, b FROM x")

def remove_column_b(node):
if isinstance(node, exp.Column) and node.name == "b":
return None
return node

self.assertEqual(expression.transform(remove_column_b).sql(), "SELECT a FROM x")
self.assertEqual(expression.transform(lambda _: None), None)

expression = parse_one("CAST(x AS FLOAT)")

def remove_non_list_arg(node):
if isinstance(node, exp.DataType):
return None
return node

self.assertEqual(expression.transform(remove_non_list_arg).sql(), "CAST(x AS )")

expression = parse_one("SELECT a, b FROM x")

def remove_all_columns(node):
if isinstance(node, exp.Column):
return None
return node

self.assertEqual(expression.transform(remove_all_columns).sql(), "SELECT FROM x")

def test_replace(self):
expression = parse_one("SELECT a, b FROM x")
expression.find(exp.Column).replace(parse_one("c"))
Expand Down

0 comments on commit 8cd81c3

Please sign in to comment.