Skip to content

Commit

Permalink
Bring back asyncio
Browse files Browse the repository at this point in the history
  • Loading branch information
cjw296 committed Oct 14, 2020
1 parent 712e80f commit 8488a47
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 202 deletions.
90 changes: 50 additions & 40 deletions mush/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,91 +1,101 @@
import asyncio
from functools import partial
from typing import Callable
from typing import Callable, Dict, Any

from . import (
Context as SyncContext, Runner as SyncRunner, ResourceError, ContextError
Context as SyncContext, Runner as SyncRunner, ResourceError, ContextError, extract_returns
)
from .declarations import RequirementsDeclaration, ReturnsDeclaration
from .extraction import default_requirement_type
from .markers import get_mush, AsyncType
from .typing import RequirementModifier
from .requirements import Annotation
from .resources import ResourceValue
from .typing import DefaultRequirement


class AsyncFromSyncContext:

def __init__(self, context, loop):
self.context: Context = context
self.loop = loop
self.remove = context.remove
self.add = context.add
self.get = context.get

def call(self, obj: Callable, requires: RequirementsDeclaration = None):
coro = self.context.call(obj, requires)
future = asyncio.run_coroutine_threadsafe(coro, self.loop)
return future.result()

def extract(self, obj: Callable, requires: RequirementsDeclaration = None, returns: ReturnsDeclaration = None):
def extract(
self,
obj: Callable,
requires: RequirementsDeclaration = None,
returns: ReturnsDeclaration = None
):
coro = self.context.extract(obj, requires, returns)
future = asyncio.run_coroutine_threadsafe(coro, self.loop)
return future.result()


def async_behaviour(callable_):
to_check = callable_
if isinstance(callable_, partial):
to_check = callable_.func
if asyncio.iscoroutinefunction(to_check):
return AsyncType.async_
elif asyncio.iscoroutinefunction(to_check.__call__):
return AsyncType.async_
else:
async_type = get_mush(callable_, 'async', default=None)
if async_type is None:
if isinstance(callable_, type):
return AsyncType.nonblocking
else:
return AsyncType.blocking
else:
return async_type


class Context(SyncContext):

def __init__(self, requirement_modifier: RequirementModifier = default_requirement_type):
super().__init__(requirement_modifier)
def __init__(self, default_requirement: DefaultRequirement = Annotation):
super().__init__(default_requirement)
self._sync_context = AsyncFromSyncContext(self, asyncio.get_event_loop())
self._async_cache = {}

async def _ensure_async(self, func, *args, **kw):
async_type = self._async_cache.get(func)
if async_type is None:
to_check = func
if isinstance(func, partial):
to_check = func.func
if asyncio.iscoroutinefunction(to_check):
async_type = AsyncType.async_
elif asyncio.iscoroutinefunction(to_check.__call__):
async_type = AsyncType.async_
else:
async_type = get_mush(func, 'async', default=None)
if async_type is None:
if isinstance(func, type):
async_type = AsyncType.nonblocking
else:
async_type = AsyncType.blocking
self._async_cache[func] = async_type
behaviour = self._async_cache.get(func)
if behaviour is None:
behaviour = async_behaviour(func)
self._async_cache[func] = behaviour

if async_type is AsyncType.nonblocking:
if behaviour is AsyncType.nonblocking:
return func(*args, **kw)
elif async_type is AsyncType.blocking:
elif behaviour is AsyncType.blocking:
if kw:
func = partial(func, **kw)
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, func, *args)
else:
return await func(*args, **kw)

def _context_for(self, obj):
return self if asyncio.iscoroutinefunction(obj) else self._sync_context
def _specials(self) -> Dict[type, Any]:
return {Context: self, SyncContext: self._sync_context}

async def call(self, obj: Callable, requires: RequirementsDeclaration = None):
args = []
kw = {}
resolving = self._resolve(obj, requires, args, kw, self._context_for(obj))
for requirement in resolving:
r = requirement.resolve
o = await self._ensure_async(r, self._context_for(r))
resolving.send(o)
return await self._ensure_async(obj, *args, **kw)
resolving = self._resolve(obj, requires)
for call in resolving:
result = await self._ensure_async(call.obj, *call.args, **call.kw)
if call.send:
resolving.send(result)
return result

async def extract(self,
obj: Callable,
requires: RequirementsDeclaration = None,
returns: ReturnsDeclaration = None):
result = await self.call(obj, requires)
self._process(obj, result, returns)
returns = extract_returns(obj, returns)
if returns:
self.add_by_keys(ResourceValue(result), returns)
return result


Expand Down Expand Up @@ -128,7 +138,7 @@ async def __call__(self, context: Context = None):

if getattr(manager, '__aenter__', None):
async with manager as managed:
if managed is not None:
if managed is not None and managed is not result:
context.add(managed)
# If the context manager swallows an exception,
# None should be returned, not the context manager:
Expand Down
25 changes: 22 additions & 3 deletions mush/context.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import namedtuple
from typing import Optional, Callable, Union, Any, Dict, Iterable

from .callpoints import CallPoint
Expand All @@ -17,6 +18,9 @@ class ResourceError(Exception):
"""


Call = namedtuple('Call', ('obj', 'args', 'kw', 'send'))


class Context:
"Stores resources for a particular run."

Expand Down Expand Up @@ -95,9 +99,12 @@ def _find_resource(self, key):
exact = False
return None, exact

def _specials(self) -> Dict[type, Any]:
return {Context: self}

def _resolve(self, obj, requires=None, specials=None):
if specials is None:
specials: Dict[type, Any] = {Context: self}
specials = self._specials()

requires = extract_requires(obj, requires, self._default_requirement)

Expand Down Expand Up @@ -127,6 +134,13 @@ def _resolve(self, obj, requires=None, specials=None):
specials_[Requirement] = requirement
specials_[ResourceKey] = first_key
o = context._resolve(resource.provider, specials=specials_)
provider = resource.provider
resolving = context._resolve(provider, specials=specials_)
for call in resolving:
o = yield Call(call.obj, call.args, call.kw, send=True)
yield
if call.send:
resolving.send(o)
if resource.cache:
if exact and context is self:
resource.obj = o
Expand Down Expand Up @@ -159,10 +173,15 @@ def _resolve(self, obj, requires=None, specials=None):
else:
kw[parameter.target] = o

return obj(*args, **kw)
yield Call(obj, args, kw, send=False)

def call(self, obj: Callable, requires: Requires = None):
return self._resolve(obj, requires)
resolving = self._resolve(obj, requires)
for call in resolving:
result = call.obj(*call.args, **call.kw)
if call.send:
resolving.send(result)
return result

def nest(self):
nested = self.__class__(self._default_requirement)
Expand Down
3 changes: 3 additions & 0 deletions mush/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def complete(self, name: str, type_: Type_, default: Any):
return self

def process(self, obj):
"""
.. warning:: This must not block when used with an async context!
"""
for op in self.ops:
obj = op(obj)
if obj is missing:
Expand Down
Loading

0 comments on commit 8488a47

Please sign in to comment.