Skip to content

Commit

Permalink
Merge pull request #1351 from crytic/dev-improve-reentrancy
Browse files Browse the repository at this point in the history
Improve reentrancy detectors
  • Loading branch information
montyly authored Nov 28, 2022
2 parents 3cba359 + 4c5a5a8 commit 133c1a1
Show file tree
Hide file tree
Showing 25 changed files with 7,665 additions and 6,366 deletions.
35 changes: 34 additions & 1 deletion slither/core/declarations/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
Contract module
"""
import logging
from collections import defaultdict
from pathlib import Path
from typing import Optional, List, Dict, Callable, Tuple, TYPE_CHECKING, Union
from typing import Optional, List, Dict, Callable, Tuple, TYPE_CHECKING, Union, Set

from crytic_compile.platform import Type as PlatformType

Expand Down Expand Up @@ -100,6 +101,11 @@ def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope
self.compilation_unit: "SlitherCompilationUnit" = compilation_unit
self.file_scope: "FileScope" = scope

# memoize
self._state_variables_used_in_reentrant_targets: Optional[
Dict["StateVariable", Set[Union["StateVariable", "Function"]]]
] = None

###################################################################################
###################################################################################
# region General's properties
Expand Down Expand Up @@ -356,6 +362,33 @@ def slithir_variables(self) -> List["SlithIRVariable"]:
slithir_variables = [item for sublist in slithir_variabless for item in sublist]
return list(set(slithir_variables))

@property
def state_variables_used_in_reentrant_targets(
self,
) -> Dict["StateVariable", Set[Union["StateVariable", "Function"]]]:
"""
Returns the state variables used in reentrant targets. Heuristics:
- Variable used (read/write) in entry points that are reentrant
- State variables that are public
"""
from slither.core.variables.state_variable import StateVariable

if self._state_variables_used_in_reentrant_targets is None:
reentrant_functions = [f for f in self.functions_entry_points if f.is_reentrant]
variables_used: Dict[
StateVariable, Set[Union[StateVariable, "Function"]]
] = defaultdict(set)
for function in reentrant_functions:
for ir in function.all_slithir_operations():
state_variables = [v for v in ir.used if isinstance(v, StateVariable)]
for state_variable in state_variables:
variables_used[state_variable].add(ir.node.function)
for variable in [v for v in self.state_variables if v.visibility == "public"]:
variables_used[variable].add(variable)
self._state_variables_used_in_reentrant_targets = variables_used
return self._state_variables_used_in_reentrant_targets

# endregion
###################################################################################
###################################################################################
Expand Down
48 changes: 45 additions & 3 deletions slither/core/declarations/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ def __init__(self, compilation_unit: "SlitherCompilationUnit"):

# set(ReacheableNode)
self._reachable_from_nodes: Set[ReacheableNode] = set()
self._reachable_from_functions: Set[ReacheableNode] = set()
self._reachable_from_functions: Set[Function] = set()
self._all_reachable_from_functions: Optional[Set[Function]] = None

# Constructor, fallback, State variable constructor
self._function_type: Optional[FunctionType] = None
Expand All @@ -214,7 +215,7 @@ def __init__(self, compilation_unit: "SlitherCompilationUnit"):

self.compilation_unit: "SlitherCompilationUnit" = compilation_unit

# Assume we are analyzing Solidty by default
# Assume we are analyzing Solidity by default
self.function_language: FunctionLanguage = FunctionLanguage.Solidity

self._id: Optional[str] = None
Expand Down Expand Up @@ -1029,9 +1030,30 @@ def reachable_from_nodes(self) -> Set[ReacheableNode]:
return self._reachable_from_nodes

@property
def reachable_from_functions(self) -> Set[ReacheableNode]:
def reachable_from_functions(self) -> Set["Function"]:
return self._reachable_from_functions

@property
def all_reachable_from_functions(self) -> Set["Function"]:
"""
Give the recursive version of reachable_from_functions (all the functions that lead to call self in the CFG)
"""
if self._all_reachable_from_functions is None:
functions: Set["Function"] = set()

new_functions = self.reachable_from_functions
# iterate until we have are finding new functions
while new_functions and not new_functions.issubset(functions):
functions = functions.union(new_functions)
# Use a temporary set, because we iterate over new_functions
new_functionss: Set["Function"] = set()
for f in new_functions:
new_functionss = new_functionss.union(f.reachable_from_functions)
new_functions = new_functionss - functions

self._all_reachable_from_functions = functions
return self._all_reachable_from_functions

def add_reachable_from_node(self, n: "Node", ir: "Operation"):
self._reachable_from_nodes.add(ReacheableNode(n, ir))
self._reachable_from_functions.add(n.function)
Expand Down Expand Up @@ -1460,6 +1482,26 @@ def is_protected(self) -> bool:
)
return self._is_protected

@property
def is_reentrant(self) -> bool:
"""
Determine if the function can be re-entered
"""
# TODO: compare with hash of known nonReentrant modifier instead of the name
if "nonReentrant" in [m.name for m in self.modifiers]:
return False

if self.visibility in ["public", "external"]:
return True

# If it's an internal function, check if all its entry points have the nonReentrant modifier
all_entry_points = [
f for f in self.all_reachable_from_functions if f.visibility in ["public", "external"]
]
if not all_entry_points:
return True
return not all(("nonReentrant" in [m.name for m in f.modifiers] for f in all_entry_points))

# endregion
###################################################################################
###################################################################################
Expand Down
61 changes: 30 additions & 31 deletions slither/detectors/reentrancy/reentrancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,32 @@
Iterate over all the nodes of the graph until reaching a fixpoint
"""
from collections import defaultdict
from typing import Set, Dict, Union
from typing import Set, Dict, List, Tuple, Optional

from slither.core.cfg.node import NodeType, Node
from slither.core.declarations import Function
from slither.core.declarations import Function, Contract
from slither.core.expressions import UnaryOperation, UnaryOperationType
from slither.core.variables.variable import Variable
from slither.detectors.abstract_detector import AbstractDetector
from slither.slithir.operations import Call, EventCall
from slither.slithir.operations import Call, EventCall, Operation
from slither.utils.output import Output


def union_dict(d1, d2):
def union_dict(d1: Dict, d2: Dict) -> Dict:
d3 = {k: d1.get(k, set()) | d2.get(k, set()) for k in set(list(d1.keys()) + list(d2.keys()))}
return defaultdict(set, d3)


def dict_are_equal(d1, d2):
def dict_are_equal(d1: Dict, d2: Dict) -> bool:
if set(list(d1.keys())) != set(list(d2.keys())):
return False
return all(set(d1[k]) == set(d2[k]) for k in d1.keys())


def is_subset(
new_info: Dict[Union[Variable, Node], Set[Node]],
old_info: Dict[Union[Variable, Node], Set[Node]],
):
new_info: Dict,
old_info: Dict,
) -> bool:
for k in new_info.keys():
if k not in old_info:
return False
Expand All @@ -38,15 +39,15 @@ def is_subset(
return True


def to_hashable(d: Dict[Node, Set[Node]]):
def to_hashable(d: Dict[Node, Set[Node]]) -> Tuple:
list_tuple = list(
tuple((k, tuple(sorted(values, key=lambda x: x.node_id)))) for k, values in d.items()
)
return tuple(sorted(list_tuple, key=lambda x: x[0].node_id))


class AbstractState:
def __init__(self):
def __init__(self) -> None:
# send_eth returns the list of calls sending value
# calls returns the list of calls that can callback
# read returns the variable read
Expand Down Expand Up @@ -106,7 +107,9 @@ def events(self) -> Dict[EventCall, Set[Node]]:
"""
return self._events

def merge_fathers(self, node, skip_father, detector):
def merge_fathers(
self, node: Node, skip_father: Optional[Node], detector: "Reentrancy"
) -> None:
for father in node.fathers:
if detector.KEY in father.context:
self._send_eth = union_dict(
Expand All @@ -131,7 +134,7 @@ def merge_fathers(self, node, skip_father, detector):
father.context[detector.KEY].reads_prior_calls,
)

def analyze_node(self, node, detector):
def analyze_node(self, node: Node, detector: "Reentrancy") -> bool:
state_vars_read: Dict[Variable, Set[Node]] = defaultdict(
set, {v: {node} for v in node.state_variables_read}
)
Expand Down Expand Up @@ -175,13 +178,13 @@ def analyze_node(self, node, detector):

return contains_call

def add(self, fathers):
def add(self, fathers: "AbstractState") -> None:
self._send_eth = union_dict(self._send_eth, fathers.send_eth)
self._calls = union_dict(self._calls, fathers.calls)
self._reads = union_dict(self._reads, fathers.reads)
self._reads_prior_calls = union_dict(self._reads_prior_calls, fathers.reads_prior_calls)

def does_not_bring_new_info(self, new_info):
def does_not_bring_new_info(self, new_info: "AbstractState") -> bool:
if is_subset(new_info.calls, self.calls):
if is_subset(new_info.send_eth, self.send_eth):
if is_subset(new_info.reads, self.reads):
Expand All @@ -190,7 +193,7 @@ def does_not_bring_new_info(self, new_info):
return False


def _filter_if(node):
def _filter_if(node: Node) -> bool:
"""
Check if the node is a condtional node where
there is an external call checked
Expand All @@ -201,10 +204,8 @@ def _filter_if(node):
This will work only on naive implementation
"""
return (
isinstance(node.expression, UnaryOperation)
and node.expression.type == UnaryOperationType.BANG
)
expression = node.expression
return isinstance(expression, UnaryOperation) and expression.type == UnaryOperationType.BANG


class Reentrancy(AbstractDetector):
Expand All @@ -214,7 +215,7 @@ class Reentrancy(AbstractDetector):
# allowing inherited classes to define different behaviors
# For example reentrancy_no_gas consider Send and Transfer as reentrant functions
@staticmethod
def can_callback(ir):
def can_callback(ir: Operation) -> bool:
"""
Detect if the node contains a call that can
be used to re-entrance
Expand All @@ -228,13 +229,13 @@ def can_callback(ir):
return isinstance(ir, Call) and ir.can_reenter()

@staticmethod
def can_send_eth(ir):
def can_send_eth(ir: Operation) -> bool:
"""
Detect if the node can send eth
"""
return isinstance(ir, Call) and ir.can_send_eth()

def _explore(self, node, visited, skip_father=None):
def _explore(self, node: Optional[Node], skip_father: Optional[Node] = None) -> None:
"""
Explore the CFG and look for re-entrancy
Heuristic: There is a re-entrancy if a state variable is written
Expand All @@ -245,11 +246,9 @@ def _explore(self, node, visited, skip_father=None):
if node.context is not empty, and variables are written, a re-entrancy is possible
"""
if node in visited:
if node is None:
return

visited = visited + [node]

fathers_context = AbstractState()
fathers_context.merge_fathers(node, skip_father, self)

Expand All @@ -271,26 +270,26 @@ def _explore(self, node, visited, skip_father=None):
if contains_call and node.type in [NodeType.IF, NodeType.IFLOOP]:
if _filter_if(node):
son = sons[0]
self._explore(son, visited, node)
self._explore(son, skip_father=node)
sons = sons[1:]
else:
son = sons[1]
self._explore(son, visited, node)
self._explore(son, skip_father=node)
sons = [sons[0]]

for son in sons:
self._explore(son, visited)
self._explore(son)

def detect_reentrancy(self, contract):
def detect_reentrancy(self, contract: Contract) -> None:
for function in contract.functions_and_modifiers_declared:
if not function.is_constructor:
if function.is_implemented:
if self.KEY in function.context:
continue
self._explore(function.entry_point, [])
self._explore(function.entry_point)
function.context[self.KEY] = True

def _detect(self):
def _detect(self) -> List[Output]:
""""""
# if a node was already visited by another path
# we will only explore it if the traversal brings
Expand Down
Loading

0 comments on commit 133c1a1

Please sign in to comment.