diff --git a/vyper/optimizer.py b/vyper/optimizer.py index e23ea5c9f5..91dce1c64b 100644 --- a/vyper/optimizer.py +++ b/vyper/optimizer.py @@ -61,6 +61,60 @@ def optimize(lll_node: LLLnode) -> LLLnode: def apply_general_optimizations(node: LLLnode) -> LLLnode: # TODO refactor this into several functions argz = [apply_general_optimizations(arg) for arg in node.args] + + if node.value == "seq": + # look for sequential mzero / calldatacopy operations that are zero'ing memory + # and merge them into a single calldatacopy + mzero_nodes: List = [] + initial_offset = 0 + total_length = 0 + for lll_node in [i for i in argz if i.value != "pass"]: + if ( + lll_node.value == "mstore" + and isinstance(lll_node.args[0].value, int) + and lll_node.args[1].value == 0 + ): + # mstore of a zero value + offset = lll_node.args[0].value + if not mzero_nodes: + initial_offset = offset + if initial_offset + total_length == offset: + mzero_nodes.append(lll_node) + total_length += 32 + continue + + if ( + lll_node.value == "calldatacopy" + and isinstance(lll_node.args[0].value, int) + and lll_node.args[1].value == "calldatasize" + and isinstance(lll_node.args[2].value, int) + ): + # calldatacopy from the end of calldata - efficient zero'ing via `empty()` + offset, length = lll_node.args[0].value, lll_node.args[2].value + if not mzero_nodes: + initial_offset = offset + if initial_offset + total_length == offset: + mzero_nodes.append(lll_node) + total_length += length + continue + + # if we get this far, the current node is not a zero'ing operation + # it's time to apply the optimization if possible + if len(mzero_nodes) > 1: + new_lll = LLLnode.from_list( + ["calldatacopy", initial_offset, "calldatasize", total_length], + pos=mzero_nodes[0].pos, + ) + # replace first zero'ing operation with optimized node and remove the rest + idx = argz.index(mzero_nodes[0]) + argz[idx] = new_lll + for i in mzero_nodes[1:]: + argz.remove(i) + + initial_offset = 0 + total_length = 0 + mzero_nodes.clear() + if node.value in arith and int_at(argz, 0) and int_at(argz, 1): left, right = get_int_at(argz, 0), get_int_at(argz, 1) # `node.value in arith` implies that `node.value` is a `str`