diff --git a/tmp.py b/tmp.py new file mode 100644 index 0000000..710ba2b --- /dev/null +++ b/tmp.py @@ -0,0 +1,22 @@ +import asyncio +import time +import tributary.streaming as ts +import requests + +def create(interval): + async def foo_(): + for _ in range(5): + yield interval + await asyncio.sleep(interval) + return foo_ + + +fast = ts.Foo(create(1)) +med = ts.Foo(create(2)) +slow = ts.Foo(create(3)) + +def reducer(fast, med, slow): + return {"fast": fast, "med": med, "slow": slow} + +node = ts.Reduce(fast, med, slow, reducer=reducer).print() +ts.run(node, period=1) diff --git a/tributary/streaming/__init__.py b/tributary/streaming/__init__.py index f96fe9c..de5b931 100644 --- a/tributary/streaming/__init__.py +++ b/tributary/streaming/__init__.py @@ -9,7 +9,8 @@ from .utils import * -def run(node, blocking=True, **kwargs): +def run(node, blocking=True, period=None, **kwargs): graph = node.constructGraph() kwargs["blocking"] = blocking + kwargs["period"] = period return graph.run(**kwargs) diff --git a/tributary/streaming/graph.py b/tributary/streaming/graph.py index 3e3ac98..ed1c4db 100644 --- a/tributary/streaming/graph.py +++ b/tributary/streaming/graph.py @@ -49,23 +49,52 @@ def rebuild(self): def stop(self): self._stop = True - async def _run(self): + async def _run(self, period=None): + """this is the main graph runner. it is pretty straightforward, we go through + all the layers of the graph and execute the layer as a batch of coroutines. + + If we generate a stop event (e.g. graph is done), we stop. + + If a period is set, a layer in the graph will run for at max `period` seconds + before pushing a None. + + Args: + period (Optional[int]): max period to wait + """ value, last, self._stop = None, None, False # run onstarts await asyncio.gather(*(asyncio.create_task(s()) for s in self._onstarts)) while True: - for level in self._nodes: - if self._stop: - break + if period is not None: + sets = {} + for i, level in enumerate(self._nodes): + sets[i] = set() + for n in level: + sets[i].add(asyncio.create_task(n())) + + await asyncio.gather(*(asyncio.create_task(n()) for n in level)) + # TODO + # wrap each individual node in a task + # add tasks to set + # execute all and remove from set on callback + # how loop checking if tasks are done up until `period` + # force push None for remaining (`_output(None)`) + # next level + # on next loop around only re-wrap and re-call those that aren't still in the set + raise NotImplementedError() + else: + for level in self._nodes: + if self._stop: + break - await asyncio.gather(*(asyncio.create_task(n()) for n in level)) + await asyncio.gather(*(asyncio.create_task(n()) for n in level)) - self.rebuild() + self.rebuild() - if self._stop: - break + if self._stop: + break value, last = self._starting_node.value(), value @@ -78,7 +107,8 @@ async def _run(self): # return last val return last - def run(self, blocking=True, newloop=False, start=True): + def run(self, blocking=True, newloop=False, period=None): + if sys.platform == "win32": # Set to proactor event loop on window # (default in python 3.8+) @@ -94,7 +124,7 @@ def run(self, blocking=True, newloop=False, start=True): asyncio.set_event_loop(loop) - task = loop.create_task(self._run()) + task = loop.create_task(self._run(period=period)) if blocking: # block until done @@ -103,13 +133,10 @@ def run(self, blocking=True, newloop=False, start=True): except KeyboardInterrupt: return - if start: - t = Thread(target=loop.run_until_complete, args=(task,)) - t.daemon = True - t.start() - return loop - - return loop, task + t = Thread(target=loop.run_until_complete, args=(task,)) + t.daemon = True + t.start() + return loop def graph(self): return self._starting_node.graph() diff --git a/tributary/streaming/node.py b/tributary/streaming/node.py index 03a6f53..40c1def 100644 --- a/tributary/streaming/node.py +++ b/tributary/streaming/node.py @@ -217,6 +217,9 @@ async def __call__(self): if isinstance(val, StreamEnd): return await self._finish() + if isinstance(val, (StreamNone,)): + ready = False + # set as active self._active[i] = val else: @@ -234,6 +237,7 @@ async def __call__(self): # Private interface # *********************** def __hash__(self): + """nodes are unique""" return self._id def __rshift__(self, other): @@ -273,10 +277,12 @@ async def _execute(self): # else call it elif isinstance(self._foo, types.FunctionType): try: - # could be a generator + # could be a kicked generator, so wrap in try try: + # execute wrapped function _last = self._foo(*self._active, **self._foo_kwargs) except ZeroDivisionError: + # catch divide by zero and force inf _last = float("inf") except ValueError: @@ -285,6 +291,7 @@ async def _execute(self): continue else: + # can only wrap function types raise TributaryException("Cannot use type:{}".format(type(self._foo))) # calculation was valid @@ -294,7 +301,7 @@ async def _execute(self): self._execution_count += 1 if isinstance(_last, types.AsyncGeneratorType): - + # Swap to async generator unroller async def _foo(g=_last): return await _agen_to_foo(g) @@ -308,6 +315,7 @@ async def _foo(g=_last): _last = self._foo() elif asyncio.iscoroutine(_last): + # await coroutine _last = await _last if self._repeat: @@ -319,8 +327,10 @@ async def _foo(g=_last): else: self._last = _last + # push result downstream await self._output(self._last) + # allow new inputs for i in range(len(self._active)): self._active[i] = StreamNone()