Skip to content

Commit

Permalink
Fix pystr_to_symbolic not correctly interpreting constants as boole…
Browse files Browse the repository at this point in the history
…an values in boolean comparisons (#1756)

Strings like `not ((N > 20) != 0)` (== `Not(Ne(Gt(N, 20), 0))`) were
incorrectly interpreted by `sympy.sympify` as constant "False". This is
a limitation by sympy, which does not assume integer 0 to be a Falsy,
and enforces exact equivalence (or difference) checks with `Ne`. To get
around this limitation, the DaCe internal AST preprocessor now replaces
constants with boolean values if they are arguments to Comparison
operations, where the other operand is also a comparison operation, thus
returning a boolean.

This fixes an issue with `DeadStateElimination`, closing issue #1129.
  • Loading branch information
phschaad authored Nov 14, 2024
1 parent 17e4a88 commit c83f601
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 2 deletions.
36 changes: 35 additions & 1 deletion dace/symbolic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
import ast
from functools import lru_cache
import sys
import sympy
import pickle
import re
Expand Down Expand Up @@ -982,6 +983,32 @@ def _process_is(elem: Union[Is, IsNot]):
return expr


# Depending on the Python version we need to handle different AST nodes to correctly interpret and detect falsy / truthy
# values.
if sys.version_info < (3, 8):
_SimpleASTNode = (ast.Constant, ast.Name, ast.NameConstant, ast.Num)
_SimpleASTNodeT = Union[ast.Constant, ast.Name, ast.NameConstant, ast.Num]

def __comp_convert_truthy_falsy(node: _SimpleASTNodeT):
if isinstance(node, ast.Num):
node_val = node.n
elif isinstance(node, ast.Name):
node_val = node.id
else:
node_val = node.value
return ast.copy_location(ast.NameConstant(bool(node_val)), node)
else:
_SimpleASTNode = (ast.Constant, ast.Name)
_SimpleASTNodeT = Union[ast.Constant, ast.Name]

def __comp_convert_truthy_falsy(node: _SimpleASTNodeT):
return ast.copy_location(ast.Constant(bool(node.value)), node)

# Convert simple AST node (constant) into a falsy / truthy. Anything other than 0, None, and an empty string '' is
# considered a truthy, while the listed exceptions are considered falsy values - following the semantics of Python's
# bool() builtin.
_convert_truthy_falsy = __comp_convert_truthy_falsy

class PythonOpToSympyConverter(ast.NodeTransformer):
"""
Replaces various operations with the appropriate SymPy functions to avoid non-symbolic evaluation.
Expand Down Expand Up @@ -1067,6 +1094,13 @@ def visit_Compare(self, node: ast.Compare):
raise NotImplementedError
op = node.ops[0]
arguments = [node.left, node.comparators[0]]

# Ensure constant values in boolean comparisons are interpreted als booleans.
if isinstance(node.left, ast.Compare) and isinstance(node.comparators[0], _SimpleASTNode):
arguments[1] = _convert_truthy_falsy(node.comparators[0])
elif isinstance(node.left, _SimpleASTNode) and isinstance(node.comparators[0], ast.Compare):
arguments[0] = _convert_truthy_falsy(node.left)

func_node = ast.copy_location(ast.Name(id=self._ast_to_sympy_comparators[type(op)], ctx=ast.Load()), node)
new_node = ast.Call(func=func_node, args=[self.visit(arg) for arg in arguments], keywords=[])
return ast.copy_location(new_node, node)
Expand Down
23 changes: 22 additions & 1 deletion tests/passes/dead_code_elimination_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
""" Various tests for dead code elimination passes. """

import numpy as np
Expand Down Expand Up @@ -45,6 +45,26 @@ def test_dse_unconditional():
assert set(sdfg.states()) == {s, s2, e}


def test_dse_edge_condition_with_integer_as_boolean_regression():
"""
This is a regression test for issue #1129, which describes dead state elimination incorrectly eliminating interstate
edges when integers are used as boolean values in interstate edge conditions. Code taken from issue #1129.
"""
sdfg = dace.SDFG('dse_edge_condition_with_integer_as_boolean_regression')
sdfg.add_scalar('N', dtype=dace.int32, transient=True)
sdfg.add_scalar('result', dtype=dace.int32)
state_init = sdfg.add_state()
state_middle = sdfg.add_state()
state_end = sdfg.add_state()
sdfg.add_edge(state_init, state_end, dace.InterstateEdge(condition='(not ((N > 20) != 0))',
assignments={'result': 'N'}))
sdfg.add_edge(state_init, state_middle, dace.InterstateEdge(condition='((N > 20) != 0)'))
sdfg.add_edge(state_middle, state_end, dace.InterstateEdge(assignments={'result': '20'}))

res = DeadStateElimination().apply_pass(sdfg, {})
assert res is None


def test_dde_simple():

@dace.program
Expand Down Expand Up @@ -307,6 +327,7 @@ def test_dce_add_type_hint_of_variable(dtype):
if __name__ == '__main__':
test_dse_simple()
test_dse_unconditional()
test_dse_edge_condition_with_integer_as_boolean_regression()
test_dde_simple()
test_dde_libnode()
test_dde_access_node_in_scope(False)
Expand Down

0 comments on commit c83f601

Please sign in to comment.