Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable more robust multiple dispatch with plum #415

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fastcore/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.4.6"
__version__ = "1.5.0"
6 changes: 2 additions & 4 deletions fastcore/_nbdev.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,8 @@
"do_request": "03b_net.ipynb",
"start_server": "03b_net.ipynb",
"start_client": "03b_net.ipynb",
"lenient_issubclass": "04_dispatch.ipynb",
"sorted_topologically": "04_dispatch.ipynb",
"TypeDispatch": "04_dispatch.ipynb",
"DispatchReg": "04_dispatch.ipynb",
"FastFunction": "04_dispatch.ipynb",
"FastDispatcher": "04_dispatch.ipynb",
"typedispatch": "04_dispatch.ipynb",
"retain_meta": "04_dispatch.ipynb",
"default_set_meta": "04_dispatch.ipynb",
Expand Down
202 changes: 72 additions & 130 deletions fastcore/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,154 +4,96 @@
from __future__ import annotations


__all__ = ['lenient_issubclass', 'sorted_topologically', 'TypeDispatch', 'DispatchReg', 'typedispatch', 'cast',
'retain_meta', 'default_set_meta', 'retain_type', 'retain_types', 'explode_types']
__all__ = ['FastFunction', 'FastDispatcher', 'typedispatch', 'cast', 'retain_meta', 'default_set_meta', 'retain_type',
'retain_types', 'explode_types']

# Cell
#nbdev_comment from __future__ import annotations
from .imports import *
from .foundation import *
from .utils import *
from .meta import delegates

from collections import defaultdict
from plum import Function, Dispatcher

# Cell
def lenient_issubclass(cls, types):
"If possible return whether `cls` is a subclass of `types`, otherwise return False."
if cls is object and types is not object: return False # treat `object` as highest level
try: return isinstance(cls, types) or issubclass(cls, types)
except: return False
def _eval_annotations(f):
"Evaluate future annotations before passing to plum to support backported union operator `|`"
f = copy_func(f)
for k, v in type_hints(f).items(): f.__annotations__[k] = Union[v] if isinstance(v, tuple) else v
return f

# Cell
def sorted_topologically(iterable, *, cmp=operator.lt, reverse=False):
"Return a new list containing all items from the iterable sorted topologically"
l,res = L(list(iterable)),[]
for _ in range(len(l)):
t = l.reduce(lambda x,y: y if cmp(y,x) else x)
res.append(t), l.remove(t)
return res[::-1] if reverse else res
def _pt_repr(o):
"Concise repr of plum types"
n = type(o).__name__
if n == 'Tuple': return f"{n.lower()}[{','.join(_pt_repr(t) for t in o._el_types)}]"
if n == 'List': return f'{n.lower()}[{_pt_repr(o._el_type)}]'
if n == 'Dict': return f'{n.lower()}[{_pt_repr(o._key_type)},{_pt_repr(o._value_type)}]'
if n in ('Sequence','Iterable'): return f'{n}[{_pt_repr(o._el_type)}]'
if n == 'VarArgs': return f'{n}[{_pt_repr(o.type)}]'
if n == 'Union': return '|'.join(sorted(t.__name__ for t in (o.get_types())))
assert len(o.get_types()) == 1
return o.get_types()[0].__name__

# Cell
def _chk_defaults(f, ann):
pass
# Implementation removed until we can figure out how to do this without `inspect` module
# try: # Some callables don't have signatures, so ignore those errors
# params = list(inspect.signature(f).parameters.values())[:min(len(ann),2)]
# if any(p.default!=inspect.Parameter.empty for p in params):
# warn(f"{f.__name__} has default params. These will be ignored.")
# except ValueError: pass

# Cell
def _p2_anno(f):
"Get the 1st 2 annotations of `f`, defaulting to `object`"
hints = type_hints(f)
ann = [o for n,o in hints.items() if n!='return']
if callable(f): _chk_defaults(f, ann)
while len(ann)<2: ann.append(object)
return ann[:2]
class FastFunction(Function):
def __repr__(self):
return '\n'.join(f"{f.__name__}: ({','.join(_pt_repr(t) for t in s.types)}) -> {_pt_repr(r)}"
for s, (f, r) in self.methods.items())

# Cell
class _TypeDict:
def __init__(self): self.d,self.cache = {},{}

def _reset(self):
self.d = {k:self.d[k] for k in sorted_topologically(self.d, cmp=lenient_issubclass)}
self.cache = {}

def add(self, t, f):
"Add type `t` and function `f`"
if not isinstance(t, tuple): t = tuple(L(union2tuple(t)))
for t_ in t: self.d[t_] = f
self._reset()

def all_matches(self, k):
"Find first matching type that is a super-class of `k`"
if k not in self.cache:
types = [f for f in self.d if lenient_issubclass(k,f)]
self.cache[k] = [self.d[o] for o in types]
return self.cache[k]

def __getitem__(self, k):
"Find first matching type that is a super-class of `k`"
res = self.all_matches(k)
return res[0] if len(res) else None

def __repr__(self): return self.d.__repr__()
def first(self): return first(self.d.values())
@delegates(Function.dispatch)
def dispatch(self, f=None, **kwargs): return super().dispatch(_eval_annotations(f), **kwargs)

# Cell
class TypeDispatch:
"Dictionary-like object; `__getitem__` matches keys of types using `issubclass`"
def __init__(self, funcs=(), bases=()):
self.funcs,self.bases = _TypeDict(),L(bases).filter(is_not(None))
for o in L(funcs): self.add(o)
self.inst = None
self.owner = None

def add(self, f):
"Add type `t` and function `f`"
if isinstance(f, staticmethod): a0,a1 = _p2_anno(f.__func__)
else: a0,a1 = _p2_anno(f)
t = self.funcs.d.get(a0)
if t is None:
t = _TypeDict()
self.funcs.add(a0, t)
t.add(a1, f)

def first(self):
"Get first function in ordered dict of type:func."
return self.funcs.first().first()

def returns(self, x):
"Get the return type of annotation of `x`."
return anno_ret(self[type(x)])

def _attname(self,k): return getattr(k,'__name__',str(k))
def __repr__(self):
r = [f'({self._attname(k)},{self._attname(l)}) -> {getattr(v, "__name__", type(v).__name__)}'
for k in self.funcs.d for l,v in self.funcs[k].d.items()]
r = r + [o.__repr__() for o in self.bases]
return '\n'.join(r)

def __call__(self, *args, **kwargs):
ts = L(args).map(type)[:2]
f = self[tuple(ts)]
if not f: return args[0]
if isinstance(f, staticmethod): f = f.__func__
elif self.inst is not None: f = MethodType(f, self.inst)
elif self.owner is not None: f = MethodType(f, self.owner)
return f(*args, **kwargs)

def __get__(self, inst, owner):
self.inst = inst
self.owner = owner
return self

def __getitem__(self, k):
"Find first matching type that is a super-class of `k`"
k = L(k)
while len(k)<2: k.append(object)
r = self.funcs.all_matches(k[0])
for t in r:
o = t[k[1]]
if o is not None: return o
for base in self.bases:
res = base[k]
if res is not None: return res
return None
def __getitem__(self, ts):
"Return the most-specific matching method with fewest parameters"
ts = L(ts)
nargs = min(len(o) for o in self.methods.keys())
while len(ts) < nargs: ts.append(object)
return self.invoke(*ts)

# Cell
class DispatchReg:
"A global registry for `TypeDispatch` objects keyed by function name"
def __init__(self): self.d = defaultdict(TypeDispatch)
def __call__(self, f):
if isinstance(f, (classmethod, staticmethod)): nm = f'{f.__func__.__qualname__}'
else: nm = f'{f.__qualname__}'
if isinstance(f, classmethod): f=f.__func__
self.d[nm].add(f)
return self.d[nm]

typedispatch = DispatchReg()
class FastDispatcher(Dispatcher):
def _get_function(self, method, owner):
"Adapted from `Dispatcher._get_function` to use `FastFunction`"
name = method.__name__
if owner:
if owner not in self._classes: self._classes[owner] = {}
namespace = self._classes[owner]
else: namespace = self._functions
if name not in namespace: namespace[name] = FastFunction(method, owner=owner)
return namespace[name]

@delegates(Dispatcher.__call__, but='method')
def __call__(self, f, **kwargs): return super().__call__(_eval_annotations(f), **kwargs)

def _to(self, cls, nm, f, **kwargs):
nf = copy_func(f)
nf.__qualname__ = f'{cls.__name__}.{nm}' # plum uses __qualname__ to infer f's owner
pf = self(nf, **kwargs)
# plum uses __set_name__ to resolve a plum.Function's owner
# since we assign after class creation, __set_name__ must be called directly
# source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__
pf.__set_name__(cls, nm)
pf = pf.resolve()
setattr(cls, nm, pf)
return pf

def to(self, cls):
"Decorator: dispatch `f` to `cls.f`"
def _inner(f, **kwargs):
nm = f.__name__
# check __dict__ to avoid inherited methods but use getattr so pf.__get__ is called, which plum relies on
if nm in cls.__dict__:
pf = getattr(cls, nm)
if not hasattr(pf, 'dispatch'): pf = self._to(cls, nm, pf, **kwargs)
pf.dispatch(f)
else: pf = self._to(cls, nm, f, **kwargs)
return pf
return _inner

typedispatch = FastDispatcher()

# Cell
#nbdev_comment _all_=['cast']
Expand Down
10 changes: 10 additions & 0 deletions fastcore/imports.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys,os,re,typing,itertools,operator,functools,math,warnings,functools,io,enum

from copy import copy
from operator import itemgetter,attrgetter
from warnings import warn
from typing import Iterable,Generator,Sequence,Iterator,List,Set,Dict,Union,Optional,Tuple
Expand All @@ -14,6 +15,15 @@
MethodDescriptorType = type(str.join)
from types import BuiltinFunctionType,BuiltinMethodType,MethodType,FunctionType,SimpleNamespace

#Patch autoreload (if its loaded) to work with plum
try: from IPython import get_ipython
except ImportError: pass
else:
ip = get_ipython()
if ip is not None and 'IPython.extensions.storemagic' in ip.extension_manager.loaded:
from plum.autoreload import activate
activate()

NoneType = type(None)
string_classes = (str,bytes)

Expand Down
38 changes: 21 additions & 17 deletions fastcore/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,36 +9,28 @@
from .utils import *
from .dispatch import *
import inspect
from plum import add_conversion_method

# Cell
_tfm_methods = 'encodes','decodes','setups'

def _is_tfm_method(n, f): return n in _tfm_methods and callable(f)

class _TfmDict(dict):
def __setitem__(self, k, v):
if not _is_tfm_method(k, v): return super().__setitem__(k,v)
if k not in self: super().__setitem__(k,TypeDispatch())
self[k].add(v)
def __setitem__(self, k, v): super().__setitem__(k, typedispatch(v) if _is_tfm_method(k, v) else v)

# Cell
class _TfmMeta(type):
def __new__(cls, name, bases, dict):
res = super().__new__(cls, name, bases, dict)
for nm in _tfm_methods:
base_td = [getattr(b,nm,None) for b in bases]
if nm in res.__dict__: getattr(res,nm).bases = base_td
else: setattr(res, nm, TypeDispatch(bases=base_td))
# _TfmMeta.__call__ shadows the signature of inheriting classes, set it back
res = super().__new__(cls, name, bases, dict)
res.__signature__ = inspect.signature(res.__init__)
return res

def __call__(cls, *args, **kwargs):
f = first(args)
n = getattr(f, '__name__', None)
if _is_tfm_method(n, f):
getattr(cls,n).add(f)
return f
if _is_tfm_method(n, f): return typedispatch.to(cls)(f)
obj = super().__call__(*args, **kwargs)
# _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable
seeM marked this conversation as resolved.
Show resolved Hide resolved
# instances of cls, fix it
Expand Down Expand Up @@ -67,13 +59,14 @@ def __init__(self, enc=None, dec=None, split_idx=None, order=None):
self.init_enc = enc or dec
if not self.init_enc: return

self.encodes,self.decodes,self.setups = TypeDispatch(),TypeDispatch(),TypeDispatch()
def identity(x): return x
for n in _tfm_methods: setattr(self,n,FastFunction(identity).dispatch(identity))
if enc:
self.encodes.add(enc)
self.encodes.dispatch(enc)
self.order = getattr(enc,'order',self.order)
if len(type_hints(enc)) > 0: self.input_types = union2tuple(first(type_hints(enc).values()))
self._name = _get_name(enc)
if dec: self.decodes.add(dec)
if dec: self.decodes.dispatch(dec)

@property
def name(self): return getattr(self, '_name', _get_name(self))
Expand All @@ -92,13 +85,24 @@ def _call(self, fn, x, split_idx=None, **kwargs):
def _do_call(self, f, x, **kwargs):
if not _is_tuple(x):
if f is None: return x
ret = f.returns(x) if hasattr(f,'returns') else None
return retain_type(f(x, **kwargs), x, ret)
ts = [type(self),type(x)] if hasattr(f,'instance') else [type(x)]
_, ret = f.resolve_method(*ts)
ret = ret._type
# plum reads empty return annotation as object, retain_type expects it as None
if ret is object: ret = None
return retain_type(f(x,**kwargs), x, ret)
res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
return retain_type(res, x)
def encodes(self, x): return x
def decodes(self, x): return x
def setups(self, dl): return dl

add_docs(Transform, decode="Delegate to <code>decodes</code> to undo transform", setup="Delegate to <code>setups</code> to set up transform")

# Cell
#Implement the Transform convention that a None return annotation disables conversion
add_conversion_method(object, NoneType, lambda x: x)

# Cell
class InplaceTransform(Transform):
"A `Transform` that modifies in-place and just returns whatever it's passed"
Expand Down
Loading