Skip to content

Commit

Permalink
Merge pull request #171 from bcaller/chains
Browse files Browse the repository at this point in the history
Chained function calls separated into multiple assignments
  • Loading branch information
KevinHock authored Sep 2, 2018
2 parents 11567c4 + 2e91ce7 commit f56e761
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 7 deletions.
4 changes: 2 additions & 2 deletions pyt/core/ast_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import subprocess
from functools import lru_cache

from .transformer import AsyncTransformer
from .transformer import PytTransformer


BLACK_LISTED_CALL_NAMES = ['self']
Expand Down Expand Up @@ -35,7 +35,7 @@ def generate_ast(path):
with open(path, 'r') as f:
try:
tree = ast.parse(f.read())
return AsyncTransformer().visit(tree)
return PytTransformer().visit(tree)
except SyntaxError: # pragma: no cover
global recursive
if not recursive:
Expand Down
52 changes: 51 additions & 1 deletion pyt/core/transformer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ast


class AsyncTransformer(ast.NodeTransformer):
class AsyncTransformer():
"""Converts all async nodes into their synchronous counterparts."""

def visit_Await(self, node):
Expand All @@ -16,3 +16,53 @@ def visit_AsyncFor(self, node):

def visit_AsyncWith(self, node):
return self.visit(ast.With(**node.__dict__))


class ChainedFunctionTransformer():
def visit_chain(self, node, depth=1):
if (
isinstance(node.value, ast.Call) and
isinstance(node.value.func, ast.Attribute) and
isinstance(node.value.func.value, ast.Call)
):
# Node is assignment or return with value like `b.c().d()`
call_node = node.value
# If we want to handle nested functions in future, depth needs fixing
temp_var_id = '__chain_tmp_{}'.format(depth)
# AST tree is from right to left, so d() is the outer Call and b.c() is the inner Call
unvisited_inner_call = ast.Assign(
targets=[ast.Name(id=temp_var_id, ctx=ast.Store())],
value=call_node.func.value,
)
ast.copy_location(unvisited_inner_call, node)
inner_calls = self.visit_chain(unvisited_inner_call, depth + 1)
for inner_call_node in inner_calls:
ast.copy_location(inner_call_node, node)
outer_call = self.generic_visit(type(node)(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id=temp_var_id, ctx=ast.Load()),
attr=call_node.func.attr,
ctx=ast.Load(),
),
args=call_node.args,
keywords=call_node.keywords,
),
**{field: value for field, value in ast.iter_fields(node) if field != 'value'} # e.g. targets
))
ast.copy_location(outer_call, node)
ast.copy_location(outer_call.value, node)
ast.copy_location(outer_call.value.func, node)
return [*inner_calls, outer_call]
else:
return [self.generic_visit(node)]

def visit_Assign(self, node):
return self.visit_chain(node)

def visit_Return(self, node):
return self.visit_chain(node)


class PytTransformer(AsyncTransformer, ChainedFunctionTransformer, ast.NodeTransformer):
pass
3 changes: 2 additions & 1 deletion tests/base_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pyt.cfg import make_cfg
from pyt.core.ast_helper import generate_ast
from pyt.core.module_definitions import project_definitions
from pyt.core.transformer import PytTransformer


class BaseTestCase(unittest.TestCase):
Expand Down Expand Up @@ -36,7 +37,7 @@ def cfg_create_from_ast(
):
project_definitions.clear()
self.cfg = make_cfg(
ast_tree,
PytTransformer().visit(ast_tree),
project_modules,
local_modules,
filename='?'
Expand Down
32 changes: 32 additions & 0 deletions tests/cfg/cfg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,3 +1497,35 @@ def test_name_for(self):

self.assert_length(self.cfg.nodes, expected_length=4)
self.assertEqual(self.cfg.nodes[1].label, 'for x in l:')


class CFGFunctionChain(CFGBaseTestCase):
def test_simple(self):
self.cfg_create_from_ast(
ast.parse('a = b.c(z)')
)
middle_nodes = self.cfg.nodes[1:-1]
self.assert_length(middle_nodes, expected_length=2)
self.assertEqual(middle_nodes[0].label, '~call_1 = ret_b.c(z)')
self.assertEqual(middle_nodes[0].func_name, 'b.c')
self.assertCountEqual(middle_nodes[0].right_hand_side_variables, ['z', 'b'])

def test_chain(self):
self.cfg_create_from_ast(
ast.parse('a = b.xxx.c(z).d(y)')
)
middle_nodes = self.cfg.nodes[1:-1]
self.assert_length(middle_nodes, expected_length=4)

self.assertEqual(middle_nodes[0].left_hand_side, '~call_1')
self.assertCountEqual(middle_nodes[0].right_hand_side_variables, ['b', 'z'])
self.assertEqual(middle_nodes[0].label, '~call_1 = ret_b.xxx.c(z)')

self.assertEqual(middle_nodes[1].left_hand_side, '__chain_tmp_1')
self.assertCountEqual(middle_nodes[1].right_hand_side_variables, ['~call_1'])

self.assertEqual(middle_nodes[2].left_hand_side, '~call_2')
self.assertCountEqual(middle_nodes[2].right_hand_side_variables, ['__chain_tmp_1', 'y'])

self.assertEqual(middle_nodes[3].left_hand_side, 'a')
self.assertCountEqual(middle_nodes[3].right_hand_side_variables, ['~call_2'])
22 changes: 20 additions & 2 deletions tests/core/transformer_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import ast
import unittest

from pyt.core.transformer import AsyncTransformer
from pyt.core.transformer import PytTransformer


class TransformerTest(unittest.TestCase):
"""Tests for the AsyncTransformer."""

def test_async_removed_by_transformer(self):
self.maxDiff = 99999
async_tree = ast.parse("\n".join([
"async def a():",
" async for b in c():",
Expand All @@ -30,7 +31,24 @@ def test_async_removed_by_transformer(self):
]))
self.assertIsInstance(sync_tree.body[0], ast.FunctionDef)

transformed = AsyncTransformer().visit(async_tree)
transformed = PytTransformer().visit(async_tree)
self.assertIsInstance(transformed.body[0], ast.FunctionDef)

self.assertEqual(ast.dump(transformed), ast.dump(sync_tree))

def test_chained_function(self):
chained_tree = ast.parse("\n".join([
"def a():",
" b = c.d(e).f(g).h(i).j(k)",
]))

separated_tree = ast.parse("\n".join([
"def a():",
" __chain_tmp_3 = c.d(e)",
" __chain_tmp_2 = __chain_tmp_3.f(g)",
" __chain_tmp_1 = __chain_tmp_2.h(i)",
" b = __chain_tmp_1.j(k)",
]))

transformed = PytTransformer().visit(chained_tree)
self.assertEqual(ast.dump(transformed), ast.dump(separated_tree))
2 changes: 1 addition & 1 deletion tests/vulnerabilities/vulnerabilities_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def test_path_traversal_sanitised_2_result(self):

def test_sql_result(self):
vulnerabilities = self.run_analysis('examples/vulnerable_code/sql/sqli.py')
self.assert_length(vulnerabilities, expected_length=2)
self.assert_length(vulnerabilities, expected_length=3)
vulnerability_description = str(vulnerabilities[0])
EXPECTED_VULNERABILITY_DESCRIPTION = """
File: examples/vulnerable_code/sql/sqli.py
Expand Down

0 comments on commit f56e761

Please sign in to comment.