From 19cc035a48a50bcabe9753b7502576f0de3cae17 Mon Sep 17 00:00:00 2001 From: Michael Phelps Date: Fri, 25 Jun 2021 21:22:15 -0400 Subject: [PATCH] Expand globally scoped code support significantly --- README.md | 66 ++++++++++++++++++++++++++++++ examples/global_code.go | 33 +++++++++++++++ examples/global_code.py | 26 ++++++++++++ pytago/go_ast/core.py | 75 ++++++++++++++++++++++++----------- pytago/go_ast/transformers.py | 69 ++++++++++++++++++++++---------- pytago/tests/test_core.py | 3 ++ setup.py | 2 +- 7 files changed, 228 insertions(+), 46 deletions(-) create mode 100644 examples/global_code.go create mode 100644 examples/global_code.py diff --git a/README.md b/README.md index 245e2c0..36a753d 100644 --- a/README.md +++ b/README.md @@ -768,6 +768,72 @@ func main() { }(a, b)) } ``` +### global_code +#### Python +```python +A = [1, 2, 3] + +for i, x in enumerate(A): + A[i] += x + +B = A[0] +C = A[0] +D: int = 3 + +while C < A[2]: + C += 1 + +if C == A[2]: + print('True') + + +def main(): + print("Main started") + print(A) + print(B) + print(C) + print(D) + + +if __name__ == '__main__': + main() +``` +#### Go +```go +package main + +import "fmt" + +var ( + A = []int{1, 2, 3} + B int + C int + D int +) + +func init() { + for i, x := range A { + A[i] += x + } + B = A[0] + C = A[0] + D = 3 + for C < A[2] { + C += 1 + } + if C == A[2] { + fmt.Println("True") + } +} + +func main() { + fmt.Println("Main started") + fmt.Println(A) + fmt.Println(B) + fmt.Println(C) + fmt.Println(D) +} +``` ### boolnumcompare #### Python ```python diff --git a/examples/global_code.go b/examples/global_code.go new file mode 100644 index 0000000..bbdaf6b --- /dev/null +++ b/examples/global_code.go @@ -0,0 +1,33 @@ +package main + +import "fmt" + +var ( + A = []int{1, 2, 3} + B int + C int + D int +) + +func init() { + for i, x := range A { + A[i] += x + } + B = A[0] + C = A[0] + D = 3 + for C < A[2] { + C += 1 + } + if C == A[2] { + fmt.Println("True") + } +} + +func main() { + fmt.Println("Main started") + fmt.Println(A) + fmt.Println(B) + fmt.Println(C) + fmt.Println(D) +} diff --git a/examples/global_code.py b/examples/global_code.py new file mode 100644 index 0000000..dbcbbe6 --- /dev/null +++ b/examples/global_code.py @@ -0,0 +1,26 @@ +A = [1, 2, 3] + +for i, x in enumerate(A): + A[i] += x + +B = A[0] +C = A[0] +D: int = 3 + +while C < A[2]: + C += 1 + +if C == A[2]: + print('True') + + +def main(): + print("Main started") + print(A) + print(B) + print(C) + print(D) + + +if __name__ == '__main__': + main() diff --git a/pytago/go_ast/core.py b/pytago/go_ast/core.py index b823a0f..6b35834 100644 --- a/pytago/go_ast/core.py +++ b/pytago/go_ast/core.py @@ -384,7 +384,16 @@ def build_stmt_list(nodes, **kwargs) -> list['Stmt']: def build_decl_list(nodes) -> list['Decl']: - return _build_x_list(_DECL_TYPES, "Decl", nodes) + decls = [] + global_mode = False + for node in nodes: + try: + decls += _build_x_list(_DECL_TYPES, "Decl", [node], global_mode=global_mode) + except ValueError: + decls.append(FuncDecl.from_global_code(node)) + # From this point on, all variable assignmnts must be handled inside an init. + global_mode = True + return decls import sys import types @@ -415,10 +424,15 @@ class GoAST(ast.AST): _prefix = "ast." - def __init__(self, **kwargs): + def __init__(self, parents=None, **kwargs): super().__init__(**kwargs) - GoAST.STORY.append(self) - self.STORY_INDEX = len(GoAST.STORY) + self.parents = parents or [] + for field_name in self._fields: + field = getattr(self, field_name, None) + if isinstance(field, GoAST): + field.parents.append(self) + # GoAST.STORY.append(self) + # self.STORY_INDEX = len(GoAST.STORY) # self.TRACE = exception_with_traceback() # Debugging def remove_falsy_fields(self): @@ -441,9 +455,9 @@ def __repr__(self): class Expr(GoAST): def __init__(self, *args, _type_help=None, _py_context=None, **kwargs): - super().__init__(**kwargs) self._type_help = _type_help self._py_context = _py_context or {} + super().__init__(**kwargs) def __or__(self, Y: 'Expr') -> 'BinaryExpr': """ @@ -2459,24 +2473,15 @@ def from_AsyncFunctionDef(cls, node: ast.AsyncFunctionDef, **kwargs): recv = None return cls(body, doc, name, recv, _type, **kwargs) - # All of these simply throw the code in a function titled _ to be swept up later @classmethod - def from_BadDecl(cls, node: ast.AST, **kwargs): + def from_global_code(cls, node: ast.AST, **kwargs): body = from_this(BlockStmt, [node]) doc = None - name = from_this(Ident, "_") + name = from_this(Ident, "init") recv = None _type = FuncType(0, FieldList(0, []), FieldList(0, [])) return cls(body, doc, name, recv, _type, **kwargs) - @classmethod - def from_If(cls, node: ast.If, **kwargs): - return cls.from_BadDecl(node, **kwargs) - - @classmethod - def from_Expr(cls, node: ast.If, **kwargs): - return cls.from_BadDecl(node, **kwargs) - class FuncLit(Expr): """A FuncLit node represents a function literal.""" @@ -2629,16 +2634,38 @@ def _matcher(_node: 'ast.AST'): return decls @classmethod - def from_Assign(cls, node: ast.Assign, **kwargs): + def from_Assign(cls, node: ast.Assign, global_mode=False, **kwargs): values = build_expr_list([node.value]) - specs = [ValueSpec(Values=values, Names=build_expr_list(node.targets))] - return cls(Tok=token.VAR, Specs=specs, **kwargs) - - @classmethod - def from_AnnAssign(cls, node: ast.AnnAssign, **kwargs): + spec = ValueSpec(Values=values, Names=build_expr_list(node.targets)) + if global_mode: + # If we've already initialized some code via init methods, + # to get close to the behavior of python we need to initialize + # all subsequent vars via init as well + _type = None + _desperate_type = None + values = values.copy() + while spec.Values: + value = spec.Values.pop(0) + _type = _type or value._type() + if not _type: + _desperate_type = _desperate_type or value._type(interface_ok=True) + spec.Type = _type or InterfaceType(_py_context={"elts": values}) + init = FuncDecl.from_global_code(node) + init.Body.List[0].Tok = token.ASSIGN + return [init, cls(Tok=token.VAR, Specs=[spec], **kwargs)] + return cls(Tok=token.VAR, Specs=[spec], **kwargs) + + @classmethod + def from_AnnAssign(cls, node: ast.AnnAssign, global_mode=False, **kwargs): values = build_expr_list([node.value]) if node.value else None - specs = [ValueSpec(Values=values, Type=_type_annotation_to_go_type(node.annotation), Names=[from_this(Ident, node.target)])] - return cls(Tok=token.VAR, Specs=specs, **kwargs) + spec = ValueSpec(Values=values, Type=_type_annotation_to_go_type(node.annotation), + Names=[from_this(Ident, node.target)]) + if values and global_mode: + init = FuncDecl.from_global_code(ast.Assign(targets=[node.target], value=node.value)) + init.Body.List[0].Tok = token.ASSIGN + spec.Values = [] + return [init, cls(Tok=token.VAR, Specs=[spec], **kwargs)] + return cls(Tok=token.VAR, Specs=[spec], **kwargs) @classmethod def from_ImportSpec(cls, node: ImportSpec, **kwargs): diff --git a/pytago/go_ast/transformers.py b/pytago/go_ast/transformers.py index 634f745..96e71f2 100644 --- a/pytago/go_ast/transformers.py +++ b/pytago/go_ast/transformers.py @@ -9,7 +9,7 @@ ast_snippets, MapType, ValueSpec, Expr, BadStmt, SendStmt, len_ # Shortcuts from pytago.go_ast.core import _find_nodes, GoAST, ChanType, StructType, InterfaceType, BadExpr, OP_COMPLIMENTS, \ - GoStmt, TypeSwitchStmt, StarExpr + GoStmt, TypeSwitchStmt, StarExpr, GenDecl v = Ident.from_str @@ -124,23 +124,24 @@ def visit_CallExpr(self, node: CallExpr): return node -class RemoveOrphanedFunctions(BaseTransformer): +class RemoveIfNameEqualsMain(BaseTransformer): STAGE = 2 REPEATABLE = False - """ - Orphaned code is placed in functions titled "_" -- later we may want to try to put such code in the main - method or elsewhere, but for now we'll remove it. - """ - def visit_File(self, node: File): - self.generic_visit(node) - to_delete = [] - for decl in node.Decls: - match decl: - case FuncDecl(Name=Ident(Name="_")): - to_delete.append(decl) - for decl in to_delete: - node.Decls.remove(decl) + # TODO: in the future we want to support adding stuff from under + # "if __name__ == "__main__" to the main function + # Really, it should become the main function unless there already is one + # and it's calling it... + + def visit_FuncDecl(self, node: FuncDecl): + match node.Name: + case Ident(Name="init"): + match node.Body: + case BlockStmt(List=[IfStmt(Cond=cond)]): + match cond: + case BinaryExpr(X=Ident(Name="__name__"), Op=token.EQL, Y=BasicLit(Value='"__main__"')) | \ + BinaryExpr(X=Ident(Name="__main__"), Op=token.EQL, Y=BasicLit(Value='"__name__"')): + return return node @@ -203,6 +204,14 @@ def __init__(self, scope=None): self.current_nonlocals = [] self.missing_type_info = [] + def should_apply_new_scope(self, value): + match value: + case FuncDecl(Name="init"): + return False + case BlockStmt(parents=[FuncDecl(Name="init")]): + return False + return isinstance(value, Stmt) and not isinstance(value, (AssignStmt, ValueSpec)) + def generic_visit(self, node): self.stack.append(node) prev_globals = self.current_globals @@ -224,7 +233,7 @@ def generic_visit(self, node): case ast.Nonlocal(names=names): self.current_nonlocals += names if isinstance(value, AST): - new_scope = isinstance(value, Stmt) and not isinstance(value, (AssignStmt, ValueSpec)) + new_scope = self.should_apply_new_scope(value) if new_scope: self.scope = Scope({}, self.scope) value = self.visit(value) @@ -238,7 +247,7 @@ def generic_visit(self, node): new_values.append(value) old_value[:] = new_values elif isinstance(old_value, AST): - new_scope = isinstance(old_value, Stmt) and not isinstance(old_value, (AssignStmt, ValueSpec)) + new_scope = self.should_apply_new_scope(old_value) if new_scope: self.scope = Scope({}, self.scope) new_node = self.visit(old_value) @@ -1599,6 +1608,25 @@ def exit_callback(*args, **kwargs): def generic_missing_type_callback(self, node: Expr, val: Expr, type_: Expr): return +class MergeAdjacentInits(BaseTransformer): + def visit_File(self, node: File): + decls = node.Decls + i = 0 + while i < (len(decls) - 1): + cur = decls[i] + nxt = decls[i+1] + match cur, nxt: + case FuncDecl(Name=Ident(Name="init")), FuncDecl(Name=Ident(Name="init")): + cur.Body.List += nxt.Body.List + del decls[i+1] + # Also try to group variable declarations/imports above init methods + case FuncDecl(Name=Ident(Name="init")), GenDecl(): + decls[i], decls[i+1] = decls[i+1], decls[i] + i += 1 + case _: + i += 1 + return node + class LoopThroughSetValuesNotKeys(NodeTransformerWithScope): def visit_RangeStmt(self, node: RangeStmt): self.generic_visit(node) @@ -1622,9 +1650,7 @@ def visit_RangeStmt(self, node: RangeStmt): PythonToGoTypes, UnpackRange, RangeRangeToFor, - - # Scope transformers - SpecialComparators, + SpecialComparators, # First scope transformer YieldTransformer, # May need to be above other scope transformers because jank, YieldRangeTransformer, ReplacePowWithMathPow, @@ -1652,7 +1678,8 @@ def visit_RangeStmt(self, node: RangeStmt): NodeTransformerWithInterfaceTypes, RemoveGoCallReturns, RemoveBadStmt, # Should be last as these are used for scoping + MergeAdjacentInits, #### STAGE 2 #### - RemoveOrphanedFunctions + RemoveIfNameEqualsMain ] \ No newline at end of file diff --git a/pytago/tests/test_core.py b/pytago/tests/test_core.py index 8a3b7b3..2abf9c6 100644 --- a/pytago/tests/test_core.py +++ b/pytago/tests/test_core.py @@ -37,6 +37,9 @@ def test_list_methods(self): def test_set_methods(self): self.assert_examples_match("set_methods") + def test_global_code(self): + self.assert_examples_match("global_code") + def test_boolnumcompare(self): self.assert_examples_match("boolnumcompare") diff --git a/setup.py b/setup.py index 4b68d55..403bc25 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='pytago', - version='0.0.7', + version='0.0.8', packages=['pytago'], url='https://github.com/nottheswimmer/pytago', license='',