From 5f577339057c5a31b4571f953d846f828a25b632 Mon Sep 17 00:00:00 2001 From: Aleksander Binion Date: Tue, 1 Oct 2024 19:15:41 -0400 Subject: [PATCH] add test + fix bug --- minject/inject.py | 4 +--- minject/registry.py | 35 +++++++++++++++-------------------- tests/test_async.py | 10 ++++++++++ 3 files changed, 26 insertions(+), 23 deletions(-) diff --git a/minject/inject.py b/minject/inject.py index cb3948f..f2ab6f3 100644 --- a/minject/inject.py +++ b/minject/inject.py @@ -150,9 +150,7 @@ def resolve(self, registry_impl: Resolver) -> T_co: async def aresolve(self, registry_impl: Resolver) -> T_co: if _is_key_async(self._key): - result = await registry_impl._aresolve(self._key) - await registry_impl._push_async_context(result) - return result + return await registry_impl._aresolve(self._key) return registry_impl.resolve(self._key) @property diff --git a/minject/registry.py b/minject/registry.py index 925d9af..baae2ee 100644 --- a/minject/registry.py +++ b/minject/registry.py @@ -8,7 +8,7 @@ from typing_extensions import Concatenate, ParamSpec -from minject.inject import _is_key_async, _RegistryReference +from minject.inject import _is_key_async, _RegistryReference, reference from .config import RegistryConfigWrapper, RegistryInitConfig from .metadata import RegistryMetadata, _get_meta, _get_meta_from_key @@ -103,7 +103,7 @@ def resolve(self, key: "RegistryKey[T]") -> T: return self[key] async def _aresolve(self, key: "RegistryKey[T]") -> T: - result = await self._aget(key) + result = await self.aget(key) if result is None: raise KeyError(key, "could not be resolved") return result @@ -241,6 +241,12 @@ async def _aregister_by_metadata(self, meta: RegistryMetadata[T]) -> RegistryWra # add to our list of all objects (this MUST happen after init so # any references come earlier in sequence and are destroyed first) self._objects.append(wrapper) + + # after creating an object, enter the objects context + # if it is marked with the @async_context decorator. + if meta.is_async_context(): + await self._push_async_context(obj) + success = True finally: if not success: @@ -333,7 +339,13 @@ def get( return _unwrap(self._get_by_metadata(meta, default)) - async def _aget(self, key: "RegistryKey[T]") -> Optional[T]: + async def aget(self, key: "RegistryKey[T]") -> Optional[T]: + """ + Resolve objects marked with the @async_context decorator. + """ + if not _is_key_async(key): + raise RegistryAPIError("cannot use aget outside of async context") + if not self._async_can_proceed: raise RegistryAPIError("cannot use aget outside of async context") @@ -375,23 +387,6 @@ def _get_if_already_in_registry( # nothing has been registered for this metadata yet return None - async def aget(self, key: "RegistryKey[T]") -> Optional[T]: - """ - Resolve objects marked with the @async_context decorator. - """ - if not _is_key_async(key): - raise RegistryAPIError("cannot use aget outside of async context") - - maybe_class = self._get_if_already_in_registry(key, _get_meta_from_key(key)) - if maybe_class is not None: - return maybe_class - - # this is done differently than the sync version because the async version - # requires that the top level context of key be entered, and contexts are entered - # in a RegistryReference's aresolve method - reference = _RegistryReference(key) - return await self._aresolve_resolvable(reference) - async def __aenter__(self) -> "Registry": """ Mark a registry instance as ready for resolving async objects. diff --git a/tests/test_async.py b/tests/test_async.py index c875dde..d7a350d 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -177,6 +177,16 @@ async def test_multiple_instantiation_top_level(registry: Registry) -> None: assert my_counter.exited_context_counter == 1 +async def test_multiple_instantiation_mixed(registry: Registry) -> None: + my_counter: MyAsyncAPIContextCounter + async with registry as r: + my_counter = await r.aget(MyAsyncAPIContextCounter) + assert my_counter.entered_context_counter == 1 + await r.aget(MyAsyncApi) + assert my_counter.entered_context_counter == 1 + assert my_counter.exited_context_counter == 1 + + async def test_async_context_outside_context_manager(registry: Registry) -> None: with pytest.raises(RegistryAPIError): # attempting to instantiate a class