Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Block scoping #601

Merged
merged 6 commits into from
Jan 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/voting/ballot.v.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def vote(proposal: num):
@constant
def winning_proposal() -> num:
winning_vote_count = 0
winning_proposal = 0
for i in range(2):
if self.proposals[i].vote_count > winning_vote_count:
winning_vote_count = self.proposals[i].vote_count
Expand Down
81 changes: 81 additions & 0 deletions tests/parser/syntax/test_blockscope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@

import pytest
from pytest import raises

from viper import compiler
from viper.exceptions import VariableDeclarationException


fail_list = [
"""
@public
def foo(choice: bool):
if (choice):
a = 1
a += 1
""",
"""
@public
def foo(choice: bool):
if (choice):
a = 0
else:
a = 1
a += 1
""",
"""
@public
def foo(choice: bool):
if (choice):
a = 0
else:
a += 1
""",
"""
@public
def foo(choice: bool):

for i in range(4):
a = 0
a += 1
""",
"""
@public
def foo(choice: bool):

for i in range(4):
a = 0
a += 1
""",
"""
a: num

@public
def foo():
a = 5
""",
]


@pytest.mark.parametrize('bad_code', fail_list)
def test_fail_(bad_code):

with raises(VariableDeclarationException):
compiler.compile(bad_code)


valid_list = [
"""
@public
def foo(choice: bool, choice2: bool):
if (choice):
a = 11
if choice2 and a > 1:
a -= 1 # should be visible here.
"""
]


@pytest.mark.parametrize('good_code', valid_list)
def test_valid_blockscope(good_code):
assert compiler.compile(good_code) is not None
3 changes: 2 additions & 1 deletion viper/function_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@

# Function argument
class VariableRecord():
def __init__(self, name, pos, typ, mutable):
def __init__(self, name, pos, typ, mutable, blockscopes=[]):
self.name = name
self.pos = pos
self.typ = typ
self.mutable = mutable
self.blockscopes = blockscopes

@property
def size(self):
Expand Down
16 changes: 15 additions & 1 deletion viper/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,27 @@ def __init__(self, vars=None, globals=None, sigs=None, forvars=None, return_type
self.in_for_loop = set()
# Count returns in function
self.function_return_count = 0
# Current block scope
self.blockscopes = set()

def set_in_for_loop(self, name_of_list):
self.in_for_loop.add(name_of_list)

def remove_in_for_loop(self, name_of_list):
self.in_for_loop.remove(name_of_list)

def start_blockscope(self, blockscope_id):
self.blockscopes.add(blockscope_id)

def end_blockscope(self, blockscope_id):
# Remove all variables that have specific blockscope_id attached.
self.vars = {
name: var_record for name, var_record in self.vars.items()
if blockscope_id not in var_record.blockscopes
}
# Remove block scopes
self.blockscopes.remove(blockscope_id)

def increment_return_counter(self):
self.function_return_count += 1

Expand All @@ -293,7 +307,7 @@ def new_variable(self, name, typ):
raise VariableDeclarationException("Variable name invalid or reserved: " + name)
if name in self.vars or name in self.globals:
raise VariableDeclarationException("Duplicate variable name: %s" % name)
self.vars[name] = VariableRecord(name, self.next_mem, typ, True)
self.vars[name] = VariableRecord(name, self.next_mem, typ, True, self.blockscopes.copy())
pos = self.next_mem
self.next_mem += 32 * get_size_of_type(typ)
return pos
Expand Down
16 changes: 15 additions & 1 deletion viper/parser/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,21 @@ def parse_if(self):
parse_body,
)
if self.stmt.orelse:
block_scope_id = id(self.stmt.orelse)
self.context.start_blockscope(block_scope_id)
add_on = [parse_body(self.stmt.orelse, self.context)]
self.context.end_blockscope(block_scope_id)
else:
add_on = []
return LLLnode.from_list(['if', Expr.parse_value_expr(self.stmt.test, self.context), parse_body(self.stmt.body, self.context)] + add_on, typ=None, pos=getpos(self.stmt))

block_scope_id = id(self.stmt)
self.context.start_blockscope(block_scope_id)
o = LLLnode.from_list(
['if', Expr.parse_value_expr(self.stmt.test, self.context), parse_body(self.stmt.body, self.context)] + add_on,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does self.stmt.test do?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DavidKnott I believe it's basically the comparison expression or statement part: if (<test>)

typ=None, pos=getpos(self.stmt)
)
self.context.end_blockscope(block_scope_id)
return o

def call(self):
from .parser import (
Expand Down Expand Up @@ -180,6 +191,8 @@ def parse_for(self):
len(self.stmt.iter.args) not in (1, 2):
raise StructureException("For statements must be of the form `for i in range(rounds): ..` or `for i in range(start, start + rounds): ..`", self.stmt.iter) # noqa

block_scope_id = id(self.stmt.orelse)
self.context.start_blockscope(block_scope_id)
# Type 1 for, eg. for i in range(10): ...
if len(self.stmt.iter.args) == 1:
if not isinstance(self.stmt.iter.args[0], ast.Num):
Expand Down Expand Up @@ -207,6 +220,7 @@ def parse_for(self):
o = LLLnode.from_list(['repeat', pos, start, rounds, parse_body(self.stmt.body, self.context)], typ=None, pos=getpos(self.stmt))
del self.context.vars[varname]
del self.context.forvars[varname]
self.context.end_blockscope(block_scope_id)
return o

def _is_list_iter(self):
Expand Down