diff --git a/slither/__main__.py b/slither/__main__.py index a4ff784db1..886d392c0d 100644 --- a/slither/__main__.py +++ b/slither/__main__.py @@ -348,6 +348,14 @@ def parse_args( default=defaults_flag_in_config["printers_to_run"], ) + group_printer.add_argument( + "--include-interfaces", + help="Include interfaces from inheritance-graph printer", + action="store_true", + dest="include_interfaces", + default=False, + ) + group_detector.add_argument( "--list-detectors", help="List available detectors", diff --git a/slither/printers/inheritance/inheritance_graph.py b/slither/printers/inheritance/inheritance_graph.py index a16ce273af..54b525c774 100644 --- a/slither/printers/inheritance/inheritance_graph.py +++ b/slither/printers/inheritance/inheritance_graph.py @@ -98,12 +98,21 @@ def _summary(self, contract): """ ret = "" + # Remove contracts that have "mock" in the name and if --include-interfaces in False (default) + # removes inherited interfaces + inheritance = [ + i + for i in contract.immediate_inheritance + if "mock" not in i.name.lower() + and (not i.is_interface or self.slither.include_interfaces) + ] + # Add arrows (number them if there is more than one path so we know order of declaration for inheritance). - if len(contract.immediate_inheritance) == 1: + if len(inheritance) == 1: immediate_inheritance = contract.immediate_inheritance[0] ret += f"c{contract.id}_{contract.name} -> c{immediate_inheritance.id}_{immediate_inheritance};\n" else: - for i, immediate_inheritance in enumerate(contract.immediate_inheritance): + for i, immediate_inheritance in enumerate(inheritance): ret += f'c{contract.id}_{contract.name} -> c{immediate_inheritance.id}_{immediate_inheritance} [ label="{i + 1}" ];\n' # Functions @@ -113,6 +122,7 @@ def _summary(self, contract): for f in contract.functions if not f.is_constructor and not f.is_constructor_variables + and not f.is_virtual and f.contract_declarer == contract and f.visibility in visibilities ] @@ -195,6 +205,12 @@ def output(self, filename): content = 'digraph "" {\n' for c in self.contracts: + if ( + "mock" in c.name.lower() + or c.is_library + or (c.is_interface and not self.slither.include_interfaces) + ): + continue content += self._summary(c) + "\n" content += "}" diff --git a/slither/slither.py b/slither/slither.py index 0f22185353..7adc0694ca 100644 --- a/slither/slither.py +++ b/slither/slither.py @@ -196,10 +196,12 @@ def __init__(self, target: Union[str, CryticCompile], **kwargs) -> None: if printers_to_run == "echidna": self.skip_data_dependency = True + # Used in inheritance-graph printer + self.include_interfaces = kwargs.get("include_interfaces", False) + self._init_parsing_and_analyses(kwargs.get("skip_analyze", False)) def _init_parsing_and_analyses(self, skip_analyze: bool) -> None: - for parser in self._parsers: try: parser.parse_contracts() diff --git a/tests/e2e/printers/test_data/test_contract_names/C.sol b/tests/e2e/printers/test_data/test_contract_names/C.sol index 90bc35df39..d6ba9b5c15 100644 --- a/tests/e2e/printers/test_data/test_contract_names/C.sol +++ b/tests/e2e/printers/test_data/test_contract_names/C.sol @@ -1,7 +1,21 @@ import "./A.sol"; -contract C is A { +interface MyInterfaceX { + function count() external view returns (uint256); + + function increment() external; +} + +contract C is A, MyInterfaceX { function c_main() public pure { a_main(); } + + function count() external view override returns (uint256){ + return 1; + } + + function increment() external override { + + } } diff --git a/tests/e2e/printers/test_printers.py b/tests/e2e/printers/test_printers.py index 26429d3381..3dea8b74a4 100644 --- a/tests/e2e/printers/test_printers.py +++ b/tests/e2e/printers/test_printers.py @@ -34,3 +34,15 @@ def test_inheritance_printer(solc_binary_path) -> None: assert counter["B -> A"] == 2 assert counter["C -> A"] == 1 + + # Lets also test the include/exclude interface behavior + # Check that the interface is not included + assert "MyInterfaceX" not in content + + slither.include_interfaces = True + output = printer.output("test_printer.dot") + content = output.elements[0]["name"]["content"] + assert "MyInterfaceX" in content + + # Remove test generated files + Path("test_printer.dot").unlink(missing_ok=True)