From 1d314f76531a016567973e1de9f9ea7c8b1e9cbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Thu, 23 Nov 2023 18:40:26 +0100 Subject: [PATCH] feat(common): allow matching on dictionaries in possibly nested patterns Currently only exact matches are supported, we can extend this to allow more advanced dictionary patterns in the future. --- ibis/common/patterns.py | 2 +- ibis/common/tests/test_patterns.py | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/ibis/common/patterns.py b/ibis/common/patterns.py index 7cccc3ff9204..5fb1728c23b3 100644 --- a/ibis/common/patterns.py +++ b/ibis/common/patterns.py @@ -1586,7 +1586,7 @@ def pattern(obj: AnyType) -> Pattern: elif isinstance(obj, (Deferred, Resolver)): return Capture(obj) elif isinstance(obj, Mapping): - raise TypeError("Cannot create a pattern from a mapping") + return EqualTo(FrozenDict(obj)) elif isinstance(obj, Sequence): if isinstance(obj, (str, bytes)): return EqualTo(obj) diff --git a/ibis/common/tests/test_patterns.py b/ibis/common/tests/test_patterns.py index 67e507ca6894..aa770da9e79c 100644 --- a/ibis/common/tests/test_patterns.py +++ b/ibis/common/tests/test_patterns.py @@ -631,6 +631,26 @@ def __eq__(self, other): assert match(p, Foo(1, 2, 1)) == Foo(1, 2, 1) +def test_object_pattern_matching_dictionary_field(): + a = Bar(1, FrozenDict()) + b = Bar(1, {}) + c = Bar(1, None) + d = Bar(1, {"foo": 1}) + + pattern = Object(Bar, 1, d={}) + assert match(pattern, a) is a + assert match(pattern, b) is b + assert match(pattern, c) is NoMatch + + pattern = Object(Bar, 1, d=None) + assert match(pattern, a) is NoMatch + assert match(pattern, c) is c + + pattern = Object(Bar, 1, d={"foo": 1}) + assert match(pattern, a) is NoMatch + assert match(pattern, d) is d + + def test_callable_with(): def func(a, b): return str(a) + b @@ -1231,6 +1251,9 @@ def f(x): # matching deferred to user defined functions assert pattern(f) == Custom(f) + # matching mapping values + assert pattern({"a": 1, "b": 2}) == EqualTo(FrozenDict({"a": 1, "b": 2})) + class Term(GraphNode): def __eq__(self, other):