diff --git a/tests/parser/features/test_assignment.py b/tests/parser/features/test_assignment.py index 65fb3a7a0e..0dd63a0d09 100644 --- a/tests/parser/features/test_assignment.py +++ b/tests/parser/features/test_assignment.py @@ -255,3 +255,63 @@ def foo(): ret : bool = self.bar() """ assert_compile_failed(lambda: get_contract_with_gas_estimation(code), InvalidType) + + +def test_assign_rhs_lhs_overlap(get_contract): + # GH issue 2418 + code = """ +@external +def bug(xs: uint256[2]) -> uint256[2]: + # Initial value + ys: uint256[2] = xs + ys = [ys[1], ys[0]] + return ys + """ + c = get_contract(code) + + assert c.bug([1, 2]) == [2, 1] + + +def test_assign_rhs_lhs_partial_overlap(get_contract): + # GH issue 2418, generalize when lhs is not only dependency of rhs. + code = """ +@external +def bug(xs: uint256[2]) -> uint256[2]: + # Initial value + ys: uint256[2] = xs + ys = [xs[1], ys[0]] + return ys + """ + c = get_contract(code) + + assert c.bug([1, 2]) == [2, 1] + + +def test_assign_rhs_lhs_overlap_dynarray(get_contract): + # GH issue 2418, generalize to dynarrays + code = """ +@external +def bug(xs: DynArray[uint256, 2]) -> DynArray[uint256, 2]: + ys: DynArray[uint256, 2] = xs + ys = [ys[1], ys[0]] + return ys + """ + c = get_contract(code) + assert c.bug([1, 2]) == [2, 1] + + +def test_assign_rhs_lhs_overlap_struct(get_contract): + # GH issue 2418, generalize to structs + code = """ +struct Point: + x: uint256 + y: uint256 + +@external +def bug(p: Point) -> Point: + t: Point = p + t = Point({x: t.y, y: t.x}) + return t + """ + c = get_contract(code) + assert c.bug((1, 2)) == (2, 1) diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index 9902cd0cf7..696b81d124 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -30,6 +30,9 @@ class VariableRecord: is_immutable: bool = False data_offset: Optional[int] = None + def __hash__(self): + return hash(id(self)) + def __post_init__(self): if self.blockscopes is None: self.blockscopes = [] diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index dd764fbe20..6da3d9501b 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -166,7 +166,7 @@ def parse_Name(self): return IRnode.from_list(["address"], typ=AddressT()) elif self.expr.id in self.context.vars: var = self.context.vars[self.expr.id] - return IRnode.from_list( + ret = IRnode.from_list( var.pos, typ=var.typ, location=var.location, # either 'memory' or 'calldata' storage is handled above. @@ -174,6 +174,8 @@ def parse_Name(self): annotation=self.expr.id, mutable=var.mutable, ) + ret._referenced_variables = {var} + return ret # TODO: use self.expr._expr_info elif self.expr.id in self.context.globals: @@ -189,9 +191,11 @@ def parse_Name(self): mutable = False location = DATA - return IRnode.from_list( + ret = IRnode.from_list( ofst, typ=varinfo.typ, location=location, annotation=self.expr.id, mutable=mutable ) + ret._referenced_variables = {varinfo} + return ret # x.y or x[5] def parse_Attribute(self): @@ -255,12 +259,16 @@ def parse_Attribute(self): # self.x: global attribute elif isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id == "self": varinfo = self.context.globals[self.expr.attr] - return IRnode.from_list( + ret = IRnode.from_list( varinfo.position.position, typ=varinfo.typ, location=STORAGE, annotation="self." + self.expr.attr, ) + ret._referenced_variables = {varinfo} + + return ret + # Reserved keywords elif ( isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id in ENVIRONMENT_VARIABLES diff --git a/vyper/codegen/ir_node.py b/vyper/codegen/ir_node.py index 1ba4122c66..d36a18ec66 100644 --- a/vyper/codegen/ir_node.py +++ b/vyper/codegen/ir_node.py @@ -397,6 +397,16 @@ def cache_when_complex(self, name): return _WithBuilder(self, name, should_inline) + @cached_property + def referenced_variables(self): + ret = set() + for arg in self.args: + ret |= arg.referenced_variables + + ret |= getattr(self, "_referenced_variables", set()) + + return ret + @cached_property def contains_self_call(self): return getattr(self, "is_self_call", False) or any(x.contains_self_call for x in self.args) diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 01c1d5f121..e24c429638 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -73,11 +73,22 @@ def parse_AnnAssign(self): def parse_Assign(self): # Assignment (e.g. x[4] = y) - sub = Expr(self.stmt.value, self.context).ir_node - target = self._get_target(self.stmt.target) + src = Expr(self.stmt.value, self.context).ir_node + dst = self._get_target(self.stmt.target) - ir_node = make_setter(target, sub) - return ir_node + ret = ["seq"] + overlap = len(dst.referenced_variables & src.referenced_variables) > 0 + if overlap and not dst.typ._is_prim_word: + # there is overlap between the lhs and rhs, and the type is + # complex - i.e., it spans multiple words. for safety, we + # copy to a temporary buffer before copying to the destination. + tmp = self.context.new_internal_variable(src.typ) + tmp = IRnode.from_list(tmp, typ=src.typ, location=MEMORY) + ret.append(make_setter(tmp, src)) + src = tmp + + ret.append(make_setter(dst, src)) + return IRnode.from_list(ret) def parse_If(self): if self.stmt.orelse: @@ -336,8 +347,12 @@ def _parse_For_list(self): def parse_AugAssign(self): target = self._get_target(self.stmt.target) + sub = Expr.parse_value_expr(self.stmt.value, self.context) if not target.typ._is_prim_word: + # because of this check, we do not need to check for + # make_setter references lhs<->rhs as in parse_Assign - + # single word load/stores are atomic. return with target.cache_when_complex("_loc") as (b, target): diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 5919c96686..5065131f29 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -165,6 +165,9 @@ class VarInfo: is_local_var: bool = False decl_node: Optional[vy_ast.VyperNode] = None + def __hash__(self): + return hash(id(self)) + def __post_init__(self): self._modification_count = 0