diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index c47993a2b2..96ca3f2e59 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -28,5 +28,10 @@ jobs: - name: Build and install capstone run: pip install ./bindings/python - - name: Run tests + - name: Run legacy tests run: python ./bindings/python/tests/test_all.py + + - name: cstest.py integration tests + run: | + cd suite/cstest/test/ + python3 ./integration_tests.py "python3 ../../../bindings/python/py_cstest/cstest.py" diff --git a/bindings/python/capstone/__init__.py b/bindings/python/capstone/__init__.py index 75ad68083f..e9669381e6 100755 --- a/bindings/python/capstone/__init__.py +++ b/bindings/python/capstone/__init__.py @@ -350,13 +350,15 @@ CS_AC_READ_WRITE = (2) # Capstone syntax value -CS_OPT_SYNTAX_DEFAULT = 1 << 1 # Default assembly syntax of all platforms (CS_OPT_SYNTAX) -CS_OPT_SYNTAX_INTEL = 1 << 2 # Intel X86 asm syntax - default syntax on X86 (CS_OPT_SYNTAX, CS_ARCH_X86) -CS_OPT_SYNTAX_ATT = 1 << 3 # ATT asm syntax (CS_OPT_SYNTAX, CS_ARCH_X86) -CS_OPT_SYNTAX_NOREGNAME = 1 << 4 # Asm syntax prints register name with only number - (CS_OPT_SYNTAX, CS_ARCH_PPC, CS_ARCH_ARM) -CS_OPT_SYNTAX_MASM = 1 << 5 # MASM syntax (CS_OPT_SYNTAX, CS_ARCH_X86) -CS_OPT_SYNTAX_MOTOROLA = 1 << 6 # MOS65XX use $ as hex prefix -CS_OPT_SYNTAX_CS_REG_ALIAS = 1 << 7 # Prints common register alias which are not defined in LLVM (ARM: r9 = sb etc.) +CS_OPT_SYNTAX_DEFAULT = (1 << 1) # Default assembly syntax of all platforms (CS_OPT_SYNTAX) +CS_OPT_SYNTAX_INTEL = (1 << 2) # Intel X86 asm syntax - default syntax on X86 (CS_OPT_SYNTAX, CS_ARCH_X86) +CS_OPT_SYNTAX_ATT = (1 << 3) # ATT asm syntax (CS_OPT_SYNTAX, CS_ARCH_X86) +CS_OPT_SYNTAX_NOREGNAME = (1 << 4) # Asm syntax prints register name with only number - (CS_OPT_SYNTAX, CS_ARCH_PPC, CS_ARCH_ARM) +CS_OPT_SYNTAX_MASM = (1 << 5) # MASM syntax (CS_OPT_SYNTAX, CS_ARCH_X86) +CS_OPT_SYNTAX_MOTOROLA = (1 << 6) # MOS65XX use $ as hex prefix +CS_OPT_SYNTAX_CS_REG_ALIAS = (1 << 7) # Prints common register alias which are not defined in LLVM (ARM: r9 = sb etc.) +CS_OPT_SYNTAX_PERCENT = (1 << 8) # Prints the % in front of PPC registers. +CS_OPT_DETAIL_REAL = (1 << 1) # If enabled, always sets the real instruction detail.Even if the instruction is an alias. # Capstone error type CS_ERR_OK = 0 # No error: everything was fine @@ -1021,9 +1023,10 @@ def __del__(self): except: # _cs might be pulled from under our feet pass - - # def option(self, opt_type, opt_value): - # return _cs.cs_option(self.csh, opt_type, opt_value) + def option(self, opt_type, opt_value): + status = _cs.cs_option(self.csh, opt_type, opt_value) + if status != CS_ERR_OK: + raise CsError(status) # is this a diet engine? diff --git a/bindings/python/py_cstest/cstest.py b/bindings/python/py_cstest/cstest.py index 83952257d8..017abe8b99 100755 --- a/bindings/python/py_cstest/cstest.py +++ b/bindings/python/py_cstest/cstest.py @@ -3,44 +3,314 @@ # SPDX-License-Identifier: BSD-3 import argparse -import logging as log +import logging import subprocess as sp import sys import os import yaml +import capstone +import cs_modes +from capstone import CsInsn, Cs, CS_ARCH_AARCH64, CS_MODE_64, CS_MODE_16 +from compare import ( + compare_asm_text, + compare_str, + compare_tbool, + compare_uint8, + compare_int8, + compare_uint16, + compare_int16, + compare_uint32, + compare_int32, + compare_uint64, + compare_int64, + compare_fp, + compare_enum, + compare_bit_flags, + compare_reg, +) +from enum import Enum from pathlib import Path +log = logging.getLogger("__name__") + + +def get_cs_int_attr(cs, attr: str, err_msg_pre: str): + try: + attr_int = getattr(cs, attr) + if not isinstance(attr_int, int): + raise AttributeError(f"{attr} not found") + return attr_int + except AttributeError: + log.warning(f"{err_msg_pre}: Capstone doesn't have the attribute '{attr}'") + return None + + +def arch_bits(arch: int, mode: int) -> int: + if arch == CS_ARCH_AARCH64 or mode & CS_MODE_64: + return 64 + elif mode & CS_MODE_16: + return 16 + return 32 + + +class TestResult(Enum): + SUCCESS = 0 + FAILED = 1 + SKIPPED = 2 + ERROR = 3 + class TestStats: def __init__(self, total_file_count: int): self.total_file_count = total_file_count + self.valid_test_files = 0 self.test_case_count = 0 self.success = 0 self.failed = 0 self.skipped = 0 - self.error = 0 + self.errors = 0 self.invalid_files = 0 self.err_msgs: list[str] = list() - def add_error(self, msg: str): - self.error += 1 + def add_error_msg(self, msg: str): self.err_msgs.append(msg) + def add_invalid_file_dp(self): + self.invalid_files += 1 + self.errors += 1 + + def add_test_case_data_point(self, dp: TestResult): + if dp == TestResult.SUCCESS: + self.success += 1 + elif dp == TestResult.FAILED: + self.failed += 1 + elif dp == TestResult.SKIPPED: + self.skipped += 1 + elif dp == TestResult.ERROR: + self.errors += 1 + self.failed += 1 + else: + raise ValueError(f"Unhandled TestResult: {dp}") + + def set_total_valid_files(self, total_valid_files: int): + self.total_valid_files = total_valid_files + def set_total_test_cases(self, total_test_cases: int): self.test_case_count = total_test_cases def get_test_case_count(self) -> int: return self.test_case_count + def print_evaluate(self): + if self.total_file_count == 0: + log.error("No test files found!") + exit(-1) + if self.test_case_count == 0: + log.error("No test cases found!") + exit(-1) + if self.err_msgs: + print("Error messages:") + for error in self.err_msgs: + print(f" - {error}") + + print("\n-----------------------------------------") + print("Test run statistics\n") + print(f"Valid files: {self.total_valid_files}") + print(f"Invalid files: {self.invalid_files}") + print(f"Errors: {self.errors}\n") + print("Test cases:") + print(f"\tTotal: {self.test_case_count}") + print(f"\tSuccessful: {self.success}") + print(f"\tSkipped: {self.skipped}") + print(f"\tFailed: {self.failed}") + print("-----------------------------------------") + print("") + + if self.test_case_count != self.success + self.failed + self.skipped: + log.error( + "Inconsistent statistics: total != successful + failed + skipped\n" + ) + + if self.errors != 0: + log.error("Failed with errors\n") + exit(-1) + elif self.failed != 0: + log.warning("Not all tests succeeded\n") + exit(-1) + log.info("All tests succeeded.\n") + exit(0) + + +class TestInput: + def __init__(self, input_dict: dict): + self.input_dict = input_dict + if "bytes" not in self.input_dict: + raise ValueError("Error: 'Missing required mapping field'\nField: 'bytes'.") + if "options" not in self.input_dict: + raise ValueError( + "Error: 'Missing required mapping field'\nField: 'options'." + ) + if "arch" not in self.input_dict: + raise ValueError("Error: 'Missing required mapping field'\nField: 'arch'.") + self.in_bytes = bytes(self.input_dict["bytes"]) + self.options = self.input_dict["options"] + self.arch = self.input_dict["arch"] + + self.name = "" if "name" not in self.input_dict else self.input_dict["name"] + if "address" not in self.input_dict: + self.address: int = 0 + else: + assert isinstance(self.input_dict["address"], int) + self.address = self.input_dict["address"] + self.handle = None + self.arch_bits = 0 + + def setup(self): + log.debug(f"Init {self}") + arch = get_cs_int_attr(capstone, self.arch, "CS_ARCH") + if arch is None: + cs_name = f"CS_ARCH_{self.arch.upper()}" + arch = get_cs_int_attr(capstone, cs_name, "CS_ARCH") + if arch is None: + raise ValueError( + f"Couldn't init architecture as '{self.arch}' or '{cs_name}'.\n" + f"'{self.arch}' is not mapped to a capstone architecture." + ) + self.handle = Cs(arch, 0) + new_mode = 0 + for opt in self.options: + if "CS_MODE_" in opt: + mode = get_cs_int_attr(capstone, opt, "CS_OPT") + if mode: + new_mode |= mode + continue + if "CS_OPT_" in opt and opt in cs_modes.configs: + mtype, val = cs_modes.configs[opt] + self.handle.option(mtype, val) + continue + log.warning(f"Option: '{opt}' not used") + + self.handle.mode = new_mode + self.arch_bits = arch_bits(self.handle.arch, self.handle.mode) + log.debug("Init done") + + def decode(self) -> list[CsInsn]: + if not self.handle: + raise ValueError("self.handle is None. Must be setup before.") + return list(self.handle.disasm(self.in_bytes, self.address)) + + def __str__(self): + if self.name: + return self.name + return ( + f"TestInput {{ arch: {self.arch}, options: {self.options}, " + f"addr: {self.address:x}, bytes: {self.in_bytes} }}" + ) + + +class TestExpected: + def __init__(self, expected_dict: dict, arch_bits: int): + self.arch_bits = arch_bits + self.expected_dict = expected_dict + self.insns = ( + list() if "insns" not in self.expected_dict else self.expected_dict["insns"] + ) + + def compare(self, actual_insns: list[CsInsn]) -> TestResult: + if len(actual_insns) != len(self.insns): + log.error( + "Number of decoded instructions don't match (actual != expected): " + f"{len(actual_insns)} != {len(self.insns):#x}" + ) + return TestResult.FAILED + for a_insn, e_insn in zip(actual_insns, self.insns): + if "asm_text" in self.expected_dict and not compare_asm_text( + a_insn, + self.expected_dict["asm_text"], + self.arch_bits, + ): + return TestResult.FAILED + + if "mnemonic" in self.expected_dict and not compare_str( + a_insn.mnemonic, self.expected_dict["mnemonic"], "mnemonic" + ): + return TestResult.FAILED + + if "op_str" in self.expected_dict and not compare_str( + a_insn.op_str, self.expected_dict["op_str"], "op_str" + ): + return TestResult.FAILED + + if "id" in self.expected_dict and not compare_uint32( + a_insn.id, self.expected_dict["id"], "id" + ): + return TestResult.FAILED + + if "is_alias" in self.expected_dict and not compare_tbool( + a_insn.is_alias, self.expected_dict["is_alias"], "is_alias" + ): + return TestResult.FAILED + + if "alias_id" in self.expected_dict and not compare_uint32( + a_insn.alias_id, self.expected_dict["alias_id"], "alias_id" + ): + return TestResult.FAILED + + if "details" in self.expected_dict: + pass + return TestResult.SUCCESS + class TestCase: def __init__(self, test_case_dict: dict): + self.arch_bits = arch_bits self.tc_dict = test_case_dict + if "input" not in self.tc_dict: + raise ValueError("Mandatory field 'input' missing") + if "expected" not in self.tc_dict: + raise ValueError("Mandatory field 'expected' missing") + self.input = TestInput(self.tc_dict["input"]) + self.expected = TestExpected(self.tc_dict["expected"], self.input.arch_bits) + self.skip = "skip" in self.tc_dict + if self.skip and "skip_reason" not in self.tc_dict: + raise ValueError( + "If 'skip' field is set a 'skip_reason' field must be set as well." + ) + self.skip_reason = ( + self.tc_dict["skip_reason"] if "skip_reason" in self.tc_dict else "" + ) + + def __str__(self) -> str: + return f"{self.input}" + + def test(self) -> TestResult: + if self.skip: + log.info(f"Skip {self}\nReason: {self.skip_reason}") + return TestResult.SKIPPED + + try: + self.input.setup() + except Exception as e: + log.error(f"Setup failed with: {e}") + return TestResult.ERROR + + try: + insns = self.input.decode() + except Exception as e: + log.error(f"Decode failed with: {e}") + return TestResult.ERROR + + try: + return self.expected.compare(insns) + except Exception as e: + log.error(f"Compare expected failed with: {e}") + return TestResult.ERROR class TestFile: def __init__(self, tfile_path: Path): + self.path = tfile_path with open(tfile_path) as f: try: self.content = yaml.safe_load(f) @@ -56,6 +326,9 @@ def __init__(self, tfile_path: Path): def num_test_cases(self) -> int: return len(self.test_cases) + def __str__(self) -> str: + return f"{self.path}" + class CSTest: def __init__(self, path: Path, exclude: list[Path], include: list[Path]): @@ -73,7 +346,7 @@ def __init__(self, path: Path, exclude: list[Path], include: list[Path]): if f.suffix in [".yaml", ".yml"]: self.yaml_paths.append(f) - log.info(f"Found {len(self.yaml_paths)} test files.") + log.info(f"Test files found: {len(self.yaml_paths)}") self.stats = TestStats(len(self.yaml_paths)) self.test_files: list[TestFile] = list() @@ -82,21 +355,46 @@ def parse_files(self): total_files = len(self.yaml_paths) count = 1 for tfile in self.yaml_paths: - print(f"Parse {count}/{total_files}: {tfile.name}", end=f"{' ' * 20}\r", flush=True) + print( + f"Parse {count}/{total_files}: {tfile.name}", + end=f"{' ' * 20}\r", + flush=True, + ) try: tf = TestFile(tfile) total_test_cases += tf.num_test_cases() self.test_files.append(tf) - except (yaml.YAMLError, ValueError) as e: - self.stats.add_error(str(e)) - log.error(f"Invalid YAML file: {tfile}") + except yaml.YAMLError as e: + self.stats.add_error_msg(str(e)) + self.stats.add_invalid_file_dp() + log.error("Error: 'libyaml parser error'") + log.error(f"{e}") + log.error(f"Failed to parse test file '{tfile}'") + except ValueError as e: + self.stats.add_error_msg(str(e)) + self.stats.add_invalid_file_dp() + log.error(f"Error: ValueError: {e}") + log.error(f"Failed to parse test file '{tfile}'") finally: count += 1 + self.stats.set_total_valid_files(len(self.test_files)) self.stats.set_total_test_cases(total_test_cases) - log.info(f"Found {self.stats.get_test_case_count()} test cases.") + log.info(f"Found {self.stats.get_test_case_count()} test cases.{' ' * 20}") def run_tests(self): self.parse_files() + for tf in self.test_files: + log.info(f"Test file: {tf}\n") + for tc in tf.test_cases: + log.info(f"Run test: {tc}") + try: + result = tc.test() + except Exception as e: + self.stats.add_error_msg(str(e)) + self.stats.add_test_case_data_point(result) + log.info(result) + print() + self.stats.print_evaluate() def get_repo_root() -> str | None: @@ -115,7 +413,6 @@ def parse_args() -> argparse.Namespace: repo_root = get_repo_root() if repo_root: parser.add_argument( - "-d", dest="search_dir", help="Directory to search for .yaml test files.", default=Path(f"{repo_root}/tests/"), @@ -123,7 +420,6 @@ def parse_args() -> argparse.Namespace: ) else: parser.add_argument( - "-d", dest="search_dir", help="Directory to search for .yaml test files.", required=True, @@ -158,18 +454,26 @@ def parse_args() -> argparse.Namespace: if __name__ == "__main__": log_levels = { - "debug": log.DEBUG, - "info": log.INFO, - "warning": log.WARNING, - "error": log.ERROR, - "fatal": log.FATAL, - "critical": log.CRITICAL, + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "fatal": logging.FATAL, + "critical": logging.CRITICAL, } args = parse_args() - log.basicConfig( - level=log_levels[args.verbosity], - stream=sys.stdout, - format="%(levelname)-5s - %(message)s", - force=True, - ) + format = logging.Formatter("%(levelname)-5s - %(message)s", None, "%") + log.setLevel(log_levels[args.verbosity]) + + h1 = logging.StreamHandler(sys.stdout) + h1.setLevel(log_levels[args.verbosity]) + h1.addFilter(lambda record: record.levelno <= log_levels[args.verbosity]) + h1.setFormatter(format) + + h2 = logging.StreamHandler(sys.stderr) + h2.setLevel(logging.WARNING) + h2.setFormatter(format) + + log.addHandler(h1) + log.addHandler(h2) CSTest(args.search_dir, args.exclude, args.include).run_tests()