Source code for multimethod

import abc
import collections
import functools
import inspect
import itertools
import types
import typing
from typing import Callable, Iterable, Iterator, Mapping, Union

__version__ = '1.4'


def groupby(func: Callable, values: Iterable) -> dict:
    """Bucket values by the key that `func` computes for each of them.

    :param func: key function applied to every value
    :param values: iterable of items to group
    :return: a `defaultdict(list)` mapping each key to its values, in input order
    """
    buckets: dict = collections.defaultdict(list)
    for item in values:
        key = func(item)
        buckets[key].append(item)
    return buckets


def get_types(func: Callable) -> tuple:
    """Return the evaluated parameter type hints of `func`, in parameter order.

    Objects without an `__annotations__` attribute yield an empty tuple.
    The return annotation is ignored.  Unannotated parameters that precede
    an annotated one are reported as `object`; parameters after the last
    annotated one are omitted.
    """
    if not hasattr(func, '__annotations__'):
        return ()
    hints = dict(typing.get_type_hints(func))
    hints.pop('return', None)
    found = []
    for name in inspect.signature(func).parameters:
        if not hints:
            # every annotation has been consumed: drop the trailing parameters
            break
        found.append(hints.pop(name, object))
    return tuple(found)


class DispatchError(TypeError):
    """Error raised when dispatch cannot select a function to call.

    Raised by `multimethod.__missing__` when the candidates do not resolve
    to exactly one function, and by `overload.__call__` when no registered
    function matches.
    """
class subtype(type):
    """A normalized generic type which checks subscripts."""

    def __new__(cls, tp, *args):
        """Normalize `tp`, returning a plain type whenever subscripts add nothing.

        :param tp: a type, `typing` generic alias, `typing.Any` or a `TypeVar`
        :param args: fallback subscripts, used when `tp` has no `__args__` of its own
        """
        if tp is typing.Any:
            return object
        if isinstance(tp, typing.TypeVar):
            if not tp.__constraints__:
                return object
            tp = typing.Union[tp.__constraints__]
        # prefer `__extra__`, then `__origin__`, falling back to the type itself
        origin = getattr(tp, '__extra__', getattr(tp, '__origin__', tp))
        args = tuple(map(cls, getattr(tp, '__args__', None) or args))
        # subscripts that are all `object` carry no information, except for the
        # tuple case where the number of subscripts is still meaningful
        if set(args) <= {object} and not (origin is tuple and args):
            return origin
        bases = (origin,) if type(origin) is type else ()
        namespace = {'__origin__': origin, '__args__': args}
        return type.__new__(cls, str(tp), bases, namespace)

    def __init__(self, tp, *args):
        # abstract origins are not used as bases by `__new__`; register instead
        if isinstance(self.__origin__, abc.ABCMeta):
            self.__origin__.register(self)

    def __getstate__(self):
        """Return the `(origin, args)` pair that identifies this type."""
        return self.__origin__, self.__args__

    def __eq__(self, other):
        return isinstance(other, subtype) and self.__getstate__() == other.__getstate__()

    def __hash__(self):
        return hash(self.__getstate__())

    def __subclasscheck__(self, subclass):
        """Return whether `subclass` matches both the origin and the subscripts."""
        origin = getattr(subclass, '__extra__', getattr(subclass, '__origin__', subclass))
        args = getattr(subclass, '__args__', ())
        if origin is typing.Union:
            return all(issubclass(cls, self) for cls in args)
        if self.__origin__ is typing.Union:
            return issubclass(subclass, self.__args__)
        return (  # check args first to avoid a recursion error in ABCMeta
            len(args) == len(self.__args__)
            and issubclass(origin, self.__origin__)
            and all(map(issubclass, args, self.__args__))
        )

    @classmethod
    def subcheck(cls, tp):
        """Return whether type requires checking subscripts using `get_type`."""
        return isinstance(tp, cls) and (
            tp.__origin__ is not Union or any(map(cls.subcheck, tp.__args__))
        )


class signature(tuple):
    """A tuple of types that supports partial ordering."""

    parents = None  # type: set

    def __new__(cls, types: Iterable):
        """Build the tuple, normalizing every element through `subtype`."""
        return tuple.__new__(cls, map(subtype, types))

    def __le__(self, other) -> bool:
        """Return whether every type in `other` is a subclass of the matching type here."""
        return len(self) <= len(other) and all(map(issubclass, other, self))

    def __lt__(self, other) -> bool:
        return self != other and self <= other

    def __sub__(self, other) -> tuple:
        """Return relative distances, assuming self >= other."""
        mros = (subclass.mro() for subclass in self)
        # distance is the position in the MRO; types absent from it count as `object`
        return tuple(mro.index(cls if cls in mro else object) for mro, cls in zip(mros, other))
class multimethod(dict):
    """A callable directed acyclic graph of methods."""

    pending = None  # type: set

    def __new__(cls, func):
        """Return the multimethod already bound to `func.__name__` in the caller's scope, or a new one."""
        # must stay directly in `__new__`: `f_back` has to be the frame applying the decorator
        namespace = inspect.currentframe().f_back.f_locals
        self = functools.update_wrapper(dict.__new__(cls), func)
        self.pending = set()
        self.get_type = type  # default type checker
        return namespace.get(func.__name__, self)

    def __init__(self, func: Callable):
        """Register `func` under its annotated types, deferring it if a hint cannot be resolved yet."""
        try:
            self[get_types(func)] = func
        except NameError:
            # unresolved forward reference: retried later by `evaluate`
            self.pending.add(func)

    def register(self, *args):
        """Decorator for registering a function.

        Optionally call with types to return a decorator for unannotated functions.
        """
        if len(args) == 1 and hasattr(args[0], '__annotations__'):
            return overload.register(self, *args)
        return lambda func: self.__setitem__(args, func) or func

    def __get__(self, instance, owner):
        """Bind to `instance` like a regular method; return self on class access."""
        return self if instance is None else types.MethodType(self, instance)

    def parents(self, types: tuple) -> set:
        """Find immediate parents of potential key."""
        parents = {key for key in self if isinstance(key, signature) and key < types}
        return parents - {ancestor for parent in parents for ancestor in parent.parents}

    def clean(self):
        """Empty the cache."""
        # registered keys are `signature` instances; any other key is a cached lookup
        for key in list(self):
            if not isinstance(key, signature):
                super().__delitem__(key)

    def __setitem__(self, types: tuple, func: Callable):
        """Register `func` for `types`, splicing the new key into the parent graph."""
        self.clean()
        types = signature(types)
        parents = types.parents = self.parents(types)
        for key in self:
            # the new signature sits between `parents` and `key`: re-link `key` to it
            if types < key and (not parents or parents & key.parents):
                key.parents -= parents
                key.parents.add(types)
        if any(map(subtype.subcheck, types)):
            self.get_type = get_type  # switch to slower generic type checker
        super().__setitem__(types, func)
        self.__doc__ = self.docstring

    def __delitem__(self, types: tuple):
        """Remove the function registered for `types` and recompute affected parents."""
        self.clean()
        super().__delitem__(types)
        for key in self:
            if types in key.parents:
                key.parents = self.parents(key)
        self.__doc__ = self.docstring

    def __missing__(self, types: tuple) -> Callable:
        """Find and cache the next applicable method of given types.

        :raises DispatchError: if the closest parents do not resolve to exactly one function
        """
        self.evaluate()
        if types in self:
            return self[types]
        groups = groupby(signature(types).__sub__, self.parents(types))
        keys = groups[min(groups)] if groups else []
        funcs = {self[key] for key in keys}
        if len(funcs) == 1:
            return self.setdefault(types, *funcs)
        msg = f"{self.__name__}: {len(keys)} methods found"  # type: ignore
        raise DispatchError(msg, types, keys)

    def __call__(self, *args, **kwargs):
        """Resolve and dispatch to best method."""
        return self[tuple(map(self.get_type, args))](*args, **kwargs)

    def evaluate(self):
        """Evaluate any pending forward references.

        This can be called explicitly when using forward references,
        otherwise cache misses will evaluate.
        """
        while self.pending:
            func = self.pending.pop()
            self[get_types(func)] = func

    @property
    def docstring(self):
        """a descriptive docstring of all registered functions"""
        docs = []
        for func in set(self.values()):
            try:
                sig = inspect.signature(func)
            except ValueError:
                sig = ''
            doc = func.__doc__ or ''
            docs.append(f'{func.__name__}{sig}\n {doc}')
        return '\n\n'.join(docs)
class multidispatch(multimethod):
    """Provisional wrapper for future compatibility with `functools.singledispatch`."""
# Generic type checker: starts out dispatching every argument to the builtin `type`.
get_type = multimethod(type)
get_type.__doc__ = """Return a generic `subtype` which checks subscripts."""
# These types are mapped straight to `type`, so their contents are never inspected.
for atomic in (Iterator, str, bytes):
    get_type[atomic,] = type


# Each `@multimethod` below resolves to the existing `get_type` object by name
# and registers the decorated function on it under its annotated type.
@multimethod  # type: ignore[no-redef]
def get_type(arg: tuple):
    """Return generic type checking all values."""
    return subtype(type(arg), *map(get_type, arg))


@multimethod  # type: ignore[no-redef]
def get_type(arg: Mapping):
    """Return generic type checking first item."""
    # an empty mapping contributes no subscripts
    return subtype(type(arg), *map(get_type, next(iter(arg.items()), ())))


@multimethod  # type: ignore[no-redef]
def get_type(arg: Iterable):
    """Return generic type checking first value."""
    return subtype(type(arg), *map(get_type, itertools.islice(arg, 1)))
def isa(*types) -> Callable:
    """Partially bound `isinstance`.

    :param types: the classes to test against
    :return: a one-argument predicate equivalent to `isinstance(arg, types)`
    """
    return lambda arg: isinstance(arg, types)
class overload(collections.OrderedDict):
    """Ordered functions which dispatch based on their annotated predicates."""

    __get__ = multimethod.__get__

    def __new__(cls, func):
        """Return the overload already bound to `func.__name__` in the caller's scope, or a new one."""
        # must stay directly in `__new__`: `f_back` has to be the frame applying the decorator
        namespace = inspect.currentframe().f_back.f_locals
        self = functools.update_wrapper(super().__new__(cls), func)
        return namespace.get(func.__name__, self)

    def __init__(self, func: Callable):
        """Register `func`, keyed by its `inspect.Signature`."""
        self[inspect.signature(func)] = func

    def __call__(self, *args, **kwargs):
        """Dispatch to first matching function.

        Functions are tried most recently registered first.  A function is
        skipped when the arguments cannot be bound to its signature or when
        any of its annotated predicates rejects the bound argument.

        :raises DispatchError: if no registered function matches
        """
        for sig, func in reversed(self.items()):
            try:
                arguments = sig.bind(*args, **kwargs).arguments
            except TypeError:
                # arguments do not fit this signature: not a match, try the next one
                continue
            if all(
                predicate(arguments[name])
                for name, predicate in overload.__get_predicates(sig.parameters)
            ):
                return func(*args, **kwargs)
        raise DispatchError("No matching functions found")

    @staticmethod
    def __get_predicates(parameters):
        """Yield `(name, annotation)` for every annotated parameter."""
        for name, parameter in parameters.items():
            annotation = parameter.annotation
            if annotation is not inspect.Parameter.empty:
                yield name, annotation

    def register(self, func: Callable) -> Callable:
        """Decorator for registering a function."""
        self.__init__(func)  # type: ignore
        # hand back the dispatcher only when the name would otherwise be rebound
        return self if self.__name__ == func.__name__ else func  # type: ignore
class multimeta(type):
    """Convert all callables in namespace to multimethods."""

    class __prepare__(dict):
        """Class-body namespace that wraps callables as they are assigned."""

        def __init__(*args):
            # accept and ignore the arguments passed when the namespace is created
            pass

        def __setitem__(self, key, value):
            if callable(value):
                # extend the entry already stored under this name if it offers
                # `register`; otherwise start a new multimethod
                value = getattr(self.get(key), 'register', multimethod)(value)
            super().__setitem__(key, value)