Skip to content

Commit

Permalink
Merge pull request #92 from fzyukio/master
Browse files Browse the repository at this point in the history
Allow provider to be a context manager (sync/async)
  • Loading branch information
ivankorobkov authored Nov 23, 2023
2 parents 5be9189 + f5272ca commit a1c5b7b
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 3 deletions.
46 changes: 43 additions & 3 deletions src/inject/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def my_config(binder):
inject.configure(my_config)
"""
import contextlib

from inject._version import __version__

import inspect
Expand Down Expand Up @@ -156,7 +158,10 @@ def bind_to_constructor(self, cls: Binding, constructor: Constructor) -> 'Binder
return self

def bind_to_provider(self, cls: Binding, provider: Provider) -> 'Binder':
"""Bind a class to a callable instance provider executed for each injection."""
"""
Bind a class to a callable instance provider executed for each injection.
A provider can be a normal function or a context manager. Both sync and async are supported.
"""
self._check_class(cls)
if provider is None:
raise InjectorException('Provider cannot be None, key=%s' % cls)
Expand Down Expand Up @@ -323,6 +328,35 @@ class _ParametersInjection(Generic[T]):
def __init__(self, **kwargs: Any) -> None:
self._params = kwargs

@staticmethod
def _aggregate_sync_stack(
sync_stack: contextlib.ExitStack,
provided_params: frozenset[str],
kwargs: dict[str, Any]
) -> None:
"""Extracts context managers, aggregate them in an ExitStack and swap out the param value with results of
running __enter__(). The result is equivalent to using `with` multiple times """
executed_kwargs = {
param: sync_stack.enter_context(inst)
for param, inst in kwargs.items()
if param not in provided_params and isinstance(inst, contextlib._GeneratorContextManager)
}
kwargs.update(executed_kwargs)

@staticmethod
async def _aggregate_async_stack(
async_stack: contextlib.AsyncExitStack,
provided_params: frozenset[str],
kwargs: dict[str, Any]
) -> None:
"""Similar to _aggregate_sync_stack, but for async context managers"""
executed_kwargs = {
param: await async_stack.enter_async_context(inst)
for param, inst in kwargs.items()
if param not in provided_params and isinstance(inst, contextlib._AsyncGeneratorContextManager)
}
kwargs.update(executed_kwargs)

def __call__(self, func: Callable[..., Union[Awaitable[T], T]]) -> Callable[..., Union[Awaitable[T], T]]:
if sys.version_info.major == 2:
arg_names = inspect.getargspec(func).args
Expand All @@ -340,7 +374,11 @@ async def async_injection_wrapper(*args: Any, **kwargs: Any) -> T:
kwargs[param] = instance(cls)
async_func = cast(Callable[..., Awaitable[T]], func)
try:
return await async_func(*args, **kwargs)
with contextlib.ExitStack() as sync_stack:
async with contextlib.AsyncExitStack() as async_stack:
self._aggregate_sync_stack(sync_stack, provided_params, kwargs)
await self._aggregate_async_stack(async_stack, provided_params, kwargs)
return await async_func(*args, **kwargs)
except TypeError as previous_error:
raise ConstructorTypeError(func, previous_error)

Expand All @@ -355,7 +393,9 @@ def injection_wrapper(*args: Any, **kwargs: Any) -> T:
kwargs[param] = instance(cls)
sync_func = cast(Callable[..., T], func)
try:
return sync_func(*args, **kwargs)
with contextlib.ExitStack() as sync_stack:
self._aggregate_sync_stack(sync_stack, provided_params, kwargs)
return sync_func(*args, **kwargs)
except TypeError as previous_error:
raise ConstructorTypeError(func, previous_error)
return injection_wrapper
Expand Down
107 changes: 107 additions & 0 deletions test/test_context_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import contextlib

import inject
from test import BaseTestInject


class Destroyable:
def __init__(self):
self.started = True

def destroy(self):
self.started = False


class MockFile(Destroyable):
...


class MockConnection(Destroyable):
...


class MockFoo(Destroyable):
...


@contextlib.contextmanager
def get_file_sync():
obj = MockFile()
yield obj
obj.destroy()


@contextlib.contextmanager
def get_conn_sync():
obj = MockConnection()
yield obj
obj.destroy()


@contextlib.contextmanager
def get_foo_sync():
obj = MockFoo()
yield obj
obj.destroy()


@contextlib.asynccontextmanager
async def get_file_async():
obj = MockFile()
yield obj
obj.destroy()


@contextlib.asynccontextmanager
async def get_conn_async():
obj = MockConnection()
yield obj
obj.destroy()


class TestContextManagerFunctional(BaseTestInject):

def test_provider_as_context_manager_sync(self):
def config(binder):
binder.bind_to_provider(MockFile, get_file_sync)
binder.bind(int, 100)
binder.bind_to_provider(str, lambda: "Hello")
binder.bind_to_provider(MockConnection, get_conn_sync)

inject.configure(config)

@inject.autoparams()
def mock_func(conn: MockConnection, name: str, f: MockFile, number: int):
assert f.started
assert conn.started
assert name == "Hello"
assert number == 100
return f, conn

f_, conn_ = mock_func()
assert not f_.started
assert not conn_.started

def test_provider_as_context_manager_async(self):
def config(binder):
binder.bind_to_provider(MockFile, get_file_async)
binder.bind(int, 100)
binder.bind_to_provider(str, lambda: "Hello")
binder.bind_to_provider(MockConnection, get_conn_async)
binder.bind_to_provider(MockFoo, get_foo_sync)

inject.configure(config)

@inject.autoparams()
async def mock_func(conn: MockConnection, name: str, f: MockFile, number: int, foo: MockFoo):
assert f.started
assert conn.started
assert foo.started
assert name == "Hello"
assert number == 100
return f, conn, foo

f_, conn_, foo_ = self.run_async(mock_func())
assert not f_.started
assert not conn_.started
assert not foo_.started

0 comments on commit a1c5b7b

Please sign in to comment.