Skip to content

Commit

Permalink
feat(common): add Dispatched base class for convenient visitor patt…
Browse files Browse the repository at this point in the history
…ern implementation
  • Loading branch information
kszucs committed Feb 12, 2024
1 parent 28fb6ec commit f80c5b3
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 1 deletion.
68 changes: 68 additions & 0 deletions ibis/common/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import abc
import functools
import inspect
import re
from collections import defaultdict

Expand Down Expand Up @@ -91,3 +92,70 @@ def call(arg, *args, **kwargs):
call.register = register

return call


class _MultiDict(dict):
"""A dictionary that allows multiple values for a single key."""

def __setitem__(self, key, value):
if key in self:
self[key].append(value)
else:
super().__setitem__(key, [value])


class DispatchedMeta(type):
"""Metaclass that allows multiple implementations of a method to be defined."""

def __new__(cls, name, bases, dct):
namespace = {}
for key, value in dct.items():
if len(value) == 1:
# there is just a single attribute so pick that
namespace[key] = value[0]
elif all(inspect.isfunction(v) for v in value):
# multiple functions are defined with the same name, so create
# a dispatcher function
first, *rest = value
func = functools.singledispatchmethod(first)
for impl in rest:
func.register(impl)
namespace[key] = func
elif all(isinstance(v, classmethod) for v in value):
first, *rest = value
func = functools.singledispatchmethod(first.__func__)
for v in rest:
func.register(v.__func__)
namespace[key] = classmethod(func)
elif all(isinstance(v, staticmethod) for v in value):
first, *rest = value
func = functools.singledispatch(first.__func__)
for v in rest:
func.register(v.__func__)
namespace[key] = staticmethod(func)
else:
raise TypeError(f"Multiple attributes are defined with name {key}")

return type.__new__(cls, name, bases, namespace)

@classmethod
def __prepare__(cls, name, bases):
return _MultiDict()


class Dispatched(metaclass=DispatchedMeta):
"""Base class supporting multiple implementations of a method.
Methods with the same name can be defined multiple times. The first method
defined is the default implementation, and subsequent methods are registered
as implementations for specific types of the first argument.
The constructed methods are equivalent as if they were defined with
`functools.singledispatchmethod` but without the need to use the decorator
syntax. The recommended application of this class is to implement visitor
patterns.
Besides ordinary methods, classmethods and staticmethods are also supported.
The implementation can be extended to overload multiple arguments by using
`multimethod` instead of `singledispatchmethod` as the dispatcher.
"""
77 changes: 76 additions & 1 deletion ibis/common/tests/test_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import collections
import decimal

from ibis.common.dispatch import lazy_singledispatch
from ibis.common.dispatch import Dispatched, lazy_singledispatch

# ruff: noqa: F811


def test_lazy_singledispatch():
Expand Down Expand Up @@ -118,3 +120,76 @@ def _(a):
assert foo({}) == "mapping"
assert foo(mydict()) == "mydict" # concrete takes precedence
assert foo(sum) == "callable"


class Visitor(Dispatched):
def a(self):
return "a"

def b(self, x: int):
return "b_int"

def b(self, x: str):
return "b_str"

@classmethod
def c(cls, x: int, **kwargs):
return "c_int"

@classmethod
def c(cls, x: str, a=0, b=1):
return "c_str"

def d(self, x: int):
return "d_int"

def d(self, x: str):
return "d_str"

@staticmethod
def e(x: int):
return "e_int"

@staticmethod
def e(x: str):
return "e_str"


class Subvisitor(Visitor):
def b(self, x):
return super().b(x)

def b(self, x: float):
return "b_float"

@classmethod
def c(cls, x):
return super().c(x)

@classmethod
def c(cls, s: float):
return "c_float"


def test_dispatched():
v = Visitor()
assert v.a == v.a
assert v.b(1) == "b_int"
assert v.b("1") == "b_str"
assert v.d(1) == "d_int"
assert v.d("1") == "d_str"

w = Subvisitor()
assert w.b(1) == "b_int"
assert w.b(1.1) == "b_float"

assert Visitor.c(1, a=0, b=0) == "c_int"
assert Visitor.c("1") == "c_str"

assert Visitor.e("1") == "e_str"
assert Visitor.e(1) == "e_int"

assert Subvisitor.c(1) == "c_int"
assert Subvisitor.c(1.1) == "c_float"

assert Subvisitor.e(1) == "e_int"

0 comments on commit f80c5b3

Please sign in to comment.