From f80c5b3731919096ddaad2b7667523acd6d007d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Thu, 28 Dec 2023 22:47:34 +0100 Subject: [PATCH] feat(common): add `Dispatched` base class for convenient visitor pattern implementation --- ibis/common/dispatch.py | 68 ++++++++++++++++++++++++++ ibis/common/tests/test_dispatch.py | 77 +++++++++++++++++++++++++++++- 2 files changed, 144 insertions(+), 1 deletion(-) diff --git a/ibis/common/dispatch.py b/ibis/common/dispatch.py index e9cf71c20109..28999b452e37 100644 --- a/ibis/common/dispatch.py +++ b/ibis/common/dispatch.py @@ -2,6 +2,7 @@ import abc import functools +import inspect import re from collections import defaultdict @@ -91,3 +92,70 @@ def call(arg, *args, **kwargs): call.register = register return call + + +class _MultiDict(dict): + """A dictionary that allows multiple values for a single key.""" + + def __setitem__(self, key, value): + if key in self: + self[key].append(value) + else: + super().__setitem__(key, [value]) + + +class DispatchedMeta(type): + """Metaclass that allows multiple implementations of a method to be defined.""" + + def __new__(cls, name, bases, dct): + namespace = {} + for key, value in dct.items(): + if len(value) == 1: + # there is just a single attribute so pick that + namespace[key] = value[0] + elif all(inspect.isfunction(v) for v in value): + # multiple functions are defined with the same name, so create + # a dispatcher function + first, *rest = value + func = functools.singledispatchmethod(first) + for impl in rest: + func.register(impl) + namespace[key] = func + elif all(isinstance(v, classmethod) for v in value): + first, *rest = value + func = functools.singledispatchmethod(first.__func__) + for v in rest: + func.register(v.__func__) + namespace[key] = classmethod(func) + elif all(isinstance(v, staticmethod) for v in value): + first, *rest = value + func = functools.singledispatch(first.__func__) + for v in rest: + func.register(v.__func__) + namespace[key] = staticmethod(func) + else: + raise TypeError(f"Multiple attributes are defined with name {key}") + + return type.__new__(cls, name, bases, namespace) + + @classmethod + def __prepare__(cls, name, bases): + return _MultiDict() + + +class Dispatched(metaclass=DispatchedMeta): + """Base class supporting multiple implementations of a method. + + Methods with the same name can be defined multiple times. The first method + defined is the default implementation, and subsequent methods are registered + as implementations for specific types of the first argument. + + The constructed methods are equivalent as if they were defined with + `functools.singledispatchmethod` but without the need to use the decorator + syntax. The recommended application of this class is to implement visitor + patterns. + + Besides ordinary methods, classmethods and staticmethods are also supported. + The implementation can be extended to overload multiple arguments by using + `multimethod` instead of `singledispatchmethod` as the dispatcher. + """ diff --git a/ibis/common/tests/test_dispatch.py b/ibis/common/tests/test_dispatch.py index 4b3a34ff5b7c..5f34c533851d 100644 --- a/ibis/common/tests/test_dispatch.py +++ b/ibis/common/tests/test_dispatch.py @@ -3,7 +3,9 @@ import collections import decimal -from ibis.common.dispatch import lazy_singledispatch +from ibis.common.dispatch import Dispatched, lazy_singledispatch + +# ruff: noqa: F811 def test_lazy_singledispatch(): @@ -118,3 +120,76 @@ def _(a): assert foo({}) == "mapping" assert foo(mydict()) == "mydict" # concrete takes precedence assert foo(sum) == "callable" + + +class Visitor(Dispatched): + def a(self): + return "a" + + def b(self, x: int): + return "b_int" + + def b(self, x: str): + return "b_str" + + @classmethod + def c(cls, x: int, **kwargs): + return "c_int" + + @classmethod + def c(cls, x: str, a=0, b=1): + return "c_str" + + def d(self, x: int): + return "d_int" + + def d(self, x: str): + return "d_str" + + @staticmethod + def e(x: int): + return "e_int" + + @staticmethod + def e(x: str): + return "e_str" + + +class Subvisitor(Visitor): + def b(self, x): + return super().b(x) + + def b(self, x: float): + return "b_float" + + @classmethod + def c(cls, x): + return super().c(x) + + @classmethod + def c(cls, s: float): + return "c_float" + + +def test_dispatched(): + v = Visitor() + assert v.a == v.a + assert v.b(1) == "b_int" + assert v.b("1") == "b_str" + assert v.d(1) == "d_int" + assert v.d("1") == "d_str" + + w = Subvisitor() + assert w.b(1) == "b_int" + assert w.b(1.1) == "b_float" + + assert Visitor.c(1, a=0, b=0) == "c_int" + assert Visitor.c("1") == "c_str" + + assert Visitor.e("1") == "e_str" + assert Visitor.e(1) == "e_int" + + assert Subvisitor.c(1) == "c_int" + assert Subvisitor.c(1.1) == "c_float" + + assert Subvisitor.e(1) == "e_int"