From 8cd81c36561463b9849a8e0c2d70248c5b1feb62 Mon Sep 17 00:00:00 2001 From: Jo <46752250+GeorgeSittas@users.noreply.github.com> Date: Fri, 23 Sep 2022 18:36:24 +0300 Subject: [PATCH] Add support for removing AST nodes using transform (#476) * Add support for removing AST nodes using transform * Always allow node removal, add more tests --- sqlglot/expressions.py | 11 +++++------ tests/test_expressions.py | 32 +++++++++++++++++++++++++++++--- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index e6888dfbb4..9214e12c87 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -7,7 +7,7 @@ from enum import auto from sqlglot.errors import ParseError -from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list +from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list, list_get class _Expression(type): @@ -351,7 +351,8 @@ def transform(self, fun, *args, copy=True, **kwargs): Args: fun (function): a function which takes a node as an argument and returns a - new transformed node or the same node without modifications. + new transformed node or the same node without modifications. If the function + returns None, then the corresponding node will be removed from the syntax tree. copy (bool): if set to True a new tree instance is constructed, otherwise the tree is modified in place. 
@@ -361,9 +362,7 @@ def transform(self, fun, *args, copy=True, **kwargs): node = self.copy() if copy else self new_node = fun(node, *args, **kwargs) - if new_node is None: - raise ValueError("A transformed node cannot be None") - if not isinstance(new_node, Expression): + if new_node is None or not isinstance(new_node, Expression): return new_node if new_node is not node: new_node.parent = node.parent @@ -3009,7 +3008,7 @@ def replace_children(expression, fun): else: new_child_nodes.append(cn) - expression.args[k] = new_child_nodes if is_list_arg else new_child_nodes[0] + expression.args[k] = new_child_nodes if is_list_arg else list_get(new_child_nodes, 0) def column_table_names(expression): diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 716e4573ab..59d584c813 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -224,9 +224,6 @@ def fun(node): self.assertEqual(actual_expression_2.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)") self.assertIs(actual_expression_2, expression) - with self.assertRaises(ValueError): - parse_one("a").transform(lambda n: None) - def test_transform_no_infinite_recursion(self): expression = parse_one("a") @@ -247,6 +244,35 @@ def fun(node): self.assertEqual(expression.transform(fun).sql(), "SELECT a, b FROM x") + def test_transform_node_removal(self): + expression = parse_one("SELECT a, b FROM x") + + def remove_column_b(node): + if isinstance(node, exp.Column) and node.name == "b": + return None + return node + + self.assertEqual(expression.transform(remove_column_b).sql(), "SELECT a FROM x") + self.assertEqual(expression.transform(lambda _: None), None) + + expression = parse_one("CAST(x AS FLOAT)") + + def remove_non_list_arg(node): + if isinstance(node, exp.DataType): + return None + return node + + self.assertEqual(expression.transform(remove_non_list_arg).sql(), "CAST(x AS )") + + expression = parse_one("SELECT a, b FROM x") + + def remove_all_columns(node): + if isinstance(node, 
exp.Column): + return None + return node + + self.assertEqual(expression.transform(remove_all_columns).sql(), "SELECT FROM x") + def test_replace(self): expression = parse_one("SELECT a, b FROM x") expression.find(exp.Column).replace(parse_one("c"))