From d04fc9e1034b86d96790acf11aec0be212c693c9 Mon Sep 17 00:00:00 2001 From: webthethird Date: Mon, 27 Feb 2023 12:09:13 -0600 Subject: [PATCH 01/50] Start `slither.utils.upgradeability` --- slither/utils/upgradeability.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 slither/utils/upgradeability.py diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py new file mode 100644 index 0000000000..d92dc3d6ea --- /dev/null +++ b/slither/utils/upgradeability.py @@ -0,0 +1,32 @@ +from slither.core.declarations.contract import Contract + + +def compare(v1: Contract, v2: Contract): + order_vars1 = [v for v in v1.state_variables if not v.is_constant and not v.is_immutable] + order_vars2 = [v for v in v2.state_variables if not v.is_constant and not v.is_immutable] + func_sigs1 = [function.solidity_signature for function in v1.functions] + func_sigs2 = [function.solidity_signature for function in v2.functions] + + results = { + "missing-vars-in-v2": [], + "new-variables": [], + "tainted-variables": [], + "new-functions": [], + "modified-functions": [], + "tainted-functions": [] + } + + if len(order_vars2) <= len(order_vars1): + for variable in order_vars1: + if variable.name not in [v.name for v in order_vars2]: + results["missing-vars-in-v2"].append(variable) + + new_modified_functions = [] + for sig in func_sigs2: + function = v2.get_function_from_signature(sig) + if sig not in func_sigs1: + new_modified_functions.append(function) + results["new-functions"].append(function) + else: + orig_function = v1.get_function_from_signature(sig) + if function From f08d2afb3ede7fc2016413755408cad422bcec51 Mon Sep 17 00:00:00 2001 From: webthethird Date: Mon, 27 Feb 2023 12:45:26 -0600 Subject: [PATCH 02/50] Implement `compare(v1: Contract, v2: Contract)` in `slither.utils.upgradeability` --- slither/utils/upgradeability.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index d92dc3d6ea..be4fe313bf 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -16,11 +16,13 @@ def compare(v1: Contract, v2: Contract): "tainted-functions": [] } + # Since this is not a detector, include any missing variables in the v2 contract if len(order_vars2) <= len(order_vars1): for variable in order_vars1: if variable.name not in [v.name for v in order_vars2]: results["missing-vars-in-v2"].append(variable) + # Find all new and modified functions in the v2 contract new_modified_functions = [] for sig in func_sigs2: function = v2.get_function_from_signature(sig) @@ -29,4 +31,31 @@ def compare(v1: Contract, v2: Contract): results["new-functions"].append(function) else: orig_function = v1.get_function_from_signature(sig) - if function + # If the function content hashes are the same, no need to investigate the function further + if function.source_mapping.content_hash != orig_function.source_mapping.content_hash: + # If the hashes differ, it is possible a change in a name or in a comment could be the only difference + # So we need to resort to walking through the CFG and comparing the IR operations + for i, node in enumerate(function.nodes): + if function in new_modified_functions: + break + for j, ir in enumerate(node.irs): + if ir != orig_function.nodes[i].irs[j]: + new_modified_functions.append(function) + results["modified-functions"].append(function) + + # Find all unmodified functions that call a modified function, i.e., tainted functions + for function in v2.functions: + if function in new_modified_functions: + continue + modified_calls = [funct for func in new_modified_functions if func in function.internal_calls] + if len(modified_calls) > 0: + results["tainted-functions"].append(function) + + # Find all new or tainted variables, i.e., variables that are read or written by a new/modified function + for idx, var in enumerate(order_vars2): + read_by = v2.get_functions_reading_from_variable(var) + written_by = v2.get_functions_writing_to_variable(var) + if len(order_vars1) <= idx: + results["new-variables"].append(var) + elif any([func in read_by or func in written_by for func in new_modified_functions]): + results["tainted-variables"].append(var) From ebd2201bdd26266b5fb7618d3c3e37508a4d96e1 Mon Sep 17 00:00:00 2001 From: webthethird Date: Mon, 27 Feb 2023 13:13:02 -0600 Subject: [PATCH 03/50] Pylint cut down on branches and nested blocks by adding an `is_function_modified` helper function --- slither/utils/upgradeability.py | 73 +++++++++++++++++++++++++-------- 1 file changed, 56 insertions(+), 17 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index be4fe313bf..9eedc6dfec 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -1,7 +1,28 @@ from slither.core.declarations.contract import Contract +from slither.core.declarations.function import Function -def compare(v1: Contract, v2: Contract): +# pylint: disable=too-many-locals +def compare(v1: Contract, v2: Contract) -> dict: + """ + Compares two versions of a contract. Most useful for upgradeable (logic) contracts, + but does not require that Contract.is_upgradeable returns true for either contract. + + Args: + v1: Original version of (upgradeable) contract + v2: Updated version of (upgradeable) contract + + Returns: dict { + "missing-vars-in-v2": list[Variable], + "new-variables": list[Variable], + "tainted-variables": list[Variable], + "new-functions": list[Function], + "modified-functions": list[Function], + "tainted-functions": list[Function] + } + + """ + order_vars1 = [v for v in v1.state_variables if not v.is_constant and not v.is_immutable] order_vars2 = [v for v in v2.state_variables if not v.is_constant and not v.is_immutable] func_sigs1 = [function.solidity_signature for function in v1.functions] @@ -13,7 +34,7 @@ def compare(v1: Contract, v2: Contract): "tainted-variables": [], "new-functions": [], "modified-functions": [], - "tainted-functions": [] + "tainted-functions": [], } # Since this is not a detector, include any missing variables in the v2 contract @@ -26,28 +47,21 @@ def compare(v1: Contract, v2: Contract): new_modified_functions = [] for sig in func_sigs2: function = v2.get_function_from_signature(sig) + orig_function = v1.get_function_from_signature(sig) if sig not in func_sigs1: new_modified_functions.append(function) results["new-functions"].append(function) - else: - orig_function = v1.get_function_from_signature(sig) - # If the function content hashes are the same, no need to investigate the function further - if function.source_mapping.content_hash != orig_function.source_mapping.content_hash: - # If the hashes differ, it is possible a change in a name or in a comment could be the only difference - # So we need to resort to walking through the CFG and comparing the IR operations - for i, node in enumerate(function.nodes): - if function in new_modified_functions: - break - for j, ir in enumerate(node.irs): - if ir != orig_function.nodes[i].irs[j]: - new_modified_functions.append(function) - results["modified-functions"].append(function) + elif is_function_modified(orig_function, function): + new_modified_functions.append(function) + results["modified-functions"].append(function) # Find all unmodified functions that call a modified function, i.e., tainted functions for function in v2.functions: if function in new_modified_functions: continue - modified_calls = [funct for func in new_modified_functions if func in function.internal_calls] + modified_calls = [ + func for func in new_modified_functions if func in function.internal_calls + ] if len(modified_calls) > 0: results["tainted-functions"].append(function) @@ -57,5 +71,30 @@ def compare(v1: Contract, v2: Contract): written_by = v2.get_functions_writing_to_variable(var) if len(order_vars1) <= idx: results["new-variables"].append(var) - elif any([func in read_by or func in written_by for func in new_modified_functions]): + elif any(func in read_by or func in written_by for func in new_modified_functions): results["tainted-variables"].append(var) + + +def is_function_modified(f1: Function, f2: Function) -> bool: + """ + Compares two versions of a function, and returns True if the function has been modified. + First checks whether the functions' content hashes are equal to quickly rule out identical functions. + Walks the CFGs and compares IR operations if hashes differ to rule out false positives, i.e., from changed comments. + + Args: + f1: Original version of the function + f2: New version of the function + + Returns: True if the functions differ, otherwise False + + """ + # If the function content hashes are the same, no need to investigate the function further + if f1.source_mapping.content_hash == f2.source_mapping.content_hash: + return False + # If the hashes differ, it is possible a change in a name or in a comment could be the only difference + # So we need to resort to walking through the CFG and comparing the IR operations + for i, node in enumerate(f2.nodes): + for j, ir in enumerate(node.irs): + if ir != f1.nodes[i].irs[j]: + return True + return False From 04c71c24bbdc6344a920b6dcd8b0d6434ca6325f Mon Sep 17 00:00:00 2001 From: webthethird Date: Tue, 28 Feb 2023 12:13:06 -0600 Subject: [PATCH 04/50] Add return statement (whoops!) --- slither/utils/upgradeability.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 9eedc6dfec..9750157435 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -74,6 +74,8 @@ def compare(v1: Contract, v2: Contract) -> dict: elif any(func in read_by or func in written_by for func in new_modified_functions): results["tainted-variables"].append(var) + return results + def is_function_modified(f1: Function, f2: Function) -> bool: """ From 5b361e82876c165f69364ecd3c4973f74fe4d16c Mon Sep 17 00:00:00 2001 From: webthethird Date: Tue, 28 Feb 2023 12:27:35 -0600 Subject: [PATCH 05/50] Also consider an unmodified function tainted if it reads/writes the same state variable(s) as a new/modified function --- slither/utils/upgradeability.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 9750157435..d0655d3245 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -45,6 +45,7 @@ def compare(v1: Contract, v2: Contract) -> dict: # Find all new and modified functions in the v2 contract new_modified_functions = [] + new_modified_function_vars = [] for sig in func_sigs2: function = v2.get_function_from_signature(sig) orig_function = v1.get_function_from_signature(sig) @@ -54,15 +55,24 @@ def compare(v1: Contract, v2: Contract) -> dict: elif is_function_modified(orig_function, function): new_modified_functions.append(function) results["modified-functions"].append(function) + else: + continue + for var in function.state_variables_read + function.state_variables_written: + if var not in new_modified_function_vars: + new_modified_function_vars.append(var) - # Find all unmodified functions that call a modified function, i.e., tainted functions + # Find all unmodified functions that call a modified function or read/write the + # same state variable(s) as a new/modified function, i.e., tainted functions for function in v2.functions: if function in new_modified_functions: continue modified_calls = [ func for func in new_modified_functions if func in function.internal_calls ] - if len(modified_calls) > 0: + tainted_vars = [ + var for var in new_modified_function_vars if var in function.variables_read_or_written + ] + if len(modified_calls) > 0 or len(tainted_vars) > 0: results["tainted-functions"].append(function) # Find all new or tainted variables, i.e., variables that are read or written by a new/modified function From 6b9d21abc23f7588a28e546482515f9549ff3340 Mon Sep 17 00:00:00 2001 From: webthethird Date: Tue, 28 Feb 2023 12:40:43 -0600 Subject: [PATCH 06/50] Make pylint happy (reduce branches) --- slither/utils/upgradeability.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index d0655d3245..2f7263f45e 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -52,14 +52,11 @@ def compare(v1: Contract, v2: Contract) -> dict: if sig not in func_sigs1: new_modified_functions.append(function) results["new-functions"].append(function) + new_modified_function_vars += function.state_variables_read + function.state_variables_written elif is_function_modified(orig_function, function): new_modified_functions.append(function) results["modified-functions"].append(function) - else: - continue - for var in function.state_variables_read + function.state_variables_written: - if var not in new_modified_function_vars: - new_modified_function_vars.append(var) + new_modified_function_vars += function.state_variables_read + function.state_variables_written # Find all unmodified functions that call a modified function or read/write the # same state variable(s) as a new/modified function, i.e., tainted functions From 3757601640053d4fdc609f12bb591b78923194b6 Mon Sep 17 00:00:00 2001 From: webthethird Date: Tue, 28 Feb 2023 13:00:15 -0600 Subject: [PATCH 07/50] Avoid duplicates, constants and immutables when finding functions tainted by `new_modified_function_vars` --- slither/utils/upgradeability.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 2f7263f45e..8b4f2a2150 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -52,11 +52,15 @@ def compare(v1: Contract, v2: Contract) -> dict: if sig not in func_sigs1: new_modified_functions.append(function) results["new-functions"].append(function) - new_modified_function_vars += function.state_variables_read + function.state_variables_written + new_modified_function_vars += ( + function.state_variables_read + function.state_variables_written + ) elif is_function_modified(orig_function, function): new_modified_functions.append(function) results["modified-functions"].append(function) - new_modified_function_vars += function.state_variables_read + function.state_variables_written + new_modified_function_vars += ( + function.state_variables_read + function.state_variables_written + ) # Find all unmodified functions that call a modified function or read/write the # same state variable(s) as a new/modified function, i.e., tainted functions @@ -67,7 +71,11 @@ def compare(v1: Contract, v2: Contract) -> dict: func for func in new_modified_functions if func in function.internal_calls ] tainted_vars = [ - var for var in new_modified_function_vars if var in function.variables_read_or_written + var + for var in set(new_modified_function_vars) + if var in function.variables_read_or_written + and not var.is_constant + and not var.is_immutable ] if len(modified_calls) > 0 or len(tainted_vars) > 0: results["tainted-functions"].append(function) From 770ca81368728b91b7b54be65b057a2d4e277bf6 Mon Sep 17 00:00:00 2001 From: webthethird Date: Tue, 28 Feb 2023 13:05:16 -0600 Subject: [PATCH 08/50] Avoid constructor when finding functions tainted by `new_modified_functions` --- slither/utils/upgradeability.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 8b4f2a2150..45bc78bdad 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -65,7 +65,7 @@ def compare(v1: Contract, v2: Contract) -> dict: # Find all unmodified functions that call a modified function or read/write the # same state variable(s) as a new/modified function, i.e., tainted functions for function in v2.functions: - if function in new_modified_functions: + if function in new_modified_functions or function.is_constructor: continue modified_calls = [ func for func in new_modified_functions if func in function.internal_calls From 596b4d08657593ed97b5df847a19a520d658724c Mon Sep 17 00:00:00 2001 From: webthethird Date: Tue, 28 Feb 2023 13:49:56 -0600 Subject: [PATCH 09/50] Avoid `slitherConstructorConstantVariables()` when finding modified functions --- slither/utils/upgradeability.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 45bc78bdad..d8bc645f6e 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -65,7 +65,11 @@ def compare(v1: Contract, v2: Contract) -> dict: # Find all unmodified functions that call a modified function or read/write the # same state variable(s) as a new/modified function, i.e., tainted functions for function in v2.functions: - if function in new_modified_functions or function.is_constructor: + if ( + function in new_modified_functions + or function.is_constructor + or function.name.startswith("slither") + ): continue modified_calls = [ func for func in new_modified_functions if func in function.internal_calls From 480ee60cc36a119fb18e0f32640ad96bc25fe1cb Mon Sep 17 00:00:00 2001 From: webthethird Date: Tue, 28 Feb 2023 13:51:01 -0600 Subject: [PATCH 10/50] Avoid `slitherConstructorConstantVariables()` when finding modified functions --- slither/utils/upgradeability.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index d8bc645f6e..22a076ff32 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -55,7 +55,9 @@ def compare(v1: Contract, v2: Contract) -> dict: new_modified_function_vars += ( function.state_variables_read + function.state_variables_written ) - elif is_function_modified(orig_function, function): + elif not function.name.startswith("slither") and is_function_modified( + orig_function, function + ): new_modified_functions.append(function) results["modified-functions"].append(function) new_modified_function_vars += ( From 37554ecda32473c40670d7f2d32c40b93500b11b Mon Sep 17 00:00:00 2001 From: webthethird Date: Sat, 11 Mar 2023 13:13:27 -0600 Subject: [PATCH 11/50] Bump to re-run CI tests --- slither/utils/upgradeability.py | 1 - 1 file changed, 1 deletion(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 22a076ff32..64143a9eb1 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -20,7 +20,6 @@ def compare(v1: Contract, v2: Contract) -> dict: "modified-functions": list[Function], "tainted-functions": list[Function] } - """ order_vars1 = [v for v in v1.state_variables if not v.is_constant and not v.is_immutable] From f1947bb8e132cd4be512de010c1b3aa9d79f620c Mon Sep 17 00:00:00 2001 From: webthethird Date: Sun, 12 Mar 2023 17:29:44 -0500 Subject: [PATCH 12/50] Add additional upgradeability utils like `get_proxy_implementation_slot(proxy: Contract)` --- slither/utils/upgradeability.py | 105 ++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 64143a9eb1..7a384726eb 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -1,5 +1,16 @@ +from typing import Optional +from slither.analyses.data_dependency.data_dependency import get_dependencies from slither.core.declarations.contract import Contract from slither.core.declarations.function import Function +from slither.core.variables.variable import Variable +from slither.core.variables.state_variable import StateVariable +from slither.core.variables.local_variable import LocalVariable +from slither.core.expressions.identifier import Identifier +from slither.core.expressions.call_expression import CallExpression +from slither.core.expressions.assignment_operation import AssignmentOperation +from slither.core.cfg.node import Node, NodeType +from slither.slithir.operations import LowLevelCall +from slither.tools.read_storage.read_storage import SlotInfo, SlitherReadStorage # pylint: disable=too-many-locals @@ -120,3 +131,97 @@ def is_function_modified(f1: Function, f2: Function) -> bool: if ir != f1.nodes[i].irs[j]: return True return False + + +def get_proxy_implementation_slot(proxy: Contract) -> Optional[SlotInfo]: + available_functions = proxy.available_functions_as_dict() + + if not proxy.is_upgradeable_proxy or not available_functions["fallback()"]: + return None + + delegate: Optional[Variable] = find_delegate_in_fallback(proxy) + + if isinstance(delegate, LocalVariable): + dependencies = get_dependencies(delegate, proxy) + delegate = next(var for var in dependencies if isinstance(var, StateVariable)) + if isinstance(delegate, StateVariable): + if not delegate.is_constant and not delegate.is_immutable: + srs = SlitherReadStorage([proxy], 20) + return srs.get_storage_slot(delegate, proxy) + if delegate.is_constant and delegate.type.name == "bytes32": + return SlotInfo( + name=delegate.name, + type_string="address", + slot=int(delegate.expression.value, 16), + size=160, + offset=0, + ) + return None + + +def find_delegate_in_fallback(proxy: Contract) -> Optional[Variable]: + delegate: Optional[Variable] = None + fallback = proxy.available_functions_as_dict()["fallback()"] + for node in fallback.all_nodes(): + for ir in node.irs: + if isinstance(ir, LowLevelCall) and ir.function_name == "delegatecall": + delegate = ir.destination + if delegate is not None: + break + if ( + node.type == NodeType.ASSEMBLY + and isinstance(node.inline_asm, str) + and "delegatecall" in node.inline_asm + ): + delegate = extract_delegate_from_asm(proxy, node) + elif node.type == NodeType.EXPRESSION: + expression = node.expression + if isinstance(expression, AssignmentOperation): + expression = expression.expression_right + if ( + isinstance(expression, CallExpression) + and "delegatecall" in str(expression.called) + and len(expression.arguments) > 1 + ): + dest = expression.arguments[1] + if isinstance(dest, Identifier): + delegate = dest.value + return delegate + + +def extract_delegate_from_asm(contract: Contract, node: Node) -> Optional[Variable]: + asm_split = str(node.inline_asm).split("\n") + asm = next(line for line in asm_split if "delegatecall" in line) + params = asm.split("call(")[1].split(", ") + dest = params[1] + if dest.endswith(")"): + dest = params[2] + if dest.startswith("sload("): + dest = dest.replace(")", "(").split("(")[1] + for v in node.function.variables_read_or_written: + if v.name == dest: + if isinstance(v, LocalVariable) and v.expression is not None: + e = v.expression + if isinstance(e, Identifier) and isinstance(e.value, StateVariable): + v = e.value + # Fall through, return constant storage slot + if isinstance(v, StateVariable) and v.is_constant: + return v + if "_fallback_asm" in dest or "_slot" in dest: + dest = dest.split("_")[0] + return find_delegate_from_name(contract, dest, node.function) + + +def find_delegate_from_name( + contract: Contract, dest: str, parent_func: Function +) -> Optional[Variable]: + for sv in contract.state_variables: + if sv.name == dest: + return sv + for lv in parent_func.local_variables: + if lv.name == dest: + return lv + for pv in parent_func.parameters: + if pv.name == dest: + return pv + return None From 8181faeec539ca4b96d14713492953b14d49eb6b Mon Sep 17 00:00:00 2001 From: webthethird Date: Sun, 12 Mar 2023 18:14:42 -0500 Subject: [PATCH 13/50] Add docstrings --- slither/utils/upgradeability.py | 45 +++++++++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 5 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 7a384726eb..621e6e9577 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -118,8 +118,8 @@ def is_function_modified(f1: Function, f2: Function) -> bool: f1: Original version of the function f2: New version of the function - Returns: True if the functions differ, otherwise False - + Returns: + True if the functions differ, otherwise False """ # If the function content hashes are the same, no need to investigate the function further if f1.source_mapping.content_hash == f2.source_mapping.content_hash: @@ -134,10 +134,14 @@ def is_function_modified(f1: Function, f2: Function) -> bool: def get_proxy_implementation_slot(proxy: Contract) -> Optional[SlotInfo]: - available_functions = proxy.available_functions_as_dict() + """ + Gets information about the storage slot where a proxy's implementation address is stored. + Args: + proxy: A Contract object (proxy.is_upgradeable_proxy should be true). - if not proxy.is_upgradeable_proxy or not available_functions["fallback()"]: - return None + Returns: + (`SlotInfo`) | None : A dictionary of the slot information. + """ delegate: Optional[Variable] = find_delegate_in_fallback(proxy) @@ -160,6 +164,15 @@ def get_proxy_implementation_slot(proxy: Contract) -> Optional[SlotInfo]: def find_delegate_in_fallback(proxy: Contract) -> Optional[Variable]: + """ + Searches a proxy's fallback function for a delegatecall, then extracts the Variable being passed in as the target. + Should typically be called by get_proxy_implementation_var(proxy). + Args: + proxy: A Contract object (should have a fallback function). + + Returns: + (`Variable`) | None : The variable being passed as the destination argument in a delegatecall in the fallback. + """ delegate: Optional[Variable] = None fallback = proxy.available_functions_as_dict()["fallback()"] for node in fallback.all_nodes(): @@ -190,6 +203,17 @@ def find_delegate_in_fallback(proxy: Contract) -> Optional[Variable]: def extract_delegate_from_asm(contract: Contract, node: Node) -> Optional[Variable]: + """ + Finds a Variable with a name matching the argument passed into a delegatecall, when all we have is an Assembly node + with a block of code as one long string. Usually only the case for solc versions < 0.6.0. + Should typically be called by find_delegate_in_fallback(proxy). + Args: + contract: The parent Contract. + node: The Assembly Node (i.e., node.type == NodeType.ASSEMBLY) + + Returns: + (`Variable`) | None : The variable being passed as the destination argument in a delegatecall in the fallback. + """ asm_split = str(node.inline_asm).split("\n") asm = next(line for line in asm_split if "delegatecall" in line) params = asm.split("call(")[1].split(", ") @@ -215,6 +239,17 @@ def extract_delegate_from_asm(contract: Contract, node: Node) -> Optional[Variab def find_delegate_from_name( contract: Contract, dest: str, parent_func: Function ) -> Optional[Variable]: + """ + Searches for a variable with a given name, starting with StateVariables declared in the contract, followed by + LocalVariables in the parent function, either declared in the function body or as parameters in the signature. + Args: + contract: The Contract object to search. + dest: The variable name to search for. + parent_func: The Function object to search. + + Returns: + (`Variable`) | None : The variable with the matching name, if found + """ for sv in contract.state_variables: if sv.name == dest: return sv From 4c571684415b818de518fae39ef96bc7ec6b4ab0 Mon Sep 17 00:00:00 2001 From: webthethird Date: Sun, 12 Mar 2023 18:15:57 -0500 Subject: [PATCH 14/50] Separate `get_proxy_implementation_var` out of `get_proxy_implementation_slot` since either one could be more useful. --- slither/utils/upgradeability.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 621e6e9577..ed2e29ac14 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -143,11 +143,7 @@ def get_proxy_implementation_slot(proxy: Contract) -> Optional[SlotInfo]: (`SlotInfo`) | None : A dictionary of the slot information. """ - delegate: Optional[Variable] = find_delegate_in_fallback(proxy) - - if isinstance(delegate, LocalVariable): - dependencies = get_dependencies(delegate, proxy) - delegate = next(var for var in dependencies if isinstance(var, StateVariable)) + delegate = get_proxy_implementation_var(proxy) if isinstance(delegate, StateVariable): if not delegate.is_constant and not delegate.is_immutable: srs = SlitherReadStorage([proxy], 20) @@ -163,6 +159,30 @@ def get_proxy_implementation_slot(proxy: Contract) -> Optional[SlotInfo]: return None +def get_proxy_implementation_var(proxy: Contract) -> Optional[Variable]: + """ + Gets the Variable that stores a proxy's implementation address. Uses data dependency to trace any LocalVariable + that is passed into a delegatecall as the target address back to its data source, ideally a StateVariable. + Args: + proxy: A Contract object (proxy.is_upgradeable_proxy should be true). + + Returns: + (`Variable`) | None : The variable, ideally a StateVariable, which stores the proxy's implementation address. + """ + available_functions = proxy.available_functions_as_dict() + if not proxy.is_upgradeable_proxy or not available_functions["fallback()"]: + return None + + delegate = find_delegate_in_fallback(proxy) + if isinstance(delegate, LocalVariable): + dependencies = get_dependencies(delegate, proxy) + try: + delegate = next(var for var in dependencies if isinstance(var, StateVariable)) + except: + return delegate + return delegate + + def find_delegate_in_fallback(proxy: Contract) -> Optional[Variable]: """ Searches a proxy's fallback function for a delegatecall, then extracts the Variable being passed in as the target. From 09128abf11fb223d98ea33a6150d3c81fbf11fa2 Mon Sep 17 00:00:00 2001 From: webthethird Date: Tue, 14 Mar 2023 12:27:28 -0500 Subject: [PATCH 15/50] pylint --- slither/utils/upgradeability.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index ed2e29ac14..ce1c43937d 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -178,7 +178,7 @@ def get_proxy_implementation_var(proxy: Contract) -> Optional[Variable]: dependencies = get_dependencies(delegate, proxy) try: delegate = next(var for var in dependencies if isinstance(var, StateVariable)) - except: + except StopIteration: return delegate return delegate From 3655708367f1d3fba86dcab3f0714cad59b06f94 Mon Sep 17 00:00:00 2001 From: webthethird Date: Tue, 14 Mar 2023 14:32:34 -0500 Subject: [PATCH 16/50] Redesign `utils.upgradeability.is_function_modified` Due to the non-deterministic order of `Function.all_nodes()` --- slither/utils/upgradeability.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index ce1c43937d..06918981d5 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -11,6 +11,7 @@ from slither.core.cfg.node import Node, NodeType from slither.slithir.operations import LowLevelCall from slither.tools.read_storage.read_storage import SlotInfo, SlitherReadStorage +from slither.tools.similarity.encode import encode_ir # pylint: disable=too-many-locals @@ -126,9 +127,17 @@ def is_function_modified(f1: Function, f2: Function) -> bool: return False # If the hashes differ, it is possible a change in a name or in a comment could be the only difference # So we need to resort to walking through the CFG and comparing the IR operations - for i, node in enumerate(f2.nodes): - for j, ir in enumerate(node.irs): - if ir != f1.nodes[i].irs[j]: + queue_f1 = [f1.entry_point] + queue_f2 = [f2.entry_point] + visited = [] + while len(queue_f1) > 0 and len(queue_f2) > 0: + node_f1 = queue_f1.pop(0) + node_f2 = queue_f2.pop(0) + visited.extend([node_f1, node_f2]) + queue_f1.extend(son for son in node_f1.sons if son not in visited) + queue_f2.extend(son for son in node_f2.sons if son not in visited) + for i, ir in enumerate(node_f1.irs): + if encode_ir(ir) != encode_ir(node_f2.irs[i]): return True return False From c195a15da09b8c3a552e9e159ad168f4cbf62c52 Mon Sep 17 00:00:00 2001 From: webthethird Date: Tue, 14 Mar 2023 15:06:51 -0500 Subject: [PATCH 17/50] Handle `sload` from slot in `delegatecall` args i.e., `delegatecall(gas(), sload(0x3608...), 0, calldatasize(), 0, 0)` where the slot is not defined as a bytes32 constant but rather is hardcoded in the fallback. --- slither/utils/upgradeability.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 06918981d5..1c4da1b5fb 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -5,9 +5,11 @@ from slither.core.variables.variable import Variable from slither.core.variables.state_variable import StateVariable from slither.core.variables.local_variable import LocalVariable +from slither.core.expressions.literal import Literal from slither.core.expressions.identifier import Identifier from slither.core.expressions.call_expression import CallExpression from slither.core.expressions.assignment_operation import AssignmentOperation +from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.cfg.node import Node, NodeType from slither.slithir.operations import LowLevelCall from slither.tools.read_storage.read_storage import SlotInfo, SlitherReadStorage @@ -226,8 +228,18 @@ def find_delegate_in_fallback(proxy: Contract) -> Optional[Variable]: and len(expression.arguments) > 1 ): dest = expression.arguments[1] + if isinstance(dest, CallExpression) and "sload" in str(dest.called): + dest = dest.arguments[0] if isinstance(dest, Identifier): delegate = dest.value + break + if isinstance(dest, Literal) and len(dest.value) == 66: + delegate = StateVariable() + delegate.is_constant = True + delegate.expression = dest + delegate.name = dest.value + delegate.type = ElementaryType("bytes32") + break return delegate From 317af452adf4b583c6a5af27b14ed72430455a20 Mon Sep 17 00:00:00 2001 From: webthethird Date: Thu, 16 Mar 2023 14:24:38 -0500 Subject: [PATCH 18/50] Minor bug fixes --- slither/utils/upgradeability.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 1c4da1b5fb..d52bb65146 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -51,7 +51,7 @@ def compare(v1: Contract, v2: Contract) -> dict: } # Since this is not a detector, include any missing variables in the v2 contract - if len(order_vars2) <= len(order_vars1): + if len(order_vars2) < len(order_vars1): for variable in order_vars1: if variable.name not in [v.name for v in order_vars2]: results["missing-vars-in-v2"].append(variable) @@ -99,11 +99,11 @@ def compare(v1: Contract, v2: Contract) -> dict: if len(modified_calls) > 0 or len(tainted_vars) > 0: results["tainted-functions"].append(function) - # Find all new or tainted variables, i.e., variables that are read or written by a new/modified function - for idx, var in enumerate(order_vars2): + # Find all new or tainted variables, i.e., variables that are read or written by a new/modified/tainted function + for _, var in enumerate(order_vars2): read_by = v2.get_functions_reading_from_variable(var) written_by = v2.get_functions_writing_to_variable(var) - if len(order_vars1) <= idx: + if v1.get_state_variable_from_name(var.name) is None: results["new-variables"].append(var) elif any(func in read_by or func in written_by for func in new_modified_functions): results["tainted-variables"].append(var) From 9f4be7d7fbfcc429e9e4ed6d9b069216958f4675 Mon Sep 17 00:00:00 2001 From: webthethird Date: Thu, 16 Mar 2023 14:24:57 -0500 Subject: [PATCH 19/50] Include variables touched by tainted functions --- slither/utils/upgradeability.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index d52bb65146..f69b336f54 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -105,7 +105,10 @@ def compare(v1: Contract, v2: Contract) -> dict: written_by = v2.get_functions_writing_to_variable(var) if v1.get_state_variable_from_name(var.name) is None: results["new-variables"].append(var) - elif any(func in read_by or func in written_by for func in new_modified_functions): + elif any( + func in read_by or func in written_by + for func in new_modified_functions + results["tainted-functions"] + ): results["tainted-variables"].append(var) return results From 6ff59f9c9a8a7edc4ff960185cb30d5397dabed9 Mon Sep 17 00:00:00 2001 From: webthethird Date: Thu, 16 Mar 2023 14:27:12 -0500 Subject: [PATCH 20/50] Copy/paste and tweak `encode_ir` to compare ir originally from slither.tools.simil.encode Needed to tweak how binary operations are encoded --- slither/utils/upgradeability.py | 173 ++++++++++++++++++++++++++++++-- 1 file changed, 165 insertions(+), 8 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index f69b336f54..3d4c64c04f 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -1,19 +1,62 @@ from typing import Optional +from slither.core.declarations import ( + Contract, + Structure, + Enum, + SolidityVariableComposed, + SolidityVariable, + Function, +) +from slither.core.solidity_types import ( + ElementaryType, + ArrayType, + MappingType, + UserDefinedType, +) +from slither.core.variables.local_variable import LocalVariable +from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple +from slither.core.variables.state_variable import StateVariable from slither.analyses.data_dependency.data_dependency import get_dependencies -from slither.core.declarations.contract import Contract -from slither.core.declarations.function import Function from slither.core.variables.variable import Variable -from slither.core.variables.state_variable import StateVariable -from slither.core.variables.local_variable import LocalVariable from slither.core.expressions.literal import Literal from slither.core.expressions.identifier import Identifier from slither.core.expressions.call_expression import CallExpression from slither.core.expressions.assignment_operation import AssignmentOperation -from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.cfg.node import Node, NodeType -from slither.slithir.operations import LowLevelCall +from slither.slithir.operations import ( + Assignment, + Index, + Member, + Length, + Binary, + Unary, + Condition, + NewArray, + NewStructure, + NewContract, + NewElementaryType, + SolidityCall, + Delete, + EventCall, + LibraryCall, + InternalDynamicCall, + HighLevelCall, + LowLevelCall, + TypeConversion, + Return, + Transfer, + Send, + Unpack, + InitArray, + InternalCall, +) +from slither.slithir.variables import ( + TemporaryVariable, + TupleVariable, + Constant, + ReferenceVariable, +) from slither.tools.read_storage.read_storage import SlotInfo, SlitherReadStorage -from slither.tools.similarity.encode import encode_ir # pylint: disable=too-many-locals @@ -142,11 +185,125 @@ def is_function_modified(f1: Function, f2: Function) -> bool: queue_f1.extend(son for son in node_f1.sons if son not in visited) queue_f2.extend(son for son in node_f2.sons if son not in visited) for i, ir in enumerate(node_f1.irs): - if encode_ir(ir) != encode_ir(node_f2.irs[i]): + if encode_ir_for_compare(ir) != encode_ir_for_compare(node_f2.irs[i]): return True return False +def ntype(_type): # pylint: disable=too-many-branches + if isinstance(_type, ElementaryType): + _type = str(_type) + elif isinstance(_type, ArrayType): + if isinstance(_type.type, ElementaryType): + _type = str(_type) + else: + _type = "user_defined_array" + elif isinstance(_type, Structure): + _type = str(_type) + elif isinstance(_type, Enum): + _type = str(_type) + elif isinstance(_type, MappingType): + _type = str(_type) + elif isinstance(_type, UserDefinedType): + _type = "user_defined_type" # TODO: this could be Contract, Enum or Struct + else: + _type = str(_type) + + _type = _type.replace(" memory", "") + _type = _type.replace(" storage ref", "") + + if "struct" in _type: + return "struct" + if "enum" in _type: + return "enum" + if "tuple" in _type: + return "tuple" + if "contract" in _type: + return "contract" + if "mapping" in _type: + return "mapping" + return _type.replace(" ", "_") + + +def encode_ir_for_compare(ir) -> str: # pylint: disable=too-many-branches + # operations + if isinstance(ir, Assignment): + return f"({encode_ir_for_compare(ir.lvalue)}):=({encode_ir_for_compare(ir.rvalue)})" + if isinstance(ir, Index): + return f"index({ntype(ir.index_type)})" + if isinstance(ir, Member): + return "member" # .format(ntype(ir._type)) + if isinstance(ir, Length): + return "length" + if isinstance(ir, Binary): + return f"binary({str(ir.variable_left)}{str(ir.type)}{str(ir.variable_right)})" + if isinstance(ir, Unary): + return f"unary({str(ir.type)})" + if isinstance(ir, Condition): + return f"condition({encode_ir_for_compare(ir.value)})" + if isinstance(ir, NewStructure): + return "new_structure" + if isinstance(ir, NewContract): + return "new_contract" + if isinstance(ir, NewArray): + return f"new_array({ntype(ir.array_type)})" + if isinstance(ir, NewElementaryType): + return f"new_elementary({ntype(ir.type)})" + if isinstance(ir, Delete): + return f"delete({encode_ir_for_compare(ir.lvalue)},{encode_ir_for_compare(ir.variable)})" + if isinstance(ir, SolidityCall): + return f"solidity_call({ir.function.full_name})" + if isinstance(ir, InternalCall): + return f"internal_call({ntype(ir.type_call)})" + if isinstance(ir, EventCall): # is this useful? + return "event" + if isinstance(ir, LibraryCall): + return "library_call" + if isinstance(ir, InternalDynamicCall): + return "internal_dynamic_call" + if isinstance(ir, HighLevelCall): # TODO: improve + return "high_level_call" + if isinstance(ir, LowLevelCall): # TODO: improve + return "low_level_call" + if isinstance(ir, TypeConversion): + return f"type_conversion({ntype(ir.type)})" + if isinstance(ir, Return): # this can be improved using values + return "return" # .format(ntype(ir.type)) + if isinstance(ir, Transfer): + return f"transfer({encode_ir_for_compare(ir.call_value)})" + if isinstance(ir, Send): + return f"send({encode_ir_for_compare(ir.call_value)})" + if isinstance(ir, Unpack): # TODO: improve + return "unpack" + if isinstance(ir, InitArray): # TODO: improve + return "init_array" + if isinstance(ir, Function): # TODO: investigate this + return "function_solc" + + # variables + if isinstance(ir, Constant): + return f"constant({ntype(ir.type)})" + if isinstance(ir, SolidityVariableComposed): + return f"solidity_variable_composed({ir.name})" + if isinstance(ir, SolidityVariable): + return f"solidity_variable{ir.name}" + if isinstance(ir, TemporaryVariable): + return "temporary_variable" + if isinstance(ir, ReferenceVariable): + return f"reference({ntype(ir.type)})" + if isinstance(ir, LocalVariable): + return f"local_solc_variable({ir.location})" + if isinstance(ir, StateVariable): + return f"state_solc_variable({ntype(ir.type)})" + if isinstance(ir, LocalVariableInitFromTuple): + return "local_variable_init_tuple" + if isinstance(ir, TupleVariable): + return "tuple_variable" + + # default + return "" + + def get_proxy_implementation_slot(proxy: Contract) -> Optional[SlotInfo]: """ Gets information about the storage slot where a proxy's implementation address is stored. From 701c8f3770c2adf27561aba3b505aa6d4635d083 Mon Sep 17 00:00:00 2001 From: webthethird Date: Thu, 16 Mar 2023 14:31:10 -0500 Subject: [PATCH 21/50] Add test for slither.utils.upgradeability.compare And contracts to test it with, including some for future tests --- tests/test_upgradeability_util.py | 70 +++++ .../TEST_upgrade_diff.json | 20 ++ tests/upgradeability-util/TestUpgrades.sol | 6 + tests/upgradeability-util/src/Address.sol | 244 ++++++++++++++++++ tests/upgradeability-util/src/ContractV1.sol | 36 +++ tests/upgradeability-util/src/ContractV2.sol | 41 +++ .../upgradeability-util/src/ERC1967Proxy.sol | 15 ++ .../src/ERC1967Upgrade.sol | 105 ++++++++ .../src/InheritedStorageProxy.sol | 39 +++ tests/upgradeability-util/src/Proxy.sol | 36 +++ .../upgradeability-util/src/ProxyStorage.sol | 6 + tests/upgradeability-util/src/StorageSlot.sol | 88 +++++++ 12 files changed, 706 insertions(+) create mode 100644 tests/test_upgradeability_util.py create mode 100644 tests/upgradeability-util/TEST_upgrade_diff.json create mode 100644 tests/upgradeability-util/TestUpgrades.sol create mode 100644 tests/upgradeability-util/src/Address.sol create mode 100644 tests/upgradeability-util/src/ContractV1.sol create mode 100644 tests/upgradeability-util/src/ContractV2.sol create mode 100644 tests/upgradeability-util/src/ERC1967Proxy.sol create mode 100644 tests/upgradeability-util/src/ERC1967Upgrade.sol create mode 100644 tests/upgradeability-util/src/InheritedStorageProxy.sol create mode 100644 tests/upgradeability-util/src/Proxy.sol create mode 100644 tests/upgradeability-util/src/ProxyStorage.sol create mode 100644 tests/upgradeability-util/src/StorageSlot.sol diff --git a/tests/test_upgradeability_util.py b/tests/test_upgradeability_util.py new file mode 100644 index 0000000000..6fbffc899a --- /dev/null +++ b/tests/test_upgradeability_util.py @@ -0,0 +1,70 @@ +import os +import json + +from solc_select import solc_select +from deepdiff import DeepDiff + +from slither import Slither +from slither.core.declarations import Function +from slither.core.variables import StateVariable +from slither.utils.upgradeability import compare + +SLITHER_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +UPGRADE_TEST_ROOT = os.path.join(SLITHER_ROOT, "tests", "upgradeability-util") + + +def test_upgrades_compare() -> None: + solc_select.switch_global_version("0.8.2", always_install=True) + + sl = Slither(os.path.join(UPGRADE_TEST_ROOT, "TestUpgrades.sol")) + v1 = sl.get_contract_from_name("ContractV1")[0] + v2 = sl.get_contract_from_name("ContractV2")[0] + diff = compare(v1, v2) + for key in diff.keys(): + if len(diff[key]) > 0: + print(f' * {str(key).replace("-", " ")}:') + for obj in diff[key]: + if isinstance(obj, StateVariable): + print(f" * {obj.full_name}") + elif isinstance(obj, Function): + print(f" * {obj.signature_str}") + with open("upgrade_diff.json", "w", encoding="utf-8") as file: + json_str = diff_to_json_str(diff) + diff_json = json.loads(json_str) + json.dump(diff_json, file, indent=4) + + expected_file = os.path.join(UPGRADE_TEST_ROOT, "TEST_upgrade_diff.json") + actual_file = os.path.join(SLITHER_ROOT, "upgrade_diff.json") + + with open(expected_file, "r", encoding="utf8") as f: + expected = json.load(f) + with open(actual_file, "r", encoding="utf8") as f: + actual = json.load(f) + + diff = DeepDiff(expected, actual, ignore_order=True, verbose_level=2, view="tree") + if diff: + for change in diff.get("values_changed", []): + path_list = re.findall(r"\['(.*?)'\]", change.path()) + path = "_".join(path_list) + with open(f"{path}_expected.txt", "w", encoding="utf8") as f: + f.write(str(change.t1)) + with open(f"{path}_actual.txt", "w", encoding="utf8") as f: + f.write(str(change.t2)) + + assert not diff + + +def diff_to_json_str(diff: dict) -> str: + out: dict = {} + for key in diff.keys(): + out[key] = [] + for obj in diff[key]: + if isinstance(obj, StateVariable): + out[key].append(obj.canonical_name) + elif isinstance(obj, Function): + out[key].append(obj.signature_str) + return str(out).replace("'", '"') + + +def __main__(): + test_upgradeability_util() diff --git a/tests/upgradeability-util/TEST_upgrade_diff.json b/tests/upgradeability-util/TEST_upgrade_diff.json new file mode 100644 index 0000000000..4ccc4bd3fa --- /dev/null +++ b/tests/upgradeability-util/TEST_upgrade_diff.json @@ -0,0 +1,20 @@ +{ + "missing-vars-in-v2": [], + "new-variables": [ + "ContractV2.stateC" + ], + "tainted-variables": [ + "ContractV2.stateB", + "ContractV2.bug" + ], + "new-functions": [ + "i() returns()" + ], + "modified-functions": [ + "checkB() returns(bool)" + ], + "tainted-functions": [ + "g(uint256) returns()", + "h() returns()" + ] +} \ No newline at end of file diff --git a/tests/upgradeability-util/TestUpgrades.sol b/tests/upgradeability-util/TestUpgrades.sol new file mode 100644 index 0000000000..d3371d3c68 --- /dev/null +++ b/tests/upgradeability-util/TestUpgrades.sol @@ -0,0 +1,6 @@ +pragma solidity ^0.8.2; + +import "./src/ContractV1.sol"; +import "./src/ContractV2.sol"; +import "./src/InheritedStorageProxy.sol"; +import "./src/ERC1967Proxy.sol"; diff --git a/tests/upgradeability-util/src/Address.sol b/tests/upgradeability-util/src/Address.sol new file mode 100644 index 0000000000..d440b259ee --- /dev/null +++ b/tests/upgradeability-util/src/Address.sol @@ -0,0 +1,244 @@ +// SPDX-License-Identifier: MIT +// OpenZeppelin Contracts (last updated v4.8.0) (utils/Address.sol) + +pragma solidity ^0.8.1; + +/** + * @dev Collection of functions related to the address type + */ +library Address { + /** + * @dev Returns true if `account` is a contract. + * + * [IMPORTANT] + * ==== + * It is unsafe to assume that an address for which this function returns + * false is an externally-owned account (EOA) and not a contract. + * + * Among others, `isContract` will return false for the following + * types of addresses: + * + * - an externally-owned account + * - a contract in construction + * - an address where a contract will be created + * - an address where a contract lived, but was destroyed + * ==== + * + * [IMPORTANT] + * ==== + * You shouldn't rely on `isContract` to protect against flash loan attacks! + * + * Preventing calls from contracts is highly discouraged. It breaks composability, breaks support for smart wallets + * like Gnosis Safe, and does not provide security since it can be circumvented by calling from a contract + * constructor. + * ==== + */ + function isContract(address account) internal view returns (bool) { + // This method relies on extcodesize/address.code.length, which returns 0 + // for contracts in construction, since the code is only stored at the end + // of the constructor execution. + + return account.code.length > 0; + } + + /** + * @dev Replacement for Solidity's `transfer`: sends `amount` wei to + * `recipient`, forwarding all available gas and reverting on errors. + * + * https://eips.ethereum.org/EIPS/eip-1884[EIP1884] increases the gas cost + * of certain opcodes, possibly making contracts go over the 2300 gas limit + * imposed by `transfer`, making them unable to receive funds via + * `transfer`. {sendValue} removes this limitation. + * + * https://diligence.consensys.net/posts/2019/09/stop-using-soliditys-transfer-now/[Learn more]. + * + * IMPORTANT: because control is transferred to `recipient`, care must be + * taken to not create reentrancy vulnerabilities. Consider using + * {ReentrancyGuard} or the + * https://solidity.readthedocs.io/en/v0.5.11/security-considerations.html#use-the-checks-effects-interactions-pattern[checks-effects-interactions pattern]. + */ + function sendValue(address payable recipient, uint256 amount) internal { + require(address(this).balance >= amount, "Address: insufficient balance"); + + (bool success, ) = recipient.call{value: amount}(""); + require(success, "Address: unable to send value, recipient may have reverted"); + } + + /** + * @dev Performs a Solidity function call using a low level `call`. A + * plain `call` is an unsafe replacement for a function call: use this + * function instead. + * + * If `target` reverts with a revert reason, it is bubbled up by this + * function (like regular Solidity function calls). + * + * Returns the raw returned data. To convert to the expected return value, + * use https://solidity.readthedocs.io/en/latest/units-and-global-variables.html?highlight=abi.decode#abi-encoding-and-decoding-functions[`abi.decode`]. + * + * Requirements: + * + * - `target` must be a contract. + * - calling `target` with `data` must not revert. + * + * _Available since v3.1._ + */ + function functionCall(address target, bytes memory data) internal returns (bytes memory) { + return functionCallWithValue(target, data, 0, "Address: low-level call failed"); + } + + /** + * @dev Same as {xref-Address-functionCall-address-bytes-}[`functionCall`], but with + * `errorMessage` as a fallback revert reason when `target` reverts. + * + * _Available since v3.1._ + */ + function functionCall( + address target, + bytes memory data, + string memory errorMessage + ) internal returns (bytes memory) { + return functionCallWithValue(target, data, 0, errorMessage); + } + + /** + * @dev Same as {xref-Address-functionCall-address-bytes-}[`functionCall`], + * but also transferring `value` wei to `target`. + * + * Requirements: + * + * - the calling contract must have an ETH balance of at least `value`. + * - the called Solidity function must be `payable`. + * + * _Available since v3.1._ + */ + function functionCallWithValue( + address target, + bytes memory data, + uint256 value + ) internal returns (bytes memory) { + return functionCallWithValue(target, data, value, "Address: low-level call with value failed"); + } + + /** + * @dev Same as {xref-Address-functionCallWithValue-address-bytes-uint256-}[`functionCallWithValue`], but + * with `errorMessage` as a fallback revert reason when `target` reverts. + * + * _Available since v3.1._ + */ + function functionCallWithValue( + address target, + bytes memory data, + uint256 value, + string memory errorMessage + ) internal returns (bytes memory) { + require(address(this).balance >= value, "Address: insufficient balance for call"); + (bool success, bytes memory returndata) = target.call{value: value}(data); + return verifyCallResultFromTarget(target, success, returndata, errorMessage); + } + + /** + * @dev Same as {xref-Address-functionCall-address-bytes-}[`functionCall`], + * but performing a static call. + * + * _Available since v3.3._ + */ + function functionStaticCall(address target, bytes memory data) internal view returns (bytes memory) { + return functionStaticCall(target, data, "Address: low-level static call failed"); + } + + /** + * @dev Same as {xref-Address-functionCall-address-bytes-string-}[`functionCall`], + * but performing a static call. + * + * _Available since v3.3._ + */ + function functionStaticCall( + address target, + bytes memory data, + string memory errorMessage + ) internal view returns (bytes memory) { + (bool success, bytes memory returndata) = target.staticcall(data); + return verifyCallResultFromTarget(target, success, returndata, errorMessage); + } + + /** + * @dev Same as {xref-Address-functionCall-address-bytes-}[`functionCall`], + * but performing a delegate call. + * + * _Available since v3.4._ + */ + function functionDelegateCall(address target, bytes memory data) internal returns (bytes memory) { + return functionDelegateCall(target, data, "Address: low-level delegate call failed"); + } + + /** + * @dev Same as {xref-Address-functionCall-address-bytes-string-}[`functionCall`], + * but performing a delegate call. + * + * _Available since v3.4._ + */ + function functionDelegateCall( + address target, + bytes memory data, + string memory errorMessage + ) internal returns (bytes memory) { + (bool success, bytes memory returndata) = target.delegatecall(data); + return verifyCallResultFromTarget(target, success, returndata, errorMessage); + } + + /** + * @dev Tool to verify that a low level call to smart-contract was successful, and revert (either by bubbling + * the revert reason or using the provided one) in case of unsuccessful call or if target was not a contract. + * + * _Available since v4.8._ + */ + function verifyCallResultFromTarget( + address target, + bool success, + bytes memory returndata, + string memory errorMessage + ) internal view returns (bytes memory) { + if (success) { + if (returndata.length == 0) { + // only check isContract if the call was successful and the return data is empty + // otherwise we already know that it was a contract + require(isContract(target), "Address: call to non-contract"); + } + return returndata; + } else { + _revert(returndata, errorMessage); + } + } + + /** + * @dev Tool to verify that a low level call was successful, and revert if it wasn't, either by bubbling the + * revert reason or using the provided one. + * + * _Available since v4.3._ + */ + function verifyCallResult( + bool success, + bytes memory returndata, + string memory errorMessage + ) internal pure returns (bytes memory) { + if (success) { + return returndata; + } else { + _revert(returndata, errorMessage); + } + } + + function _revert(bytes memory returndata, string memory errorMessage) private pure { + // Look for revert reason and bubble it up if present + if (returndata.length > 0) { + // The easiest way to bubble the revert reason is using memory via assembly + /// @solidity memory-safe-assembly + assembly { + let returndata_size := mload(returndata) + revert(add(32, returndata), returndata_size) + } + } else { + revert(errorMessage); + } + } +} diff --git a/tests/upgradeability-util/src/ContractV1.sol b/tests/upgradeability-util/src/ContractV1.sol new file mode 100644 index 0000000000..1e2c4b476e --- /dev/null +++ b/tests/upgradeability-util/src/ContractV1.sol @@ -0,0 +1,36 @@ +pragma solidity ^0.8.2; + +import "./ProxyStorage.sol"; + +contract ContractV1 is ProxyStorage { + uint private stateA = 0; + uint private stateB = 0; + uint constant CONST = 32; + bool bug = false; + + function f(uint x) public { + if (msg.sender == admin) { + stateA = x; + } + } + + function g(uint y) public { + if (checkA()) { + stateB = y - 10; + } + } + + function h() public { + if (checkB()) { + bug = true; + } + } + + function checkA() internal returns (bool) { + return stateA % CONST == 1; + } + + function checkB() internal returns (bool) { + return stateB == 62; + } +} diff --git a/tests/upgradeability-util/src/ContractV2.sol b/tests/upgradeability-util/src/ContractV2.sol new file mode 100644 index 0000000000..9b102f3e9f --- /dev/null +++ b/tests/upgradeability-util/src/ContractV2.sol @@ -0,0 +1,41 @@ +pragma solidity ^0.8.2; + +import "./ProxyStorage.sol"; + +contract ContractV2 is ProxyStorage { + uint private stateA = 0; + uint private stateB = 0; + uint constant CONST = 32; + bool bug = false; + uint private stateC = 0; + + function f(uint x) public { + if (msg.sender == admin) { + stateA = x; + } + } + + function g(uint y) public { + if (checkA()) { + stateB = y - 10; + } + } + + function h() public { + if (checkB()) { + bug = true; + } + } + + function i() public { + stateC = stateC + 1; + } + + function checkA() internal returns (bool) { + return stateA % CONST == 1; + } + + function checkB() internal returns (bool) { + return stateB == 32; + } +} diff --git a/tests/upgradeability-util/src/ERC1967Proxy.sol b/tests/upgradeability-util/src/ERC1967Proxy.sol new file mode 100644 index 0000000000..f1496c27e1 --- /dev/null +++ b/tests/upgradeability-util/src/ERC1967Proxy.sol @@ -0,0 +1,15 @@ +pragma solidity ^0.8.0; + +import "./Proxy.sol"; +import "./ERC1967Upgrade.sol"; + +contract ERC1967Proxy is Proxy, ERC1967Upgrade { + + constructor(address _logic, bytes memory _data) payable { + _upgradeToAndCall(_logic, _data, false); + } + + function _implementation() internal view virtual override returns (address impl) { + return ERC1967Upgrade._getImplementation(); + } +} diff --git a/tests/upgradeability-util/src/ERC1967Upgrade.sol b/tests/upgradeability-util/src/ERC1967Upgrade.sol new file mode 100644 index 0000000000..d089e94d9d --- /dev/null +++ b/tests/upgradeability-util/src/ERC1967Upgrade.sol @@ -0,0 +1,105 @@ +pragma solidity ^0.8.2; + +import "./Address.sol"; +import "./StorageSlot.sol"; + +interface IBeacon { + function implementation() external view returns (address); +} + +interface IERC1822Proxiable { + function proxiableUUID() external view returns (bytes32); +} + +abstract contract ERC1967Upgrade { + + bytes32 private constant _ROLLBACK_SLOT = 0x4910fdfa16fed3260ed0e7147f7cc6da11a60208b5b9406d12a635614ffd9143; + bytes32 internal constant _IMPLEMENTATION_SLOT = 0x360894a13ba1a3210667c828492db98dca3e2076cc3735a920a3ca505d382bbc; + bytes32 internal constant _ADMIN_SLOT = 0xb53127684a568b3173ae13b9f8a6016e243e63b6e8ee1178d6a717850b5d6103; + bytes32 internal constant _BEACON_SLOT = 0xa3f0ad74e5423aebfd80d3ef4346578335a9a72aeaee59ff6cb3582b35133d50; + + event Upgraded(address indexed implementation); + event AdminChanged(address previousAdmin, address newAdmin); + event BeaconUpgraded(address indexed beacon); + + function _getImplementation() internal view returns (address) { + return StorageSlot.getAddressSlot(_IMPLEMENTATION_SLOT).value; + } + + function _setImplementation(address newImplementation) private { + require(Address.isContract(newImplementation), "ERC1967: new implementation is not a contract"); + StorageSlot.getAddressSlot(_IMPLEMENTATION_SLOT).value = newImplementation; + } + + function _upgradeTo(address newImplementation) internal { + _setImplementation(newImplementation); + emit Upgraded(newImplementation); + } + + function _upgradeToAndCall( + address newImplementation, + bytes memory data, + bool forceCall + ) internal { + _upgradeTo(newImplementation); + if (data.length > 0 || forceCall) { + Address.functionDelegateCall(newImplementation, data); + } + } + + function _upgradeToAndCallUUPS( + address newImplementation, + bytes memory data, + bool forceCall + ) internal { + if (StorageSlot.getBooleanSlot(_ROLLBACK_SLOT).value) { + _setImplementation(newImplementation); + } else { + try IERC1822Proxiable(newImplementation).proxiableUUID() returns (bytes32 slot) { + require(slot == _IMPLEMENTATION_SLOT, "ERC1967Upgrade: unsupported proxiableUUID"); + } catch { + revert("ERC1967Upgrade: new implementation is not UUPS"); + } + _upgradeToAndCall(newImplementation, data, forceCall); + } + } + + function _getAdmin() internal view returns (address) { + return StorageSlot.getAddressSlot(_ADMIN_SLOT).value; + } + + function _setAdmin(address newAdmin) private { + require(newAdmin != address(0), "ERC1967: new admin is the zero address"); + StorageSlot.getAddressSlot(_ADMIN_SLOT).value = newAdmin; + } + + function _changeAdmin(address newAdmin) internal { + emit AdminChanged(_getAdmin(), newAdmin); + _setAdmin(newAdmin); + } + + function _getBeacon() internal view returns (address) { + return StorageSlot.getAddressSlot(_BEACON_SLOT).value; + } + + function _setBeacon(address newBeacon) private { + require(Address.isContract(newBeacon), "ERC1967: new beacon is not a contract"); + require( + Address.isContract(IBeacon(newBeacon).implementation()), + "ERC1967: beacon implementation is not a contract" + ); + StorageSlot.getAddressSlot(_BEACON_SLOT).value = newBeacon; + } + + function _upgradeBeaconToAndCall( + address newBeacon, + bytes memory data, + bool forceCall + ) internal { + _setBeacon(newBeacon); + emit BeaconUpgraded(newBeacon); + if (data.length > 0 || forceCall) { + Address.functionDelegateCall(IBeacon(newBeacon).implementation(), data); + } + } +} \ No newline at end of file diff --git a/tests/upgradeability-util/src/InheritedStorageProxy.sol b/tests/upgradeability-util/src/InheritedStorageProxy.sol new file mode 100644 index 0000000000..eddbfb0f12 --- /dev/null +++ b/tests/upgradeability-util/src/InheritedStorageProxy.sol @@ -0,0 +1,39 @@ +pragma solidity ^0.8.0; + +import "./Proxy.sol"; +import "./ProxyStorage.sol"; + +contract InheritedStorageProxy is Proxy, ProxyStorage { + constructor(address _implementation) { + admin = msg.sender; + implementation = _implementation; + } + + function getImplementation() external view returns (address) { + return _implementation(); + } + + function getAdmin() external view returns (address) { + return _admin(); + } + + function upgrade(address _newImplementation) external { + require(msg.sender == admin, "Only admin can upgrade"); + implementation = _newImplementation; + } + + function setAdmin(address _newAdmin) external { + require(msg.sender == admin, "Only current admin can change admin"); + admin = _newAdmin; + } + + function _implementation() internal view override returns (address) { + return implementation; + } + + function _admin() internal view returns (address) { + return admin; + } + + function _beforeFallback() internal override {} +} diff --git a/tests/upgradeability-util/src/Proxy.sol b/tests/upgradeability-util/src/Proxy.sol new file mode 100644 index 0000000000..445ddb1704 --- /dev/null +++ b/tests/upgradeability-util/src/Proxy.sol @@ -0,0 +1,36 @@ +pragma solidity ^0.8.0; + +abstract contract Proxy { + + function _delegate(address implementation) internal virtual { + assembly { + calldatacopy(0, 0, calldatasize()) + let result := delegatecall(gas(), implementation, 0, calldatasize(), 0, 0) + returndatacopy(0, 0, returndatasize()) + switch result + case 0 { + revert(0, returndatasize()) + } + default { + return(0, returndatasize()) + } + } + } + + function _implementation() internal view virtual returns (address); + + function _fallback() internal virtual { + _beforeFallback(); + _delegate(_implementation()); + } + + fallback() external payable virtual { + _fallback(); + } + + receive() external payable virtual { + _fallback(); + } + + function _beforeFallback() internal virtual {} +} diff --git a/tests/upgradeability-util/src/ProxyStorage.sol b/tests/upgradeability-util/src/ProxyStorage.sol new file mode 100644 index 0000000000..d591040bd2 --- /dev/null +++ b/tests/upgradeability-util/src/ProxyStorage.sol @@ -0,0 +1,6 @@ +pragma solidity ^0.8.0; + +contract ProxyStorage { + address internal admin; + address internal implementation; +} diff --git a/tests/upgradeability-util/src/StorageSlot.sol b/tests/upgradeability-util/src/StorageSlot.sol new file mode 100644 index 0000000000..6ab8f5dc6b --- /dev/null +++ b/tests/upgradeability-util/src/StorageSlot.sol @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// OpenZeppelin Contracts (last updated v4.7.0) (utils/StorageSlot.sol) + +pragma solidity ^0.8.0; + +/** + * @dev Library for reading and writing primitive types to specific storage slots. + * + * Storage slots are often used to avoid storage conflict when dealing with upgradeable contracts. + * This library helps with reading and writing to such slots without the need for inline assembly. + * + * The functions in this library return Slot structs that contain a `value` member that can be used to read or write. + * + * Example usage to set ERC1967 implementation slot: + * ``` + * contract ERC1967 { + * bytes32 internal constant _IMPLEMENTATION_SLOT = 0x360894a13ba1a3210667c828492db98dca3e2076cc3735a920a3ca505d382bbc; + * + * function _getImplementation() internal view returns (address) { + * return StorageSlot.getAddressSlot(_IMPLEMENTATION_SLOT).value; + * } + * + * function _setImplementation(address newImplementation) internal { + * require(Address.isContract(newImplementation), "ERC1967: new implementation is not a contract"); + * StorageSlot.getAddressSlot(_IMPLEMENTATION_SLOT).value = newImplementation; + * } + * } + * ``` + * + * _Available since v4.1 for `address`, `bool`, `bytes32`, and `uint256`._ + */ +library StorageSlot { + struct AddressSlot { + address value; + } + + struct BooleanSlot { + bool value; + } + + struct Bytes32Slot { + bytes32 value; + } + + struct Uint256Slot { + uint256 value; + } + + /** + * @dev Returns an `AddressSlot` with member `value` located at `slot`. + */ + function getAddressSlot(bytes32 slot) internal pure returns (AddressSlot storage r) { + /// @solidity memory-safe-assembly + assembly { + r.slot := slot + } + } + + /** + * @dev Returns an `BooleanSlot` with member `value` located at `slot`. + */ + function getBooleanSlot(bytes32 slot) internal pure returns (BooleanSlot storage r) { + /// @solidity memory-safe-assembly + assembly { + r.slot := slot + } + } + + /** + * @dev Returns an `Bytes32Slot` with member `value` located at `slot`. + */ + function getBytes32Slot(bytes32 slot) internal pure returns (Bytes32Slot storage r) { + /// @solidity memory-safe-assembly + assembly { + r.slot := slot + } + } + + /** + * @dev Returns an `Uint256Slot` with member `value` located at `slot`. + */ + function getUint256Slot(bytes32 slot) internal pure returns (Uint256Slot storage r) { + /// @solidity memory-safe-assembly + assembly { + r.slot := slot + } + } +} From a035d271e03ced8d52423cda70528e638fe1ab94 Mon Sep 17 00:00:00 2001 From: webthethird Date: Thu, 16 Mar 2023 14:40:10 -0500 Subject: [PATCH 22/50] pylint --- tests/test_upgradeability_util.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/test_upgradeability_util.py b/tests/test_upgradeability_util.py index 6fbffc899a..9e673d949a 100644 --- a/tests/test_upgradeability_util.py +++ b/tests/test_upgradeability_util.py @@ -1,5 +1,6 @@ import os import json +import re from solc_select import solc_select from deepdiff import DeepDiff @@ -13,6 +14,7 @@ UPGRADE_TEST_ROOT = os.path.join(SLITHER_ROOT, "tests", "upgradeability-util") +# pylint: disable=too-many-locals def test_upgrades_compare() -> None: solc_select.switch_global_version("0.8.2", always_install=True) @@ -20,10 +22,10 @@ def test_upgrades_compare() -> None: v1 = sl.get_contract_from_name("ContractV1")[0] v2 = sl.get_contract_from_name("ContractV2")[0] diff = compare(v1, v2) - for key in diff.keys(): - if len(diff[key]) > 0: + for key, lst in diff.items(): + if len(lst) > 0: print(f' * {str(key).replace("-", " ")}:') - for obj in diff[key]: + for obj in lst: if isinstance(obj, StateVariable): print(f" * {obj.full_name}") elif isinstance(obj, Function): @@ -64,7 +66,3 @@ def diff_to_json_str(diff: dict) -> str: elif isinstance(obj, Function): out[key].append(obj.signature_str) return str(out).replace("'", '"') - - -def __main__(): - test_upgradeability_util() From c5b54635354388f2a08c5e228485628cbc6fc6ce Mon Sep 17 00:00:00 2001 From: webthethird Date: Thu, 16 Mar 2023 14:44:03 -0500 Subject: [PATCH 23/50] Rename shadowed var --- tests/test_upgradeability_util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_upgradeability_util.py b/tests/test_upgradeability_util.py index 9e673d949a..9d3daba9c1 100644 --- a/tests/test_upgradeability_util.py +++ b/tests/test_upgradeability_util.py @@ -21,8 +21,8 @@ def test_upgrades_compare() -> None: sl = Slither(os.path.join(UPGRADE_TEST_ROOT, "TestUpgrades.sol")) v1 = sl.get_contract_from_name("ContractV1")[0] v2 = sl.get_contract_from_name("ContractV2")[0] - diff = compare(v1, v2) - for key, lst in diff.items(): + diff_dict = compare(v1, v2) + for key, lst in diff_dict.items(): if len(lst) > 0: print(f' * {str(key).replace("-", " ")}:') for obj in lst: @@ -31,7 +31,7 @@ def test_upgrades_compare() -> None: elif isinstance(obj, Function): print(f" * {obj.signature_str}") with open("upgrade_diff.json", "w", encoding="utf-8") as file: - json_str = diff_to_json_str(diff) + json_str = diff_to_json_str(diff_dict) diff_json = json.loads(json_str) json.dump(diff_json, file, indent=4) From 5622230117ba2d45eb2be7e8acc5f51f26c07823 Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 09:33:15 -0500 Subject: [PATCH 24/50] Return six lists instead of dictionary of lists --- slither/utils/upgradeability.py | 34 +++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 3d4c64c04f..5f1a395f53 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple from slither.core.declarations import ( Contract, Structure, @@ -60,7 +60,11 @@ # pylint: disable=too-many-locals -def compare(v1: Contract, v2: Contract) -> dict: +def compare( + v1: Contract, v2: Contract +) -> Tuple[ + list[Variable], list[Variable], list[Variable], list[Function], list[Function], list[Function] +]: """ Compares two versions of a contract. Most useful for upgradeable (logic) contracts, but does not require that Contract.is_upgradeable returns true for either contract. @@ -69,14 +73,13 @@ def compare(v1: Contract, v2: Contract) -> dict: v1: Original version of (upgradeable) contract v2: Updated version of (upgradeable) contract - Returns: dict { - "missing-vars-in-v2": list[Variable], - "new-variables": list[Variable], - "tainted-variables": list[Variable], - "new-functions": list[Function], - "modified-functions": list[Function], - "tainted-functions": list[Function] - } + Returns: + missing-vars-in-v2: list[Variable], + new-variables: list[Variable], + tainted-variables: list[Variable], + new-functions: list[Function], + modified-functions: list[Function], + tainted-functions: list[Function] """ order_vars1 = [v for v in v1.state_variables if not v.is_constant and not v.is_immutable] @@ -143,7 +146,7 @@ def compare(v1: Contract, v2: Contract) -> dict: results["tainted-functions"].append(function) # Find all new or tainted variables, i.e., variables that are read or written by a new/modified/tainted function - for _, var in enumerate(order_vars2): + for var in order_vars2: read_by = v2.get_functions_reading_from_variable(var) written_by = v2.get_functions_writing_to_variable(var) if v1.get_state_variable_from_name(var.name) is None: @@ -154,7 +157,14 @@ def compare(v1: Contract, v2: Contract) -> dict: ): results["tainted-variables"].append(var) - return results + return ( + results["missing-vars-in-v2"], + results["new-variables"], + results["tainted-variables"], + results["new-functions"], + results["modified-functions"], + results["tainted-functions"], + ) def is_function_modified(f1: Function, f2: Function) -> bool: From c159f56d70f0c0a4a4f20250f662a07168596ffc Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 09:33:55 -0500 Subject: [PATCH 25/50] Rewrite test for `utils.upgradeability` --- tests/test_upgradeability_util.py | 58 +++++++------------------------ 1 file changed, 13 insertions(+), 45 deletions(-) diff --git a/tests/test_upgradeability_util.py b/tests/test_upgradeability_util.py index 9d3daba9c1..72fb8c4ece 100644 --- a/tests/test_upgradeability_util.py +++ b/tests/test_upgradeability_util.py @@ -21,48 +21,16 @@ def test_upgrades_compare() -> None: sl = Slither(os.path.join(UPGRADE_TEST_ROOT, "TestUpgrades.sol")) v1 = sl.get_contract_from_name("ContractV1")[0] v2 = sl.get_contract_from_name("ContractV2")[0] - diff_dict = compare(v1, v2) - for key, lst in diff_dict.items(): - if len(lst) > 0: - print(f' * {str(key).replace("-", " ")}:') - for obj in lst: - if isinstance(obj, StateVariable): - print(f" * {obj.full_name}") - elif isinstance(obj, Function): - print(f" * {obj.signature_str}") - with open("upgrade_diff.json", "w", encoding="utf-8") as file: - json_str = diff_to_json_str(diff_dict) - diff_json = json.loads(json_str) - json.dump(diff_json, file, indent=4) - - expected_file = os.path.join(UPGRADE_TEST_ROOT, "TEST_upgrade_diff.json") - actual_file = os.path.join(SLITHER_ROOT, "upgrade_diff.json") - - with open(expected_file, "r", encoding="utf8") as f: - expected = json.load(f) - with open(actual_file, "r", encoding="utf8") as f: - actual = json.load(f) - - diff = DeepDiff(expected, actual, ignore_order=True, verbose_level=2, view="tree") - if diff: - for change in diff.get("values_changed", []): - path_list = re.findall(r"\['(.*?)'\]", change.path()) - path = "_".join(path_list) - with open(f"{path}_expected.txt", "w", encoding="utf8") as f: - f.write(str(change.t1)) - with open(f"{path}_actual.txt", "w", encoding="utf8") as f: - f.write(str(change.t2)) - - assert not diff - - -def diff_to_json_str(diff: dict) -> str: - out: dict = {} - for key in diff.keys(): - out[key] = [] - for obj in diff[key]: - if isinstance(obj, StateVariable): - out[key].append(obj.canonical_name) - elif isinstance(obj, Function): - out[key].append(obj.signature_str) - return str(out).replace("'", '"') + missing_vars, new_vars, tainted_vars, new_funcs, modified_funcs, tainted_funcs = compare(v1, v2) + assert len(missing_vars) == 0 + assert new_vars == [v2.get_state_variable_from_name("stateC")] + assert tainted_vars == [ + v2.get_state_variable_from_name("stateB"), + v2.get_state_variable_from_name("bug") + ] + assert new_funcs == [v2.get_function_from_signature("i()")] + assert modified_funcs == [v2.get_function_from_signature("checkB()")] + assert tainted_funcs == [ + v2.get_function_from_signature("g(uint256)"), + v2.get_function_from_signature("h()") + ] From a6f6fc0bfa8ada650ab508ad2387b7dd150bf0a6 Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 09:40:17 -0500 Subject: [PATCH 26/50] Use `.state_variables_ordered` and `.is_constructor_variables` --- slither/utils/upgradeability.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 5f1a395f53..f948138766 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -82,8 +82,8 @@ def compare( tainted-functions: list[Function] """ - order_vars1 = [v for v in v1.state_variables if not v.is_constant and not v.is_immutable] - order_vars2 = [v for v in v2.state_variables if not v.is_constant and not v.is_immutable] + order_vars1 = [v for v in v1.state_variables_ordered if not v.is_constant and not v.is_immutable] + order_vars2 = [v for v in v2.state_variables_ordered if not v.is_constant and not v.is_immutable] func_sigs1 = [function.solidity_signature for function in v1.functions] func_sigs2 = [function.solidity_signature for function in v2.functions] @@ -114,7 +114,7 @@ def compare( new_modified_function_vars += ( function.state_variables_read + function.state_variables_written ) - elif not function.name.startswith("slither") and is_function_modified( + elif not function.is_constructor_variables and is_function_modified( orig_function, function ): new_modified_functions.append(function) From af6727c5610a38a326aecb73f887d48987c1ffa2 Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 09:52:29 -0500 Subject: [PATCH 27/50] Add `Contract.fallback_function` and `.receive_function` properties --- slither/core/declarations/contract.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index 95b05aa6b9..88405b98a3 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -89,6 +89,9 @@ def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope self._signatures: Optional[List[str]] = None self._signatures_declared: Optional[List[str]] = None + self._fallback_function: Optional["FunctionContract"] = None + self._receive_function: Optional["FunctionContract"] = None + self._is_upgradeable: Optional[bool] = None self._is_upgradeable_proxy: Optional[bool] = None self._upgradeable_version: Optional[str] = None @@ -649,6 +652,24 @@ def functions_and_modifiers_declared(self) -> List["Function"]: """ return self.functions_declared + self.modifiers_declared # type: ignore + @property + def fallback_function(self) -> Optional["FunctionContract"]: + if self._fallback_function is None: + for f in self.functions: + if f.is_fallback: + self._fallback_function = f + break + return self._fallback_function + + @property + def receive_function(self) -> Optional["FunctionContract"]: + if self._receive_function is None: + for f in self.functions: + if f.is_receive: + self._receive_function = f + break + return self._receive_function + def available_elements_from_inheritances( self, elements: Dict[str, "Function"], From 26f80cfcf03c6a2f371d06e6cf178e04bfbf7691 Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 09:53:50 -0500 Subject: [PATCH 28/50] Use `Contract.fallback_function` --- slither/utils/upgradeability.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index f948138766..cefc0868c7 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -350,8 +350,7 @@ def get_proxy_implementation_var(proxy: Contract) -> Optional[Variable]: Returns: (`Variable`) | None : The variable, ideally a StateVariable, which stores the proxy's implementation address. """ - available_functions = proxy.available_functions_as_dict() - if not proxy.is_upgradeable_proxy or not available_functions["fallback()"]: + if not proxy.is_upgradeable_proxy or not proxy.fallback_function: return None delegate = find_delegate_in_fallback(proxy) @@ -375,7 +374,7 @@ def find_delegate_in_fallback(proxy: Contract) -> Optional[Variable]: (`Variable`) | None : The variable being passed as the destination argument in a delegatecall in the fallback. """ delegate: Optional[Variable] = None - fallback = proxy.available_functions_as_dict()["fallback()"] + fallback = proxy.fallback_function for node in fallback.all_nodes(): for ir in node.irs: if isinstance(ir, LowLevelCall) and ir.function_name == "delegatecall": From 1a54d0d325e78288eeaf5c2978ec25dc00c6974e Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 10:01:54 -0500 Subject: [PATCH 29/50] Handle hardcoded slot sloaded in delegatecall in `extract_delegate_from_asm` --- slither/utils/upgradeability.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index cefc0868c7..20759cb180 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -428,10 +428,17 @@ def extract_delegate_from_asm(contract: Contract, node: Node) -> Optional[Variab asm = next(line for line in asm_split if "delegatecall" in line) params = asm.split("call(")[1].split(", ") dest = params[1] - if dest.endswith(")"): + if dest.endswith(")") and not dest.startswith("sload("): dest = params[2] if dest.startswith("sload("): dest = dest.replace(")", "(").split("(")[1] + if len(dest) == 66 and dest.startswith("0x"): + v = StateVariable() + v.is_constant = True + v.expression = Literal(dest, ElementaryType("bytes32")) + v.name = dest + v.type = ElementaryType("bytes32") + return v for v in node.function.variables_read_or_written: if v.name == dest: if isinstance(v, LocalVariable) and v.expression is not None: From 24dad035f4dd00b0fd267a5bb6e8d0bd37d7cb06 Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 10:02:16 -0500 Subject: [PATCH 30/50] Document when a newly created variable can be returned --- slither/utils/upgradeability.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 20759cb180..1547116da1 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -344,6 +344,7 @@ def get_proxy_implementation_var(proxy: Contract) -> Optional[Variable]: """ Gets the Variable that stores a proxy's implementation address. Uses data dependency to trace any LocalVariable that is passed into a delegatecall as the target address back to its data source, ideally a StateVariable. + Can return a newly created StateVariable if an `sload` from a hardcoded storage slot is found in assembly. Args: proxy: A Contract object (proxy.is_upgradeable_proxy should be true). @@ -366,6 +367,7 @@ def get_proxy_implementation_var(proxy: Contract) -> Optional[Variable]: def find_delegate_in_fallback(proxy: Contract) -> Optional[Variable]: """ Searches a proxy's fallback function for a delegatecall, then extracts the Variable being passed in as the target. + Can return a newly created StateVariable if an `sload` from a hardcoded storage slot is found in assembly. Should typically be called by get_proxy_implementation_var(proxy). Args: proxy: A Contract object (should have a fallback function). @@ -416,6 +418,7 @@ def extract_delegate_from_asm(contract: Contract, node: Node) -> Optional[Variab """ Finds a Variable with a name matching the argument passed into a delegatecall, when all we have is an Assembly node with a block of code as one long string. Usually only the case for solc versions < 0.6.0. + Can return a newly created StateVariable if an `sload` from a hardcoded storage slot is found in assembly. Should typically be called by find_delegate_in_fallback(proxy). Args: contract: The parent Contract. From 9a9acbede982fa7363a6dffe08b6c08bd0a19b1f Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 10:05:57 -0500 Subject: [PATCH 31/50] Comment when a newly created variable can be returned and explain the magic value 66: 32 bytes = 64 chars + "0x" = 66 chars --- slither/utils/upgradeability.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 1547116da1..abaac0c10d 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -404,7 +404,9 @@ def find_delegate_in_fallback(proxy: Contract) -> Optional[Variable]: if isinstance(dest, Identifier): delegate = dest.value break - if isinstance(dest, Literal) and len(dest.value) == 66: + if isinstance(dest, Literal) and len(dest.value) == 66: # 32 bytes = 64 chars + "0x" = 66 chars + # Storage slot is not declared as a constant, but rather is hardcoded in the assembly, + # so create a new StateVariable to represent it. delegate = StateVariable() delegate.is_constant = True delegate.expression = dest @@ -435,7 +437,9 @@ def extract_delegate_from_asm(contract: Contract, node: Node) -> Optional[Variab dest = params[2] if dest.startswith("sload("): dest = dest.replace(")", "(").split("(")[1] - if len(dest) == 66 and dest.startswith("0x"): + if len(dest) == 66 and dest.startswith("0x"): # 32 bytes = 64 chars + "0x" = 66 chars + # Storage slot is not declared as a constant, but rather is hardcoded in the assembly, + # so create a new StateVariable to represent it. v = StateVariable() v.is_constant = True v.expression = Literal(dest, ElementaryType("bytes32")) From f148bbc0a89a7221a7950ea4d7aab2fa5abc91f9 Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 10:17:51 -0500 Subject: [PATCH 32/50] Also search `parent_func.returns` in `find_delegate_from_name` --- slither/utils/upgradeability.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index abaac0c10d..ead8039681 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -480,7 +480,7 @@ def find_delegate_from_name( for lv in parent_func.local_variables: if lv.name == dest: return lv - for pv in parent_func.parameters: + for pv in parent_func.parameters + parent_func.returns: if pv.name == dest: return pv return None From 722a343b77ccd590ce63631146ecdd0a1e173767 Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 10:28:36 -0500 Subject: [PATCH 33/50] Remove unused TEST_upgrade_diff.json --- .../TEST_upgrade_diff.json | 20 ------------------- 1 file changed, 20 deletions(-) delete mode 100644 tests/upgradeability-util/TEST_upgrade_diff.json diff --git a/tests/upgradeability-util/TEST_upgrade_diff.json b/tests/upgradeability-util/TEST_upgrade_diff.json deleted file mode 100644 index 4ccc4bd3fa..0000000000 --- a/tests/upgradeability-util/TEST_upgrade_diff.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "missing-vars-in-v2": [], - "new-variables": [ - "ContractV2.stateC" - ], - "tainted-variables": [ - "ContractV2.stateB", - "ContractV2.bug" - ], - "new-functions": [ - "i() returns()" - ], - "modified-functions": [ - "checkB() returns(bool)" - ], - "tainted-functions": [ - "g(uint256) returns()", - "h() returns()" - ] -} \ No newline at end of file From fc1b94c89c0e28d54baf3bc163f9d649c1f00f14 Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 12:05:39 -0500 Subject: [PATCH 34/50] Handle named variable declared in assembly which is not parsed as a LocalVariable for solc <0.6.0 If it gets its value from an sload, create new StateVariable --- slither/utils/upgradeability.py | 52 ++++++++++++++++++++++++--------- 1 file changed, 39 insertions(+), 13 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index ead8039681..a5017fafcc 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -407,11 +407,7 @@ def find_delegate_in_fallback(proxy: Contract) -> Optional[Variable]: if isinstance(dest, Literal) and len(dest.value) == 66: # 32 bytes = 64 chars + "0x" = 66 chars # Storage slot is not declared as a constant, but rather is hardcoded in the assembly, # so create a new StateVariable to represent it. - delegate = StateVariable() - delegate.is_constant = True - delegate.expression = dest - delegate.name = dest.value - delegate.type = ElementaryType("bytes32") + delegate = create_state_variable_from_slot(dest.value) break return delegate @@ -437,14 +433,8 @@ def extract_delegate_from_asm(contract: Contract, node: Node) -> Optional[Variab dest = params[2] if dest.startswith("sload("): dest = dest.replace(")", "(").split("(")[1] - if len(dest) == 66 and dest.startswith("0x"): # 32 bytes = 64 chars + "0x" = 66 chars - # Storage slot is not declared as a constant, but rather is hardcoded in the assembly, - # so create a new StateVariable to represent it. - v = StateVariable() - v.is_constant = True - v.expression = Literal(dest, ElementaryType("bytes32")) - v.name = dest - v.type = ElementaryType("bytes32") + v = create_state_variable_from_slot(dest) + if v is not None: return v for v in node.function.variables_read_or_written: if v.name == dest: @@ -466,6 +456,7 @@ def find_delegate_from_name( """ Searches for a variable with a given name, starting with StateVariables declared in the contract, followed by LocalVariables in the parent function, either declared in the function body or as parameters in the signature. + Can return a newly created StateVariable if an `sload` from a hardcoded storage slot is found in assembly. Args: contract: The Contract object to search. dest: The variable name to search for. @@ -483,4 +474,39 @@ def find_delegate_from_name( for pv in parent_func.parameters + parent_func.returns: if pv.name == dest: return pv + if parent_func.contains_assembly: + for node in parent_func.all_nodes(): + if node.type == NodeType.ASSEMBLY and isinstance(node.inline_asm, str): + asm = next((s for s in node.inline_asm.split("\n") if f"{dest}:=sload(" in s.replace(" ", "")), None) + if asm: + slot = asm.split("sload(")[1].split(")")[0] + return create_state_variable_from_slot(slot, name=dest) return None + + +def create_state_variable_from_slot(slot: str, name: str = None) -> Optional[StateVariable]: + """ + Creates a new StateVariable object to wrap a hardcoded storage slot found in assembly. + Args: + slot: The storage slot hex string. + name: Optional name for the variable. The slot string is used if name is not provided. + + Returns: + A newly created constant StateVariable of type bytes32, with the slot as the variable's expression and name, + if slot matches the length and prefix of a bytes32. Otherwise, returns None. + """ + if len(slot) == 66 and slot.startswith("0x"): # 32 bytes = 64 chars + "0x" = 66 chars + # Storage slot is not declared as a constant, but rather is hardcoded in the assembly, + # so create a new StateVariable to represent it. + v = StateVariable() + v.is_constant = True + v.expression = Literal(slot, ElementaryType("bytes32")) + if name is not None: + v.name = name + else: + v.name = slot + v.type = ElementaryType("bytes32") + return v + else: + # This should probably also handle hashed strings, but for now return None + return None From cbbcb8c8b128c329868614bccd493277daf4a4cc Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 13:16:50 -0500 Subject: [PATCH 35/50] pylint and black --- slither/utils/upgradeability.py | 41 ++++++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index a5017fafcc..1c53a6cf84 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -82,8 +82,12 @@ def compare( tainted-functions: list[Function] """ - order_vars1 = [v for v in v1.state_variables_ordered if not v.is_constant and not v.is_immutable] - order_vars2 = [v for v in v2.state_variables_ordered if not v.is_constant and not v.is_immutable] + order_vars1 = [ + v for v in v1.state_variables_ordered if not v.is_constant and not v.is_immutable + ] + order_vars2 = [ + v for v in v2.state_variables_ordered if not v.is_constant and not v.is_immutable + ] func_sigs1 = [function.solidity_signature for function in v1.functions] func_sigs2 = [function.solidity_signature for function in v2.functions] @@ -404,7 +408,9 @@ def find_delegate_in_fallback(proxy: Contract) -> Optional[Variable]: if isinstance(dest, Identifier): delegate = dest.value break - if isinstance(dest, Literal) and len(dest.value) == 66: # 32 bytes = 64 chars + "0x" = 66 chars + if ( + isinstance(dest, Literal) and len(dest.value) == 66 + ): # 32 bytes = 64 chars + "0x" = 66 chars # Storage slot is not declared as a constant, but rather is hardcoded in the assembly, # so create a new StateVariable to represent it. delegate = create_state_variable_from_slot(dest.value) @@ -477,10 +483,30 @@ def find_delegate_from_name( if parent_func.contains_assembly: for node in parent_func.all_nodes(): if node.type == NodeType.ASSEMBLY and isinstance(node.inline_asm, str): - asm = next((s for s in node.inline_asm.split("\n") if f"{dest}:=sload(" in s.replace(" ", "")), None) + asm = next( + ( + s + for s in node.inline_asm.split("\n") + if f"{dest}:=sload(" in s.replace(" ", "") + ), + None, + ) if asm: slot = asm.split("sload(")[1].split(")")[0] - return create_state_variable_from_slot(slot, name=dest) + if slot.startswith("0x"): + return create_state_variable_from_slot(slot, name=dest) + try: + slot_idx = int(slot) + return next( + ( + v + for v in contract.state_variables_ordered + if SlitherReadStorage.get_variable_info(contract, v)[0] == slot_idx + ), + None, + ) + except TypeError: + continue return None @@ -507,6 +533,5 @@ def create_state_variable_from_slot(slot: str, name: str = None) -> Optional[Sta v.name = slot v.type = ElementaryType("bytes32") return v - else: - # This should probably also handle hashed strings, but for now return None - return None + # This should probably also handle hashed strings, but for now return None + return None From bff30a3481f61babbd55323b3ab6a42eb3be6118 Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 13:18:28 -0500 Subject: [PATCH 36/50] Update upgradeability util tests with tests of implementation variable/slot getter functions --- tests/test_upgradeability_util.py | 60 ++++++++++++++--- .../TestUpgrades-0.5.0.sol | 5 ++ ...estUpgrades.sol => TestUpgrades-0.8.2.sol} | 0 .../upgradeability-util/src/EIP1822Proxy.sol | 47 +++++++++++++ .../src/MasterCopyProxy.sol | 27 ++++++++ tests/upgradeability-util/src/ZosProxy.sol | 67 +++++++++++++++++++ 6 files changed, 197 insertions(+), 9 deletions(-) create mode 100644 tests/upgradeability-util/TestUpgrades-0.5.0.sol rename tests/upgradeability-util/{TestUpgrades.sol => TestUpgrades-0.8.2.sol} (100%) create mode 100644 tests/upgradeability-util/src/EIP1822Proxy.sol create mode 100644 tests/upgradeability-util/src/MasterCopyProxy.sol create mode 100644 tests/upgradeability-util/src/ZosProxy.sol diff --git a/tests/test_upgradeability_util.py b/tests/test_upgradeability_util.py index 72fb8c4ece..b0beba0cdf 100644 --- a/tests/test_upgradeability_util.py +++ b/tests/test_upgradeability_util.py @@ -1,14 +1,14 @@ import os -import json -import re from solc_select import solc_select -from deepdiff import DeepDiff from slither import Slither -from slither.core.declarations import Function -from slither.core.variables import StateVariable -from slither.utils.upgradeability import compare +from slither.core.expressions import Literal +from slither.utils.upgradeability import ( + compare, + get_proxy_implementation_var, + get_proxy_implementation_slot, +) SLITHER_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) UPGRADE_TEST_ROOT = os.path.join(SLITHER_ROOT, "tests", "upgradeability-util") @@ -18,7 +18,7 @@ def test_upgrades_compare() -> None: solc_select.switch_global_version("0.8.2", always_install=True) - sl = Slither(os.path.join(UPGRADE_TEST_ROOT, "TestUpgrades.sol")) + sl = Slither(os.path.join(UPGRADE_TEST_ROOT, "TestUpgrades-0.8.2.sol")) v1 = sl.get_contract_from_name("ContractV1")[0] v2 = sl.get_contract_from_name("ContractV2")[0] missing_vars, new_vars, tainted_vars, new_funcs, modified_funcs, tainted_funcs = compare(v1, v2) @@ -26,11 +26,53 @@ def test_upgrades_compare() -> None: assert new_vars == [v2.get_state_variable_from_name("stateC")] assert tainted_vars == [ v2.get_state_variable_from_name("stateB"), - v2.get_state_variable_from_name("bug") + v2.get_state_variable_from_name("bug"), ] assert new_funcs == [v2.get_function_from_signature("i()")] assert modified_funcs == [v2.get_function_from_signature("checkB()")] assert tainted_funcs == [ v2.get_function_from_signature("g(uint256)"), - v2.get_function_from_signature("h()") + v2.get_function_from_signature("h()"), ] + + +def test_upgrades_implementation_var() -> None: + solc_select.switch_global_version("0.8.2", always_install=True) + sl = Slither(os.path.join(UPGRADE_TEST_ROOT, "TestUpgrades-0.8.2.sol")) + + erc_1967_proxy = sl.get_contract_from_name("ERC1967Proxy")[0] + storage_proxy = sl.get_contract_from_name("InheritedStorageProxy")[0] + + target = get_proxy_implementation_var(erc_1967_proxy) + slot = get_proxy_implementation_slot(erc_1967_proxy) + assert target == erc_1967_proxy.get_state_variable_from_name("_IMPLEMENTATION_SLOT") + assert slot.slot == 0x360894A13BA1A3210667C828492DB98DCA3E2076CC3735A920A3CA505D382BBC + target = get_proxy_implementation_var(storage_proxy) + slot = get_proxy_implementation_slot(storage_proxy) + assert target == storage_proxy.get_state_variable_from_name("implementation") + assert slot.slot == 1 + + solc_select.switch_global_version("0.5.0", always_install=True) + sl = Slither(os.path.join(UPGRADE_TEST_ROOT, "TestUpgrades-0.5.0.sol")) + + eip_1822_proxy = sl.get_contract_from_name("EIP1822Proxy")[0] + zos_proxy = sl.get_contract_from_name("ZosProxy")[0] + master_copy_proxy = sl.get_contract_from_name("MasterCopyProxy")[0] + + target = get_proxy_implementation_var(eip_1822_proxy) + slot = get_proxy_implementation_slot(eip_1822_proxy) + assert target not in eip_1822_proxy.state_variables_ordered + assert target.name == "contractLogic" and isinstance(target.expression, Literal) + assert ( + target.expression.value + == "0xc5f16f0fcc639fa48a6947836d9850f504798523bf8c9a3a87d5876cf622bcf7" + ) + assert slot.slot == 0xC5F16F0FCC639FA48A6947836D9850F504798523BF8C9A3A87D5876CF622BCF7 + target = get_proxy_implementation_var(zos_proxy) + slot = get_proxy_implementation_slot(zos_proxy) + assert target == zos_proxy.get_state_variable_from_name("IMPLEMENTATION_SLOT") + assert slot.slot == 0x7050C9E0F4CA769C69BD3A8EF740BC37934F8E2C036E5A723FD8EE048ED3F8C3 + target = get_proxy_implementation_var(master_copy_proxy) + slot = get_proxy_implementation_slot(master_copy_proxy) + assert target == master_copy_proxy.get_state_variable_from_name("masterCopy") + assert slot.slot == 0 diff --git a/tests/upgradeability-util/TestUpgrades-0.5.0.sol b/tests/upgradeability-util/TestUpgrades-0.5.0.sol new file mode 100644 index 0000000000..86fa42ec59 --- /dev/null +++ b/tests/upgradeability-util/TestUpgrades-0.5.0.sol @@ -0,0 +1,5 @@ +pragma solidity ^0.5.0; + +import "./src/EIP1822Proxy.sol"; +import "./src/ZosProxy.sol"; +import "./src/MasterCopyProxy.sol"; diff --git a/tests/upgradeability-util/TestUpgrades.sol b/tests/upgradeability-util/TestUpgrades-0.8.2.sol similarity index 100% rename from tests/upgradeability-util/TestUpgrades.sol rename to tests/upgradeability-util/TestUpgrades-0.8.2.sol diff --git a/tests/upgradeability-util/src/EIP1822Proxy.sol b/tests/upgradeability-util/src/EIP1822Proxy.sol new file mode 100644 index 0000000000..3145eb17e1 --- /dev/null +++ b/tests/upgradeability-util/src/EIP1822Proxy.sol @@ -0,0 +1,47 @@ +pragma solidity ^0.5.0; + +contract EIP1822Proxy { + // Code position in storage is keccak256("PROXIABLE") = "0xc5f16f0fcc639fa48a6947836d9850f504798523bf8c9a3a87d5876cf622bcf7" + constructor(bytes memory constructData, address contractLogic) public { + // save the code address + assembly { // solium-disable-line + sstore(0xc5f16f0fcc639fa48a6947836d9850f504798523bf8c9a3a87d5876cf622bcf7, contractLogic) + } + (bool success, bytes memory _ ) = contractLogic.delegatecall(constructData); // solium-disable-line + require(success, "Construction failed"); + } + + function() external payable { + assembly { // solium-disable-line + let contractLogic := sload(0xc5f16f0fcc639fa48a6947836d9850f504798523bf8c9a3a87d5876cf622bcf7) + calldatacopy(0x0, 0x0, calldatasize) + let success := delegatecall(sub(gas, 10000), contractLogic, 0x0, calldatasize, 0, 0) + let retSz := returndatasize + returndatacopy(0, 0, retSz) + switch success + case 0 { + revert(0, retSz) + } + default { + return(0, retSz) + } + } + } +} + +contract EIP1822Proxiable { + // Code position in storage is keccak256("PROXIABLE") = "0xc5f16f0fcc639fa48a6947836d9850f504798523bf8c9a3a87d5876cf622bcf7" + + function updateCodeAddress(address newAddress) internal { + require( + bytes32(0xc5f16f0fcc639fa48a6947836d9850f504798523bf8c9a3a87d5876cf622bcf7) == EIP1822Proxiable(newAddress).proxiableUUID(), + "Not compatible" + ); + assembly { // solium-disable-line + sstore(0xc5f16f0fcc639fa48a6947836d9850f504798523bf8c9a3a87d5876cf622bcf7, newAddress) + } + } + function proxiableUUID() public pure returns (bytes32) { + return 0xc5f16f0fcc639fa48a6947836d9850f504798523bf8c9a3a87d5876cf622bcf7; + } +} \ No newline at end of file diff --git a/tests/upgradeability-util/src/MasterCopyProxy.sol b/tests/upgradeability-util/src/MasterCopyProxy.sol new file mode 100644 index 0000000000..d25a2a920b --- /dev/null +++ b/tests/upgradeability-util/src/MasterCopyProxy.sol @@ -0,0 +1,27 @@ +pragma solidity ^0.5.0; + +contract MasterCopyProxy { + address internal masterCopy; + + constructor(address _masterCopy) + public + { + require(_masterCopy != address(0), "Invalid master copy address provided"); + masterCopy = _masterCopy; + } + + /// @dev Fallback function forwards all transactions and returns all received return data. + function () + external + payable + { + // solium-disable-next-line security/no-inline-assembly + assembly { + calldatacopy(0, 0, calldatasize()) + let success := delegatecall(gas, sload(0), 0, calldatasize(), 0, 0) + returndatacopy(0, 0, returndatasize()) + if eq(success, 0) { revert(0, returndatasize()) } + return(0, returndatasize()) + } + } +} diff --git a/tests/upgradeability-util/src/ZosProxy.sol b/tests/upgradeability-util/src/ZosProxy.sol new file mode 100644 index 0000000000..db44f4c983 --- /dev/null +++ b/tests/upgradeability-util/src/ZosProxy.sol @@ -0,0 +1,67 @@ +pragma solidity ^0.5.0; + +contract ZosProxy { + function () payable external { + _fallback(); + } + + function _implementation() internal view returns (address); + + function _delegate(address implementation) internal { + assembly { + calldatacopy(0, 0, calldatasize) + let result := delegatecall(gas, implementation, 0, calldatasize, 0, 0) + returndatacopy(0, 0, returndatasize) + switch result + case 0 { revert(0, returndatasize) } + default { return(0, returndatasize) } + } + } + + function _willFallback() internal { + } + + function _fallback() internal { + _willFallback(); + _delegate(_implementation()); + } +} + +library AddressUtils { + function isContract(address addr) internal view returns (bool) { + uint256 size; + assembly { size := extcodesize(addr) } + return size > 0; + } +} + +contract UpgradeabilityProxy is ZosProxy { + event Upgraded(address indexed implementation); + + bytes32 private constant IMPLEMENTATION_SLOT = 0x7050c9e0f4ca769c69bd3a8ef740bc37934f8e2c036e5a723fd8ee048ed3f8c3; + + constructor(address _implementation) public payable { + assert(IMPLEMENTATION_SLOT == keccak256("org.zeppelinos.proxy.implementation")); + _setImplementation(_implementation); + } + + function _implementation() internal view returns (address impl) { + bytes32 slot = IMPLEMENTATION_SLOT; + assembly { + impl := sload(slot) + } + } + + function _upgradeTo(address newImplementation) internal { + _setImplementation(newImplementation); + emit Upgraded(newImplementation); + } + + function _setImplementation(address newImplementation) private { + require(AddressUtils.isContract(newImplementation), "Cannot set a proxy implementation to a non-contract address"); + bytes32 slot = IMPLEMENTATION_SLOT; + assembly { + sstore(slot, newImplementation) + } + } +} From 574afbe04bf399bdb72c1fad160557062f101a56 Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 13:39:41 -0500 Subject: [PATCH 37/50] Add `get_missing_vars` to util, use it in `compare` --- slither/utils/upgradeability.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 1c53a6cf84..ef3000fe17 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional, Tuple, List from slither.core.declarations import ( Contract, Structure, @@ -102,9 +102,7 @@ def compare( # Since this is not a detector, include any missing variables in the v2 contract if len(order_vars2) < len(order_vars1): - for variable in order_vars1: - if variable.name not in [v.name for v in order_vars2]: - results["missing-vars-in-v2"].append(variable) + results["missing-vars-in-v2"].extend(get_missing_vars(v1, v2)) # Find all new and modified functions in the v2 contract new_modified_functions = [] @@ -171,6 +169,30 @@ def compare( ) +def get_missing_vars(v1: Contract, v2: Contract) -> List[StateVariable]: + """ + Gets all non-constant/immutable StateVariables that appear in v1 but not v2 + Args: + v1: Contract version 1 + v2: Contract version 2 + + Returns: + List of StateVariables from v1 missing in v2 + """ + results = [] + order_vars1 = [ + v for v in v1.state_variables_ordered if not v.is_constant and not v.is_immutable + ] + order_vars2 = [ + v for v in v2.state_variables_ordered if not v.is_constant and not v.is_immutable + ] + if len(order_vars2) < len(order_vars1): + for variable in order_vars1: + if variable.name not in [v.name for v in order_vars2]: + results.append(variable) + return results + + def is_function_modified(f1: Function, f2: Function) -> bool: """ Compares two versions of a function, and returns True if the function has been modified. From f216817d21e8398f4663708cc7d1ca07984dd0e2 Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 13:40:28 -0500 Subject: [PATCH 38/50] Use `get_missing_vars` in MissingVariable detector --- .../upgradeability/checks/variables_order.py | 22 +++++-------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/slither/tools/upgradeability/checks/variables_order.py b/slither/tools/upgradeability/checks/variables_order.py index 030fb0f65f..e83d8b32b4 100644 --- a/slither/tools/upgradeability/checks/variables_order.py +++ b/slither/tools/upgradeability/checks/variables_order.py @@ -2,6 +2,7 @@ CheckClassification, AbstractCheck, ) +from slither.utils.upgradeability import get_missing_vars class MissingVariable(AbstractCheck): @@ -48,24 +49,13 @@ class MissingVariable(AbstractCheck): def _check(self): contract1 = self.contract contract2 = self.contract_v2 - order1 = [ - variable - for variable in contract1.state_variables_ordered - if not (variable.is_constant or variable.is_immutable) - ] - order2 = [ - variable - for variable in contract2.state_variables_ordered - if not (variable.is_constant or variable.is_immutable) - ] + missing = get_missing_vars(contract1, contract2) results = [] - for idx, _ in enumerate(order1): - variable1 = order1[idx] - if len(order2) <= idx: - info = ["Variable missing in ", contract2, ": ", variable1, "\n"] - json = self.generate_result(info) - results.append(json) + for variable1 in missing: + info = ["Variable missing in ", contract2, ": ", variable1, "\n"] + json = self.generate_result(info) + results.append(json) return results From 1588334844ba8e137e2f3bf69dd970041fc485ae Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 13:49:45 -0500 Subject: [PATCH 39/50] Fix `compare` return signature --- slither/utils/upgradeability.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index ef3000fe17..8ea7a4d0cc 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -63,7 +63,7 @@ def compare( v1: Contract, v2: Contract ) -> Tuple[ - list[Variable], list[Variable], list[Variable], list[Function], list[Function], list[Function] + List[Variable], List[Variable], List[Variable], List[Function], List[Function], List[Function] ]: """ Compares two versions of a contract. Most useful for upgradeable (logic) contracts, From 0e708e6c051bce69e61a2ee80f7c379d17eeb7ff Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 14:21:31 -0500 Subject: [PATCH 40/50] Handle sload from integer slot, i.e., `sload(0)` --- slither/utils/upgradeability.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 8ea7a4d0cc..73720ca54f 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -461,9 +461,18 @@ def extract_delegate_from_asm(contract: Contract, node: Node) -> Optional[Variab dest = params[2] if dest.startswith("sload("): dest = dest.replace(")", "(").split("(")[1] - v = create_state_variable_from_slot(dest) - if v is not None: - return v + if dest.startswith("0x"): + return create_state_variable_from_slot(dest) + if dest.isnumeric(): + slot_idx = int(dest) + return next( + ( + v + for v in contract.state_variables_ordered + if SlitherReadStorage.get_variable_info(contract, v)[0] == slot_idx + ), + None, + ) for v in node.function.variables_read_or_written: if v.name == dest: if isinstance(v, LocalVariable) and v.expression is not None: From 9192fef4a7ad007b52c07eb79b6c3477836d1c89 Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 14:22:25 -0500 Subject: [PATCH 41/50] Comment out ZosProxy test for now (see issue #1775 for why it fails) --- tests/test_upgradeability_util.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_upgradeability_util.py b/tests/test_upgradeability_util.py index b0beba0cdf..ddd0ac5935 100644 --- a/tests/test_upgradeability_util.py +++ b/tests/test_upgradeability_util.py @@ -56,7 +56,7 @@ def test_upgrades_implementation_var() -> None: sl = Slither(os.path.join(UPGRADE_TEST_ROOT, "TestUpgrades-0.5.0.sol")) eip_1822_proxy = sl.get_contract_from_name("EIP1822Proxy")[0] - zos_proxy = sl.get_contract_from_name("ZosProxy")[0] + # zos_proxy = sl.get_contract_from_name("ZosProxy")[0] master_copy_proxy = sl.get_contract_from_name("MasterCopyProxy")[0] target = get_proxy_implementation_var(eip_1822_proxy) @@ -68,10 +68,11 @@ def test_upgrades_implementation_var() -> None: == "0xc5f16f0fcc639fa48a6947836d9850f504798523bf8c9a3a87d5876cf622bcf7" ) assert slot.slot == 0xC5F16F0FCC639FA48A6947836D9850F504798523BF8C9A3A87D5876CF622BCF7 - target = get_proxy_implementation_var(zos_proxy) - slot = get_proxy_implementation_slot(zos_proxy) - assert target == zos_proxy.get_state_variable_from_name("IMPLEMENTATION_SLOT") - assert slot.slot == 0x7050C9E0F4CA769C69BD3A8EF740BC37934F8E2C036E5A723FD8EE048ED3F8C3 + # # The util fails with this proxy due to how Slither parses assembly w/ Solidity versions < 0.6.0 (see issue #1775) + # target = get_proxy_implementation_var(zos_proxy) + # slot = get_proxy_implementation_slot(zos_proxy) + # assert target == zos_proxy.get_state_variable_from_name("IMPLEMENTATION_SLOT") + # assert slot.slot == 0x7050C9E0F4CA769C69BD3A8EF740BC37934F8E2C036E5A723FD8EE048ED3F8C3 target = get_proxy_implementation_var(master_copy_proxy) slot = get_proxy_implementation_slot(master_copy_proxy) assert target == master_copy_proxy.get_state_variable_from_name("masterCopy") From 72c6d78130b5d9cc42d6b46f61c8b90368f16a2b Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 15:04:59 -0500 Subject: [PATCH 42/50] Add SynthProxy.sol to test_upgradeability_util.py test --- tests/test_upgradeability_util.py | 5 ++ .../TestUpgrades-0.5.0.sol | 1 + tests/upgradeability-util/src/SynthProxy.sol | 58 +++++++++++++++++++ 3 files changed, 64 insertions(+) create mode 100644 tests/upgradeability-util/src/SynthProxy.sol diff --git a/tests/test_upgradeability_util.py b/tests/test_upgradeability_util.py index ddd0ac5935..dd12d68a15 100644 --- a/tests/test_upgradeability_util.py +++ b/tests/test_upgradeability_util.py @@ -58,6 +58,7 @@ def test_upgrades_implementation_var() -> None: eip_1822_proxy = sl.get_contract_from_name("EIP1822Proxy")[0] # zos_proxy = sl.get_contract_from_name("ZosProxy")[0] master_copy_proxy = sl.get_contract_from_name("MasterCopyProxy")[0] + synth_proxy = sl.get_contract_from_name("SynthProxy")[0] target = get_proxy_implementation_var(eip_1822_proxy) slot = get_proxy_implementation_slot(eip_1822_proxy) @@ -77,3 +78,7 @@ def test_upgrades_implementation_var() -> None: slot = get_proxy_implementation_slot(master_copy_proxy) assert target == master_copy_proxy.get_state_variable_from_name("masterCopy") assert slot.slot == 0 + target = get_proxy_implementation_var(synth_proxy) + slot = get_proxy_implementation_slot(synth_proxy) + assert target == synth_proxy.get_state_variable_from_name("target") + assert slot.slot == 1 diff --git a/tests/upgradeability-util/TestUpgrades-0.5.0.sol b/tests/upgradeability-util/TestUpgrades-0.5.0.sol index 86fa42ec59..eaecfa6e97 100644 --- a/tests/upgradeability-util/TestUpgrades-0.5.0.sol +++ b/tests/upgradeability-util/TestUpgrades-0.5.0.sol @@ -3,3 +3,4 @@ pragma solidity ^0.5.0; import "./src/EIP1822Proxy.sol"; import "./src/ZosProxy.sol"; import "./src/MasterCopyProxy.sol"; +import "./src/SynthProxy.sol"; diff --git a/tests/upgradeability-util/src/SynthProxy.sol b/tests/upgradeability-util/src/SynthProxy.sol new file mode 100644 index 0000000000..9b3a6bdefa --- /dev/null +++ b/tests/upgradeability-util/src/SynthProxy.sol @@ -0,0 +1,58 @@ +pragma solidity ^0.5.0; + +contract Owned { + address public owner; + + constructor(address _owner) public { + require(_owner != address(0), "Owner address cannot be 0"); + owner = _owner; + } + + modifier onlyOwner { + require(msg.sender == owner, "Only the contract owner may perform this action"); + _; + } +} + +contract Proxyable is Owned { + /* The proxy this contract exists behind. */ + SynthProxy public proxy; + + constructor(address payable _proxy) internal { + // This contract is abstract, and thus cannot be instantiated directly + require(owner != address(0), "Owner must be set"); + + proxy = SynthProxy(_proxy); + } + + function setProxy(address payable _proxy) external onlyOwner { + proxy = SynthProxy(_proxy); + } +} + + +contract SynthProxy is Owned { + Proxyable public target; + + constructor(address _owner) public Owned(_owner) {} + + function setTarget(Proxyable _target) external onlyOwner { + target = _target; + } + + // solhint-disable no-complex-fallback + function() external payable { + assembly { + calldatacopy(0, 0, calldatasize) + + /* We must explicitly forward ether to the underlying contract as well. */ + let result := delegatecall(gas, sload(target_slot), 0, calldatasize, 0, 0) + returndatacopy(0, 0, returndatasize) + + if iszero(result) { + revert(0, returndatasize) + } + return(0, returndatasize) + } + } +} From 78e2ea37da4a00e325d2168abea570964935a231 Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 24 Mar 2023 08:27:31 -0500 Subject: [PATCH 43/50] Add types to function signatures --- slither/utils/upgradeability.py | 50 +++++++++++++++++---------------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 73720ca54f..d3b2c0a53b 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, List +from typing import Optional, Tuple, List, Union from slither.core.declarations import ( Contract, Structure, @@ -8,6 +8,7 @@ Function, ) from slither.core.solidity_types import ( + Type, ElementaryType, ArrayType, MappingType, @@ -24,6 +25,7 @@ from slither.core.expressions.assignment_operation import AssignmentOperation from slither.core.cfg.node import Node, NodeType from slither.slithir.operations import ( + Operation, Assignment, Index, Member, @@ -91,18 +93,16 @@ def compare( func_sigs1 = [function.solidity_signature for function in v1.functions] func_sigs2 = [function.solidity_signature for function in v2.functions] - results = { - "missing-vars-in-v2": [], - "new-variables": [], - "tainted-variables": [], - "new-functions": [], - "modified-functions": [], - "tainted-functions": [], - } + missing_vars_in_v2 = [] + new_variables = [] + tainted_variables = [] + new_functions = [] + modified_functions = [] + tainted_functions = [] # Since this is not a detector, include any missing variables in the v2 contract if len(order_vars2) < len(order_vars1): - results["missing-vars-in-v2"].extend(get_missing_vars(v1, v2)) + missing_vars_in_v2.extend(get_missing_vars(v1, v2)) # Find all new and modified functions in the v2 contract new_modified_functions = [] @@ -112,7 +112,7 @@ def compare( orig_function = v1.get_function_from_signature(sig) if sig not in func_sigs1: new_modified_functions.append(function) - results["new-functions"].append(function) + new_functions.append(function) new_modified_function_vars += ( function.state_variables_read + function.state_variables_written ) @@ -120,7 +120,7 @@ def compare( orig_function, function ): new_modified_functions.append(function) - results["modified-functions"].append(function) + modified_functions.append(function) new_modified_function_vars += ( function.state_variables_read + function.state_variables_written ) @@ -145,27 +145,27 @@ def compare( and not var.is_immutable ] if len(modified_calls) > 0 or len(tainted_vars) > 0: - results["tainted-functions"].append(function) + tainted_functions.append(function) # Find all new or tainted variables, i.e., variables that are read or written by a new/modified/tainted function for var in order_vars2: read_by = v2.get_functions_reading_from_variable(var) written_by = v2.get_functions_writing_to_variable(var) if v1.get_state_variable_from_name(var.name) is None: - results["new-variables"].append(var) + new_variables.append(var) elif any( func in read_by or func in written_by - for func in new_modified_functions + results["tainted-functions"] + for func in new_modified_functions + tainted_functions ): - results["tainted-variables"].append(var) + tainted_variables.append(var) return ( - results["missing-vars-in-v2"], - results["new-variables"], - results["tainted-variables"], - results["new-functions"], - results["modified-functions"], - results["tainted-functions"], + missing_vars_in_v2, + new_variables, + tainted_variables, + new_functions, + modified_functions, + tainted_functions, ) @@ -226,7 +226,7 @@ def is_function_modified(f1: Function, f2: Function) -> bool: return False -def ntype(_type): # pylint: disable=too-many-branches +def ntype(_type: Union[Type, str]) -> str: # pylint: disable=too-many-branches if isinstance(_type, ElementaryType): _type = str(_type) elif isinstance(_type, ArrayType): @@ -261,7 +261,9 @@ def ntype(_type): # pylint: disable=too-many-branches return _type.replace(" ", "_") -def encode_ir_for_compare(ir) -> str: # pylint: disable=too-many-branches +def encode_ir_for_compare( + ir: Union[Operation, Variable] +) -> str: # pylint: disable=too-many-branches # operations if isinstance(ir, Assignment): return f"({encode_ir_for_compare(ir.lvalue)}):=({encode_ir_for_compare(ir.rvalue)})" From c094818be4a3e2fc4e94f4b33ec02d205db936a6 Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 24 Mar 2023 09:55:59 -0500 Subject: [PATCH 44/50] Pylint --- slither/utils/upgradeability.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index d3b2c0a53b..9ef76170cb 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -226,7 +226,8 @@ def is_function_modified(f1: Function, f2: Function) -> bool: return False -def ntype(_type: Union[Type, str]) -> str: # pylint: disable=too-many-branches +# pylint: disable=too-many-branches +def ntype(_type: Union[Type, str]) -> str: if isinstance(_type, ElementaryType): _type = str(_type) elif isinstance(_type, ArrayType): @@ -261,9 +262,10 @@ def ntype(_type: Union[Type, str]) -> str: # pylint: disable=too-many-branches return _type.replace(" ", "_") +# pylint: disable=too-many-branches def encode_ir_for_compare( ir: Union[Operation, Variable] -) -> str: # pylint: disable=too-many-branches +) -> str: # operations if isinstance(ir, Assignment): return f"({encode_ir_for_compare(ir.lvalue)}):=({encode_ir_for_compare(ir.rvalue)})" From 0ac4c01abddd5f9ba184f6910606e91fcc9c8d8e Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 24 Mar 2023 12:51:49 -0500 Subject: [PATCH 45/50] Black --- slither/utils/upgradeability.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 9ef76170cb..30213a53c6 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -263,9 +263,7 @@ def ntype(_type: Union[Type, str]) -> str: # pylint: disable=too-many-branches -def encode_ir_for_compare( - ir: Union[Operation, Variable] -) -> str: +def encode_ir_for_compare(ir: Union[Operation, Variable]) -> str: # operations if isinstance(ir, Assignment): return f"({encode_ir_for_compare(ir.lvalue)}):=({encode_ir_for_compare(ir.rvalue)})" From 2695243b2fca9c9eb2de7d4b6feb92dcde41f5d5 Mon Sep 17 00:00:00 2001 From: webthethird Date: Sat, 25 Mar 2023 11:42:34 -0500 Subject: [PATCH 46/50] Create separate `encore_var_for_compare` --- slither/utils/upgradeability.py | 54 +++++++++++++++++++++------------ 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 30213a53c6..b05f0e2139 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -221,7 +221,18 @@ def is_function_modified(f1: Function, f2: Function) -> bool: queue_f1.extend(son for son in node_f1.sons if son not in visited) queue_f2.extend(son for son in node_f2.sons if son not in visited) for i, ir in enumerate(node_f1.irs): - if encode_ir_for_compare(ir) != encode_ir_for_compare(node_f2.irs[i]): + ir2 = node_f2.irs[i] + encoded1 = ( + encode_var_for_compare(ir) + if isinstance(ir, Variable) + else encode_ir_for_compare(ir) + ) + encoded2 = ( + encode_var_for_compare(ir2) + if isinstance(ir2, Variable) + else encode_ir_for_compare(ir2) + ) + if encoded1 != encoded2: return True return False @@ -263,7 +274,7 @@ def ntype(_type: Union[Type, str]) -> str: # pylint: disable=too-many-branches -def encode_ir_for_compare(ir: Union[Operation, Variable]) -> str: +def encode_ir_for_compare(ir: Operation) -> str: # operations if isinstance(ir, Assignment): return f"({encode_ir_for_compare(ir.lvalue)}):=({encode_ir_for_compare(ir.rvalue)})" @@ -315,27 +326,32 @@ def encode_ir_for_compare(ir: Union[Operation, Variable]) -> str: return "unpack" if isinstance(ir, InitArray): # TODO: improve return "init_array" - if isinstance(ir, Function): # TODO: investigate this - return "function_solc" + + # default + return "" + + +# pylint: disable=too-many-branches +def encode_var_for_compare(var: Variable) -> str: # variables - if isinstance(ir, Constant): - return f"constant({ntype(ir.type)})" - if isinstance(ir, SolidityVariableComposed): - return f"solidity_variable_composed({ir.name})" - if isinstance(ir, SolidityVariable): - return f"solidity_variable{ir.name}" - if isinstance(ir, TemporaryVariable): + if isinstance(var, Constant): + return f"constant({ntype(var.type)})" + if isinstance(var, SolidityVariableComposed): + return f"solidity_variable_composed({var.name})" + if isinstance(var, SolidityVariable): + return f"solidity_variable{var.name}" + if isinstance(var, TemporaryVariable): return "temporary_variable" - if isinstance(ir, ReferenceVariable): - return f"reference({ntype(ir.type)})" - if isinstance(ir, LocalVariable): - return f"local_solc_variable({ir.location})" - if isinstance(ir, StateVariable): - return f"state_solc_variable({ntype(ir.type)})" - if isinstance(ir, LocalVariableInitFromTuple): + if isinstance(var, ReferenceVariable): + return f"reference({ntype(var.type)})" + if isinstance(var, LocalVariable): + return f"local_solc_variable({var.location})" + if isinstance(var, StateVariable): + return f"state_solc_variable({ntype(var.type)})" + if isinstance(var, LocalVariableInitFromTuple): return "local_variable_init_tuple" - if isinstance(ir, TupleVariable): + if isinstance(var, TupleVariable): return "tuple_variable" # default From 47c92f80f09492588834a506774ad090ab50193b Mon Sep 17 00:00:00 2001 From: webthethird Date: Mon, 27 Mar 2023 08:28:01 -0500 Subject: [PATCH 47/50] Fix ir encoding in comparison --- slither/utils/upgradeability.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index b05f0e2139..7b4e8493a7 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -221,18 +221,7 @@ def is_function_modified(f1: Function, f2: Function) -> bool: queue_f1.extend(son for son in node_f1.sons if son not in visited) queue_f2.extend(son for son in node_f2.sons if son not in visited) for i, ir in enumerate(node_f1.irs): - ir2 = node_f2.irs[i] - encoded1 = ( - encode_var_for_compare(ir) - if isinstance(ir, Variable) - else encode_ir_for_compare(ir) - ) - encoded2 = ( - encode_var_for_compare(ir2) - if isinstance(ir2, Variable) - else encode_ir_for_compare(ir2) - ) - if encoded1 != encoded2: + if encode_ir_for_compare(ir) != encode_ir_for_compare(node_f2.irs[i]): return True return False @@ -253,7 +242,12 @@ def ntype(_type: Union[Type, str]) -> str: elif isinstance(_type, MappingType): _type = str(_type) elif isinstance(_type, UserDefinedType): - _type = "user_defined_type" # TODO: this could be Contract, Enum or Struct + if isinstance(_type.type, Contract): + _type = f"contract({_type.type.name})" + elif isinstance(_type.type, Structure): + _type = f"struct({_type.type.name})" + elif isinstance(_type.type, Enum): + _type = f"enum({_type.type.name})" else: _type = str(_type) @@ -277,7 +271,7 @@ def ntype(_type: Union[Type, str]) -> str: def encode_ir_for_compare(ir: Operation) -> str: # operations if isinstance(ir, Assignment): - return f"({encode_ir_for_compare(ir.lvalue)}):=({encode_ir_for_compare(ir.rvalue)})" + return f"({encode_var_for_compare(ir.lvalue)}):=({encode_var_for_compare(ir.rvalue)})" if isinstance(ir, Index): return f"index({ntype(ir.index_type)})" if isinstance(ir, Member): @@ -289,7 +283,7 @@ def encode_ir_for_compare(ir: Operation) -> str: if isinstance(ir, Unary): return f"unary({str(ir.type)})" if isinstance(ir, Condition): - return f"condition({encode_ir_for_compare(ir.value)})" + return f"condition({encode_var_for_compare(ir.value)})" if isinstance(ir, NewStructure): return "new_structure" if isinstance(ir, NewContract): @@ -299,7 +293,7 @@ def encode_ir_for_compare(ir: Operation) -> str: if isinstance(ir, NewElementaryType): return f"new_elementary({ntype(ir.type)})" if isinstance(ir, Delete): - return f"delete({encode_ir_for_compare(ir.lvalue)},{encode_ir_for_compare(ir.variable)})" + return f"delete({encode_var_for_compare(ir.lvalue)},{encode_var_for_compare(ir.variable)})" if isinstance(ir, SolidityCall): return f"solidity_call({ir.function.full_name})" if isinstance(ir, InternalCall): @@ -319,9 +313,9 @@ def encode_ir_for_compare(ir: Operation) -> str: if isinstance(ir, Return): # this can be improved using values return "return" # .format(ntype(ir.type)) if isinstance(ir, Transfer): - return f"transfer({encode_ir_for_compare(ir.call_value)})" + return f"transfer({encode_var_for_compare(ir.call_value)})" if isinstance(ir, Send): - return f"send({encode_ir_for_compare(ir.call_value)})" + return f"send({encode_var_for_compare(ir.call_value)})" if isinstance(ir, Unpack): # TODO: improve return "unpack" if isinstance(ir, InitArray): # TODO: improve From be0e405f48d89bbcba858bf3acd1321dea02ea96 Mon Sep 17 00:00:00 2001 From: webthethird Date: Tue, 28 Mar 2023 09:12:51 -0500 Subject: [PATCH 48/50] Move upgradeability util test files --- .../test_data/upgradeability_util}/TestUpgrades-0.5.0.sol | 0 .../test_data/upgradeability_util}/TestUpgrades-0.8.2.sol | 0 .../utils/test_data/upgradeability_util}/src/Address.sol | 0 .../test_data/upgradeability_util}/src/ContractV1.sol | 0 .../test_data/upgradeability_util}/src/ContractV2.sol | 0 .../test_data/upgradeability_util}/src/EIP1822Proxy.sol | 0 .../test_data/upgradeability_util}/src/ERC1967Proxy.sol | 0 .../test_data/upgradeability_util}/src/ERC1967Upgrade.sol | 0 .../upgradeability_util}/src/InheritedStorageProxy.sol | 0 .../upgradeability_util}/src/MasterCopyProxy.sol | 0 .../utils/test_data/upgradeability_util}/src/Proxy.sol | 0 .../test_data/upgradeability_util}/src/ProxyStorage.sol | 0 .../test_data/upgradeability_util}/src/StorageSlot.sol | 0 .../test_data/upgradeability_util}/src/SynthProxy.sol | 0 .../utils/test_data/upgradeability_util}/src/ZosProxy.sol | 0 tests/{ => unit/utils}/test_upgradeability_util.py | 8 ++++---- 16 files changed, 4 insertions(+), 4 deletions(-) rename tests/{upgradeability-util => unit/utils/test_data/upgradeability_util}/TestUpgrades-0.5.0.sol (100%) rename tests/{upgradeability-util => unit/utils/test_data/upgradeability_util}/TestUpgrades-0.8.2.sol (100%) rename tests/{upgradeability-util => unit/utils/test_data/upgradeability_util}/src/Address.sol (100%) rename tests/{upgradeability-util => unit/utils/test_data/upgradeability_util}/src/ContractV1.sol (100%) rename tests/{upgradeability-util => unit/utils/test_data/upgradeability_util}/src/ContractV2.sol (100%) rename tests/{upgradeability-util => unit/utils/test_data/upgradeability_util}/src/EIP1822Proxy.sol (100%) rename tests/{upgradeability-util => unit/utils/test_data/upgradeability_util}/src/ERC1967Proxy.sol (100%) rename tests/{upgradeability-util => unit/utils/test_data/upgradeability_util}/src/ERC1967Upgrade.sol (100%) rename tests/{upgradeability-util => unit/utils/test_data/upgradeability_util}/src/InheritedStorageProxy.sol (100%) rename tests/{upgradeability-util => unit/utils/test_data/upgradeability_util}/src/MasterCopyProxy.sol (100%) rename tests/{upgradeability-util => unit/utils/test_data/upgradeability_util}/src/Proxy.sol (100%) rename tests/{upgradeability-util => unit/utils/test_data/upgradeability_util}/src/ProxyStorage.sol (100%) rename tests/{upgradeability-util => unit/utils/test_data/upgradeability_util}/src/StorageSlot.sol (100%) rename tests/{upgradeability-util => unit/utils/test_data/upgradeability_util}/src/SynthProxy.sol (100%) rename tests/{upgradeability-util => unit/utils/test_data/upgradeability_util}/src/ZosProxy.sol (100%) rename tests/{ => unit/utils}/test_upgradeability_util.py (92%) diff --git a/tests/upgradeability-util/TestUpgrades-0.5.0.sol b/tests/unit/utils/test_data/upgradeability_util/TestUpgrades-0.5.0.sol similarity index 100% rename from tests/upgradeability-util/TestUpgrades-0.5.0.sol rename to tests/unit/utils/test_data/upgradeability_util/TestUpgrades-0.5.0.sol diff --git a/tests/upgradeability-util/TestUpgrades-0.8.2.sol b/tests/unit/utils/test_data/upgradeability_util/TestUpgrades-0.8.2.sol similarity index 100% rename from tests/upgradeability-util/TestUpgrades-0.8.2.sol rename to tests/unit/utils/test_data/upgradeability_util/TestUpgrades-0.8.2.sol diff --git a/tests/upgradeability-util/src/Address.sol b/tests/unit/utils/test_data/upgradeability_util/src/Address.sol similarity index 100% rename from tests/upgradeability-util/src/Address.sol rename to tests/unit/utils/test_data/upgradeability_util/src/Address.sol diff --git a/tests/upgradeability-util/src/ContractV1.sol b/tests/unit/utils/test_data/upgradeability_util/src/ContractV1.sol similarity index 100% rename from tests/upgradeability-util/src/ContractV1.sol rename to tests/unit/utils/test_data/upgradeability_util/src/ContractV1.sol diff --git a/tests/upgradeability-util/src/ContractV2.sol b/tests/unit/utils/test_data/upgradeability_util/src/ContractV2.sol similarity index 100% rename from tests/upgradeability-util/src/ContractV2.sol rename to tests/unit/utils/test_data/upgradeability_util/src/ContractV2.sol diff --git a/tests/upgradeability-util/src/EIP1822Proxy.sol b/tests/unit/utils/test_data/upgradeability_util/src/EIP1822Proxy.sol similarity index 100% rename from tests/upgradeability-util/src/EIP1822Proxy.sol rename to tests/unit/utils/test_data/upgradeability_util/src/EIP1822Proxy.sol diff --git a/tests/upgradeability-util/src/ERC1967Proxy.sol b/tests/unit/utils/test_data/upgradeability_util/src/ERC1967Proxy.sol similarity index 100% rename from tests/upgradeability-util/src/ERC1967Proxy.sol rename to tests/unit/utils/test_data/upgradeability_util/src/ERC1967Proxy.sol diff --git a/tests/upgradeability-util/src/ERC1967Upgrade.sol b/tests/unit/utils/test_data/upgradeability_util/src/ERC1967Upgrade.sol similarity index 100% rename from tests/upgradeability-util/src/ERC1967Upgrade.sol rename to tests/unit/utils/test_data/upgradeability_util/src/ERC1967Upgrade.sol diff --git a/tests/upgradeability-util/src/InheritedStorageProxy.sol b/tests/unit/utils/test_data/upgradeability_util/src/InheritedStorageProxy.sol similarity index 100% rename from tests/upgradeability-util/src/InheritedStorageProxy.sol rename to tests/unit/utils/test_data/upgradeability_util/src/InheritedStorageProxy.sol diff --git a/tests/upgradeability-util/src/MasterCopyProxy.sol b/tests/unit/utils/test_data/upgradeability_util/src/MasterCopyProxy.sol similarity index 100% rename from tests/upgradeability-util/src/MasterCopyProxy.sol rename to tests/unit/utils/test_data/upgradeability_util/src/MasterCopyProxy.sol diff --git a/tests/upgradeability-util/src/Proxy.sol b/tests/unit/utils/test_data/upgradeability_util/src/Proxy.sol similarity index 100% rename from tests/upgradeability-util/src/Proxy.sol rename to tests/unit/utils/test_data/upgradeability_util/src/Proxy.sol diff --git a/tests/upgradeability-util/src/ProxyStorage.sol b/tests/unit/utils/test_data/upgradeability_util/src/ProxyStorage.sol similarity index 100% rename from tests/upgradeability-util/src/ProxyStorage.sol rename to tests/unit/utils/test_data/upgradeability_util/src/ProxyStorage.sol diff --git a/tests/upgradeability-util/src/StorageSlot.sol b/tests/unit/utils/test_data/upgradeability_util/src/StorageSlot.sol similarity index 100% rename from tests/upgradeability-util/src/StorageSlot.sol rename to tests/unit/utils/test_data/upgradeability_util/src/StorageSlot.sol diff --git a/tests/upgradeability-util/src/SynthProxy.sol b/tests/unit/utils/test_data/upgradeability_util/src/SynthProxy.sol similarity index 100% rename from tests/upgradeability-util/src/SynthProxy.sol rename to tests/unit/utils/test_data/upgradeability_util/src/SynthProxy.sol diff --git a/tests/upgradeability-util/src/ZosProxy.sol b/tests/unit/utils/test_data/upgradeability_util/src/ZosProxy.sol similarity index 100% rename from tests/upgradeability-util/src/ZosProxy.sol rename to tests/unit/utils/test_data/upgradeability_util/src/ZosProxy.sol diff --git a/tests/test_upgradeability_util.py b/tests/unit/utils/test_upgradeability_util.py similarity index 92% rename from tests/test_upgradeability_util.py rename to tests/unit/utils/test_upgradeability_util.py index dd12d68a15..520edaef9d 100644 --- a/tests/test_upgradeability_util.py +++ b/tests/unit/utils/test_upgradeability_util.py @@ -11,14 +11,14 @@ ) SLITHER_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -UPGRADE_TEST_ROOT = os.path.join(SLITHER_ROOT, "tests", "upgradeability-util") +TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data" / "upgradeability_util" # pylint: disable=too-many-locals def test_upgrades_compare() -> None: solc_select.switch_global_version("0.8.2", always_install=True) - sl = Slither(os.path.join(UPGRADE_TEST_ROOT, "TestUpgrades-0.8.2.sol")) + sl = Slither(os.path.join(TEST_DATA_DIR, "TestUpgrades-0.8.2.sol")) v1 = sl.get_contract_from_name("ContractV1")[0] v2 = sl.get_contract_from_name("ContractV2")[0] missing_vars, new_vars, tainted_vars, new_funcs, modified_funcs, tainted_funcs = compare(v1, v2) @@ -38,7 +38,7 @@ def test_upgrades_compare() -> None: def test_upgrades_implementation_var() -> None: solc_select.switch_global_version("0.8.2", always_install=True) - sl = Slither(os.path.join(UPGRADE_TEST_ROOT, "TestUpgrades-0.8.2.sol")) + sl = Slither(os.path.join(TEST_DATA_DIR, "TestUpgrades-0.8.2.sol")) erc_1967_proxy = sl.get_contract_from_name("ERC1967Proxy")[0] storage_proxy = sl.get_contract_from_name("InheritedStorageProxy")[0] @@ -53,7 +53,7 @@ def test_upgrades_implementation_var() -> None: assert slot.slot == 1 solc_select.switch_global_version("0.5.0", always_install=True) - sl = Slither(os.path.join(UPGRADE_TEST_ROOT, "TestUpgrades-0.5.0.sol")) + sl = Slither(os.path.join(TEST_DATA_DIR, "TestUpgrades-0.5.0.sol")) eip_1822_proxy = sl.get_contract_from_name("EIP1822Proxy")[0] # zos_proxy = sl.get_contract_from_name("ZosProxy")[0] From 8731a923617857473fc36103d696bea44daa001c Mon Sep 17 00:00:00 2001 From: webthethird Date: Tue, 28 Mar 2023 09:25:55 -0500 Subject: [PATCH 49/50] Add Path import --- tests/unit/utils/test_upgradeability_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/utils/test_upgradeability_util.py b/tests/unit/utils/test_upgradeability_util.py index 520edaef9d..7d6fb82dad 100644 --- a/tests/unit/utils/test_upgradeability_util.py +++ b/tests/unit/utils/test_upgradeability_util.py @@ -1,4 +1,5 @@ import os +from pathlib import Path from solc_select import solc_select From 5703b9db8006c4d9f09028892824d40eb6268709 Mon Sep 17 00:00:00 2001 From: webthethird Date: Tue, 28 Mar 2023 09:35:07 -0500 Subject: [PATCH 50/50] Add test for `Contract.fallback_function` and `Contract.receive_function` --- tests/unit/core/test_data/fallback.sol | 29 ++++++++++++++++++++++++ tests/unit/core/test_fallback_receive.py | 20 ++++++++++++++++ 2 files changed, 49 insertions(+) create mode 100644 tests/unit/core/test_data/fallback.sol create mode 100644 tests/unit/core/test_fallback_receive.py diff --git a/tests/unit/core/test_data/fallback.sol b/tests/unit/core/test_data/fallback.sol new file mode 100644 index 0000000000..cd7dc18121 --- /dev/null +++ b/tests/unit/core/test_data/fallback.sol @@ -0,0 +1,29 @@ +pragma solidity ^0.6.12; + +contract FakeFallback { + mapping(address => uint) public contributions; + address payable public owner; + + constructor() public { + owner = payable(msg.sender); + contributions[msg.sender] = 1000 * (1 ether); + } + + function fallback() public payable { + contributions[msg.sender] += msg.value; + } + + function receive() public payable { + contributions[msg.sender] += msg.value; + } +} + +contract Fallback is FakeFallback { + receive() external payable { + contributions[msg.sender] += msg.value; + } + + fallback() external payable { + contributions[msg.sender] += msg.value; + } +} diff --git a/tests/unit/core/test_fallback_receive.py b/tests/unit/core/test_fallback_receive.py new file mode 100644 index 0000000000..505a9dd6fd --- /dev/null +++ b/tests/unit/core/test_fallback_receive.py @@ -0,0 +1,20 @@ +from pathlib import Path +from solc_select import solc_select + +from slither import Slither +from slither.core.declarations.function import FunctionType + +TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data" + + +def test_fallback_receive(): + solc_select.switch_global_version("0.6.12", always_install=True) + file = Path(TEST_DATA_DIR, "fallback.sol").as_posix() + slither = Slither(file) + fake_fallback = slither.get_contract_from_name("FakeFallback")[0] + real_fallback = slither.get_contract_from_name("Fallback")[0] + + assert fake_fallback.fallback_function is None + assert fake_fallback.receive_function is None + assert real_fallback.fallback_function.function_type == FunctionType.FALLBACK + assert real_fallback.receive_function.function_type == FunctionType.RECEIVE