Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chained function calls separated into multiple assignments #171

Merged
merged 1 commit into from
Sep 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyt/core/ast_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import subprocess
from functools import lru_cache

from .transformer import AsyncTransformer
from .transformer import PytTransformer


BLACK_LISTED_CALL_NAMES = ['self']
Expand Down Expand Up @@ -35,7 +35,7 @@ def generate_ast(path):
with open(path, 'r') as f:
try:
tree = ast.parse(f.read())
return AsyncTransformer().visit(tree)
return PytTransformer().visit(tree)
except SyntaxError: # pragma: no cover
global recursive
if not recursive:
Expand Down
52 changes: 51 additions & 1 deletion pyt/core/transformer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ast


class AsyncTransformer(ast.NodeTransformer):
class AsyncTransformer():
"""Converts all async nodes into their synchronous counterparts."""

def visit_Await(self, node):
Expand All @@ -16,3 +16,53 @@ def visit_AsyncFor(self, node):

def visit_AsyncWith(self, node):
return self.visit(ast.With(**node.__dict__))


class ChainedFunctionTransformer():
def visit_chain(self, node, depth=1):
if (
isinstance(node.value, ast.Call) and
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love how you did
node.value
node.value.func
node.value.func.value
in that order, it's super clean.

isinstance(node.value.func, ast.Attribute) and
isinstance(node.value.func.value, ast.Call)
):
# Node is assignment or return with value like `b.c().d()`
call_node = node.value
# If we want to handle nested functions in future, depth needs fixing
temp_var_id = '__chain_tmp_{}'.format(depth)
# AST tree is from right to left, so d() is the outer Call and b.c() is the inner Call
unvisited_inner_call = ast.Assign(
targets=[ast.Name(id=temp_var_id, ctx=ast.Store())],
value=call_node.func.value,
)
ast.copy_location(unvisited_inner_call, node)
inner_calls = self.visit_chain(unvisited_inner_call, depth + 1)
for inner_call_node in inner_calls:
ast.copy_location(inner_call_node, node)
outer_call = self.generic_visit(type(node)(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indentation nit: This is awesomely smart but more indentation may make it clearer

outer_call = self.generic_visit(
    type(node)(
        value=ast.Call(
            func=ast.Attribute(
                value=ast.Name(id=temp_var_id, ctx=ast.Load()),
                attr=call_node.func.attr,
                ctx=ast.Load(),
            ),
            args=call_node.args,
            keywords=call_node.keywords,
        ),
        **{field: value for field, value in ast.iter_fields(node) if field != 'value'}  # e.g. targets
    )
)

Nit: w/r/t generic_visit, since it isn't defined in this class, maybe inheriting from ast.NodeTransformer in both classes would be clearer.

value=ast.Call(
func=ast.Attribute(
value=ast.Name(id=temp_var_id, ctx=ast.Load()),
attr=call_node.func.attr,
ctx=ast.Load(),
),
args=call_node.args,
keywords=call_node.keywords,
),
**{field: value for field, value in ast.iter_fields(node) if field != 'value'} # e.g. targets
))
ast.copy_location(outer_call, node)
ast.copy_location(outer_call.value, node)
ast.copy_location(outer_call.value.func, node)
return [*inner_calls, outer_call]
else:
return [self.generic_visit(node)]

def visit_Assign(self, node):
return self.visit_chain(node)

def visit_Return(self, node):
return self.visit_chain(node)


class PytTransformer(AsyncTransformer, ChainedFunctionTransformer, ast.NodeTransformer):
pass
3 changes: 2 additions & 1 deletion tests/base_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pyt.cfg import make_cfg
from pyt.core.ast_helper import generate_ast
from pyt.core.module_definitions import project_definitions
from pyt.core.transformer import PytTransformer


class BaseTestCase(unittest.TestCase):
Expand Down Expand Up @@ -36,7 +37,7 @@ def cfg_create_from_ast(
):
project_definitions.clear()
self.cfg = make_cfg(
ast_tree,
PytTransformer().visit(ast_tree),
project_modules,
local_modules,
filename='?'
Expand Down
32 changes: 32 additions & 0 deletions tests/cfg/cfg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,3 +1497,35 @@ def test_name_for(self):

self.assert_length(self.cfg.nodes, expected_length=4)
self.assertEqual(self.cfg.nodes[1].label, 'for x in l:')


class CFGFunctionChain(CFGBaseTestCase):
def test_simple(self):
self.cfg_create_from_ast(
ast.parse('a = b.c(z)')
)
middle_nodes = self.cfg.nodes[1:-1]
self.assert_length(middle_nodes, expected_length=2)
self.assertEqual(middle_nodes[0].label, '~call_1 = ret_b.c(z)')
self.assertEqual(middle_nodes[0].func_name, 'b.c')
self.assertCountEqual(middle_nodes[0].right_hand_side_variables, ['z', 'b'])

def test_chain(self):
self.cfg_create_from_ast(
ast.parse('a = b.xxx.c(z).d(y)')
)
middle_nodes = self.cfg.nodes[1:-1]
self.assert_length(middle_nodes, expected_length=4)

self.assertEqual(middle_nodes[0].left_hand_side, '~call_1')
self.assertCountEqual(middle_nodes[0].right_hand_side_variables, ['b', 'z'])
self.assertEqual(middle_nodes[0].label, '~call_1 = ret_b.xxx.c(z)')

self.assertEqual(middle_nodes[1].left_hand_side, '__chain_tmp_1')
self.assertCountEqual(middle_nodes[1].right_hand_side_variables, ['~call_1'])

self.assertEqual(middle_nodes[2].left_hand_side, '~call_2')
self.assertCountEqual(middle_nodes[2].right_hand_side_variables, ['__chain_tmp_1', 'y'])

self.assertEqual(middle_nodes[3].left_hand_side, 'a')
self.assertCountEqual(middle_nodes[3].right_hand_side_variables, ['~call_2'])
22 changes: 20 additions & 2 deletions tests/core/transformer_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import ast
import unittest

from pyt.core.transformer import AsyncTransformer
from pyt.core.transformer import PytTransformer


class TransformerTest(unittest.TestCase):
"""Tests for the AsyncTransformer."""

def test_async_removed_by_transformer(self):
self.maxDiff = 99999
async_tree = ast.parse("\n".join([
"async def a():",
" async for b in c():",
Expand All @@ -30,7 +31,24 @@ def test_async_removed_by_transformer(self):
]))
self.assertIsInstance(sync_tree.body[0], ast.FunctionDef)

transformed = AsyncTransformer().visit(async_tree)
transformed = PytTransformer().visit(async_tree)
self.assertIsInstance(transformed.body[0], ast.FunctionDef)

self.assertEqual(ast.dump(transformed), ast.dump(sync_tree))

def test_chained_function(self):
chained_tree = ast.parse("\n".join([
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great use of join :D

"def a():",
" b = c.d(e).f(g).h(i).j(k)",
]))

separated_tree = ast.parse("\n".join([
"def a():",
" __chain_tmp_3 = c.d(e)",
" __chain_tmp_2 = __chain_tmp_3.f(g)",
" __chain_tmp_1 = __chain_tmp_2.h(i)",
" b = __chain_tmp_1.j(k)",
]))

transformed = PytTransformer().visit(chained_tree)
self.assertEqual(ast.dump(transformed), ast.dump(separated_tree))
2 changes: 1 addition & 1 deletion tests/vulnerabilities/vulnerabilities_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def test_path_traversal_sanitised_2_result(self):

def test_sql_result(self):
vulnerabilities = self.run_analysis('examples/vulnerable_code/sql/sqli.py')
self.assert_length(vulnerabilities, expected_length=2)
self.assert_length(vulnerabilities, expected_length=3)
vulnerability_description = str(vulnerabilities[0])
EXPECTED_VULNERABILITY_DESCRIPTION = """
File: examples/vulnerable_code/sql/sqli.py
Expand Down