Skip to content

Commit

Permalink
Improve pattern-matching
Browse files Browse the repository at this point in the history
Resolves   #847, #848.
  • Loading branch information
evhub committed Jul 26, 2024
1 parent 7ea0163 commit 3d7577e
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 22 deletions.
6 changes: 4 additions & 2 deletions DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -1204,6 +1204,7 @@ base_pattern ::= (
| NAME "(" patterns ")" # classes or data types
| "data" NAME "(" patterns ")" # data types
| "class" NAME "(" patterns ")" # classes
| "(" name "=" pattern ... ")" # anonymous named tuples
| "{" pattern_pairs # dictionaries
["," "**" (NAME | "{}")] "}" # (keys must be constants or equality checks)
| ["s" | "f" | "m"] "{"
Expand Down Expand Up @@ -1269,7 +1270,8 @@ base_pattern ::= (
- Classes or Data Types (`<name>(<args>)`): will match as a data type if given [a Coconut `data` type](#data) (or a tuple of Coconut data types) and a class otherwise.
- Data Types (`data <name>(<args>)`): will check that whatever is in that position is of data type `<name>` and will match the attributes to `<args>`. Generally, `data <name>(<args>)` will match any data type that could have been constructed with `makedata(<name>, <args>)`. Includes support for positional arguments, named arguments, default arguments, and starred arguments. Also supports strict attributes by prepending a dot to the attribute name that raises `AttributError` if the attribute is not present rather than failing the match (e.g. `data MyData(.my_attr=<some_pattern>)`).
- Classes (`class <name>(<args>)`): does [PEP-634-style class matching](https://www.python.org/dev/peps/pep-0634/#class-patterns). Also supports strict attribute matching as above.
- Mapping Destructuring:
- Anonymous Named Tuples (`(<name>=<pattern>, ...)`): checks that the object is a `tuple` of the given length with the given attributes. For matching [anonymous `namedtuple`s](#anonymous-namedtuples).
- Dict Destructuring:
- Dicts (`{<key>: <value>, ...}`): will match any mapping (`collections.abc.Mapping`) with the given keys and values that match the value patterns. Keys must be constants or equality checks.
- Dicts With Rest (`{<pairs>, **<rest>}`): will match a mapping (`collections.abc.Mapping`) containing all the `<pairs>`, and will put a `dict` of everything else into `<rest>`. If `<rest>` is `{}`, will enforce that the mapping is exactly the same length as `<pairs>`.
- Set Destructuring:
Expand Down Expand Up @@ -2233,7 +2235,7 @@ as a shorthand for
f(long_variable_name=long_variable_name)
```

Such syntax is also supported in [partial application](#partial-application) and [anonymous `namedtuple`s](#anonymous-namedtuples).
Such syntax is also supported in [partial application](#partial-application), [anonymous `namedtuple`s](#anonymous-namedtuples), and [`class`/`data`/anonymous `namedtuple` patterns](#match).

_Deprecated: Coconut also supports `f(...=long_variable_name)` as an alternative shorthand syntax._

Expand Down
15 changes: 12 additions & 3 deletions coconut/compiler/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -1995,8 +1995,17 @@ class Grammar(object):

del_stmt = addspace(keyword("del") - simple_assignlist)

matchlist_data_item = Group(Optional(star | Optional(dot) + unsafe_name + equals) + match)
matchlist_data = Group(Optional(tokenlist(matchlist_data_item, comma)))
interior_name_match = labeled_group(setname, "var")
matchlist_anon_named_tuple_item = (
Group(Optional(dot) + unsafe_name) + equals + match
| Group(Optional(dot) + interior_name_match) + equals
)
matchlist_data_item = (
matchlist_anon_named_tuple_item
| Optional(star) + match
)
matchlist_data = Group(Optional(tokenlist(Group(matchlist_data_item), comma)))
matchlist_anon_named_tuple = Optional(tokenlist(Group(matchlist_anon_named_tuple_item), comma))

match_check_equals = Forward()
match_check_equals_ref = equals
Expand Down Expand Up @@ -2031,7 +2040,6 @@ class Grammar(object):
match_tuple = Group(lparen + matchlist_tuple + rparen.suppress())
match_lazy = Group(lbanana + matchlist_list + rbanana.suppress())

interior_name_match = labeled_group(setname, "var")
match_string = interleaved_tokenlist(
# f_string_atom must come first
f_string_atom("f_string") | fixed_len_string_tokens("string"),
Expand Down Expand Up @@ -2085,6 +2093,7 @@ class Grammar(object):
| (keyword("data").suppress() + dotted_refname + lparen.suppress() + matchlist_data + rparen.suppress())("data")
| (keyword("class").suppress() + dotted_refname + lparen.suppress() + matchlist_data + rparen.suppress())("class")
| (dotted_refname + lparen.suppress() + matchlist_data + rparen.suppress())("data_or_class")
| (lparen.suppress() + matchlist_anon_named_tuple + rparen.suppress())("anon_named_tuple")
| Optional(keyword("as").suppress()) + setname("var"),
)

Expand Down
50 changes: 34 additions & 16 deletions coconut/compiler/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class Matcher(object):
"data": lambda self: self.match_data,
"class": lambda self: self.match_class,
"data_or_class": lambda self: self.match_data_or_class,
"anon_named_tuple": lambda self: self.match_anon_named_tuple,
"paren": lambda self: self.match_paren,
"as": lambda self: self.match_as,
"and": lambda self: self.match_and,
Expand Down Expand Up @@ -1056,10 +1057,8 @@ def match_set(self, tokens, item):
for const in match:
self.add_check(const + " in " + item)

def split_data_or_class_match(self, tokens):
"""Split data/class match tokens into cls_name, pos_matches, name_matches, star_match."""
cls_name, matches = tokens

def split_data_or_class_matches(self, matches):
"""Split data/class match tokens into pos_matches, name_matches, star_match."""
pos_matches = []
name_matches = {}
star_match = None
Expand All @@ -1073,8 +1072,7 @@ def split_data_or_class_match(self, tokens):
raise CoconutDeferredSyntaxError("positional arg after keyword arg in data/class match", self.loc)
pos_matches.append(match)
# starred arg
elif len(match_arg) == 2:
internal_assert(match_arg[0] == "*", "invalid starred data/class match arg tokens", match_arg)
elif len(match_arg) == 2 and match_arg[0] == "*":
_, match = match_arg
if star_match is not None:
raise CoconutDeferredSyntaxError("duplicate starred arg in data/class match", self.loc)
Expand All @@ -1083,23 +1081,30 @@ def split_data_or_class_match(self, tokens):
star_match = match
# keyword arg
else:
internal_assert(match_arg[1] == "=", "invalid keyword data/class match arg tokens", match_arg)
if len(match_arg) == 3:
internal_assert(match_arg[1] == "=", "invalid keyword data/class match arg tokens", match_arg)
name, _, match = match_arg
strict = False
elif len(match_arg) == 4:
internal_assert(match_arg[0] == "." and match_arg[2] == "=", "invalid strict keyword data/class match arg tokens", match_arg)
_, name, _, match = match_arg
strict = True
name_grp, _, match = match_arg
elif len(match_arg) == 2:
match_grp, _ = match_arg
match = match_grp[-1]
name, = match
name_grp = match_grp[:-1] + [name]
else:
raise CoconutInternalException("invalid data/class match arg", match_arg)
if len(name_grp) == 1:
name, = name_grp
strict = False
else:
internal_assert(name_grp[0] == ".", "invalid keyword data/class match arg tokens", name_grp)
_, name = name_grp
strict = True
if star_match is not None:
raise CoconutDeferredSyntaxError("both keyword arg and starred arg in data/class match", self.loc)
if name in name_matches:
raise CoconutDeferredSyntaxError("duplicate keyword arg {name!r} in data/class match".format(name=name), self.loc)
name_matches[name] = (match, strict)

return cls_name, pos_matches, name_matches, star_match
return pos_matches, name_matches, star_match

def match_class_attr(self, match, attr, item):
"""Match an attribute for a class match where attr is an expression that evaluates to the attribute name."""
Expand All @@ -1119,7 +1124,8 @@ def match_class_names(self, name_matches, item):

def match_class(self, tokens, item):
"""Matches a class PEP-622-style."""
cls_name, pos_matches, name_matches, star_match = self.split_data_or_class_match(tokens)
cls_name, matches = tokens
pos_matches, name_matches, star_match = self.split_data_or_class_matches(matches)

self.add_check("_coconut.isinstance(" + item + ", " + cls_name + ")")

Expand Down Expand Up @@ -1191,7 +1197,8 @@ def match_class(self, tokens, item):

def match_data(self, tokens, item):
"""Matches a data type."""
cls_name, pos_matches, name_matches, star_match = self.split_data_or_class_match(tokens)
cls_name, matches = tokens
pos_matches, name_matches, star_match = self.split_data_or_class_matches(matches)

self.add_check("_coconut.isinstance(" + item + ", " + cls_name + ")")

Expand Down Expand Up @@ -1240,6 +1247,17 @@ def match_data(self, tokens, item):
with self.down_a_level():
self.add_check(temp_var)

def match_anon_named_tuple(self, tokens, item):
"""Matches an anonymous named tuple pattern."""
pos_matches, name_matches, star_match = self.split_data_or_class_matches(tokens)
internal_assert(not pos_matches and not star_match, "got invalid pos/star matches in anon named tuple pattern", (pos_matches, star_match))
self.add_check("_coconut.isinstance(" + item + ", tuple)")
self.add_check("_coconut.len({item}) == {expected_len}".format(
item=item,
expected_len=len(name_matches),
))
self.match_class_names(name_matches, item)

def match_data_or_class(self, tokens, item):
"""Matches an ambiguous data or class match."""
cls_name, matches = tokens
Expand Down
2 changes: 1 addition & 1 deletion coconut/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
VERSION = "3.1.1"
VERSION_NAME = None
# False for release, int >= 1 for develop
DEVELOP = 2
DEVELOP = 3
ALPHA = False # for pre releases rather than post releases

assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1"
Expand Down
21 changes: 21 additions & 0 deletions coconut/tests/src/cocotest/agnostic/primary_2.coco
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,27 @@ def primary_test_2() -> bool:
""" == '\n"2"\n'
assert f"\{1}" == "\\1"
assert f''' '{1}' ''' == " '1' "
tuple(x=) = (x=4)
assert x == 4
tuple(x=, y=) = (x=5, y=5)
assert x == 5 == y
data tuple(x=) = (x=6)
assert x == 6
class tuple(x=) = (x=7)
assert x == 7
data tuple(x, y=) = (x=8, y=8)
assert x == 8 == y
(x=, y=) = (x=9, y=9)
assert x == 9 == y
(x=x) = (x=10)
assert x == 10
(x=, y=y) = (x=11, y=11)
assert x == 11 == y
tuple(x=) = (x=12, y=12)
assert x == 12
match (x=) in (x=13, y=13):
assert False
assert x == 12

with process_map.multiple_sequential_calls(): # type: ignore
assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore
Expand Down

0 comments on commit 3d7577e

Please sign in to comment.