-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval_stmts.py
68 lines (48 loc) · 1.66 KB
/
eval_stmts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Modified version of https://gist.github.com/nitros12/2c3c265813121492655bc95aa54da6b9
import ast
def insert_returns(body):
    """Rewrite *body* in place so that its trailing expression is returned.

    *body* is a list of AST statement nodes (for example ``Module.body`` or
    the ``body`` of a function definition). If the last statement is a bare
    expression statement, it is replaced by a ``return`` of that expression.
    The rewrite recurses into both branches of a trailing ``if``/``else``
    (including ``elif`` chains, which appear as a nested ``If`` in
    ``orelse``). It also recurses into the body of a trailing
    ``with``/``async with``. Whichever branch runs then produces the value.

    An empty *body* is left untouched. This happens, for example, with the
    ``orelse`` of an ``if`` that has no ``else``.

    Args:
        body: List of ``ast.stmt`` nodes; mutated in place.
    """
    if not body:
        # e.g. `if cond: expr` with no else -> orelse == []; nothing to do.
        return
    last = body[-1]
    if isinstance(last, ast.Expr):
        # Copy the expression statement's location onto the new Return so
        # tracebacks point at the user's line instead of defaulting to line 1.
        body[-1] = ast.copy_location(ast.Return(value=last.value), last)
    elif isinstance(last, ast.If):
        # Both branches may end in an expression whose value should be returned.
        insert_returns(last.body)
        insert_returns(last.orelse)
    elif isinstance(last, (ast.With, ast.AsyncWith)):
        # The value of a (async) with block is the value of its last statement.
        insert_returns(last.body)
async def eval_stmts(stmts, env=None):
    """
    Evaluate *stmts* as the body of an async function and return its result.

    If the last statement is an expression, its value is returned. This also
    applies to the last expression inside a trailing ``if``/``else`` or
    ``with``/``async with`` block. Otherwise the result is ``None``.
    Top-level ``await`` works because the statements are placed inside an
    ``async def``.

    Args:
        stmts: Python source code to run.
        env: Dict used as the globals of the wrapper function. A fresh dict
            is used when omitted, so the code never touches this module's
            globals. Plain assignments in *stmts* are locals of the wrapper
            function and are not written back to *env*.

    Returns:
        The value of the trailing expression, or ``None``.

    Security: this executes arbitrary code; never pass untrusted input.

    >>> from asyncio import run
    >>> run(eval_stmts("1+1"))
    2
    >>> run(eval_stmts(""))
    >>> ctx = {}
    >>> run(eval_stmts("ctx['foo'] = 1", {"ctx": ctx}))
    >>> ctx['foo']
    1
    >>> run(eval_stmts('''
    ... async def f():
    ...     return 42
    ...
    ... await f()'''))
    42
    """
    if env is None:
        # With globals=None, exec/eval run in this module's namespace and, in a
        # function, against a locals() snapshot that Python 3.13+ (PEP 667)
        # no longer shares between calls -- so always use a real dict.
        env = {}
    parsed_stmts = ast.parse(stmts)
    if not parsed_stmts.body:
        # Empty or comment-only source: nothing to run, and an empty function
        # body would not compile.
        return None
    fn_name = "_eval_expr"
    parsed_fn = ast.parse(f"async def {fn_name}(): pass")
    for node in parsed_stmts.body:
        # The user's code moves down one line, below the `async def` header.
        ast.increment_lineno(node)
    insert_returns(parsed_stmts.body)
    parsed_fn.body[0].body = parsed_stmts.body
    exec(compile(parsed_fn, filename="<ast>", mode="exec"), env)
    # Take the wrapper back out so it does not linger in the caller's env.
    wrapper = env.pop(fn_name)
    return await wrapper()
if __name__ == "__main__":
    # Running this file directly executes the docstring examples as tests.
    from doctest import testmod

    testmod()