diff --git a/examples/vulnerable_code/recursive.py b/examples/vulnerable_code/recursive.py new file mode 100644 index 00000000..d2cd6163 --- /dev/null +++ b/examples/vulnerable_code/recursive.py @@ -0,0 +1,32 @@ +from flask import Flask, request + +app = Flask(__name__) + + +def recur_without_any_propagation(x): + if len(x) < 20: + return recur_without_any_propagation("a" * 24) + return "Done" + + +def recur_no_propagation_false_positive(x): + if len(x) < 20: + return recur_no_propagation_false_positive(x + "!") + return "Done" + + +def recur_with_propagation(x): + if len(x) < 20: + return recur_with_propagation(x + "!") + return x + + +@app.route('/recursive') +def route(): + param = request.args.get('param', 'not set') + repeated_completely_untainted = recur_without_any_propagation(param) + app.db.execute(repeated_completely_untainted) + repeated_untainted = recur_no_propagation_false_positive(param) + app.db.execute(repeated_untainted) + repeated_tainted = recur_with_propagation(param) + app.db.execute(repeated_tainted) diff --git a/pyt/__main__.py b/pyt/__main__.py index d952cbe3..8646bceb 100644 --- a/pyt/__main__.py +++ b/pyt/__main__.py @@ -125,7 +125,9 @@ def main(command_line_args=sys.argv[1:]): # noqa: C901 ) initialize_constraint_table(cfg_list) + log.info("Analysing") analyse(cfg_list) + log.info("Finding vulnerabilities") vulnerabilities = find_vulnerabilities( cfg_list, args.blackbox_mapping_file, diff --git a/pyt/cfg/expr_visitor.py b/pyt/cfg/expr_visitor.py index 6623d717..57537875 100644 --- a/pyt/cfg/expr_visitor.py +++ b/pyt/cfg/expr_visitor.py @@ -1,4 +1,5 @@ import ast +import logging from .alias_helper import handle_aliases_in_calls from ..core.ast_helper import ( @@ -30,6 +31,8 @@ from .stmt_visitor import StmtVisitor from .stmt_visitor_helper import CALL_IDENTIFIER +log = logging.getLogger(__name__) + class ExprVisitor(StmtVisitor): def __init__( @@ -52,6 +55,7 @@ def __init__( self.undecided = False self.function_names = list() self.function_return_stack = list() + self.function_definition_stack = list() # used to avoid recursion self.module_definitions_stack = list() self.prev_nodes_to_avoid = list() self.last_control_flow_nodes = list() @@ -543,6 +547,7 @@ def process_function(self, call_node, definition): first_node ) self.function_return_stack.pop() + self.function_definition_stack.pop() return self.nodes[-1] @@ -560,11 +565,15 @@ def visit_Call(self, node): last_attribute = _id.rpartition('.')[-1] if definition: + if definition in self.function_definition_stack: + log.debug("Recursion encountered in function %s", _id) + return self.add_blackbox_or_builtin_call(node, blackbox=True) if isinstance(definition.node, ast.ClassDef): self.add_blackbox_or_builtin_call(node, blackbox=False) elif isinstance(definition.node, ast.FunctionDef): self.undecided = False self.function_return_stack.append(_id) + self.function_definition_stack.append(definition) return self.process_function(node, definition) else: raise Exception('Definition was neither FunctionDef or ' + diff --git a/pyt/web_frameworks/framework_adaptor.py b/pyt/web_frameworks/framework_adaptor.py index 2bc4d7ee..96d2a32f 100644 --- a/pyt/web_frameworks/framework_adaptor.py +++ b/pyt/web_frameworks/framework_adaptor.py @@ -1,6 +1,7 @@ """A generic framework adaptor that leaves route criteria to the caller.""" import ast +import logging from ..cfg import make_cfg from ..core.ast_helper import Arguments @@ -10,6 +11,8 @@ TaintedNode ) +log = logging.getLogger(__name__) + class FrameworkAdaptor(): """An engine that uses the template pattern to find all @@ -31,6 +34,7 @@ def __init__( def get_func_cfg_with_tainted_args(self, definition): """Build a function cfg and return it, with all arguments tainted.""" + log.debug("Getting CFG for %s", definition.name) func_cfg = make_cfg( definition.node, self.project_modules, diff --git a/tests/cfg/cfg_test.py b/tests/cfg/cfg_test.py index 3c215983..ef478396 100644 --- a/tests/cfg/cfg_test.py +++ b/tests/cfg/cfg_test.py @@ -3,6 +3,7 @@ from .cfg_base_test_case import CFGBaseTestCase from pyt.core.node_types import ( + BBorBInode, EntryOrExitNode, Node ) @@ -1389,6 +1390,13 @@ def test_call_on_call(self): path = 'examples/example_inputs/call_on_call.py' self.cfg_create_from_file(path) + def test_recursive_function(self): + path = 'examples/example_inputs/recursive.py' + self.cfg_create_from_file(path) + recursive_call = self.cfg.nodes[7] + assert recursive_call.label == '~call_3 = ret_rec(wat)' + assert isinstance(recursive_call, BBorBInode) # Not RestoreNode + class CFGCallWithAttributeTest(CFGBaseTestCase): def setUp(self): diff --git a/tests/main_test.py b/tests/main_test.py index 1e33ee24..bc985629 100644 --- a/tests/main_test.py +++ b/tests/main_test.py @@ -108,11 +108,11 @@ def test_targets_with_recursive(self): excluded_files = "" included_files = discover_files(targets, excluded_files, True) - self.assertEqual(len(included_files), 31) + self.assertEqual(len(included_files), 32) def test_targets_with_recursive_and_excluded(self): targets = ["examples/vulnerable_code/"] excluded_files = "inter_command_injection.py" included_files = discover_files(targets, excluded_files, True) - self.assertEqual(len(included_files), 30) + self.assertEqual(len(included_files), 31) diff --git a/tests/vulnerabilities/vulnerabilities_test.py b/tests/vulnerabilities/vulnerabilities_test.py index 52fbfb2c..8f5e70f1 100644 --- a/tests/vulnerabilities/vulnerabilities_test.py +++ b/tests/vulnerabilities/vulnerabilities_test.py @@ -465,6 +465,11 @@ def assert_vulnerable(fixture): assert_vulnerable('result = repr(str("%s" % TAINT.lower().upper()))') assert_vulnerable('result = repr(str("{}".format(TAINT.lower())))') + def test_recursion(self): + # Really this file only has one vulnerability, but for now it's safer to keep the false positive. + vulnerabilities = self.run_analysis('examples/vulnerable_code/recursive.py') + self.assert_length(vulnerabilities, expected_length=2) + class EngineDjangoTest(VulnerabilitiesBaseTestCase): def run_analysis(self, path):