Skip to content

Commit

Permalink
Merge pull request #87 from CBNeurotech/cboulay/gen_axarr_base
Browse files Browse the repository at this point in the history
generator - Add GenAxisArray base Unit
  • Loading branch information
griffinmilsap authored Feb 1, 2024
2 parents 74965eb + 04de791 commit 1cc4d15
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions src/ezmsg/util/generator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray
import traceback
from typing import Any, AsyncGenerator, Generator, Callable, TypeVar
from typing_extensions import ParamSpec
Expand Down Expand Up @@ -86,3 +87,29 @@ async def on_message(self, message: Any) -> AsyncGenerator:
except Exception:
ez.logger.info(traceback.format_exc())


class GenAxisArray(ez.Unit):
STATE: GenState

INPUT_SIGNAL = ez.InputStream(AxisArray)
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

def initialize(self) -> None:
self.construct_generator()

# Method to be implemented by subclasses to construct the specific generator
def construct_generator(self):
raise NotImplementedError

@ez.subscriber(INPUT_SIGNAL)
@ez.publisher(OUTPUT_SIGNAL)
async def on_message(self, message: AxisArray) -> AsyncGenerator:
try:
ret = self.STATE.gen.send(message)
if ret is not None:
yield self.OUTPUT_SIGNAL, ret
except (StopIteration, GeneratorExit):
ez.logger.debug(f"Generator closed in {self.address}")
except Exception:
ez.logger.info(traceback.format_exc())

0 comments on commit 1cc4d15

Please sign in to comment.