Skip to content
This repository has been archived by the owner on Sep 17, 2024. It is now read-only.

Commit

Permalink
feat: simplify module routing further
Browse files Browse the repository at this point in the history
- Removed PartialType
- ModuleRouter.__init__() accepts a ProtocolIdentifier
- Modules now must define protocol as a string

BREAKING CHANGE: Modules must define protocol as a string and
ModuleRouter must accept a protocol identifier

Signed-off-by: Daniel Bluhm <dbluhm@pm.me>
  • Loading branch information
dbluhm committed Oct 11, 2021
1 parent 61354c3 commit d641cda
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 214 deletions.
263 changes: 150 additions & 113 deletions aries_staticagent/module.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,37 @@
""" Module base class """

from abc import ABC, abstractclassmethod
from abc import ABC
from functools import partial
from typing import Callable, Dict, Iterable, Mapping, NamedTuple, Optional, Union
from typing import (
Callable,
ClassVar,
Dict,
Iterable,
Mapping,
TypeVar,
Union,
cast,
)

from .message import MsgType
from .message import MsgType, ProtocolIdentifier


class PartialType(NamedTuple):
"""Class containing the type information of a route before having the
context of the module as is the case when statically defining routes in
a module definition.
"""
RouteFunc = TypeVar("RouteFunc", bound=Callable)

name: Optional[str] = None
doc_uri: Optional[str] = None
protocol: Optional[str] = None
version: Optional[str] = None

def complete(
self,
doc_uri: str = None,
protocol: str = None,
version: str = None,
name: str = None,
) -> MsgType:
"""Return a complete type given the module context."""
doc_uri = self.doc_uri or doc_uri or ""
protocol = self.protocol or protocol or ""
version = self.version or version or ""
name = self.name or name or ""
return MsgType.unparse(doc_uri, protocol, version, name)


class ModuleRouter(Mapping[Union[PartialType, MsgType], Callable]):
class ModuleRouter(Mapping[MsgType, Callable]):
"""Collect module routes."""

def __init__(self):
self._routes: Dict[Union[MsgType, PartialType], Callable] = {}
def __init__(
self,
protocol: Union[str, ProtocolIdentifier],
):
if not isinstance(protocol, ProtocolIdentifier):
protocol = ProtocolIdentifier(protocol)
self.protocol = protocol
self._routes: Dict[Union[str, MsgType], Callable] = {}

def __getitem__(self, item: MsgType) -> Callable:
def __getitem__(self, item: Union[str, MsgType]) -> Callable:
return self._routes[item]

def __iter__(self) -> Iterable:
Expand All @@ -48,122 +40,167 @@ def __iter__(self) -> Iterable:
def __len__(self):
return len(self._routes)

def __call__(self, *args, **kwargs):
"""Route definition decorator.
if just @route is used, type_or_func is the decorated function
if @route(type) is used, type_or_func is the type string.
"""
if args:
type_or_func: Union[Callable, str, MsgType] = args[0]
if callable(type_or_func):
func = type_or_func
self._routes[PartialType(name=func.__name__)] = func
return func

if isinstance(type_or_func, MsgType):
msg_type = type_or_func

def _route_from_type(func):
self._routes[msg_type] = func
return func

return _route_from_type

if isinstance(type_or_func, str):
msg_type_str = type_or_func

def _route_from_str(func):
self._routes[MsgType(msg_type_str)] = func
return func

return _route_from_str

if kwargs:

def _route_from_kwargs(func):
self._routes[PartialType(**kwargs)] = func
return func

return _route_from_kwargs

raise ValueError(
"Expecting @route before a function or @route(msg_type) "
"before a function!"
)

def complete(
def route(
self,
func_or_name: Union[RouteFunc, str] = None,
*,
doc_uri: str = None,
protocol: str = None,
version: str = None,
name: str = None,
context: object = None,
) -> Dict[MsgType, Callable]:
routes = {}
for msg_type, handler in self._routes.items():
if isinstance(msg_type, PartialType):
route_type = msg_type.complete(doc_uri, protocol, version, name)
elif isinstance(msg_type, MsgType):
route_type = msg_type
else:
raise TypeError(
f"Route of invaild type {type(msg_type).__name__} registered"
)
routes[route_type] = partial(handler, context) if context else handler
return routes
msg_type: Union[str, MsgType] = None
) -> Union[Callable[..., RouteFunc], RouteFunc]:
"""Decorator for defining routes within a module.
This decorator can be used in a few different ways, as demonstrated by
the following examples:
>>> router = ModuleRouter("doc/protocol/1.0")
>>> @router
... def test():
... pass
>>> assert "doc/protocol/1.0/test" in router
>>> assert router["doc/protocol/1.0/test"] == test
>>>
>>> @router("alt-name")
... def test1():
... pass
>>> assert "doc/protocol/1.0/alt-name" in router
>>> assert router["doc/protocol/1.0/alt-name"] == test1
>>>
>>> @router(msg_type="another-doc/some-protocol/2.0/name")
... def test2():
... pass
>>> assert "another-doc/some-protocol/2.0/name" in router
>>> assert router["another-doc/some-protocol/2.0/name"] == test2
>>>
>>> @router(doc_uri="another-doc/")
... def test3():
... pass
>>> assert "another-doc/protocol/1.0/test3" in router
>>> assert router["another-doc/protocol/1.0/test3"] == test3
>>> @router(protocol="some-protocol")
... def test4():
... pass
>>> assert "doc/some-protocol/1.0/test4" in router
>>> assert router["doc/some-protocol/1.0/test4"] == test4
>>>
>>> @router(version="2.0")
... def test5():
... pass
>>> assert "doc/protocol/2.0/test5" in router
>>> assert router["doc/protocol/2.0/test5"] == test5
>>>
>>> @router(name="another-alt-name")
... def test6():
... pass
>>> assert "doc/protocol/1.0/another-alt-name" in router
>>> assert router["doc/protocol/1.0/another-alt-name"] == test6
"""

if not func_or_name:
return lambda f: cast(
RouteFunc,
self.route(
f,
doc_uri=doc_uri,
protocol=protocol,
version=version,
name=name,
msg_type=msg_type,
),
)
if isinstance(func_or_name, str):
name = func_or_name
return lambda f: cast(
RouteFunc,
self.route(
f,
doc_uri=doc_uri,
protocol=protocol,
version=version,
name=name,
msg_type=msg_type,
),
)
if not isinstance(func_or_name, Callable):
raise TypeError("func is not a callable")

func: Callable = func_or_name
if msg_type:
if isinstance(msg_type, str):
msg_type = MsgType(msg_type)
self._routes[MsgType(msg_type)] = func
return func

type_to_route = MsgType.unparse(
doc_uri=doc_uri or self.protocol.doc_uri or "",
protocol=protocol or self.protocol.protocol or "",
version=version or self.protocol.version or "",
name=name or func.__name__ or "",
)
self._routes[MsgType(type_to_route)] = func
return func

def __call__(self, *args, **kwargs):
"""Route definition decorator."""
return self.route(*args, **kwargs)

def contextualize(self, context: object) -> Dict[MsgType, Callable]:
return {
msg_type: partial(handler, context) for msg_type, handler in self.items()
}


class Module(ABC): # pylint: disable=too-few-public-methods
"""Base Module class"""
"""Base Module class."""

protocol: ClassVar[str]
route: ClassVar[ModuleRouter]

def __init__(self):
self._routes = None
self._protocol_identifier = ProtocolIdentifier(self.protocol)

@property
def protocol_identifier(self) -> ProtocolIdentifier:
return self._protocol_identifier

@property
@classmethod
@abstractclassmethod
def doc_uri(cls) -> str:
"""Return doc_uri of module."""
def router(self) -> ModuleRouter:
"""Alias to route."""
return self.route

@property
@classmethod
@abstractclassmethod
def protocol(cls) -> str:
"""Return protocol of module."""
def doc_uri(self) -> str:
return self.protocol_identifier.doc_uri

@property
@classmethod
@abstractclassmethod
def version(cls) -> str:
"""Return protocol of module."""
def protocol_name(self) -> str:
return self.protocol_identifier.protocol

@property
@classmethod
@abstractclassmethod
def route(cls) -> ModuleRouter:
"""Return router for module."""
def version(self) -> str:
return self.protocol_identifier.version

def type(
self, name: str, doc_uri: str = None, protocol: str = None, version: str = None
):
"""Build a type string for this module."""
# doc_url can be falsey, need explicit none check
doc_uri = doc_uri if doc_uri is not None else self.doc_uri
protocol = protocol or self.protocol
protocol = protocol or self.protocol_name
version = version or self.version
return MsgType.unparse(doc_uri, protocol, version, name)

def _finish_routes(self) -> Mapping[MsgType, Callable]:
return self.route.complete(
self.doc_uri, self.protocol, self.version, context=self
)
def _contextualize_routes(self) -> Mapping[MsgType, Callable]:
return self.router.contextualize(context=self)

@property
def routes(self) -> Mapping[MsgType, Callable]:
"""Get the routes statically defined for this module and
save in instance.
"""
if self._routes is None:
self._routes = self._finish_routes()
self._routes = self._contextualize_routes()
return self._routes
8 changes: 3 additions & 5 deletions examples/webserver_with_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@ class BasicMessageCounter(Module):
Responds with the number of messages received.
"""

doc_uri = "https://didcomm.org/"
protocol = "basicmessage"
version = "1.0"
route = ModuleRouter()
protocol = "https://didcomm.org/basicmessage/1.0"
route = ModuleRouter(protocol)

def __init__(self):
super().__init__()
Expand Down Expand Up @@ -45,7 +43,7 @@ def main():
async def handle(request):
"""aiohttp handle POST."""
await conn.handle(await request.read())
raise web.HTTPAccepted()
return web.Response(status=201)

app = web.Application()
app.add_routes([web.post("/", handle)])
Expand Down
Loading

0 comments on commit d641cda

Please sign in to comment.