Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support user defined operators #1684

Merged
merged 9 commits into from
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion slither/core/solidity_types/type_alias.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Tuple, Dict

from slither.core.declarations.top_level import TopLevel
from slither.core.declarations.contract_level import ContractLevel
from slither.core.solidity_types import Type, ElementaryType

if TYPE_CHECKING:
from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations import Contract
from slither.core.scope.scope import FileScope

Expand Down Expand Up @@ -43,6 +44,8 @@ class TypeAliasTopLevel(TypeAlias, TopLevel):
def __init__(self, underlying_type: ElementaryType, name: str, scope: "FileScope") -> None:
super().__init__(underlying_type, name)
self.file_scope: "FileScope" = scope
# operators redefined
self.operators: Dict[str, "FunctionTopLevel"] = {}

def __str__(self) -> str:
return self.name
Expand Down
48 changes: 34 additions & 14 deletions slither/solc_parsing/declarations/using_for_top_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,29 @@ def analyze(self) -> None:
self._propagate_global(type_name)
else:
for f in self._functions:
full_name_split = f["function"]["name"].split(".")
if len(full_name_split) == 1:
# User defined operator
if "operator" in f:
# Top level function
function_name: str = full_name_split[0]
self._analyze_top_level_function(function_name, type_name)
elif len(full_name_split) == 2:
# It can be a top level function behind an aliased import
# or a library function
first_part = full_name_split[0]
function_name = full_name_split[1]
self._check_aliased_import(first_part, function_name, type_name)
function_name: str = f["definition"]["name"]
operator: str = f["operator"]
self._analyze_operator(operator, function_name, type_name)
else:
# MyImport.MyLib.a we don't care of the alias
library_name_str = full_name_split[1]
function_name = full_name_split[2]
self._analyze_library_function(library_name_str, function_name, type_name)
full_name_split = f["function"]["name"].split(".")
if len(full_name_split) == 1:
# Top level function
function_name: str = full_name_split[0]
self._analyze_top_level_function(function_name, type_name)
elif len(full_name_split) == 2:
# It can be a top level function behind an aliased import
# or a library function
first_part = full_name_split[0]
function_name = full_name_split[1]
self._check_aliased_import(first_part, function_name, type_name)
else:
# MyImport.MyLib.a we don't care of the alias
library_name_str = full_name_split[1]
function_name = full_name_split[2]
self._analyze_library_function(library_name_str, function_name, type_name)

def _check_aliased_import(
self,
Expand Down Expand Up @@ -101,6 +108,19 @@ def _analyze_top_level_function(
self._propagate_global(type_name)
break

def _analyze_operator(
self, operator: str, function_name: str, type_name: TypeAliasTopLevel
) -> None:
for tl_function in self._using_for.file_scope.functions:
# The library function is bound to the first parameter's type
if (
tl_function.name == function_name
and tl_function.parameters
and type_name == tl_function.parameters[0].type
):
type_name.operators[operator] = tl_function
break

def _analyze_library_function(
self,
library_name: str,
Expand Down
58 changes: 46 additions & 12 deletions slither/solc_parsing/expressions/expression_parsing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import re
from typing import Union, Dict, TYPE_CHECKING
from typing import Union, Dict, TYPE_CHECKING, List, Any

import slither.core.expressions.type_conversion
from slither.core.declarations.solidity_variables import (
Expand Down Expand Up @@ -236,6 +236,24 @@ def _parse_elementary_type_name_expression(
pass


def _user_defined_op_call(
caller_context: CallerContextExpression, src, function_id: int, args: List[Any], type_call: str
) -> CallExpression:
var, was_created = find_variable(None, caller_context, function_id)

if was_created:
var.set_offset(src, caller_context.compilation_unit)

identifier = Identifier(var)
identifier.set_offset(src, caller_context.compilation_unit)

var.references.append(identifier.source_mapping)

call = CallExpression(identifier, args, type_call)
call.set_offset(src, caller_context.compilation_unit)
return call


def parse_expression(expression: Dict, caller_context: CallerContextExpression) -> "Expression":
# pylint: disable=too-many-nested-blocks,too-many-statements
"""
Expand Down Expand Up @@ -274,34 +292,50 @@ def parse_expression(expression: Dict, caller_context: CallerContextExpression)
if name == "UnaryOperation":
if is_compact_ast:
attributes = expression
else:
attributes = expression["attributes"]
assert "prefix" in attributes
operation_type = UnaryOperationType.get_type(attributes["operator"], attributes["prefix"])

if is_compact_ast:
expression = parse_expression(expression["subExpression"], caller_context)
else:
attributes = expression["attributes"]
assert len(expression["children"]) == 1
expression = parse_expression(expression["children"][0], caller_context)
assert "prefix" in attributes

# Use of user defined operation
if "function" in attributes:
return _user_defined_op_call(
caller_context,
src,
attributes["function"],
[expression],
attributes["typeDescriptions"]["typeString"],
)

operation_type = UnaryOperationType.get_type(attributes["operator"], attributes["prefix"])
unary_op = UnaryOperation(expression, operation_type)
unary_op.set_offset(src, caller_context.compilation_unit)
return unary_op

if name == "BinaryOperation":
if is_compact_ast:
attributes = expression
else:
attributes = expression["attributes"]
operation_type = BinaryOperationType.get_type(attributes["operator"])

if is_compact_ast:
left_expression = parse_expression(expression["leftExpression"], caller_context)
right_expression = parse_expression(expression["rightExpression"], caller_context)
else:
assert len(expression["children"]) == 2
attributes = expression["attributes"]
left_expression = parse_expression(expression["children"][0], caller_context)
right_expression = parse_expression(expression["children"][1], caller_context)

# Use of user defined operation
if "function" in attributes:
return _user_defined_op_call(
caller_context,
src,
attributes["function"],
[left_expression, right_expression],
attributes["typeDescriptions"]["typeString"],
)

operation_type = BinaryOperationType.get_type(attributes["operator"])
binary_op = BinaryOperation(left_expression, right_expression, operation_type)
binary_op.set_offset(src, caller_context.compilation_unit)
return binary_op
Expand Down
1 change: 1 addition & 0 deletions tests/e2e/solc_parsing/test_ast_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ def make_version(minor: int, patch_min: int, patch_max: int) -> List[str]:
"assembly-functions.sol",
["0.6.9", "0.7.6", "0.8.16"],
),
Test("user_defined_operators-0.8.19.sol", ["0.8.19"]),
]
# create the output folder if needed
try:
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"Lib": {
"f(Int)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n}\n"
},
"T": {
"add_function_call(Int,Int)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: NEW VARIABLE 1\n\"];\n1->2;\n2[label=\"Node Type: RETURN 2\n\"];\n}\n",
"add_op(Int,Int)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n",
"lib_call(Int)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n",
"neg_usertype(Int)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: NEW VARIABLE 1\n\"];\n1->2;\n2[label=\"Node Type: RETURN 2\n\"];\n}\n",
"neg_int(int256)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n",
"eq_op(Int,Int)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
pragma solidity ^0.8.19;

type Int is int;
using {add as +, eq as ==, add, neg as -, Lib.f} for Int global;

function add(Int a, Int b) pure returns (Int) {
return Int.wrap(Int.unwrap(a) + Int.unwrap(b));
}

function eq(Int a, Int b) pure returns (bool) {
return true;
}

function neg(Int a) pure returns (Int) {
return a;
}

library Lib {
function f(Int r) internal {}
}

contract T {
function add_function_call(Int b, Int c) public returns(Int) {
Int res = add(b,c);
return res;
}

function add_op(Int b, Int c) public returns(Int) {
return b + c;
}

function lib_call(Int b) public {
return b.f();
}

function neg_usertype(Int b) public returns(Int) {
Int res = -b;
return res;
}

function neg_int(int b) public returns(int) {
return -b;
}

function eq_op(Int b, Int c) public returns(bool) {
return b == c;
}
}