Skip to content
This repository has been archived by the owner on Sep 17, 2024. It is now read-only.

Commit

Permalink
feat: better typing on ModuleRouter.__call__ using @overload
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Bluhm <dbluhm@pm.me>
  • Loading branch information
dbluhm committed Oct 12, 2021
1 parent 2125c46 commit 9d1cc3a
Showing 1 changed file with 105 additions and 49 deletions.
154 changes: 105 additions & 49 deletions aries_staticagent/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
Mapping,
TypeVar,
Union,
cast,
overload,
)

from .message import MsgType, ProtocolIdentifier
Expand Down Expand Up @@ -40,34 +40,77 @@ def __iter__(self) -> Iterable:
def __len__(self):
return len(self._routes)

def route(
def _route(
self,
func_or_name: Union[RouteFunc, str] = None,
func: RouteFunc,
*,
doc_uri: str = None,
protocol: str = None,
version: str = None,
name: str = None,
msg_type: Union[str, MsgType] = None
) -> Union[Callable[..., RouteFunc], RouteFunc]:
"""Decorator for defining routes within a module.
) -> RouteFunc:
"""Collect route."""
if msg_type:
if isinstance(msg_type, str):
msg_type = MsgType(msg_type)
type_to_route = msg_type
else:
type_to_route = MsgType.unparse(
doc_uri=doc_uri or self.protocol.doc_uri or "",
protocol=protocol or self.protocol.protocol or "",
version=version or self.protocol.version or "",
name=name or func.__name__ or "",
)

This decorator can be used in a few different ways, as demonstrated by
the following examples:
self._routes[type_to_route] = func
return func

@overload
def route(
self,
func_or_name: RouteFunc,
) -> RouteFunc:
"""Decorator for defining routes within a module.
>>> router = ModuleRouter("doc/protocol/1.0")
>>> @router
... def test():
... pass
>>> assert "doc/protocol/1.0/test" in router
>>> assert router["doc/protocol/1.0/test"] == test
>>>
"""
...

@overload
def route(
self,
func_or_name: str,
) -> Callable[..., RouteFunc]:
"""Decorator for defining routes within a module.
>>> router = ModuleRouter("doc/protocol/1.0")
>>> @router("alt-name")
... def test1():
... pass
>>> assert "doc/protocol/1.0/alt-name" in router
>>> assert router["doc/protocol/1.0/alt-name"] == test1
>>>
"""
...

@overload
def route(
self,
*,
doc_uri: str = None,
protocol: str = None,
version: str = None,
name: str = None,
msg_type: Union[str, MsgType] = None
) -> Callable[..., RouteFunc]:
"""Decorator for defining routes within a module.
>>> router = ModuleRouter("doc/protocol/1.0")
>>> @router(msg_type="another-doc/some-protocol/2.0/name")
... def test2():
... pass
Expand Down Expand Up @@ -97,59 +140,72 @@ def route(
>>> assert "doc/protocol/1.0/another-alt-name" in router
>>> assert router["doc/protocol/1.0/another-alt-name"] == test6
"""
...

@overload
def route(
self,
func_or_name: RouteFunc,
*,
doc_uri: str = None,
protocol: str = None,
version: str = None,
name: str = None,
msg_type: Union[str, MsgType] = None
) -> RouteFunc:
"""Decorator for defining routes within a module."""
...

def route(
self,
func_or_name: Union[RouteFunc, str] = None,
*,
doc_uri: str = None,
protocol: str = None,
version: str = None,
name: str = None,
msg_type: Union[str, MsgType] = None
) -> Union[Callable[..., RouteFunc], RouteFunc]:
"""Decorator for defining routes within a module."""

# Empty @route() case
if not func_or_name:
# Empty @route case
return lambda f: cast(
RouteFunc,
self.route(
f,
doc_uri=doc_uri,
protocol=protocol,
version=version,
name=name,
msg_type=msg_type,
),
return lambda f: self.route(
f,
doc_uri=doc_uri,
protocol=protocol,
version=version,
name=name,
msg_type=msg_type,
)

# @route("msg_name") case
if isinstance(func_or_name, str):
# @route("msg_name") case
name = func_or_name
return lambda f: cast(
RouteFunc,
self.route(
f,
doc_uri=doc_uri,
protocol=protocol,
version=version,
name=name,
msg_type=msg_type,
),
return lambda f: self.route(
f,
doc_uri=doc_uri,
protocol=protocol,
version=version,
name=name,
msg_type=msg_type,
)

# After the previous checks, the first positional argument must now be
# the method to decorate.
if not isinstance(func_or_name, Callable):
if not callable(func_or_name):
raise TypeError("func is not a callable")

# Func and name (if present) in expected parameters
func: Callable = func_or_name
if msg_type:
if isinstance(msg_type, str):
msg_type = MsgType(msg_type)
type_to_route = msg_type
else:
type_to_route = MsgType.unparse(
doc_uri=doc_uri or self.protocol.doc_uri or "",
protocol=protocol or self.protocol.protocol or "",
version=version or self.protocol.version or "",
name=name or func.__name__ or "",
)

self._routes[type_to_route] = func
return func
return self._route(
func_or_name,
doc_uri=doc_uri,
protocol=protocol,
version=version,
name=name,
msg_type=msg_type,
)

def __call__(self, *args, **kwargs):
"""Route definition decorator."""
return self.route(*args, **kwargs)

def contextualize(self, context: object) -> Dict[MsgType, Callable]:
Expand Down

0 comments on commit 9d1cc3a

Please sign in to comment.