Skip to content

Commit

Permalink
add test + fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
biniona committed Oct 1, 2024
1 parent f00b433 commit 5f57733
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 23 deletions.
4 changes: 1 addition & 3 deletions minject/inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,7 @@ def resolve(self, registry_impl: Resolver) -> T_co:

async def aresolve(self, registry_impl: Resolver) -> T_co:
if _is_key_async(self._key):
result = await registry_impl._aresolve(self._key)
await registry_impl._push_async_context(result)
return result
return await registry_impl._aresolve(self._key)
return registry_impl.resolve(self._key)

@property
Expand Down
35 changes: 15 additions & 20 deletions minject/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from typing_extensions import Concatenate, ParamSpec

from minject.inject import _is_key_async, _RegistryReference
from minject.inject import _is_key_async, _RegistryReference, reference

from .config import RegistryConfigWrapper, RegistryInitConfig
from .metadata import RegistryMetadata, _get_meta, _get_meta_from_key
Expand Down Expand Up @@ -103,7 +103,7 @@ def resolve(self, key: "RegistryKey[T]") -> T:
return self[key]

async def _aresolve(self, key: "RegistryKey[T]") -> T:
result = await self._aget(key)
result = await self.aget(key)
if result is None:
raise KeyError(key, "could not be resolved")
return result
Expand Down Expand Up @@ -241,6 +241,12 @@ async def _aregister_by_metadata(self, meta: RegistryMetadata[T]) -> RegistryWra
# add to our list of all objects (this MUST happen after init so
# any references come earlier in sequence and are destroyed first)
self._objects.append(wrapper)

# after creating an object, enter the objects context
# if it is marked with the @async_context decorator.
if meta.is_async_context():
await self._push_async_context(obj)

success = True
finally:
if not success:
Expand Down Expand Up @@ -333,7 +339,13 @@ def get(

return _unwrap(self._get_by_metadata(meta, default))

async def _aget(self, key: "RegistryKey[T]") -> Optional[T]:
async def aget(self, key: "RegistryKey[T]") -> Optional[T]:
"""
Resolve objects marked with the @async_context decorator.
"""
if not _is_key_async(key):
raise RegistryAPIError("cannot use aget outside of async context")

if not self._async_can_proceed:
raise RegistryAPIError("cannot use aget outside of async context")

Expand Down Expand Up @@ -375,23 +387,6 @@ def _get_if_already_in_registry(
# nothing has been registered for this metadata yet
return None

async def aget(self, key: "RegistryKey[T]") -> Optional[T]:
"""
Resolve objects marked with the @async_context decorator.
"""
if not _is_key_async(key):
raise RegistryAPIError("cannot use aget outside of async context")

maybe_class = self._get_if_already_in_registry(key, _get_meta_from_key(key))
if maybe_class is not None:
return maybe_class

# this is done differently than the sync version because the async version
# requires that the top level context of key be entered, and contexts are entered
# in a RegistryReference's aresolve method
reference = _RegistryReference(key)
return await self._aresolve_resolvable(reference)

async def __aenter__(self) -> "Registry":
"""
Mark a registry instance as ready for resolving async objects.
Expand Down
10 changes: 10 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,16 @@ async def test_multiple_instantiation_top_level(registry: Registry) -> None:
assert my_counter.exited_context_counter == 1


async def test_multiple_instantiation_mixed(registry: Registry) -> None:
my_counter: MyAsyncAPIContextCounter
async with registry as r:
my_counter = await r.aget(MyAsyncAPIContextCounter)
assert my_counter.entered_context_counter == 1
await r.aget(MyAsyncApi)
assert my_counter.entered_context_counter == 1
assert my_counter.exited_context_counter == 1


async def test_async_context_outside_context_manager(registry: Registry) -> None:
with pytest.raises(RegistryAPIError):
# attempting to instantiate a class
Expand Down

0 comments on commit 5f57733

Please sign in to comment.