Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: assignment when rhs is complex type and references lhs #3410

Merged
merged 6 commits into from
May 15, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions tests/parser/features/test_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,33 @@ def foo():
ret : bool = self.bar()
"""
assert_compile_failed(lambda: get_contract_with_gas_estimation(code), InvalidType)


def test_assign_rhs_lhs_overlap(get_contract):
# issue #2418
code = """
@external
def bug(xs: uint256[2]) -> uint256[2]:
# Initial value
ys: uint256[2] = xs
ys = [ys[1], ys[0]]
return ys
"""
c = get_contract(code)

assert c.bug([1, 2]) == [2, 1]


def test_assign_rhs_lhs_partial_overlap(get_contract):
# issue #2418, generalize when lhs is not only dependency of rhs.
code = """
@external
def bug(xs: uint256[2]) -> uint256[2]:
# Initial value
ys: uint256[2] = xs
ys = [xs[1], ys[0]]
return ys
"""
c = get_contract(code)

assert c.bug([1, 2]) == [2, 1]
3 changes: 3 additions & 0 deletions vyper/codegen/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ class VariableRecord:
is_immutable: bool = False
data_offset: Optional[int] = None

def __hash__(self):
return hash(id(self))

def __post_init__(self):
if self.blockscopes is None:
self.blockscopes = []
Expand Down
14 changes: 11 additions & 3 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,16 @@ def parse_Name(self):
return IRnode.from_list(["address"], typ=AddressT())
elif self.expr.id in self.context.vars:
var = self.context.vars[self.expr.id]
return IRnode.from_list(
ret = IRnode.from_list(
var.pos,
typ=var.typ,
location=var.location, # either 'memory' or 'calldata' storage is handled above.
encoding=var.encoding,
annotation=self.expr.id,
mutable=var.mutable,
)
ret._referenced_variables = {var}
return ret

# TODO: use self.expr._expr_info
elif self.expr.id in self.context.globals:
Expand All @@ -189,9 +191,11 @@ def parse_Name(self):
mutable = False
location = DATA

return IRnode.from_list(
ret = IRnode.from_list(
ofst, typ=varinfo.typ, location=location, annotation=self.expr.id, mutable=mutable
)
ret._referenced_variables = {varinfo}
return ret

# x.y or x[5]
def parse_Attribute(self):
Expand Down Expand Up @@ -255,12 +259,16 @@ def parse_Attribute(self):
# self.x: global attribute
elif isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id == "self":
varinfo = self.context.globals[self.expr.attr]
return IRnode.from_list(
ret = IRnode.from_list(
varinfo.position.position,
typ=varinfo.typ,
location=STORAGE,
annotation="self." + self.expr.attr,
)
ret._referenced_variables = {varinfo}

return ret

# Reserved keywords
elif (
isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id in ENVIRONMENT_VARIABLES
Expand Down
10 changes: 10 additions & 0 deletions vyper/codegen/ir_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,16 @@ def cache_when_complex(self, name):

return _WithBuilder(self, name, should_inline)

@cached_property
def referenced_variables(self):
ret = set()
for arg in self.args:
ret |= arg.referenced_variables

ret |= getattr(self, "_referenced_variables", set())

return ret

@cached_property
def contains_self_call(self):
return getattr(self, "is_self_call", False) or any(x.contains_self_call for x in self.args)
Expand Down
23 changes: 19 additions & 4 deletions vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,22 @@ def parse_AnnAssign(self):

def parse_Assign(self):
# Assignment (e.g. x[4] = y)
sub = Expr(self.stmt.value, self.context).ir_node
target = self._get_target(self.stmt.target)
src = Expr(self.stmt.value, self.context).ir_node
dst = self._get_target(self.stmt.target)

ir_node = make_setter(target, sub)
return ir_node
ret = ["seq"]
overlap = len(dst.referenced_variables & src.referenced_variables) > 0
if overlap and not dst.typ._is_prim_word:
# there is overlap between the lhs and rhs, and the type is
# complex - i.e., it spans multiple words. for safety, we
# copy to a temporary buffer before copying to the destination.
tmp = self.context.new_internal_variable(src.typ)
tmp = IRnode.from_list(tmp, typ=src.typ, location=MEMORY)
ret.append(make_setter(tmp, src))
src = tmp

ret.append(make_setter(dst, src))
return IRnode.from_list(ret)

def parse_If(self):
if self.stmt.orelse:
Expand Down Expand Up @@ -336,8 +347,12 @@ def _parse_For_list(self):

def parse_AugAssign(self):
target = self._get_target(self.stmt.target)

sub = Expr.parse_value_expr(self.stmt.value, self.context)
if not target.typ._is_prim_word:
# because of this check, we do not need to check for
# make_setter references lhs<->rhs as in parse_Assign -
# single word load/stores are atomic.
return

with target.cache_when_complex("_loc") as (b, target):
Expand Down
3 changes: 3 additions & 0 deletions vyper/semantics/analysis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@ class VarInfo:
is_local_var: bool = False
decl_node: Optional[vy_ast.VyperNode] = None

def __hash__(self):
return hash(id(self))

def __post_init__(self):
self._modification_count = 0

Expand Down