Skip to content

Commit

Permalink
refactor(common): ibis.common.patterns.match() should return with the…
Browse files Browse the repository at this point in the history
… matched value rather than the context
  • Loading branch information
kszucs authored and cpcloud committed Aug 7, 2023
1 parent 4697e7d commit cbb9b2f
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 59 deletions.
29 changes: 17 additions & 12 deletions ibis/common/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def __hash__(self) -> int:
return self.__precomputed_hash__

def __setattr__(self, name, value) -> None:
raise AttributeError("Can't set attributes on immutable instance")
raise AttributeError("Can't set attributes on an immutable instance")

def __repr__(self):
fields = {k: getattr(self, k) for k in self.__slots__}
Expand Down Expand Up @@ -1062,7 +1062,8 @@ def match(self, value, context):
except StopIteration:
break

if match(following, item, context) is NoMatch:
res = following.match(item, context)
if res is NoMatch:
matches.append(item)
else:
it.rewind()
Expand Down Expand Up @@ -1202,7 +1203,9 @@ def pattern(obj: AnyType) -> Pattern:
return EqualTo(obj)


def match(pat: Pattern, value: AnyType, context: Optional[dict[str, AnyType]] = None):
def match(
pat: Pattern, value: AnyType, context: Optional[dict[str, AnyType]] = None
) -> Any:
"""Match a value against a pattern.
Parameters
Expand All @@ -1214,24 +1217,26 @@ def match(pat: Pattern, value: AnyType, context: Optional[dict[str, AnyType]] =
context
Arbitrary mapping of values to be used while matching.
Returns
-------
The matched value if the pattern matches, otherwise :obj:`NoMatch`.
Examples
--------
>>> assert match(Any(), 1) == {}
>>> assert match(1, 1) == {}
>>> assert match(Any(), 1) == 1
>>> assert match(1, 1) == 1
>>> assert match(1, 2) is NoMatch
>>> assert match(1, 1, context={"x": 1}) == {"x": 1}
>>> assert match(1, 1, context={"x": 1}) == 1
>>> assert match(1, 2, context={"x": 1}) is NoMatch
>>> assert match([1, int], [1, 2]) == {}
>>> assert match([1, int, "a" @ InstanceOf(str)], [1, 2, "three"]) == {"a": "three"}
>>> assert match([1, int], [1, 2]) == [1, 2]
>>> assert match([1, int, "a" @ InstanceOf(str)], [1, 2, "three"]) == [1, 2, "three"]
"""
if context is None:
context = {}

pat = pattern(pat)
if pat.match(value, context=context) is NoMatch:
return NoMatch

return context
result = pat.match(value, context)
return NoMatch if result is NoMatch else result


class Topmost(Matcher):
Expand Down
131 changes: 84 additions & 47 deletions ibis/common/tests/test_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,10 @@ def test_generic_instance_of_with_covariant_typevar():
p = Pattern.from_typehint(My[int, AnyType])
assert p.match(My(1, 2, "3"), context={}) == My(1, 2, "3")

assert match(My[int, AnyType], My(1, 2, "3"), context={}) == {}
assert match(My[int, int], My(1, 2, "3"), context={}) == {}
assert match(My[int, float], My(1, 2, "3"), context={}) is NoMatch
assert match(My[int, float], My(1, 2.0, "3"), context={}) == {}
assert match(My[int, AnyType], v := My(1, 2, "3")) == v
assert match(My[int, int], v := My(1, 2, "3")) == v
assert match(My[int, float], My(1, 2, "3")) is NoMatch
assert match(My[int, float], v := My(1, 2.0, "3")) == v


def test_generic_instance_of_disallow_nested_coercion():
Expand Down Expand Up @@ -409,8 +409,14 @@ def __init__(self, a, b):
self.a = a
self.b = b

def __eq__(self, other):
return type(self) == type(other) and self.a == other.a and self.b == other.b

p = Object(Foo, 1, b=2)
assert match(p, Foo(1, 2)) == {}
o = Foo(1, 2)
r = match(p, o)
assert r is o
assert r == Foo(1, 2)


def test_callable_with():
Expand Down Expand Up @@ -470,49 +476,57 @@ def test_pattern_list():


def test_matching():
assert match("foo", "foo") == {}
assert match("foo", "foo") == "foo"
assert match("foo", "bar") is NoMatch

assert match(InstanceOf(int), 1) == {}
assert match(InstanceOf(int), 1) == 1
assert match(InstanceOf(int), "foo") is NoMatch

assert Capture(InstanceOf(float), "pi") == "pi" @ InstanceOf(float)
assert Capture(InstanceOf(float), "pi") == InstanceOf(float) >> "pi"

assert match(Capture(InstanceOf(float), "pi"), 3.14) == {"pi": 3.14}
assert match("pi" @ InstanceOf(float), 3.14) == {"pi": 3.14}
assert match(Capture(InstanceOf(float), "pi"), 3.14, ctx := {}) == 3.14
assert ctx == {"pi": 3.14}

assert match("pi" @ InstanceOf(float), 3.14, ctx := {}) == 3.14
assert ctx == {"pi": 3.14}

assert match(InstanceOf(int) | InstanceOf(float), 3) == {}
assert match(InstanceOf(object) & InstanceOf(float), 3.14) == {}
assert match(InstanceOf(int) | InstanceOf(float), 3) == 3
assert match(InstanceOf(object) & InstanceOf(float), 3.14) == 3.14


def test_matching_sequence_pattern():
assert match([], []) == {}
assert match([], []) == []
assert match([], [1]) is NoMatch

assert match([1, 2, 3, 4, ...], list(range(1, 9))) == {}
assert match([1, 2, 3, 4, ...], list(range(1, 9))) == list(range(1, 9))
assert match([1, 2, 3, 4, ...], list(range(1, 3))) is NoMatch
assert match([1, 2, 3, 4, ...], list(range(1, 5))) == {}
assert match([1, 2, 3, 4, ...], list(range(1, 6))) == {}
assert match([1, 2, 3, 4, ...], list(range(1, 5))) == list(range(1, 5))
assert match([1, 2, 3, 4, ...], list(range(1, 6))) == list(range(1, 6))

assert match([..., 3, 4], list(range(5))) == {}
assert match([..., 3, 4], list(range(5))) == list(range(5))
assert match([..., 3, 4], list(range(3))) is NoMatch

assert match([0, 1, ..., 4], list(range(5))) == {}
assert match([0, 1, ..., 4], list(range(5))) == list(range(5))
assert match([0, 1, ..., 4], list(range(4))) is NoMatch

assert match([...], list(range(5))) == {}
assert match([..., 2, 3, 4, ...], list(range(8))) == {}
assert match([...], list(range(5))) == list(range(5))
assert match([..., 2, 3, 4, ...], list(range(8))) == list(range(8))


def test_matching_sequence_with_captures():
assert match([1, 2, 3, 4, SequenceOf(...)], list(range(1, 9))) == {}
assert match([1, 2, 3, 4, "rest" @ SequenceOf(...)], list(range(1, 9))) == {
"rest": (5, 6, 7, 8)
}
assert match([1, 2, 3, 4, SequenceOf(...)], v := list(range(1, 9))) == v
assert (
match([1, 2, 3, 4, "rest" @ SequenceOf(...)], v := list(range(1, 9)), ctx := {})
== v
)
assert ctx == {"rest": (5, 6, 7, 8)}

assert match([0, 1, "var" @ SequenceOf(...), 4], list(range(5))) == {"var": (2, 3)}
assert match([0, 1, SequenceOf(...) >> "var", 4], list(range(5))) == {"var": (2, 3)}
v = list(range(5))
assert match([0, 1, "var" @ SequenceOf(...), 4], v, ctx := {}) == v
assert ctx == {"var": (2, 3)}
assert match([0, 1, "var" @ SequenceOf(...), 4], v, ctx := {}) == v
assert ctx == {"var": (2, 3)}

p = [
0,
Expand All @@ -521,23 +535,28 @@ def test_matching_sequence_with_captures():
"floats" @ SequenceOf(InstanceOf(float)),
6,
]
assert match(p, [0, 1, 2, 3, 4.0, 5.0, 6]) == {"ints": (2, 3), "floats": (4.0, 5.0)}
v = [0, 1, 2, 3, 4.0, 5.0, 6]
assert match(p, v, ctx := {}) == v
assert ctx == {"ints": (2, 3), "floats": (4.0, 5.0)}


def test_matching_sequence_remaining():
Seq = SequenceOf
IsInt = InstanceOf(int)

assert match([1, 2, 3, Seq(IsInt, at_least=1)], [1, 2, 3, 4]) == {}
assert match([1, 2, 3, Seq(IsInt, at_least=1)], [1, 2, 3]) is NoMatch
assert match([1, 2, 3, Seq(IsInt)], [1, 2, 3]) == {}
assert match([1, 2, 3, Seq(IsInt, at_most=1)], [1, 2, 3]) == {}
assert match([1, 2, 3, Seq(IsInt & Between(0, 10))], [1, 2, 3, 4, 5]) == {}
assert match([1, 2, 3, Seq(IsInt & Between(0, 4))], [1, 2, 3, 4, 5]) is NoMatch
assert match([1, 2, 3, Seq(IsInt, at_least=2)], [1, 2, 3, 4]) is NoMatch
assert match([1, 2, 3, Seq(IsInt, at_least=2) >> "res"], [1, 2, 3, 4, 5]) == {
"res": (4, 5)
}
three = [1, 2, 3]
four = [1, 2, 3, 4]
five = [1, 2, 3, 4, 5]

assert match([1, 2, 3, Seq(IsInt, at_least=1)], four) == four
assert match([1, 2, 3, Seq(IsInt, at_least=1)], three) is NoMatch
assert match([1, 2, 3, Seq(IsInt)], three) == three
assert match([1, 2, 3, Seq(IsInt, at_most=1)], three) == three
assert match([1, 2, 3, Seq(IsInt & Between(0, 10))], five) == five
assert match([1, 2, 3, Seq(IsInt & Between(0, 4))], five) is NoMatch
assert match([1, 2, 3, Seq(IsInt, at_least=2)], four) is NoMatch
assert match([1, 2, 3, "res" @ Seq(IsInt, at_least=2)], five, ctx := {}) == five
assert ctx == {"res": (4, 5)}


def test_matching_sequence_complicated():
Expand All @@ -553,11 +572,13 @@ def test_matching_sequence_complicated():
"a": [2, 3],
"b": (5, 6, 7),
}
assert match(pattern, range(1, 10)) == expected
assert match(pattern, range(1, 10), ctx := {}) == list(range(1, 10))
assert ctx == expected

pattern = [0, PatternSequence([1, 2]) >> "pairs", 3]
expected = {"pairs": [1, 2]}
assert match(pattern, [0, 1, 2, 1, 2, 3]) == expected
pattern = [0, "pairs" @ PatternSequence([-1, -2]), 3]
expected = {"pairs": [-1, -2]}
assert match(pattern, [0, -1, -2, 3], ctx := {}) == [0, -1, -2, 3]
assert ctx == expected

pattern = [
0,
Expand All @@ -566,13 +587,15 @@ def test_matching_sequence_complicated():
3,
]
expected = {"first": [1, 2], "second": [4, 5]}
assert match(pattern, [0, 1, 2, 4, 5, 3]) == expected
assert match(pattern, [0, 1, 2, 4, 5, 3], ctx := {}) == [0, 1, 2, 4, 5, 3]
assert ctx == expected

pattern = [1, 2, "remaining" @ SequenceOf(...)]
expected = {"remaining": (3, 4, 5, 6, 7, 8, 9)}
assert match(pattern, range(1, 10)) == expected
assert match(pattern, range(1, 10), ctx := {}) == list(range(1, 10))
assert ctx == expected

assert match([0, SequenceOf([1, 2]), 3], [0, [1, 2], [1, 2], 3]) == {}
assert match([0, SequenceOf([1, 2]), 3], v := [0, [1, 2], [1, 2], 3]) == v


def test_pattern_map():
Expand All @@ -584,26 +607,40 @@ def test_matching_mapping():
assert match({}, {}) == {}
assert match({}, {1: 2}) is NoMatch

assert match({1: 2}, {1: 2}) == {}
assert match({1: 2}, {1: 2}) == {1: 2}
assert match({1: 2}, {1: 3}) is NoMatch

assert match({}, 3) is NoMatch
assert match({"a": "capture" @ InstanceOf(int)}, {"a": 1}) == {"capture": 1}
ctx = {}
assert match({"a": "capture" @ InstanceOf(int)}, {"a": 1}, ctx) == {"a": 1}
assert ctx == {"capture": 1}

p = {
"a": "capture" @ InstanceOf(int),
"b": InstanceOf(float),
...: InstanceOf(str),
}
assert match(p, {"a": 1, "b": 2.0, "c": "foo"}) == {"capture": 1}
ctx = {}
assert match(p, {"a": 1, "b": 2.0, "c": "foo"}, ctx) == {
"a": 1,
"b": 2.0,
"c": "foo",
}
assert ctx == {"capture": 1}
assert match(p, {"a": 1, "b": 2.0, "c": 3}) is NoMatch

p = {
"a": "capture" @ InstanceOf(int),
"b": InstanceOf(float),
"rest" @ SequenceOf(...): InstanceOf(str),
}
assert match(p, {"a": 1, "b": 2.0, "c": "foo"}) == {"capture": 1, "rest": ("c",)}
ctx = {}
assert match(p, {"a": 1, "b": 2.0, "c": "foo"}, ctx) == {
"a": 1,
"b": 2.0,
"c": "foo",
}
assert ctx == {"capture": 1, "rest": ("c",)}


@pytest.mark.parametrize(
Expand Down

0 comments on commit cbb9b2f

Please sign in to comment.