Rough POC that we can do async without touching function graph
@Stefan this is what I meant

Basically we pass coroutines the entire way through. When a function
depends on other data, we wrap it in a coroutine that first awaits its
dependencies. Finally, we gather the requested outputs at the end.

I *think* this is close to optimal, but I'll need to dig in a bit to
confirm.

Otherwise, a nifty POC with *very* rough edges.
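
To make the mechanics concrete outside of Hamilton, here is a minimal, hypothetical sketch of the pattern (the names are illustrative, not from the codebase): each node becomes a coroutine that awaits its dependencies, and only the requested outputs get gathered at the end.

import asyncio


async def load_a() -> int:
    await asyncio.sleep(0.1)
    return 1


async def load_b() -> int:
    await asyncio.sleep(0.1)
    return 2


async def resolve(value):
    # Dependencies may be plain values or coroutines -- only await the coroutines.
    return await value if asyncio.iscoroutine(value) else value


async def combine(a, b) -> int:
    # A "node" first awaits whatever it depends on, then does its work.
    return await resolve(a) + await resolve(b)


async def main():
    # Coroutines are passed through unawaited; the requested outputs are gathered at the very end.
    result, = await asyncio.gather(combine(load_a(), load_b()))
    print(result)  # 3


asyncio.run(main())
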
elijahbenizzy committed Aug 6, 2022
1 parent 9ecf72c commit 081db80
Showing 2 changed files with 124 additions and 0 deletions.
62 changes: 62 additions & 0 deletions examples/async/poc.py
@@ -0,0 +1,62 @@
import asyncio
import aiohttp
import fastapi

from hamilton import driver, ad_hoc_utils


async def request_raw(request: fastapi.Request) -> dict:
return await request.json()


def foo(request_raw: dict) -> str:
return request_raw.get('foo', 'far')


def bar(request_raw: dict) -> str:
return request_raw.get('bar', 'baz')


async def computation1(foo: str, some_data: dict) -> bool:
await asyncio.sleep(1)
return False


async def some_data() -> dict:
async with aiohttp.ClientSession() as session:
async with session.get('http://httpbin.org/get') as resp:
return await resp.json()


async def computation2(bar: str) -> bool:
await asyncio.sleep(1)
return True


async def pipeline(computation1: bool, computation2: bool) -> dict:
await asyncio.sleep(1)
return {'computation1': computation1, 'computation2': computation2}


# Some logic similar to computation1


app = fastapi.FastAPI()


@app.post('/execute')
async def call(
request: fastapi.Request
) -> dict:
"""Handler for pipeline call"""
    dr = driver.AsyncDriver(
        {},
        ad_hoc_utils.create_temporary_module(
            pipeline,
            computation1,
            foo,
            bar,
            some_data,
            computation2,
            request_raw))
    # The request_raw node (defined above) parses the JSON body out of the raw request.
    input_data = {'request': request}
    return await dr.raw_execute(['pipeline'], inputs=input_data)
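
For reference, here is one hypothetical way to exercise this endpoint locally (a rough sketch, assuming the file above is importable as `poc`; note that the `some_data` node makes a real network call to httpbin.org):

from fastapi.testclient import TestClient

from poc import app  # assumes examples/async/poc.py is on the path as `poc`

client = TestClient(app)
response = client.post('/execute', json={'foo': 'hello', 'bar': 'world'})
# Expect something along the lines of:
# {'pipeline': {'computation1': False, 'computation2': True}}
print(response.json())
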
62 changes: 62 additions & 0 deletions hamilton/driver.py
@@ -1,3 +1,5 @@
import asyncio
import inspect
import typing
import logging
from datetime import datetime
from typing import Dict, Collection, List, Any
@@ -267,6 +269,66 @@ def what_is_upstream_of(self, *node_names: str) -> List[Variable]:
return [Variable(node.name, node.type, node.tags) for node in upstream_nodes]


class AsyncGraphAdapter(base.HamiltonGraphAdapter):
    """Graph adapter that makes every node execution return a coroutine, resolving any coroutine inputs lazily."""

    def __init__(self):
        # Ughh, python's coroutine API is terrible -- a coroutine can only be awaited once,
        # so we cache awaited results (keyed by coroutine id) for reuse by downstream nodes.
        self.coroutine_cache = {}

@staticmethod
def check_input_type(node_type: typing.Type, input_value: typing.Any) -> bool:
        # TODO -- check if it's a coroutine
return True

@staticmethod
def check_node_type_equivalence(node_type: typing.Type, input_type: typing.Type) -> bool:
return True # We should just delegate to the superclass -- that should have a default implementation...

    async def process_value(self, val) -> Any:
        """Returns the value as-is, unless it is a coroutine -- then awaits it once and caches the result."""
if not inspect.iscoroutine(val):
return val
val_id = id(val)
if val_id in self.coroutine_cache:
return self.coroutine_cache[val_id]
output = await val
self.coroutine_cache[val_id] = output
return output

    def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> typing.Any:
        """Wraps node execution in a coroutine that resolves any coroutine inputs before calling the node."""
        callabl = node.callable

        async def new_fn(fn=callabl, **fn_kwargs):
            # Await any upstream coroutine results before invoking the underlying function.
            fn_kwargs = {key: await self.process_value(value) for key, value in fn_kwargs.items()}
            if inspect.iscoroutinefunction(fn):
                return await fn(**fn_kwargs)
            return fn(**fn_kwargs)

        # new_fn is an async def, so calling it just builds the coroutine -- nothing runs until it is awaited.
        # (asyncio.coroutine is deprecated and was a no-op on an async function anyway.)
        return new_fn(**kwargs)

@staticmethod
def build_result(**outputs: typing.Dict[str, typing.Any]) -> typing.Any:
return outputs # don't really care yet
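
As an aside (a standalone illustration, not part of the commit): the coroutine_cache above exists because a coroutine object can only be awaited once, so when two downstream nodes share the same upstream coroutine, the second await would otherwise fail.

import asyncio


async def upstream() -> int:
    return 42


async def main():
    shared = upstream()   # one coroutine object, shared by two consumers
    print(await shared)   # 42
    try:
        await shared      # awaiting the same coroutine object again...
    except RuntimeError as err:
        print(err)        # ...raises "cannot reuse already awaited coroutine"


asyncio.run(main())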


class AsyncDriver(Driver):
    """Driver variant whose execution path is async end-to-end."""

    def __init__(self, config, *modules):
        super().__init__(config, *modules, adapter=AsyncGraphAdapter())

    async def raw_execute(self,
                          final_vars: List[str],
                          overrides: Dict[str, Any] = None,
                          display_graph: bool = False,  # don't care
                          inputs: Dict[str, Any] = None) -> Dict[str, Any]:
        """Runs the DAG asynchronously: builds (unawaited) coroutines for everything upstream of
        the requested outputs, then awaits just those outputs together."""
        nodes, user_nodes = self.graph.get_upstream_nodes(final_vars, inputs)
        memoized_computation = dict()  # memoized storage -- node values here are (mostly) unawaited coroutines
        self.graph.execute(nodes, memoized_computation, overrides, inputs)
        # Awaiting a requested output recursively resolves everything it depends on via process_value.
        coroutines_to_await = [
            asyncio.create_task(self.adapter.process_value(memoized_computation[key])) for key in final_vars
        ]
        final_var_results = await asyncio.gather(*coroutines_to_await)
        return dict(zip(final_vars, final_var_results))


if __name__ == '__main__':
"""some example test code"""
import sys
