diff --git a/pyt/__main__.py b/pyt/__main__.py index 52275da6..062fea8b 100644 --- a/pyt/__main__.py +++ b/pyt/__main__.py @@ -30,7 +30,25 @@ ) -def main(command_line_args=sys.argv[1:]): # noqa: C901 +def discover_files(targets, excluded_files, recursive=False): + included_files = list() + excluded_list = excluded_files.split(",") + for target in targets: + if os.path.isdir(target): + for root, dirs, files in os.walk(target): + for f in files: + fullpath = os.path.join(root, f) + if os.path.splitext(fullpath)[1] == '.py' and fullpath.split("/")[-1] not in excluded_list: + included_files.append(fullpath) + if not recursive: + break + else: + if target not in excluded_list: + included_files.append(target) + return included_files + + +def main(command_line_args=sys.argv[1:]): args = parse_args(command_line_args) ui_mode = UImode.NORMAL @@ -39,60 +57,67 @@ def main(command_line_args=sys.argv[1:]): # noqa: C901 elif args.trim_reassigned_in: ui_mode = UImode.TRIM - path = os.path.normpath(args.filepath) + files = discover_files( + args.targets, + args.excluded_paths, + args.recursive + ) + + for path in files: + vulnerabilities = list() + if args.ignore_nosec: + nosec_lines = set() + else: + file = open(path, 'r') + lines = file.readlines() + nosec_lines = set( + lineno for + (lineno, line) in enumerate(lines, start=1) + if '#nosec' in line or '# nosec' in line + ) - if args.ignore_nosec: - nosec_lines = set() - else: - file = open(path, 'r') - lines = file.readlines() - nosec_lines = set( - lineno for - (lineno, line) in enumerate(lines, start=1) - if '#nosec' in line or '# nosec' in line - ) + if args.project_root: + directory = os.path.normpath(args.project_root) + else: + directory = os.path.dirname(path) + project_modules = get_modules(directory) + local_modules = get_directory_modules(directory) + tree = generate_ast(path) - if args.project_root: - directory = os.path.normpath(args.project_root) - else: - directory = os.path.dirname(path) - project_modules = get_modules(directory) - local_modules = get_directory_modules(directory) + cfg = make_cfg( + tree, + project_modules, + local_modules, + path + ) + cfg_list = [cfg] - tree = generate_ast(path) - cfg = make_cfg( - tree, - project_modules, - local_modules, - path - ) - cfg_list = [cfg] - framework_route_criteria = is_flask_route_function - if args.adaptor: - if args.adaptor.lower().startswith('e'): - framework_route_criteria = is_function - elif args.adaptor.lower().startswith('p'): - framework_route_criteria = is_function_without_leading_ - elif args.adaptor.lower().startswith('d'): - framework_route_criteria = is_django_view_function - # Add all the route functions to the cfg_list - FrameworkAdaptor( - cfg_list, - project_modules, - local_modules, - framework_route_criteria - ) + framework_route_criteria = is_flask_route_function + if args.adaptor: + if args.adaptor.lower().startswith('e'): + framework_route_criteria = is_function + elif args.adaptor.lower().startswith('p'): + framework_route_criteria = is_function_without_leading_ + elif args.adaptor.lower().startswith('d'): + framework_route_criteria = is_django_view_function + # Add all the route functions to the cfg_list + FrameworkAdaptor( + cfg_list, + project_modules, + local_modules, + framework_route_criteria + ) - initialize_constraint_table(cfg_list) - analyse(cfg_list) - vulnerabilities = find_vulnerabilities( - cfg_list, - ui_mode, - args.blackbox_mapping_file, - args.trigger_word_file, - nosec_lines - ) + initialize_constraint_table(cfg_list) + analyse(cfg_list) + vulnerabilities.extend(find_vulnerabilities( + cfg_list, + ui_mode, + args.blackbox_mapping_file, + args.trigger_word_file, + nosec_lines + )) if args.baseline: vulnerabilities = get_vulnerabilities_not_in_baseline( diff --git a/pyt/usage.py b/pyt/usage.py index 4930eb02..30286215 100644 --- a/pyt/usage.py +++ b/pyt/usage.py @@ -30,9 +30,8 @@ def valid_date(s): def _add_required_group(parser): required_group = parser.add_argument_group('required arguments') required_group.add_argument( - '-f', '--filepath', - help='Path to the file that should be analysed.', - type=str + 'targets', metavar='targets', type=str, nargs='+', + help='source file(s) or directory(s) to be tested' ) @@ -91,6 +90,17 @@ def _add_optional_group(parser): action='store_true', help='do not skip lines with # nosec comments' ) + optional_group.add_argument( + '-r', '--recursive', dest='recursive', + action='store_true', help='find and process files in subdirectories' + ) + optional_group.add_argument( + '-x', '--exclude', + dest='excluded_paths', + action='store', + default='', + help='Separate files with commas' + ) def _add_print_group(parser): @@ -110,8 +120,8 @@ def _add_print_group(parser): def _check_required_and_mutually_exclusive_args(parser, args): - if args.filepath is None: - parser.error('The -f/--filepath argument is required') + if args.targets is None: + parser.error('The targets argument is required') def parse_args(args): diff --git a/tests/main_test.py b/tests/main_test.py index eea6ff47..aee80c68 100644 --- a/tests/main_test.py +++ b/tests/main_test.py @@ -5,17 +5,18 @@ class MainTest(BaseTestCase): + @mock.patch('pyt.__main__.discover_files') @mock.patch('pyt.__main__.parse_args') @mock.patch('pyt.__main__.find_vulnerabilities') @mock.patch('pyt.__main__.text') - def test_text_output(self, mock_text, mock_find_vulnerabilities, mock_parse_args): + def test_text_output(self, mock_text, mock_find_vulnerabilities, mock_parse_args, mock_discover_files): mock_find_vulnerabilities.return_value = 'stuff' example_file = 'examples/vulnerable_code/inter_command_injection.py' output_file = 'mocked_outfile' + mock_discover_files.return_value = [example_file] mock_parse_args.return_value = mock.Mock( autospec=True, - filepath=example_file, project_root=None, baseline=None, json=None, @@ -32,17 +33,18 @@ def test_text_output(self, mock_text, mock_find_vulnerabilities, mock_parse_args mock_parse_args.return_value.output_file ) + @mock.patch('pyt.__main__.discover_files') @mock.patch('pyt.__main__.parse_args') @mock.patch('pyt.__main__.find_vulnerabilities') @mock.patch('pyt.__main__.json') - def test_json_output(self, mock_json, mock_find_vulnerabilities, mock_parse_args): + def test_json_output(self, mock_json, mock_find_vulnerabilities, mock_parse_args, mock_discover_files): mock_find_vulnerabilities.return_value = 'stuff' example_file = 'examples/vulnerable_code/inter_command_injection.py' output_file = 'mocked_outfile' + mock_discover_files.return_value = [example_file] mock_parse_args.return_value = mock.Mock( autospec=True, - filepath=example_file, project_root=None, baseline=None, json=True, diff --git a/tests/usage_test.py b/tests/usage_test.py index cae390e5..d9ed7cec 100644 --- a/tests/usage_test.py +++ b/tests/usage_test.py @@ -25,14 +25,14 @@ def test_no_args(self): self.maxDiff = None - EXPECTED = """usage: python -m pyt [-h] [-f FILEPATH] [-a ADAPTOR] [-pr PROJECT_ROOT] + EXPECTED = """usage: python -m pyt [-h] [-a ADAPTOR] [-pr PROJECT_ROOT] [-b BASELINE_JSON_FILE] [-j] [-m BLACKBOX_MAPPING_FILE] [-t TRIGGER_WORD_FILE] [-o OUTPUT_FILE] [--ignore-nosec] - [-trim] [-i] + [-r] [-x EXCLUDED_PATHS] [-trim] [-i] + targets [targets ...] required arguments: - -f FILEPATH, --filepath FILEPATH - Path to the file that should be analysed. + targets source file(s) or directory(s) to be tested optional arguments: -a ADAPTOR, --adaptor ADAPTOR @@ -52,6 +52,9 @@ def test_no_args(self): -o OUTPUT_FILE, --output OUTPUT_FILE write report to filename --ignore-nosec do not skip lines with # nosec comments + -r, --recursive find and process files in subdirectories + -x EXCLUDED_PATHS, --exclude EXCLUDED_PATHS + Separate files with commas print arguments: -trim, --trim-reassigned-in @@ -62,16 +65,17 @@ def test_no_args(self): self.assertEqual(stdout.getvalue(), EXPECTED) - def test_valid_args_but_no_filepath(self): + def test_valid_args_but_no_targets(self): with self.assertRaises(SystemExit): with capture_sys_output() as (_, stderr): parse_args(['-j']) - EXPECTED = """usage: python -m pyt [-h] [-f FILEPATH] [-a ADAPTOR] [-pr PROJECT_ROOT] + EXPECTED = """usage: python -m pyt [-h] [-a ADAPTOR] [-pr PROJECT_ROOT] [-b BASELINE_JSON_FILE] [-j] [-m BLACKBOX_MAPPING_FILE] [-t TRIGGER_WORD_FILE] [-o OUTPUT_FILE] [--ignore-nosec] - [-trim] [-i] -python -m pyt: error: The -f/--filepath argument is required\n""" + [-r] [-x EXCLUDED_PATHS] [-trim] [-i] + targets [targets ...] +python -m pyt: error: the following arguments are required: targets\n""" self.assertEqual(stderr.getvalue(), EXPECTED) @@ -89,7 +93,7 @@ def test_valid_args_but_no_filepath(self): def test_normal_usage(self): with capture_sys_output() as (stdout, stderr): - parse_args(['-f', 'foo.py']) + parse_args(['foo.py']) self.assertEqual(stdout.getvalue(), '') self.assertEqual(stderr.getvalue(), '')