diff --git a/examples/flat/a.sol b/examples/flat/a.sol index d252e04ba9..4fa9c75490 100644 --- a/examples/flat/a.sol +++ b/examples/flat/a.sol @@ -1,3 +1,9 @@ -contract A{ +pragma solidity 0.8.19; -} +error RevertIt(); + +contract Example { + function reverts() external pure { + revert RevertIt(); + } +} \ No newline at end of file diff --git a/examples/flat/b.sol b/examples/flat/b.sol index 74b4d78ce5..edbd902256 100644 --- a/examples/flat/b.sol +++ b/examples/flat/b.sol @@ -1,5 +1,16 @@ import "./a.sol"; -contract B is A{ +pragma solidity 0.8.19; +enum B { + a, + b } + +contract T { + Example e = new Example(); + function b() public returns(uint) { + B b = B.a; + return 4; + } +} \ No newline at end of file diff --git a/scripts/ci_test_flat.sh b/scripts/ci_test_flat.sh index e3a837a035..0d9185171e 100755 --- a/scripts/ci_test_flat.sh +++ b/scripts/ci_test_flat.sh @@ -1,6 +1,8 @@ #!/usr/bin/env bash +shopt -s extglob -### Test slither-prop +### Test slither-flat +solc-select use 0.8.19 --always-install cd examples/flat || exit 1 @@ -8,5 +10,11 @@ if ! slither-flat b.sol; then echo "slither-flat failed" exit 1 fi + +SUFFIX="@(sol)" +if ! solc "crytic-export/flattening/"*$SUFFIX; then + echo "solc failed on flattened files" + exit 1 +fi exit 0 diff --git a/slither/core/compilation_unit.py b/slither/core/compilation_unit.py index 4550ea8948..6d24786eb1 100644 --- a/slither/core/compilation_unit.py +++ b/slither/core/compilation_unit.py @@ -13,7 +13,7 @@ Function, Modifier, ) -from slither.core.declarations.custom_error import CustomError +from slither.core.declarations.custom_error_top_level import CustomErrorTopLevel from slither.core.declarations.enum_top_level import EnumTopLevel from slither.core.declarations.function_top_level import FunctionTopLevel from slither.core.declarations.structure_top_level import StructureTopLevel @@ -46,7 +46,7 @@ def __init__(self, core: "SlitherCore", crytic_compilation_unit: CompilationUnit self._using_for_top_level: List[UsingForTopLevel] = [] self._pragma_directives: List[Pragma] = [] self._import_directives: List[Import] = [] - self._custom_errors: List[CustomError] = [] + self._custom_errors: List[CustomErrorTopLevel] = [] self._user_defined_value_types: Dict[str, TypeAliasTopLevel] = {} self._all_functions: Set[Function] = set() @@ -216,7 +216,7 @@ def using_for_top_level(self) -> List[UsingForTopLevel]: return self._using_for_top_level @property - def custom_errors(self) -> List[CustomError]: + def custom_errors(self) -> List[CustomErrorTopLevel]: return self._custom_errors @property diff --git a/slither/core/declarations/solidity_variables.py b/slither/core/declarations/solidity_variables.py index f0e903d7b2..552cf9e7f5 100644 --- a/slither/core/declarations/solidity_variables.py +++ b/slither/core/declarations/solidity_variables.py @@ -201,6 +201,10 @@ def __init__(self, custom_error: CustomError) -> None: # pylint: disable=super- self._custom_error = custom_error self._return_type: List[Union[TypeInformation, ElementaryType]] = [] + @property + def custom_error(self) -> CustomError: + return self._custom_error + def __eq__(self, other: Any) -> bool: return ( self.__class__ == other.__class__ diff --git a/slither/tools/flattening/export/export.py b/slither/tools/flattening/export/export.py index e9b4552efb..8b8ce73559 100644 --- a/slither/tools/flattening/export/export.py +++ b/slither/tools/flattening/export/export.py @@ -15,7 +15,7 @@ Export = namedtuple("Export", ["filename", "content"]) -logger = logging.getLogger("Slither") +logger = logging.getLogger("Slither-flat") def save_to_zip(files: List[Export], zip_filename: str, zip_type: str = "lzma"): diff --git a/slither/tools/flattening/flattening.py b/slither/tools/flattening/flattening.py index 67b3c00a36..55e1af21d3 100644 --- a/slither/tools/flattening/flattening.py +++ b/slither/tools/flattening/flattening.py @@ -11,6 +11,7 @@ from slither.core.declarations.contract import Contract from slither.core.declarations.function_top_level import FunctionTopLevel from slither.core.declarations.top_level import TopLevel +from slither.core.declarations.solidity_variables import SolidityCustomRevert from slither.core.solidity_types import MappingType, ArrayType from slither.core.solidity_types.type import Type from slither.core.solidity_types.user_defined_type import UserDefinedType @@ -23,7 +24,8 @@ save_to_disk, ) -logger = logging.getLogger("Slither-flattening") +logger = logging.getLogger("Slither-flat") +logger.setLevel(logging.INFO) # index: where to start # patch_type: @@ -75,6 +77,7 @@ def __init__( self._get_source_code_top_level(compilation_unit.structures_top_level) self._get_source_code_top_level(compilation_unit.enums_top_level) + self._get_source_code_top_level(compilation_unit.custom_errors) self._get_source_code_top_level(compilation_unit.variables_top_level) self._get_source_code_top_level(compilation_unit.functions_top_level) @@ -249,12 +252,14 @@ def _export_from_type( t: Type, contract: Contract, exported: Set[str], - list_contract: List[Contract], - list_top_level: List[TopLevel], + list_contract: Set[Contract], + list_top_level: Set[TopLevel], ): if isinstance(t, UserDefinedType): t_type = t.type - if isinstance(t_type, (EnumContract, StructureContract)): + if isinstance(t_type, TopLevel): + list_top_level.add(t_type) + elif isinstance(t_type, (EnumContract, StructureContract)): if t_type.contract != contract and t_type.contract not in exported: self._export_list_used_contracts( t_type.contract, exported, list_contract, list_top_level @@ -275,8 +280,8 @@ def _export_list_used_contracts( # pylint: disable=too-many-branches self, contract: Contract, exported: Set[str], - list_contract: List[Contract], - list_top_level: List[TopLevel], + list_contract: Set[Contract], + list_top_level: Set[TopLevel], ): # TODO: investigate why this happen if not isinstance(contract, Contract): @@ -332,19 +337,21 @@ def _export_list_used_contracts( # pylint: disable=too-many-branches for read in ir.read: if isinstance(read, TopLevel): - if read not in list_top_level: - list_top_level.append(read) - if isinstance(ir, InternalCall): - function_called = ir.function - if isinstance(function_called, FunctionTopLevel): - list_top_level.append(function_called) - - if contract not in list_contract: - list_contract.append(contract) + list_top_level.add(read) + if isinstance(ir, InternalCall) and isinstance(ir.function, FunctionTopLevel): + list_top_level.add(ir.function) + if ( + isinstance(ir, SolidityCall) + and isinstance(ir.function, SolidityCustomRevert) + and isinstance(ir.function.custom_error, TopLevel) + ): + list_top_level.add(ir.function.custom_error) + + list_contract.add(contract) def _export_contract_with_inheritance(self, contract) -> Export: - list_contracts: List[Contract] = [] # will contain contract itself - list_top_level: List[TopLevel] = [] + list_contracts: Set[Contract] = set() # will contain contract itself + list_top_level: Set[TopLevel] = set() self._export_list_used_contracts(contract, set(), list_contracts, list_top_level) path = Path(self._export_path, f"{contract.name}_{uuid.uuid4()}.sol") @@ -401,8 +408,8 @@ def _export_all(self) -> List[Export]: def _export_with_import(self) -> List[Export]: exports: List[Export] = [] for contract in self._compilation_unit.contracts: - list_contracts: List[Contract] = [] # will contain contract itself - list_top_level: List[TopLevel] = [] + list_contracts: Set[Contract] = set() # will contain contract itself + list_top_level: Set[TopLevel] = set() self._export_list_used_contracts(contract, set(), list_contracts, list_top_level) if list_top_level: