diff --git a/tests/parser/features/decorators/test_nonreentrant.py b/tests/parser/features/decorators/test_nonreentrant.py index a844e13b32..fa0cbad1e8 100644 --- a/tests/parser/features/decorators/test_nonreentrant.py +++ b/tests/parser/features/decorators/test_nonreentrant.py @@ -30,11 +30,14 @@ def set_callback(c: address): @public @nonreentrant('protect_special_value') -def protected_function(val: string[100], do_callback: bool): +def protected_function(val: string[100], do_callback: bool) -> uint256: self.special_value = val if do_callback: self.callback.updated_protected() + return 1 + else: + return 2 @public def unprotected_function(val: string[100], do_callback: bool): diff --git a/vyper/parser/context.py b/vyper/parser/context.py index ea64a83d76..81bcbd27e2 100644 --- a/vyper/parser/context.py +++ b/vyper/parser/context.py @@ -33,7 +33,8 @@ def __init__(self, is_private=False, is_payable=False, origcode='', - method_id=''): + method_id='', + sig=None): # In-memory variables, in the form (name, memory location, type) self.vars = vars or {} # Memory alloctor, keeps track of currently allocated memory. @@ -79,6 +80,8 @@ def __init__(self, self.method_id = method_id # store global context self.global_ctx = global_ctx + # full function signature + self.sig = sig def is_constant(self): return self.constancy is Constancy.Constant or \ diff --git a/vyper/parser/function_definitions/parse_function.py b/vyper/parser/function_definitions/parse_function.py index 00888c074a..5b80157de9 100644 --- a/vyper/parser/function_definitions/parse_function.py +++ b/vyper/parser/function_definitions/parse_function.py @@ -62,7 +62,8 @@ def parse_function(code, sigs, origcode, global_ctx, _vars=None): is_payable=sig.payable, origcode=origcode, is_private=sig.private, - method_id=sig.method_id + method_id=sig.method_id, + sig=sig ) if sig.private: diff --git a/vyper/parser/parser_utils.py b/vyper/parser/parser_utils.py index 5b8f805a31..b5e7199128 100644 --- a/vyper/parser/parser_utils.py +++ b/vyper/parser/parser_utils.py @@ -906,6 +906,10 @@ def zero_pad(bytez_placeholder, maxlen, context): # Generate return code for stmt def make_return_stmt(stmt, context, begin_pos, _size, loop_memory_position=None): + from vyper.parser.function_definitions.utils import ( + get_nonreentrant_lock + ) + _, nonreentrant_post = get_nonreentrant_lock(context.sig, context.global_ctx) if context.is_private: if loop_memory_position is None: loop_memory_position = context.new_placeholder(typ=BaseType('uint256')) @@ -922,7 +926,8 @@ def make_return_stmt(stmt, context, begin_pos, _size, loop_memory_position=None) mloads = [ ['mload', pos] for pos in range(begin_pos, _size, 32) ] - return ['seq_unchecked'] + mloads + [['jump', ['mload', context.callback_ptr]]] + return ['seq_unchecked'] + mloads + nonreentrant_post + \ + [['jump', ['mload', context.callback_ptr]]] else: mloads = [ 'seq_unchecked', @@ -945,9 +950,10 @@ def make_return_stmt(stmt, context, begin_pos, _size, loop_memory_position=None) ['goto', start_label], ['label', exit_label] ] - return ['seq_unchecked'] + [mloads] + [['jump', ['mload', context.callback_ptr]]] + return ['seq_unchecked'] + [mloads] + nonreentrant_post + \ + [['jump', ['mload', context.callback_ptr]]] else: - return ['return', begin_pos, _size] + return ['seq_unchecked'] + nonreentrant_post + [['return', begin_pos, _size]] # Generate code for returning a tuple or struct. diff --git a/vyper/parser/self_call.py b/vyper/parser/self_call.py index 072ef9885d..237a5183b0 100644 --- a/vyper/parser/self_call.py +++ b/vyper/parser/self_call.py @@ -130,7 +130,7 @@ def call_self_private(stmt_expr, context, sig): static_arg_size = 32 * sum( [get_static_size_of_type(arg.typ) for arg in expr_args]) - static_pos = arg_pos + static_arg_size + static_pos = int(arg_pos + static_arg_size) needs_dyn_section = any( [has_dynamic_data(arg.typ) for arg in expr_args])