Skip to content

Commit

Permalink
Test sarif generation (crytic#68)
Browse files Browse the repository at this point in the history
* test sarif generation in pytest

* Refactor testing file.

Co-authored-by: Filipe Casal <fcasal@users.noreply.github.com>
  • Loading branch information
2 people authored and 0xLucqs committed Sep 15, 2022
1 parent 23b7ca0 commit af0fff3
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 33 deletions.
9 changes: 5 additions & 4 deletions amarna/Result.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,7 @@ def result_multiple_positions(
return ResultMultiplePositions(filenames, rule_name, text, position_list)


def create_sarif(
results: List[Any], fname: Optional[str] = None, printoutput: bool = False
) -> None:
def create_sarif(results: List[Any], fname: Optional[str] = None, printoutput: bool = False) -> str:
"""
Create the sarif output json for the results, and write it to file or print it.
"""
Expand All @@ -168,8 +166,11 @@ def create_sarif(
with open(os.path.join(os.getcwd(), fname), "w", encoding="utf8") as f:
json.dump(sarif, f, indent=1)

sarif_str = json.dumps(sarif)
if printoutput:
print(json.dumps(sarif))
print(sarif_str)

return sarif_str


def sarif_region_from_position(position: PositionType) -> Dict[str, int]:
Expand Down
12 changes: 6 additions & 6 deletions amarna/command_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from amarna.amarna import Amarna, analyze_directory, analyze_file
from amarna.Result import Result, ResultMultiplePositions, output_result
from amarna.Result import SARIF_MODE, SUMMARY_MODE
from typing import List, Union
from typing import List, Union, Dict
import sys

example_usage = """---------------\nUsage examples\n---------------
Expand Down Expand Up @@ -34,11 +34,11 @@ def parse_comma_sep_strings(s: str) -> List[str]:
return []


def get_rule_names(rules: str, excluded: str) -> List[str]:
def get_rule_names(rule_str: str, excluded_str: str) -> List[str]:
ALL_RULES = Amarna.get_all_rule_names()

rules = parse_comma_sep_strings(rules)
excluded = parse_comma_sep_strings(excluded)
rules = parse_comma_sep_strings(rule_str)
excluded = parse_comma_sep_strings(excluded_str)

for rule in rules + excluded:
if rule not in ALL_RULES:
Expand All @@ -56,7 +56,7 @@ def get_rule_names(rules: str, excluded: str) -> List[str]:
def filter_results_from_disable(
results: List[Union[Result, ResultMultiplePositions]]
) -> List[Union[Result, ResultMultiplePositions]]:
first_lines_per_file = {}
first_lines_per_file: Dict[str, str] = {}
disable_token = "# amarna: disable="

new_results = []
Expand All @@ -77,7 +77,7 @@ def filter_results_from_disable(
continue

rule_tok = first_line.split(disable_token)[1]
rule_list = get_rule_names(None, rule_tok)
rule_list = get_rule_names("", rule_tok)

if result.rule_name in rule_list:
new_results.append(result)
Expand Down
65 changes: 42 additions & 23 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
from pathlib import Path
import json
from typing import List, Union

from amarna.amarna import Amarna, analyze_directory, analyze_file
from amarna.Result import create_summary
from amarna.Result import Result, ResultMultiplePositions, create_summary, create_sarif
from amarna.command_line import filter_results_from_disable

_module_dir = Path(__file__).resolve().parent
Expand All @@ -20,6 +22,36 @@ def test_all() -> None:
_test_directory(subdir)


def load_sarif(results: List[Union[Result, ResultMultiplePositions]]) -> None:
"""
Loads a sarif str with json.loads to test if generation worked.
"""
sarif_str = create_sarif(results, None, False)

try:
json.loads(sarif_str)
except json.JSONDecodeError as e:
assert False, "Sarif generation is broken."


def compare_expected(summary_results: str, expected_filename: str) -> None:
"""
Compares the obtained summary results with the ones saved in the .expected files.
If the expected file does not exist, create it with the current summary_results.
"""
expected_result = str(TESTS_DIR.joinpath("expected", expected_filename + ".expected"))
try:
with open(expected_result, "r", encoding="utf8") as f:
expected = f.read()
assert expected == summary_results

except FileNotFoundError as e:
print("Expected test result does not exist. Creating it.")
print("at {}".format(expected_result))
with open(expected_result, "w", encoding="utf8") as f:
f.write(summary_results)


def _test_single(filename: str) -> None:
FILE, ext = os.path.splitext(filename)
if ext != ".cairo":
Expand All @@ -30,21 +62,15 @@ def _test_single(filename: str) -> None:
all_rules = Amarna.get_all_rule_names()

results = analyze_file(test_file, all_rules)

# filter results to test the # amarna: disable= rule
results = filter_results_from_disable(results)

# Generate summary and compare with the expected result
summary = create_summary(results)
compare_expected(summary, FILE)

expected_result = str(TESTS_DIR.joinpath("expected", FILE + ".expected"))
try:
with open(expected_result, "r", encoding="utf8") as f:
expected = f.read()
assert expected == summary

except FileNotFoundError as e:
print("Expected test result does not exist. Creating it.")
print("at {}".format(expected_result))
with open(expected_result, "w", encoding="utf8") as f:
f.write(summary)
load_sarif(results)


def _test_directory(filename: str) -> None:
Expand All @@ -56,16 +82,9 @@ def _test_directory(filename: str) -> None:
all_rules = Amarna.get_all_rule_names()

results = analyze_directory(filename, all_rules)
summary = create_summary(results)

expected_result = str(TESTS_DIR.joinpath("expected", test_name + ".expected"))
try:
with open(expected_result, "r", encoding="utf8") as f:
expected = f.read()
assert expected == summary
# Generate summary and compare with the expected result
summary = create_summary(results)
compare_expected(summary, test_name)

except FileNotFoundError as e:
print("Expected test result does not exist. Creating it.")
print("at {}".format(expected_result))
with open(expected_result, "w", encoding="utf8") as f:
f.write(summary)
load_sarif(results)

0 comments on commit af0fff3

Please sign in to comment.