-
Notifications
You must be signed in to change notification settings - Fork 244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Chained function calls separated into multiple assignments #171
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
import ast | ||
|
||
|
||
class AsyncTransformer(ast.NodeTransformer): | ||
class AsyncTransformer(): | ||
"""Converts all async nodes into their synchronous counterparts.""" | ||
|
||
def visit_Await(self, node): | ||
|
@@ -16,3 +16,53 @@ def visit_AsyncFor(self, node): | |
|
||
def visit_AsyncWith(self, node): | ||
return self.visit(ast.With(**node.__dict__)) | ||
|
||
|
||
class ChainedFunctionTransformer(): | ||
def visit_chain(self, node, depth=1): | ||
if ( | ||
isinstance(node.value, ast.Call) and | ||
isinstance(node.value.func, ast.Attribute) and | ||
isinstance(node.value.func.value, ast.Call) | ||
): | ||
# Node is assignment or return with value like `b.c().d()` | ||
call_node = node.value | ||
# If we want to handle nested functions in future, depth needs fixing | ||
temp_var_id = '__chain_tmp_{}'.format(depth) | ||
# AST tree is from right to left, so d() is the outer Call and b.c() is the inner Call | ||
unvisited_inner_call = ast.Assign( | ||
targets=[ast.Name(id=temp_var_id, ctx=ast.Store())], | ||
value=call_node.func.value, | ||
) | ||
ast.copy_location(unvisited_inner_call, node) | ||
inner_calls = self.visit_chain(unvisited_inner_call, depth + 1) | ||
for inner_call_node in inner_calls: | ||
ast.copy_location(inner_call_node, node) | ||
outer_call = self.generic_visit(type(node)( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indentation nit: This is awesomely smart but more indentation may make it clearer outer_call = self.generic_visit(
type(node)(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id=temp_var_id, ctx=ast.Load()),
attr=call_node.func.attr,
ctx=ast.Load(),
),
args=call_node.args,
keywords=call_node.keywords,
),
**{field: value for field, value in ast.iter_fields(node) if field != 'value'} # e.g. targets
)
) Nit: w/r/t |
||
value=ast.Call( | ||
func=ast.Attribute( | ||
value=ast.Name(id=temp_var_id, ctx=ast.Load()), | ||
attr=call_node.func.attr, | ||
ctx=ast.Load(), | ||
), | ||
args=call_node.args, | ||
keywords=call_node.keywords, | ||
), | ||
**{field: value for field, value in ast.iter_fields(node) if field != 'value'} # e.g. targets | ||
)) | ||
ast.copy_location(outer_call, node) | ||
ast.copy_location(outer_call.value, node) | ||
ast.copy_location(outer_call.value.func, node) | ||
return [*inner_calls, outer_call] | ||
else: | ||
return [self.generic_visit(node)] | ||
|
||
def visit_Assign(self, node): | ||
return self.visit_chain(node) | ||
|
||
def visit_Return(self, node): | ||
return self.visit_chain(node) | ||
|
||
|
||
class PytTransformer(AsyncTransformer, ChainedFunctionTransformer, ast.NodeTransformer): | ||
pass |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,14 @@ | ||
import ast | ||
import unittest | ||
|
||
from pyt.core.transformer import AsyncTransformer | ||
from pyt.core.transformer import PytTransformer | ||
|
||
|
||
class TransformerTest(unittest.TestCase): | ||
"""Tests for the AsyncTransformer.""" | ||
|
||
def test_async_removed_by_transformer(self): | ||
self.maxDiff = 99999 | ||
async_tree = ast.parse("\n".join([ | ||
"async def a():", | ||
" async for b in c():", | ||
|
@@ -30,7 +31,24 @@ def test_async_removed_by_transformer(self): | |
])) | ||
self.assertIsInstance(sync_tree.body[0], ast.FunctionDef) | ||
|
||
transformed = AsyncTransformer().visit(async_tree) | ||
transformed = PytTransformer().visit(async_tree) | ||
self.assertIsInstance(transformed.body[0], ast.FunctionDef) | ||
|
||
self.assertEqual(ast.dump(transformed), ast.dump(sync_tree)) | ||
|
||
def test_chained_function(self): | ||
chained_tree = ast.parse("\n".join([ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great use of |
||
"def a():", | ||
" b = c.d(e).f(g).h(i).j(k)", | ||
])) | ||
|
||
separated_tree = ast.parse("\n".join([ | ||
"def a():", | ||
" __chain_tmp_3 = c.d(e)", | ||
" __chain_tmp_2 = __chain_tmp_3.f(g)", | ||
" __chain_tmp_1 = __chain_tmp_2.h(i)", | ||
" b = __chain_tmp_1.j(k)", | ||
])) | ||
|
||
transformed = PytTransformer().visit(chained_tree) | ||
self.assertEqual(ast.dump(transformed), ast.dump(separated_tree)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Love how you did
node.value
node.value.func
node.value.func.value
in that order, it's super clean.