Skip to content

Commit

Permalink
Expand globally scoped code support significantly
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Phelps committed Jun 26, 2021
1 parent eb3da84 commit 19cc035
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 46 deletions.
66 changes: 66 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,72 @@ func main() {
}(a, b))
}
```
### global_code
#### Python
```python
A = [1, 2, 3]

for i, x in enumerate(A):
A[i] += x

B = A[0]
C = A[0]
D: int = 3

while C < A[2]:
C += 1

if C == A[2]:
print('True')


def main():
print("Main started")
print(A)
print(B)
print(C)
print(D)


if __name__ == '__main__':
main()
```
#### Go
```go
package main

import "fmt"

var (
A = []int{1, 2, 3}
B int
C int
D int
)

func init() {
for i, x := range A {
A[i] += x
}
B = A[0]
C = A[0]
D = 3
for C < A[2] {
C += 1
}
if C == A[2] {
fmt.Println("True")
}
}

func main() {
fmt.Println("Main started")
fmt.Println(A)
fmt.Println(B)
fmt.Println(C)
fmt.Println(D)
}
```
### boolnumcompare
#### Python
```python
Expand Down
33 changes: 33 additions & 0 deletions examples/global_code.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package main

import "fmt"

var (
A = []int{1, 2, 3}
B int
C int
D int
)

func init() {
for i, x := range A {
A[i] += x
}
B = A[0]
C = A[0]
D = 3
for C < A[2] {
C += 1
}
if C == A[2] {
fmt.Println("True")
}
}

func main() {
fmt.Println("Main started")
fmt.Println(A)
fmt.Println(B)
fmt.Println(C)
fmt.Println(D)
}
26 changes: 26 additions & 0 deletions examples/global_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
A = [1, 2, 3]

for i, x in enumerate(A):
A[i] += x

B = A[0]
C = A[0]
D: int = 3

while C < A[2]:
C += 1

if C == A[2]:
print('True')


def main():
print("Main started")
print(A)
print(B)
print(C)
print(D)


if __name__ == '__main__':
main()
75 changes: 51 additions & 24 deletions pytago/go_ast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,16 @@ def build_stmt_list(nodes, **kwargs) -> list['Stmt']:


def build_decl_list(nodes) -> list['Decl']:
return _build_x_list(_DECL_TYPES, "Decl", nodes)
decls = []
global_mode = False
for node in nodes:
try:
decls += _build_x_list(_DECL_TYPES, "Decl", [node], global_mode=global_mode)
except ValueError:
decls.append(FuncDecl.from_global_code(node))
# From this point on, all variable assignmnts must be handled inside an init.
global_mode = True
return decls

import sys
import types
Expand Down Expand Up @@ -415,10 +424,15 @@ class GoAST(ast.AST):

_prefix = "ast."

def __init__(self, **kwargs):
def __init__(self, parents=None, **kwargs):
super().__init__(**kwargs)
GoAST.STORY.append(self)
self.STORY_INDEX = len(GoAST.STORY)
self.parents = parents or []
for field_name in self._fields:
field = getattr(self, field_name, None)
if isinstance(field, GoAST):
field.parents.append(self)
# GoAST.STORY.append(self)
# self.STORY_INDEX = len(GoAST.STORY)
# self.TRACE = exception_with_traceback() # Debugging

def remove_falsy_fields(self):
Expand All @@ -441,9 +455,9 @@ def __repr__(self):

class Expr(GoAST):
def __init__(self, *args, _type_help=None, _py_context=None, **kwargs):
super().__init__(**kwargs)
self._type_help = _type_help
self._py_context = _py_context or {}
super().__init__(**kwargs)

def __or__(self, Y: 'Expr') -> 'BinaryExpr':
"""
Expand Down Expand Up @@ -2459,24 +2473,15 @@ def from_AsyncFunctionDef(cls, node: ast.AsyncFunctionDef, **kwargs):
recv = None
return cls(body, doc, name, recv, _type, **kwargs)

# All of these simply throw the code in a function titled _ to be swept up later
@classmethod
def from_BadDecl(cls, node: ast.AST, **kwargs):
def from_global_code(cls, node: ast.AST, **kwargs):
body = from_this(BlockStmt, [node])
doc = None
name = from_this(Ident, "_")
name = from_this(Ident, "init")
recv = None
_type = FuncType(0, FieldList(0, []), FieldList(0, []))
return cls(body, doc, name, recv, _type, **kwargs)

@classmethod
def from_If(cls, node: ast.If, **kwargs):
return cls.from_BadDecl(node, **kwargs)

@classmethod
def from_Expr(cls, node: ast.If, **kwargs):
return cls.from_BadDecl(node, **kwargs)


class FuncLit(Expr):
"""A FuncLit node represents a function literal."""
Expand Down Expand Up @@ -2629,16 +2634,38 @@ def _matcher(_node: 'ast.AST'):
return decls

@classmethod
def from_Assign(cls, node: ast.Assign, **kwargs):
def from_Assign(cls, node: ast.Assign, global_mode=False, **kwargs):
values = build_expr_list([node.value])
specs = [ValueSpec(Values=values, Names=build_expr_list(node.targets))]
return cls(Tok=token.VAR, Specs=specs, **kwargs)

@classmethod
def from_AnnAssign(cls, node: ast.AnnAssign, **kwargs):
spec = ValueSpec(Values=values, Names=build_expr_list(node.targets))
if global_mode:
# If we've already initialized some code via init methods,
# to get close to the behavior of python we need to initialize
# all subsequent vars via init as well
_type = None
_desperate_type = None
values = values.copy()
while spec.Values:
value = spec.Values.pop(0)
_type = _type or value._type()
if not _type:
_desperate_type = _desperate_type or value._type(interface_ok=True)
spec.Type = _type or InterfaceType(_py_context={"elts": values})
init = FuncDecl.from_global_code(node)
init.Body.List[0].Tok = token.ASSIGN
return [init, cls(Tok=token.VAR, Specs=[spec], **kwargs)]
return cls(Tok=token.VAR, Specs=[spec], **kwargs)

@classmethod
def from_AnnAssign(cls, node: ast.AnnAssign, global_mode=False, **kwargs):
values = build_expr_list([node.value]) if node.value else None
specs = [ValueSpec(Values=values, Type=_type_annotation_to_go_type(node.annotation), Names=[from_this(Ident, node.target)])]
return cls(Tok=token.VAR, Specs=specs, **kwargs)
spec = ValueSpec(Values=values, Type=_type_annotation_to_go_type(node.annotation),
Names=[from_this(Ident, node.target)])
if values and global_mode:
init = FuncDecl.from_global_code(ast.Assign(targets=[node.target], value=node.value))
init.Body.List[0].Tok = token.ASSIGN
spec.Values = []
return [init, cls(Tok=token.VAR, Specs=[spec], **kwargs)]
return cls(Tok=token.VAR, Specs=[spec], **kwargs)

@classmethod
def from_ImportSpec(cls, node: ImportSpec, **kwargs):
Expand Down
69 changes: 48 additions & 21 deletions pytago/go_ast/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
ast_snippets, MapType, ValueSpec, Expr, BadStmt, SendStmt, len_
# Shortcuts
from pytago.go_ast.core import _find_nodes, GoAST, ChanType, StructType, InterfaceType, BadExpr, OP_COMPLIMENTS, \
GoStmt, TypeSwitchStmt, StarExpr
GoStmt, TypeSwitchStmt, StarExpr, GenDecl

v = Ident.from_str

Expand Down Expand Up @@ -124,23 +124,24 @@ def visit_CallExpr(self, node: CallExpr):
return node


class RemoveOrphanedFunctions(BaseTransformer):
class RemoveIfNameEqualsMain(BaseTransformer):
STAGE = 2
REPEATABLE = False
"""
Orphaned code is placed in functions titled "_" -- later we may want to try to put such code in the main
method or elsewhere, but for now we'll remove it.
"""

def visit_File(self, node: File):
self.generic_visit(node)
to_delete = []
for decl in node.Decls:
match decl:
case FuncDecl(Name=Ident(Name="_")):
to_delete.append(decl)
for decl in to_delete:
node.Decls.remove(decl)
# TODO: in the future we want to support adding stuff from under
# "if __name__ == "__main__" to the main function
# Really, it should become the main function unless there already is one
# and it's calling it...

def visit_FuncDecl(self, node: FuncDecl):
match node.Name:
case Ident(Name="init"):
match node.Body:
case BlockStmt(List=[IfStmt(Cond=cond)]):
match cond:
case BinaryExpr(X=Ident(Name="__name__"), Op=token.EQL, Y=BasicLit(Value='"__main__"')) | \
BinaryExpr(X=Ident(Name="__main__"), Op=token.EQL, Y=BasicLit(Value='"__name__"')):
return
return node


Expand Down Expand Up @@ -203,6 +204,14 @@ def __init__(self, scope=None):
self.current_nonlocals = []
self.missing_type_info = []

def should_apply_new_scope(self, value):
match value:
case FuncDecl(Name="init"):
return False
case BlockStmt(parents=[FuncDecl(Name="init")]):
return False
return isinstance(value, Stmt) and not isinstance(value, (AssignStmt, ValueSpec))

def generic_visit(self, node):
self.stack.append(node)
prev_globals = self.current_globals
Expand All @@ -224,7 +233,7 @@ def generic_visit(self, node):
case ast.Nonlocal(names=names):
self.current_nonlocals += names
if isinstance(value, AST):
new_scope = isinstance(value, Stmt) and not isinstance(value, (AssignStmt, ValueSpec))
new_scope = self.should_apply_new_scope(value)
if new_scope:
self.scope = Scope({}, self.scope)
value = self.visit(value)
Expand All @@ -238,7 +247,7 @@ def generic_visit(self, node):
new_values.append(value)
old_value[:] = new_values
elif isinstance(old_value, AST):
new_scope = isinstance(old_value, Stmt) and not isinstance(old_value, (AssignStmt, ValueSpec))
new_scope = self.should_apply_new_scope(old_value)
if new_scope:
self.scope = Scope({}, self.scope)
new_node = self.visit(old_value)
Expand Down Expand Up @@ -1599,6 +1608,25 @@ def exit_callback(*args, **kwargs):
def generic_missing_type_callback(self, node: Expr, val: Expr, type_: Expr):
return

class MergeAdjacentInits(BaseTransformer):
def visit_File(self, node: File):
decls = node.Decls
i = 0
while i < (len(decls) - 1):
cur = decls[i]
nxt = decls[i+1]
match cur, nxt:
case FuncDecl(Name=Ident(Name="init")), FuncDecl(Name=Ident(Name="init")):
cur.Body.List += nxt.Body.List
del decls[i+1]
# Also try to group variable declarations/imports above init methods
case FuncDecl(Name=Ident(Name="init")), GenDecl():
decls[i], decls[i+1] = decls[i+1], decls[i]
i += 1
case _:
i += 1
return node

class LoopThroughSetValuesNotKeys(NodeTransformerWithScope):
def visit_RangeStmt(self, node: RangeStmt):
self.generic_visit(node)
Expand All @@ -1622,9 +1650,7 @@ def visit_RangeStmt(self, node: RangeStmt):
PythonToGoTypes,
UnpackRange,
RangeRangeToFor,

# Scope transformers
SpecialComparators,
SpecialComparators, # First scope transformer
YieldTransformer, # May need to be above other scope transformers because jank,
YieldRangeTransformer,
ReplacePowWithMathPow,
Expand Down Expand Up @@ -1652,7 +1678,8 @@ def visit_RangeStmt(self, node: RangeStmt):
NodeTransformerWithInterfaceTypes,
RemoveGoCallReturns,
RemoveBadStmt, # Should be last as these are used for scoping
MergeAdjacentInits,

#### STAGE 2 ####
RemoveOrphanedFunctions
RemoveIfNameEqualsMain
]
3 changes: 3 additions & 0 deletions pytago/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def test_list_methods(self):
def test_set_methods(self):
self.assert_examples_match("set_methods")

def test_global_code(self):
self.assert_examples_match("global_code")

def test_boolnumcompare(self):
self.assert_examples_match("boolnumcompare")

Expand Down
Loading

0 comments on commit 19cc035

Please sign in to comment.