Skip to content

Commit

Permalink
Merge pull request #173 from bcaller/recursion
Browse files Browse the repository at this point in the history
 Recursive function calls shouldn't raise RecursionError
  • Loading branch information
KevinHock authored Sep 8, 2018
2 parents c7b244d + 093f506 commit f023eaa
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 2 deletions.
32 changes: 32 additions & 0 deletions examples/vulnerable_code/recursive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from flask import Flask, request

app = Flask(__name__)


def recur_without_any_propagation(x):
if len(x) < 20:
return recur_without_any_propagation("a" * 24)
return "Done"


def recur_no_propagation_false_positive(x):
if len(x) < 20:
return recur_no_propagation_false_positive(x + "!")
return "Done"


def recur_with_propagation(x):
if len(x) < 20:
return recur_with_propagation(x + "!")
return x


@app.route('/recursive')
def route():
param = request.args.get('param', 'not set')
repeated_completely_untainted = recur_without_any_propagation(param)
app.db.execute(repeated_completely_untainted)
repeated_untainted = recur_no_propagation_false_positive(param)
app.db.execute(repeated_untainted)
repeated_tainted = recur_with_propagation(param)
app.db.execute(repeated_tainted)
2 changes: 2 additions & 0 deletions pyt/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ def main(command_line_args=sys.argv[1:]): # noqa: C901
)

initialize_constraint_table(cfg_list)
log.info("Analysing")
analyse(cfg_list)
log.info("Finding vulnerabilities")
vulnerabilities = find_vulnerabilities(
cfg_list,
args.blackbox_mapping_file,
Expand Down
9 changes: 9 additions & 0 deletions pyt/cfg/expr_visitor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
import logging

from .alias_helper import handle_aliases_in_calls
from ..core.ast_helper import (
Expand Down Expand Up @@ -30,6 +31,8 @@
from .stmt_visitor import StmtVisitor
from .stmt_visitor_helper import CALL_IDENTIFIER

log = logging.getLogger(__name__)


class ExprVisitor(StmtVisitor):
def __init__(
Expand All @@ -52,6 +55,7 @@ def __init__(
self.undecided = False
self.function_names = list()
self.function_return_stack = list()
self.function_definition_stack = list() # used to avoid recursion
self.module_definitions_stack = list()
self.prev_nodes_to_avoid = list()
self.last_control_flow_nodes = list()
Expand Down Expand Up @@ -543,6 +547,7 @@ def process_function(self, call_node, definition):
first_node
)
self.function_return_stack.pop()
self.function_definition_stack.pop()

return self.nodes[-1]

Expand All @@ -560,11 +565,15 @@ def visit_Call(self, node):
last_attribute = _id.rpartition('.')[-1]

if definition:
if definition in self.function_definition_stack:
log.debug("Recursion encountered in function %s", _id)
return self.add_blackbox_or_builtin_call(node, blackbox=True)
if isinstance(definition.node, ast.ClassDef):
self.add_blackbox_or_builtin_call(node, blackbox=False)
elif isinstance(definition.node, ast.FunctionDef):
self.undecided = False
self.function_return_stack.append(_id)
self.function_definition_stack.append(definition)
return self.process_function(node, definition)
else:
raise Exception('Definition was neither FunctionDef or ' +
Expand Down
4 changes: 4 additions & 0 deletions pyt/web_frameworks/framework_adaptor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""A generic framework adaptor that leaves route criteria to the caller."""

import ast
import logging

from ..cfg import make_cfg
from ..core.ast_helper import Arguments
Expand All @@ -10,6 +11,8 @@
TaintedNode
)

log = logging.getLogger(__name__)


class FrameworkAdaptor():
"""An engine that uses the template pattern to find all
Expand All @@ -31,6 +34,7 @@ def __init__(

def get_func_cfg_with_tainted_args(self, definition):
"""Build a function cfg and return it, with all arguments tainted."""
log.debug("Getting CFG for %s", definition.name)
func_cfg = make_cfg(
definition.node,
self.project_modules,
Expand Down
8 changes: 8 additions & 0 deletions tests/cfg/cfg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .cfg_base_test_case import CFGBaseTestCase

from pyt.core.node_types import (
BBorBInode,
EntryOrExitNode,
Node
)
Expand Down Expand Up @@ -1389,6 +1390,13 @@ def test_call_on_call(self):
path = 'examples/example_inputs/call_on_call.py'
self.cfg_create_from_file(path)

def test_recursive_function(self):
path = 'examples/example_inputs/recursive.py'
self.cfg_create_from_file(path)
recursive_call = self.cfg.nodes[7]
assert recursive_call.label == '~call_3 = ret_rec(wat)'
assert isinstance(recursive_call, BBorBInode) # Not RestoreNode


class CFGCallWithAttributeTest(CFGBaseTestCase):
def setUp(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,11 @@ def test_targets_with_recursive(self):
excluded_files = ""

included_files = discover_files(targets, excluded_files, True)
self.assertEqual(len(included_files), 31)
self.assertEqual(len(included_files), 32)

def test_targets_with_recursive_and_excluded(self):
targets = ["examples/vulnerable_code/"]
excluded_files = "inter_command_injection.py"

included_files = discover_files(targets, excluded_files, True)
self.assertEqual(len(included_files), 30)
self.assertEqual(len(included_files), 31)
5 changes: 5 additions & 0 deletions tests/vulnerabilities/vulnerabilities_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,11 @@ def assert_vulnerable(fixture):
assert_vulnerable('result = repr(str("%s" % TAINT.lower().upper()))')
assert_vulnerable('result = repr(str("{}".format(TAINT.lower())))')

def test_recursion(self):
# Really this file only has one vulnerability, but for now it's safer to keep the false positive.
vulnerabilities = self.run_analysis('examples/vulnerable_code/recursive.py')
self.assert_length(vulnerabilities, expected_length=2)


class EngineDjangoTest(VulnerabilitiesBaseTestCase):
def run_analysis(self, path):
Expand Down

0 comments on commit f023eaa

Please sign in to comment.