Skip to content

Commit

Permalink
Add async stmt lambdas
Browse files Browse the repository at this point in the history
Resolves   #644.
  • Loading branch information
evhub committed Nov 2, 2022
1 parent f0a2a2c commit 8080679
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 34 deletions.
2 changes: 1 addition & 1 deletion DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -1455,7 +1455,7 @@ The statement lambda syntax is an extension of the [normal lambda syntax](#lambd

The syntax for a statement lambda is
```
def (arguments) -> statement; statement; ...
[async] [match] def (arguments) -> statement; statement; ...
```
where `arguments` can be standard function arguments or [pattern-matching function definition](#pattern-matching-functions) arguments and `statement` can be an assignment statement or a keyword statement. If the last `statement` (not followed by a semicolon) is an `expression`, it will automatically be returned.

Expand Down
44 changes: 29 additions & 15 deletions coconut/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1743,7 +1743,7 @@ def transform_returns(self, original, loc, raw_lines, tre_return_grammar=None, i
func_code = "".join(lines)
return func_code, tco, tre

def proc_funcdef(self, original, loc, decorators, funcdef, is_async, in_method):
def proc_funcdef(self, original, loc, decorators, funcdef, is_async, in_method, is_stmt_lambda):
"""Determines if TCO or TRE can be done and if so does it,
handles dotted function names, and universalizes async functions."""
# process tokens
Expand Down Expand Up @@ -1864,8 +1864,8 @@ def proc_funcdef(self, original, loc, decorators, funcdef, is_async, in_method):
attempt_tre = (
func_name is not None
and not is_gen
# tre does not work with methods or decorators (though tco does)
and not in_method
and not is_stmt_lambda
and not decorators
)
if attempt_tre:
Expand Down Expand Up @@ -2003,7 +2003,7 @@ def deferred_code_proc(self, inputstring, add_code_at_start=False, ignore_names=
# look for functions
if line.startswith(funcwrapper):
func_id = int(line[len(funcwrapper):])
original, loc, decorators, funcdef, is_async, in_method = self.get_ref("func", func_id)
original, loc, decorators, funcdef, is_async, in_method, is_stmt_lambda = self.get_ref("func", func_id)

# process inner code
decorators = self.deferred_code_proc(decorators, add_code_at_start=True, ignore_names=ignore_names, **kwargs)
Expand All @@ -2022,7 +2022,7 @@ def deferred_code_proc(self, inputstring, add_code_at_start=False, ignore_names=

out.append(bef_ind)
out.extend(pre_def_lines)
out.append(self.proc_funcdef(original, loc, decorators, "".join(post_def_lines), is_async, in_method))
out.append(self.proc_funcdef(original, loc, decorators, "".join(post_def_lines), is_async, in_method, is_stmt_lambda))
out.append(aft_ind)

# look for add_code_before regexes
Expand Down Expand Up @@ -3071,43 +3071,57 @@ def set_letter_literal_handle(self, tokens):
def stmt_lambdef_handle(self, original, loc, tokens):
"""Process multi-line lambdef statements."""
if len(tokens) == 2:
params, stmts = tokens
params, stmts_toks = tokens
is_async = False
elif len(tokens) == 3:
params, stmts, last = tokens
if "tests" in tokens:
async_kwd, params, stmts_toks = tokens
internal_assert(async_kwd == "async", "invalid stmt lambdef async kwd", async_kwd)
is_async = True
else:
raise CoconutInternalException("invalid statement lambda tokens", tokens)

if len(stmts_toks) == 1:
stmts, = stmts_toks
elif len(stmts_toks) == 2:
stmts, last = stmts_toks
if "tests" in stmts_toks:
stmts = stmts.asList() + ["return " + last]
else:
stmts = stmts.asList() + [last]
else:
raise CoconutInternalException("invalid statement lambda tokens", tokens)
raise CoconutInternalException("invalid statement lambda body tokens", stmts_toks)

name = self.get_temp_var("lambda")
body = openindent + "\n".join(stmts) + closeindent

if isinstance(params, str):
self.add_code_before[name] = "def " + name + params + ":\n" + body
decorators = ""
funcdef = "def " + name + params + ":\n" + body
else:
match_tokens = [name] + list(params)
before_colon, after_docstring = self.name_match_funcdef_handle(original, loc, match_tokens)
self.add_code_before[name] = (
"@_coconut_mark_as_match\n"
+ before_colon
decorators = "@_coconut_mark_as_match\n"
funcdef = (
before_colon
+ ":\n"
+ after_docstring
+ body
)

self.add_code_before[name] = self.decoratable_funcdef_stmt_handle(original, loc, [decorators, funcdef], is_async, is_stmt_lambda=True)

return name

def decoratable_funcdef_stmt_handle(self, original, loc, tokens, is_async=False):
def decoratable_funcdef_stmt_handle(self, original, loc, tokens, is_async=False, is_stmt_lambda=False):
"""Wraps the given function for later processing"""
if len(tokens) == 1:
decorators, funcdef = "", tokens[0]
funcdef, = tokens
decorators = ""
elif len(tokens) == 2:
decorators, funcdef = tokens
else:
raise CoconutInternalException("invalid function definition tokens", tokens)
return funcwrapper + self.add_ref("func", (original, loc, decorators, funcdef, is_async, self.in_method)) + "\n"
return funcwrapper + self.add_ref("func", (original, loc, decorators, funcdef, is_async, self.in_method, is_stmt_lambda)) + "\n"

def await_expr_handle(self, original, loc, tokens):
"""Check for Python 3.5 await expression."""
Expand Down
38 changes: 24 additions & 14 deletions coconut/compiler/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -1441,12 +1441,13 @@ class Grammar(object):
| stmt_lambdef_match_params,
default="(_=None)",
)
stmt_lambdef_body = (
stmt_lambdef_body = Group(
Group(OneOrMore(simple_stmt_item + semicolon.suppress())) + Optional(closing_stmt)
| Group(ZeroOrMore(simple_stmt_item + semicolon.suppress())) + closing_stmt
| Group(ZeroOrMore(simple_stmt_item + semicolon.suppress())) + closing_stmt,
)
general_stmt_lambdef = (
keyword("def").suppress()
Optional(async_kwd)
+ keyword("def").suppress()
+ stmt_lambdef_params
+ arrow.suppress()
+ stmt_lambdef_body
Expand All @@ -1458,7 +1459,16 @@ class Grammar(object):
+ arrow.suppress()
+ stmt_lambdef_body
)
stmt_lambdef_ref = general_stmt_lambdef | match_stmt_lambdef
async_match_stmt_lambdef = (
any_len_perm(
match_kwd.suppress(),
required=(async_kwd,),
) + keyword("def").suppress()
+ stmt_lambdef_match_params
+ arrow.suppress()
+ stmt_lambdef_body
)
stmt_lambdef_ref = general_stmt_lambdef | match_stmt_lambdef | async_match_stmt_lambdef

lambdef <<= addspace(lambdef_base + test) | stmt_lambdef
lambdef_no_cond = trace(addspace(lambdef_base + test_no_cond))
Expand Down Expand Up @@ -1970,17 +1980,17 @@ class Grammar(object):
match_kwd.suppress(),
# we don't suppress addpattern so its presence can be detected later
addpattern_kwd,
# makes async required
(1, async_kwd.suppress()),
required=(async_kwd.suppress(),),
) + (def_match_funcdef | math_match_funcdef),
),
)
async_yield_funcdef = attach(
trace(
any_len_perm(
# makes both required
(1, async_kwd.suppress()),
(2, keyword("yield").suppress()),
required=(
async_kwd.suppress(),
keyword("yield").suppress(),
),
) + (funcdef | math_funcdef),
),
yield_funcdef_handle,
Expand All @@ -1992,9 +2002,10 @@ class Grammar(object):
match_kwd.suppress(),
# we don't suppress addpattern so its presence can be detected later
addpattern_kwd,
# makes both required
(1, async_kwd.suppress()),
(2, keyword("yield").suppress()),
required=(
async_kwd.suppress(),
keyword("yield").suppress(),
),
) + (def_match_funcdef | math_match_funcdef),
),
),
Expand All @@ -2014,8 +2025,7 @@ class Grammar(object):
match_kwd.suppress(),
# we don't suppress addpattern so its presence can be detected later
addpattern_kwd,
# makes yield required
(1, keyword("yield").suppress()),
required=(keyword("yield").suppress(),),
) + (def_match_funcdef | math_match_funcdef),
),
)
Expand Down
2 changes: 1 addition & 1 deletion coconut/compiler/templates/header.py_template
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _coconut_tco(func):{COMMENT._coconut_tco_func_attr_is_used_in_main_coco}
wkref_func = None if wkref is None else wkref()
if wkref_func is call_func:
call_func = call_func._coconut_tco_func
result = call_func(*args, **kwargs) # pass --no-tco to clean up your traceback
result = call_func(*args, **kwargs) # use coconut --no-tco to clean up your traceback
if not isinstance(result, _coconut_tail_call):
return result
call_func, args, kwargs = result.func, result.args, result.kwargs
Expand Down
13 changes: 12 additions & 1 deletion coconut/compiler/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ def keyword(name, explicit_prefix=None, require_whitespace=False):
boundary = regex_item(r"\b")


def any_len_perm(*groups_and_elems):
def any_len_perm_with_one_of_each_group(*groups_and_elems):
"""Matches any len permutation of elems that contains at least one of each group."""
elems = []
groups = defaultdict(list)
Expand Down Expand Up @@ -850,6 +850,17 @@ def any_len_perm(*groups_and_elems):
return out


def any_len_perm(*optional, **kwargs):
"""Any length permutation of optional and required."""
required = kwargs.pop("required", ())
internal_assert(not kwargs, "invalid any_len_perm kwargs", kwargs)

groups_and_elems = []
groups_and_elems.extend(optional)
groups_and_elems.extend(enumerate(required))
return any_len_perm_with_one_of_each_group(*groups_and_elems)


# -----------------------------------------------------------------------------------------------------------------------
# UTILITIES:
# -----------------------------------------------------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion coconut/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
VERSION = "2.1.0"
VERSION_NAME = "The Spanish Inquisition"
# False for release, int >= 1 for develop
DEVELOP = 18
DEVELOP = 19
ALPHA = False # for pre releases rather than post releases

# -----------------------------------------------------------------------------------------------------------------------
Expand Down
4 changes: 4 additions & 0 deletions coconut/tests/src/cocotest/target_sys/target_sys_test.coco
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def asyncio_test() -> bool:
async def main():
assert await async_map_test()
assert `(+)$(1) .. await aplus 1` 1 == 3
assert await (async def (x, y) -> x + y)(1, 2) == 3
assert await (async def (int(x), int(y)) -> x + y)(1, 2) == 3
assert await (async match def (int(x), int(y)) -> x + y)(1, 2) == 3
assert await (match async def (int(x), int(y)) -> x + y)(1, 2) == 3

loop = asyncio.new_event_loop()
loop.run_until_complete(main())
Expand Down
2 changes: 1 addition & 1 deletion coconut/tests/src/extras.coco
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_setup_none() -> bool:

# things that don't parse correctly without the computation graph
if not PYPY:
exec(parse("assert (1,2,3,4) == ([1, 2], [3, 4]) |*> def (x, y) -> *x, *y"))
exec(parse("assert (1,2,3,4) == ([1, 2], [3, 4]) |*> def (x, y) -> *x, *y"), {})

assert_raises(-> parse("(a := b)"), CoconutTargetError)
assert_raises(-> parse("async def f() = 1"), CoconutTargetError)
Expand Down

0 comments on commit 8080679

Please sign in to comment.