Skip to content

Commit

Permalink
fix(patterns): support optional keyword arguments in CallableWith
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and jcrist committed Sep 9, 2023
1 parent b0bcdde commit a78aa60
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
16 changes: 7 additions & 9 deletions ibis/common/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,6 @@ class CoercionError(Exception):
...


class MatchError(Exception):
...


class Coercible(ABC):
"""Protocol for defining coercible types.
Expand Down Expand Up @@ -1512,26 +1508,28 @@ def __init__(self, args, return_=_any):
super().__init__(args=tuple(args), return_=return_)

def match(self, value, context):
from ibis.common.annotations import annotated
from ibis.common.annotations import EMPTY, annotated

if not callable(value):
return NoMatch

fn = annotated(self.args, self.return_, value)

has_varargs = False
positional = []
positional, required_positional = [], []
for p in fn.__signature__.parameters.values():
if p.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD):
positional.append(p)
elif p.kind is Parameter.KEYWORD_ONLY and p.default is Parameter.empty:
raise MatchError(
if p.default is EMPTY:
required_positional.append(p)
elif p.kind is Parameter.KEYWORD_ONLY and p.default is EMPTY:
raise TypeError(
"Callable has mandatory keyword-only arguments which cannot be specified"
)
elif p.kind is Parameter.VAR_POSITIONAL:
has_varargs = True

if len(positional) > len(self.args):
if len(required_positional) > len(self.args):
# Callable has more positional arguments than expected")
return NoMatch
elif len(positional) < len(self.args) and not has_varargs:
Expand Down
21 changes: 18 additions & 3 deletions ibis/common/tests/test_patterns.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
import re
import sys
from collections.abc import Callable as CallableABC
Expand Down Expand Up @@ -51,7 +52,6 @@
Length,
ListOf,
MappingOf,
MatchError,
Never,
Node,
NoMatch,
Expand Down Expand Up @@ -537,12 +537,12 @@ def func_with_required_keyword_only_kwargs(*, c):
assert p.match(10, context={}) is NoMatch

msg = "Callable has mandatory keyword-only arguments which cannot be specified"
with pytest.raises(MatchError, match=msg):
with pytest.raises(TypeError, match=msg):
p.match(func_with_required_keyword_only_kwargs, context={})

# Callable has more positional arguments than expected
p = CallableWith([InstanceOf(int)] * 2)
assert p.match(func_with_kwargs, context={}) is NoMatch
assert p.match(func_with_kwargs, context={}).__wrapped__ is func_with_kwargs

# Callable has less positional arguments than expected
p = CallableWith([InstanceOf(int)] * 4)
Expand All @@ -564,6 +564,21 @@ def func_with_required_keyword_only_kwargs(*, c):
assert wrapped(1) == 2


def test_callable_with_default_arguments():
def f(a: int, b: str, c: str):
return a + int(b) + int(c)

def g(a: int, b: str, c: str = "0"):
return a + int(b) + int(c)

h = functools.partial(f, c="0")

p = Pattern.from_typehint(Callable[[int, str], int])
assert p.match(f, {}) is NoMatch
assert p.match(g, {}).__wrapped__ == g
assert p.match(h, {}).__wrapped__ == h


def test_pattern_list():
p = PatternSequence([1, 2, InstanceOf(int), ...])
assert p.match([1, 2, 3, 4, 5], context={}) == [1, 2, 3, 4, 5]
Expand Down

0 comments on commit a78aa60

Please sign in to comment.