Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add implementation of the DiMiddleware to didiator #16

Merged
merged 7 commits into from
Nov 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion didiator/dispatchers/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ async def send(self, command: Command[CRes], *args: Any, **kwargs: Any) -> CRes:
try:
return await self._handle(command, *args, **kwargs)
except HandlerNotFound:
raise CommandHandlerNotFound()
raise CommandHandlerNotFound(f"Command handler for {type(command).__name__} command is not registered", command)
2 changes: 1 addition & 1 deletion didiator/dispatchers/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ async def query(self, query: Query[QRes], *args: Any, **kwargs: Any) -> QRes:
try:
return await self._handle(query, *args, **kwargs)
except HandlerNotFound:
raise QueryHandlerNotFound()
raise QueryHandlerNotFound(f"Query handler for {type(query).__name__} query is not registered", query)
5 changes: 3 additions & 2 deletions didiator/dispatchers/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

from didiator.interface.entities.request import Request
from didiator.interface.exceptions import HandlerNotFound
from didiator.interface.handlers import HandlerType
from didiator.middlewares.base import Middleware
from didiator.interface.dispatchers.request import Dispatcher, MiddlewareType, HandlerType
from didiator.interface.dispatchers.request import Dispatcher, MiddlewareType

RRes = TypeVar("RRes")
R = TypeVar("R", bound=Request[Any])
Expand Down Expand Up @@ -33,7 +34,7 @@ async def _handle(self, request: Request[RRes], *args: Any, **kwargs: Any) -> RR
try:
handler = self._handlers[type(request)]
except KeyError:
raise HandlerNotFound()
raise HandlerNotFound(f"Request handler for {type(request).__name__} request is not registered", request)

# Handler has to be wrapped with at least one middleware to initialize the handler if it is necessary
middlewares = self._middlewares if self._middlewares else DEFAULT_MIDDLEWARES
Expand Down
22 changes: 12 additions & 10 deletions didiator/interface/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# from .command import Command, CommandHandler
from didiator.interface.dispatchers.command import CommandDispatcher
from didiator.interface.dispatchers.request import Dispatcher
from .dispatchers.command import CommandDispatcher
from .dispatchers.request import Dispatcher
from .dispatchers.query import QueryDispatcher
from .mediator import CommandMediator, Mediator, QueryMediator
from didiator.interface.dispatchers.query import QueryDispatcher
# from .query import Query, QueryHandler
from .entities import Command, Query, Request
from .handlers import CommandHandler, Handler, QueryHandler


__all__ = (
"Mediator",
Expand All @@ -12,9 +13,10 @@
"Dispatcher",
"CommandDispatcher",
"QueryDispatcher",
# "Command",
# "CommandHandler",
# "Query",
# "QueryHandler",
"Request",
"Command",
"Query",
"Handler",
"CommandHandler",
"QueryHandler",
)

5 changes: 2 additions & 3 deletions didiator/interface/dispatchers/command.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from typing import Any, Protocol, Type, TypeVar

from didiator.interface.entities.command import Command
from didiator.interface.dispatchers.request import Dispatcher, HandlerType
from didiator.interface.dispatchers.request import Dispatcher
from didiator.interface.handlers import HandlerType

C = TypeVar("C", bound=Command[Any])
CRes = TypeVar("CRes")

# HandlerType = Callable[[C], Awaitable[CRes]] | Type[RequestHandler[C, CRes]]


class CommandDispatcher(Dispatcher, Protocol):
@property
Expand Down
5 changes: 2 additions & 3 deletions didiator/interface/dispatchers/query.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from typing import Any, Protocol, Type, TypeVar

from didiator.interface.dispatchers.request import Dispatcher, HandlerType
from didiator.interface.dispatchers.request import Dispatcher
from didiator.interface.entities.query import Query
from didiator.interface.handlers import HandlerType

Q = TypeVar("Q", bound=Query[Any])
QRes = TypeVar("QRes")

# HandlerType = Callable[[Q], Awaitable[QRes]] | Type[RequestHandler[Q, QRes]]


class QueryDispatcher(Dispatcher, Protocol):
@property
Expand Down
6 changes: 1 addition & 5 deletions didiator/interface/dispatchers/request.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from typing import Any, Awaitable, Callable, Protocol, Type, TypeVar

from didiator.interface.entities.request import Request
from didiator.interface.handlers.request import HandlerType
from didiator.interface.handlers import HandlerType

R = TypeVar("R", bound=Request[Any])
RRes = TypeVar("RRes")

# MiddlewareType = Callable[[HandlerType | "MiddlewareType", C], Awaitable[CR]]
MiddlewareType = Callable[[HandlerType[R, RRes], R], Awaitable[RRes]]
# MiddlewareType = Callable[[Callable[..., Awaitable[CR]], C], CR]

# x: MiddlewareType[Command[int], str]


class Dispatcher(Protocol):
Expand Down
18 changes: 14 additions & 4 deletions didiator/interface/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
from didiator.interface import Request

class HandlerNotFound(TypeError):
...

class MediatorError(Exception):
pass


class HandlerNotFound(MediatorError, TypeError):
request: Request

def __init__(self, text: str, request: Request):
super().__init__(text)
self.request = request


class CommandHandlerNotFound(HandlerNotFound):
...
pass


class QueryHandlerNotFound(HandlerNotFound):
...
pass
19 changes: 4 additions & 15 deletions didiator/middlewares/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Any, TypeVar

from didiator.interface.dispatchers.request import HandlerType
from didiator.interface.entities.request import Request
from didiator.interface.handlers.request import Handler
from didiator.interface.handlers import HandlerType

RRes = TypeVar("RRes")
R = TypeVar("R", bound=Request)
R = TypeVar("R", bound=Request[Any])


class Middleware:
Expand All @@ -25,17 +24,7 @@ async def _call(
*args: Any,
**kwargs: Any,
) -> RRes:
if isinstance(handler, type) and issubclass(handler, Handler):
if isinstance(handler, type):
handler = handler()

return await handler(request, *args, **kwargs) # noqa: type
# return await cast(
# handler,
# Callable[[HandlerType[C, CR] | "Middleware"], Awaitable[CR]],
# )(command, *args, **kwargs)


# NextMiddlewareType = Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]]
# MiddlewareType = Union[
# BaseMiddleware, Callable[[NextMiddlewareType, TelegramObject, Dict[str, Any]], Awaitable[Any]]
# ]
return await handler(request, *args, **kwargs)
91 changes: 91 additions & 0 deletions didiator/middlewares/di.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, TypeVar

from di.api.executor import SupportsAsyncExecutor
from di.api.scopes import Scope
from di.api.solved import SolvedDependent
from di.container import Container, ContainerState
from di.dependent import Dependent

from didiator.interface.entities.request import Request
from didiator.interface.handlers import HandlerType
from didiator.middlewares import Middleware

RRes = TypeVar("RRes")
R = TypeVar("R", bound=Request)
DEFAULT_STATE_KEY = "di_state"


@dataclass(frozen=True)
class MediatorDiScope:
cls_handler: Scope
func_handler: Scope


class DiMiddleware(Middleware):
def __init__(
self, di_container: Container, di_executor: SupportsAsyncExecutor, di_scopes: Sequence[Scope] = (),
*, cls_scope: Scope = ..., func_scope: Scope = "mediator_request", state_key: str = DEFAULT_STATE_KEY,
) -> None:
self._di_container = di_container
self._di_executor = di_executor

mediator_scope = MediatorDiScope(func_scope if cls_scope is ... else cls_scope, func_scope)
self._di_scopes = self._get_di_scopes(tuple(di_scopes), mediator_scope)
self._mediator_scope = mediator_scope

self._state_key = state_key
self._solved_handlers: dict[HandlerType[Any, Any], SolvedDependent[HandlerType[Any, Any]]] = {}

def _get_di_scopes(self, di_scopes: tuple[Scope, ...], mediator_scope: MediatorDiScope) -> tuple[Scope, ...]:
if mediator_scope.cls_handler not in di_scopes:
di_scopes += (mediator_scope.cls_handler,)
if mediator_scope.func_handler not in di_scopes:
di_scopes += (mediator_scope.func_handler,)
return di_scopes

async def _call(
self,
handler: HandlerType[R, RRes],
request: R,
*args: Any,
**kwargs: Any,
) -> RRes:
di_state: ContainerState | None = kwargs.pop(self._state_key, None)

if isinstance(handler, type):
return await self._call_class_handler(handler, request, di_state, *args, **kwargs)
return await self._call_func_handler(handler, request, di_state)

async def _call_class_handler(
self, handler: HandlerType[R, RRes], request: R, di_state: ContainerState | None,
*args: Any, **kwargs: Any,
) -> RRes:
solved_handler = self._get_cached_solved_handler(handler, self._mediator_scope.cls_handler)
async with self._di_container.enter_scope(self._mediator_scope.func_handler, di_state) as scoped_di_state:
handler = await self._di_container.execute_async(
solved_handler, executor=self._di_executor, state=scoped_di_state, values={type(request): request},
)
return await handler(request, *args, **kwargs)

async def _call_func_handler(
self, handler: HandlerType[R, RRes], request: R, di_state: ContainerState | None,
) -> RRes:
solved_handler = self._get_cached_solved_handler(handler, self._mediator_scope.func_handler)
async with self._di_container.enter_scope(self._mediator_scope.func_handler, di_state) as scoped_di_state:
return await self._di_container.execute_async(
solved_handler, executor=self._di_executor, state=scoped_di_state, values={type(request): request},
)

def _get_cached_solved_handler(self, handler: HandlerType, scope: Scope) -> SolvedDependent[HandlerType]:
try:
solved_handler = self._solved_handlers[handler]
print("Get cached solved handler:", solved_handler)
except KeyError:
solved_handler = self._di_container.solve(
Dependent(handler, scope=scope), scopes=self._di_scopes,
)
self._solved_handlers[handler] = solved_handler
print("Solve handler:", solved_handler)
return solved_handler
Loading