diff --git a/aries_staticagent/module.py b/aries_staticagent/module.py index 756f869..d1acacc 100644 --- a/aries_staticagent/module.py +++ b/aries_staticagent/module.py @@ -99,6 +99,7 @@ def route( """ if not func_or_name: + # Empty @route case return lambda f: cast( RouteFunc, self.route( @@ -111,6 +112,7 @@ def route( ), ) if isinstance(func_or_name, str): + # @route("msg_name") case name = func_or_name return lambda f: cast( RouteFunc, @@ -123,23 +125,27 @@ def route( msg_type=msg_type, ), ) + + # After the previous checks, the first positional argument must now be + # the method to decorate. if not isinstance(func_or_name, Callable): raise TypeError("func is not a callable") + # Func and name (if present) in expected parameters func: Callable = func_or_name if msg_type: if isinstance(msg_type, str): msg_type = MsgType(msg_type) - self._routes[MsgType(msg_type)] = func - return func - - type_to_route = MsgType.unparse( - doc_uri=doc_uri or self.protocol.doc_uri or "", - protocol=protocol or self.protocol.protocol or "", - version=version or self.protocol.version or "", - name=name or func.__name__ or "", - ) - self._routes[MsgType(type_to_route)] = func + type_to_route = msg_type + else: + type_to_route = MsgType.unparse( + doc_uri=doc_uri or self.protocol.doc_uri or "", + protocol=protocol or self.protocol.protocol or "", + version=version or self.protocol.version or "", + name=name or func.__name__ or "", + ) + + self._routes[type_to_route] = func return func def __call__(self, *args, **kwargs): @@ -147,6 +153,7 @@ def __call__(self, *args, **kwargs): return self.route(*args, **kwargs) def contextualize(self, context: object) -> Dict[MsgType, Callable]: + """Return routes with handlers wrapped as partials to include 'self'.""" return { msg_type: partial(handler, context) for msg_type, handler in self.items() } @@ -164,6 +171,7 @@ def __init__(self): @property def protocol_identifier(self) -> ProtocolIdentifier: + """Parsed protocol identifier.""" return self._protocol_identifier @property @@ -173,14 +181,17 @@ def router(self) -> ModuleRouter: @property def doc_uri(self) -> str: + """Protocol doc URI.""" return self.protocol_identifier.doc_uri @property def protocol_name(self) -> str: + """Protocol name.""" return self.protocol_identifier.protocol @property def version(self) -> str: + """Protocol version.""" return self.protocol_identifier.version def type(