diff --git a/examples/applications/llamaindex-cassandra-sink/python/llamaindex_cassandra.py b/examples/applications/llamaindex-cassandra-sink/python/llamaindex_cassandra.py index 912dd1881..19bd2ff6c 100644 --- a/examples/applications/llamaindex-cassandra-sink/python/llamaindex_cassandra.py +++ b/examples/applications/llamaindex-cassandra-sink/python/llamaindex_cassandra.py @@ -16,19 +16,18 @@ import base64 import io -from typing import List, Optional, Dict, Any +from typing import Dict, Any import openai from cassandra.auth import PlainTextAuthProvider from cassandra.cluster import Cluster -from langstream import Sink, CommitCallback, Record +from langstream import Sink, Record from llama_index import VectorStoreIndex, Document from llama_index.vector_stores import CassandraVectorStore class LlamaIndexCassandraSink(Sink): def __init__(self): - self.commit_cb: Optional[CommitCallback] = None self.config = None self.session = None self.index = None @@ -63,13 +62,8 @@ def start(self): self.index = VectorStoreIndex.from_vector_store(vector_store) - def write(self, records: List[Record]): - for record in records: - self.index.insert(Document(text=record.value())) - self.commit_cb.commit([record]) - - def set_commit_callback(self, commit_callback: CommitCallback): - self.commit_cb = commit_callback + def write(self, record: Record): + self.index.insert(Document(text=record.value())) def close(self): if self.session: diff --git a/examples/applications/python-processor-embeddings/python/embeddings.py b/examples/applications/python-processor-embeddings/python/embeddings.py index 37f16e084..a1ffdca55 100644 --- a/examples/applications/python-processor-embeddings/python/embeddings.py +++ b/examples/applications/python-processor-embeddings/python/embeddings.py @@ -25,11 +25,8 @@ def init(self, config): print("init", config) openai.api_key = config["openaiKey"] - def process(self, records): - processed_records = [] - for record in records: - embedding = get_embedding(record.value(), engine="text-embedding-ada-002") - result = {"input": str(record.value()), "embedding": embedding} - new_value = json.dumps(result) - processed_records.append((record, [(new_value,)])) - return processed_records + def process(self, record): + embedding = get_embedding(record.value(), engine="text-embedding-ada-002") + result = {"input": str(record.value()), "embedding": embedding} + new_value = json.dumps(result) + return [(new_value,)] diff --git a/examples/applications/python-processor-exclamation/python/example.py b/examples/applications/python-processor-exclamation/python/example.py index c2697200b..3fd401623 100644 --- a/examples/applications/python-processor-exclamation/python/example.py +++ b/examples/applications/python-processor-exclamation/python/example.py @@ -15,10 +15,10 @@ # limitations under the License. # -from langstream import SimpleRecord, SingleRecordProcessor +from langstream import SimpleRecord, Processor # Example Python processor that adds an exclamation mark to the end of the record value -class Exclamation(SingleRecordProcessor): - def process_record(self, record): +class Exclamation(Processor): + def process(self, record): return [SimpleRecord(record.value() + "!!", headers=record.headers())] diff --git a/langstream-e2e-tests/src/test/resources/apps/python-processor/python/example.py b/langstream-e2e-tests/src/test/resources/apps/python-processor/python/example.py index 8016fd259..dd947fcbf 100644 --- a/langstream-e2e-tests/src/test/resources/apps/python-processor/python/example.py +++ b/langstream-e2e-tests/src/test/resources/apps/python-processor/python/example.py @@ -14,15 +14,15 @@ # limitations under the License. # -from langstream import SimpleRecord, SingleRecordProcessor +from langstream import SimpleRecord, Processor -class Exclamation(SingleRecordProcessor): +class Exclamation(Processor): def init(self, config): print("init", config) self.secret_value = config["secret_value"] - def process_record(self, record): + def process(self, record): return [ SimpleRecord( record.value() + "!!" + self.secret_value, headers=record.headers() diff --git a/langstream-e2e-tests/src/test/resources/apps/python-sink/python/example.py b/langstream-e2e-tests/src/test/resources/apps/python-sink/python/example.py index 02e3baf86..d815bcda9 100644 --- a/langstream-e2e-tests/src/test/resources/apps/python-sink/python/example.py +++ b/langstream-e2e-tests/src/test/resources/apps/python-sink/python/example.py @@ -20,25 +20,19 @@ class TestSink(object): def __init__(self): - self.commit_callback = None self.producer = None def init(self, config): logging.info("Init config: " + str(config)) self.producer = Producer({"bootstrap.servers": config["bootstrapServers"]}) - def write(self, records): - logging.info("Write records: " + str(records)) + def write(self, record): + logging.info("Write record: " + str(record)) try: - for record in records: - self.producer.produce( - "ls-test-output", value=("write: " + record.value()).encode("utf-8") - ) + self.producer.produce( + "ls-test-output", value=("write: " + record.value()).encode("utf-8") + ) self.producer.flush() - self.commit_callback.commit(records) except Exception as e: logging.error("Error writing records: " + str(e)) raise e - - def set_commit_callback(self, commit_callback): - self.commit_callback = commit_callback diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/examples/example.py b/langstream-runtime/langstream-runtime-impl/src/main/python/examples/example.py index b809ef584..df6bfc9e6 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/examples/example.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/examples/example.py @@ -27,12 +27,9 @@ def read(self): print(f"read {records}") return records - def set_commit_callback(self, cb): - pass + def process(self, record): + print(f"process {record}") + return [record] - def process(self, records): - print(f"process {records}") - return [(record, [record]) for record in records] - - def write(self, records): - print(f"write {records}") + def write(self, record): + print(f"write {record}") diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/__init__.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/__init__.py index 4cd5835ed..c0a3e0bca 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/__init__.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/__init__.py @@ -22,9 +22,8 @@ Sink, Source, Processor, - CommitCallback, ) -from .util import SimpleRecord, SingleRecordProcessor, AvroValue +from .util import SimpleRecord, AvroValue __all__ = [ "Record", @@ -33,8 +32,6 @@ "Source", "Sink", "Processor", - "CommitCallback", "SimpleRecord", - "SingleRecordProcessor", "AvroValue", ] diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/api.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/api.py index 136cfd259..089462b37 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/api.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/api.py @@ -16,7 +16,8 @@ # from abc import ABC, abstractmethod -from typing import Any, List, Tuple, Dict, Union +from concurrent.futures import Future +from typing import Any, List, Tuple, Dict, Union, Optional __all__ = [ "Record", @@ -25,7 +26,6 @@ "Source", "Sink", "Processor", - "CommitCallback", ] @@ -91,9 +91,9 @@ class Source(Agent): def read(self) -> List[RecordType]: """The Source agent generates records and returns them as list of records. - :returns: the list of records. The records must either respect the Record - API contract (have methods value(), key() and so on) or be a dict or - tuples/list. + :returns: the list of records. + The records must either respect the Record API contract (have methods value(), + key() and so on) or be a dict or tuples/list. If the records are dict, the keys if present shall be "value", "key", "headers", "origin" and "timestamp". Eg: @@ -108,15 +108,15 @@ def read(self) -> List[RecordType]: """ pass - def commit(self, records: List[Record]): - """Called by the framework to indicate the records that have been successfully + def commit(self, record: Record): + """Called by the framework to indicate that a record has been successfully processed.""" pass def permanent_failure(self, record: Record, error: Exception): """Called by the framework to indicate that the agent has permanently failed to - process the record. - The Source agent may send the records to a dead letter queue or raise an error. + process a record. + The Source agent may send the record to a dead letter queue or raise an error. """ raise error @@ -129,47 +129,30 @@ class Processor(Agent): @abstractmethod def process( - self, records: List[Record] - ) -> List[Tuple[Record, Union[List[RecordType], Exception]]]: - """The agent processes records and returns a list containing the associations of - these records with the result of these record processing. - The result of each record processing is a list of new records or an exception. - The transactionality of the function is guaranteed by the runtime. - - :returns: the list of associations between an input record and the output - records processed from it. - Eg: [(input_record, [output_record1, output_record2])] - If an input record cannot be processed, the associated element shall be an - exception. - Eg: [(input_record, RuntimeError("Could not process"))] + self, record: Record + ) -> Union[List[RecordType], Future[List[RecordType]]]: + """The agent processes a record and returns a list of new records. + + :returns: the list of records or a concurrent.futures.Future that will complete + with the list of records. When the processing is successful, the output records must either respect the Record API contract (have methods value(), key() and so on) or be a dict or tuples/list. If the records are dict, the keys if present shall be "value", "key", "headers", "origin" and "timestamp". Eg: - * if you return [(input_record, [{"value": "foo"}])] a record - Record(value="foo") will be built. + * if you return {"value": "foo"} a record Record(value="foo") will be built. If the output records are tuples/list, the framework will automatically construct Record objects from them with the values in the following order : value, key, headers, origin, timestamp. Eg: - * if you return [(input_record, [("foo",)])] a record Record(value="foo") will - be built. - * if you return [(input_record, [("foo", "bar")])] a record - Record(value="foo", key="bar") will be built. + * if you return ("foo",) a record Record(value="foo") will be built. + * if you return ("foo", "bar") a record Record(value="foo", key="bar") will be + built. """ pass -class CommitCallback(ABC): - @abstractmethod - def commit(self, records: List[Record]): - """Called by a Sink to indicate the records that have been successfully - written.""" - pass - - class Sink(Agent): """The Sink agent interface @@ -177,13 +160,13 @@ class Sink(Agent): """ @abstractmethod - def write(self, records: List[Record]): + def write(self, record: Record) -> Optional[Future[None]]: """The Sink agent receives records from the framework and typically writes them - to an external service.""" - pass + to an external service. + For a synchronous result, return None/nothing if successful or otherwise raise + an Exception. + For an asynchronous result, return a concurrent.futures.Future. - @abstractmethod - def set_commit_callback(self, commit_callback: CommitCallback): - """Called by the framework to specify a CommitCallback that shall be used by the - Sink to indicate the records that have been written.""" + :returns: nothing if the write is successful or a concurrent.futures.Future + """ pass diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/util.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/util.py index 8ec3556fb..98f06444f 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/util.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/util.py @@ -14,13 +14,12 @@ # limitations under the License. # -from abc import abstractmethod from dataclasses import dataclass -from typing import Any, List, Tuple, Union +from typing import Any, List, Tuple -from .api import Record, Processor, RecordType +from .api import Record -__all__ = ["SimpleRecord", "SingleRecordProcessor", "AvroValue"] +__all__ = ["SimpleRecord", "AvroValue"] class SimpleRecord(Record): @@ -57,46 +56,15 @@ def timestamp(self) -> int: def __str__(self): return ( - f"Record(value={self._value}, key={self._key}, origin={self._origin}, " - f"timestamp={self._timestamp}, headers={self._headers})" + f"SimpleRecord(value={self._value}, key={self._key}, " + f"origin={self._origin},timestamp={self._timestamp}, " + f"headers={self._headers})" ) def __repr__(self): return self.__str__() -class SingleRecordProcessor(Processor): - """A Processor that processes records one-by-one""" - - @abstractmethod - def process_record(self, record: Record) -> List[RecordType]: - """Process one record and return a list of records or raise an exception. - - :returns: the list of processed records. The records must either respect the - Record API contract (have methods value(), key() and so on) or be tuples/list. - If the records are tuples/list, the framework will automatically construct - Record objects from them with the values in the following order : value, key, - headers, origin, timestamp. - Eg: - * if you return [("foo",)] a record Record(value="foo") will be built. - * if you return [("foo", "bar")] a record Record(value="foo", key="bar") will - be built. - """ - pass - - def process( - self, records: List[Record] - ) -> List[Tuple[Record, Union[List[RecordType], Exception]]]: - results = [] - for record in records: - try: - processed = self.process_record(record) - results.append((record, processed)) - except Exception as e: - results.append((record, e)) - return results - - @dataclass class AvroValue(object): schema: dict diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/api.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/api.py index 136cfd259..089462b37 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/api.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/api.py @@ -16,7 +16,8 @@ # from abc import ABC, abstractmethod -from typing import Any, List, Tuple, Dict, Union +from concurrent.futures import Future +from typing import Any, List, Tuple, Dict, Union, Optional __all__ = [ "Record", @@ -25,7 +26,6 @@ "Source", "Sink", "Processor", - "CommitCallback", ] @@ -91,9 +91,9 @@ class Source(Agent): def read(self) -> List[RecordType]: """The Source agent generates records and returns them as list of records. - :returns: the list of records. The records must either respect the Record - API contract (have methods value(), key() and so on) or be a dict or - tuples/list. + :returns: the list of records. + The records must either respect the Record API contract (have methods value(), + key() and so on) or be a dict or tuples/list. If the records are dict, the keys if present shall be "value", "key", "headers", "origin" and "timestamp". Eg: @@ -108,15 +108,15 @@ def read(self) -> List[RecordType]: """ pass - def commit(self, records: List[Record]): - """Called by the framework to indicate the records that have been successfully + def commit(self, record: Record): + """Called by the framework to indicate that a record has been successfully processed.""" pass def permanent_failure(self, record: Record, error: Exception): """Called by the framework to indicate that the agent has permanently failed to - process the record. - The Source agent may send the records to a dead letter queue or raise an error. + process a record. + The Source agent may send the record to a dead letter queue or raise an error. """ raise error @@ -129,47 +129,30 @@ class Processor(Agent): @abstractmethod def process( - self, records: List[Record] - ) -> List[Tuple[Record, Union[List[RecordType], Exception]]]: - """The agent processes records and returns a list containing the associations of - these records with the result of these record processing. - The result of each record processing is a list of new records or an exception. - The transactionality of the function is guaranteed by the runtime. - - :returns: the list of associations between an input record and the output - records processed from it. - Eg: [(input_record, [output_record1, output_record2])] - If an input record cannot be processed, the associated element shall be an - exception. - Eg: [(input_record, RuntimeError("Could not process"))] + self, record: Record + ) -> Union[List[RecordType], Future[List[RecordType]]]: + """The agent processes a record and returns a list of new records. + + :returns: the list of records or a concurrent.futures.Future that will complete + with the list of records. When the processing is successful, the output records must either respect the Record API contract (have methods value(), key() and so on) or be a dict or tuples/list. If the records are dict, the keys if present shall be "value", "key", "headers", "origin" and "timestamp". Eg: - * if you return [(input_record, [{"value": "foo"}])] a record - Record(value="foo") will be built. + * if you return {"value": "foo"} a record Record(value="foo") will be built. If the output records are tuples/list, the framework will automatically construct Record objects from them with the values in the following order : value, key, headers, origin, timestamp. Eg: - * if you return [(input_record, [("foo",)])] a record Record(value="foo") will - be built. - * if you return [(input_record, [("foo", "bar")])] a record - Record(value="foo", key="bar") will be built. + * if you return ("foo",) a record Record(value="foo") will be built. + * if you return ("foo", "bar") a record Record(value="foo", key="bar") will be + built. """ pass -class CommitCallback(ABC): - @abstractmethod - def commit(self, records: List[Record]): - """Called by a Sink to indicate the records that have been successfully - written.""" - pass - - class Sink(Agent): """The Sink agent interface @@ -177,13 +160,13 @@ class Sink(Agent): """ @abstractmethod - def write(self, records: List[Record]): + def write(self, record: Record) -> Optional[Future[None]]: """The Sink agent receives records from the framework and typically writes them - to an external service.""" - pass + to an external service. + For a synchronous result, return None/nothing if successful or otherwise raise + an Exception. + For an asynchronous result, return a concurrent.futures.Future. - @abstractmethod - def set_commit_callback(self, commit_callback: CommitCallback): - """Called by the framework to specify a CommitCallback that shall be used by the - Sink to indicate the records that have been written.""" + :returns: nothing if the write is successful or a concurrent.futures.Future + """ pass diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py index 8197981d3..cb30fa25d 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py @@ -20,6 +20,7 @@ import logging import queue import threading +from concurrent.futures import Future from io import BytesIO from typing import Iterable, Union, List, Tuple, Any, Optional, Dict @@ -48,7 +49,6 @@ Processor, Record, Agent, - CommitCallback, ) from .util import SimpleRecord, AvroValue @@ -125,12 +125,10 @@ def handle_read_requests( try: for request in requests: if len(request.committed_records) > 0: - records = [] for record_id in request.committed_records: record = read_records.pop(record_id, None) if record is not None: - records.append(record) - call_method_if_exists(self.agent, "commit", records) + call_method_if_exists(self.agent, "commit", record) if request.HasField("permanent_failure"): failure = request.permanent_failure record = read_records.pop(failure.record_id, None) @@ -144,56 +142,85 @@ def handle_read_requests( except Exception as e: read_result.append(e) + @staticmethod + def handle_requests(handler, requests): + results = queue.Queue(1000) + thread = threading.Thread(target=handler, args=(requests, results)) + thread.start() + + while True: + try: + result = results.get(True, 0.1) + if isinstance(result, bool): + break + yield result + except queue.Empty: + pass + thread.join() + def process(self, requests: Iterable[ProcessorRequest], _): + return self.handle_requests(self.handle_process_requests, requests) + + def process_record( + self, source_record, get_processed_fn, get_processed_args, process_results + ): + grpc_result = ProcessorResult(record_id=source_record.record_id) + try: + processed_records = get_processed_fn(*get_processed_args) + if isinstance(processed_records, Future): + processed_records.add_done_callback( + lambda f: self.process_record( + source_record, f.result, (), process_results + ) + ) + else: + for record in processed_records: + schemas, grpc_record = self.to_grpc_record(wrap_in_record(record)) + for schema in schemas: + process_results.put(ProcessorResponse(schema=schema)) + grpc_result.records.append(grpc_record) + process_results.put(ProcessorResponse(results=[grpc_result])) + except Exception as e: + grpc_result.error = str(e) + process_results.put(ProcessorResponse(results=[grpc_result])) + + def handle_process_requests( + self, requests: Iterable[ProcessorRequest], process_results + ): for request in requests: if request.HasField("schema"): schema = fastavro.parse_schema(json.loads(request.schema.value)) self.client_schemas[request.schema.schema_id] = schema if len(request.records) > 0: - records = [self.from_grpc_record(record) for record in request.records] - process_result = self.agent.process(records) - grpc_results = [] - for source_record, result in process_result: - grpc_result = ProcessorResult(record_id=source_record.record_id) - if isinstance(result, Exception): - grpc_result.error = str(result) - else: - for record in result: - schemas, grpc_record = self.to_grpc_record( - wrap_in_record(record) - ) - for schema in schemas: - yield ProcessorResponse(schema=schema) - grpc_result.records.append(grpc_record) - grpc_results.append(grpc_result) - yield ProcessorResponse(results=grpc_results) + for source_record in request.records: + self.process_record( + source_record, + lambda r: self.agent.process(self.from_grpc_record(r)), + (source_record,), + process_results, + ) + process_results.put(True) def write(self, requests: Iterable[SinkRequest], _): - write_results = queue.Queue(1000) - - class GrpcCommitCallback(CommitCallback): - def commit(self, records: List[RecordWithId]): - for record in records: - write_results.put(record.record_id) - - self.agent.set_commit_callback(GrpcCommitCallback()) - write_thread = threading.Thread( - target=self.handle_write_requests, args=(requests, write_results) - ) - write_thread.start() + return self.handle_requests(self.handle_write_requests, requests) - while True: - try: - result = write_results.get(True, 0.1) - if isinstance(result, Exception): - yield SinkResponse(error=str(result)) - elif isinstance(result, bool): - break - else: - yield SinkResponse(record_id=result) - except queue.Empty: - pass - write_thread.join() + def write_record( + self, source_record, get_written_fn, get_written_args, write_results + ): + try: + result = get_written_fn(*get_written_args) + if isinstance(result, Future): + result.add_done_callback( + lambda f: self.write_record( + source_record, f.result, (), write_results + ) + ) + else: + write_results.put(SinkResponse(record_id=source_record.record_id)) + except Exception as e: + write_results.put( + SinkResponse(record_id=source_record.record_id, error=str(e)) + ) def handle_write_requests(self, requests: Iterable[SinkRequest], write_results): for request in requests: @@ -201,11 +228,12 @@ def handle_write_requests(self, requests: Iterable[SinkRequest], write_results): schema = fastavro.parse_schema(json.loads(request.schema.value)) self.client_schemas[request.schema.schema_id] = schema if request.HasField("record"): - record = self.from_grpc_record(request.record) - try: - self.agent.write([record]) - except Exception as e: - write_results.put(e) + self.write_record( + request.record, + lambda r: self.agent.write(self.from_grpc_record(r)), + (request.record,), + write_results, + ) write_results.put(True) def from_grpc_record(self, record: GrpcRecord) -> SimpleRecord: diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/server_and_stub.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/server_and_stub.py new file mode 100644 index 000000000..1ce338ed3 --- /dev/null +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/server_and_stub.py @@ -0,0 +1,44 @@ +# +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional + +import grpc + +from langstream_grpc.grpc_service import AgentServer +from langstream_grpc.proto.agent_pb2_grpc import AgentServiceStub + + +class ServerAndStub(object): + def __init__(self, class_name): + self.class_name = class_name + self.server: Optional[AgentServer] = None + self.channel: Optional[grpc.Channel] = None + self.stub: Optional[AgentServiceStub] = None + + def __enter__(self): + config = f"""{{ + "className": "{self.class_name}" + }}""" + self.server = AgentServer("[::]:0", config) + self.server.start() + self.channel = grpc.insecure_channel("localhost:%d" % self.server.port) + self.stub = AgentServiceStub(channel=self.channel) + return self + + def __exit__(self, *args): + self.channel.close() + self.server.stop() diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_processor.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_processor.py index 0fbb8b1e3..25ae02ff9 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_processor.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_processor.py @@ -15,16 +15,15 @@ # import json +from concurrent.futures import ThreadPoolExecutor, Future from io import BytesIO from typing import List, Dict, Any import fastavro -import grpc import pytest from google.protobuf import empty_pb2 -from langstream_grpc.api import Record, RecordType -from langstream_grpc.grpc_service import AgentServer +from langstream_grpc.api import Record, RecordType, Processor from langstream_grpc.proto.agent_pb2 import ( ProcessorRequest, Record as GrpcRecord, @@ -34,23 +33,7 @@ Schema, InfoResponse, ) -from langstream_grpc.proto.agent_pb2_grpc import AgentServiceStub -from langstream_grpc.util import SingleRecordProcessor - - -@pytest.fixture(autouse=True) -def stub(): - config = """{ - "className": "langstream_grpc.tests.test_grpc_processor.MyProcessor" - }""" - server = AgentServer("[::]:0", config) - server.start() - channel = grpc.insecure_channel("localhost:%d" % server.port) - - yield AgentServiceStub(channel=channel) - - channel.close() - server.stop() +from langstream_grpc.tests.server_and_stub import ServerAndStub @pytest.mark.parametrize( @@ -78,128 +61,162 @@ def stub(): pytest.param("double_value", "double_value", 42.0, 43.0, 44.0), ], ) -def test_process(input_type, output_type, value, key, header, request): - stub = request.getfixturevalue("stub") - - record = GrpcRecord( - record_id=42, - key=Value(**{input_type: key}), - value=Value(**{input_type: value}), - headers=[ - Header( - name="test-header", - value=Value(**{input_type: header}), - ) - ], - origin="test-origin", - timestamp=43, - ) - response: ProcessorResponse - for response in stub.process(iter([ProcessorRequest(records=[record])])): - assert len(response.results) == 1 - assert response.results[0].record_id == record.record_id - assert response.results[0].HasField("error") is False - assert len(response.results[0].records) == 1 - result = response.results[0].records[0] - assert result.key == Value(**{output_type: key}) - assert result.value == Value(**{output_type: value}) - assert len(result.headers) == 1 - assert result.headers[0].name == result.headers[0].name - assert result.headers[0].value == Value(**{output_type: header}) - assert result.origin == record.origin - assert result.timestamp == record.timestamp - - -def test_avro(stub): - requests = [] - schema = { - "type": "record", - "name": "Test", - "namespace": "test", - "fields": [{"name": "field", "type": {"type": "string"}}], - } - canonical_schema = fastavro.schema.to_parsing_canonical_form(schema) - requests.append( - ProcessorRequest( - schema=Schema(schema_id=42, value=canonical_schema.encode("utf-8")) +def test_process(input_type, output_type, value, key, header): + with ServerAndStub( + "langstream_grpc.tests.test_grpc_processor.MyProcessor" + ) as server_and_stub: + record = GrpcRecord( + record_id=42, + key=Value(**{input_type: key}), + value=Value(**{input_type: value}), + headers=[ + Header( + name="test-header", + value=Value(**{input_type: header}), + ) + ], + origin="test-origin", + timestamp=43, ) - ) - - fp = BytesIO() - try: - fastavro.schemaless_writer(fp, schema, {"field": "test"}) + response: ProcessorResponse + for response in server_and_stub.stub.process( + iter([ProcessorRequest(records=[record])]) + ): + assert len(response.results) == 1 + assert response.results[0].record_id == record.record_id + assert response.results[0].HasField("error") is False + assert len(response.results[0].records) == 1 + result = response.results[0].records[0] + assert result.key == Value(**{output_type: key}) + assert result.value == Value(**{output_type: value}) + assert len(result.headers) == 1 + assert result.headers[0].name == result.headers[0].name + assert result.headers[0].value == Value(**{output_type: header}) + assert result.origin == record.origin + assert result.timestamp == record.timestamp + + +def test_avro(): + with ServerAndStub( + "langstream_grpc.tests.test_grpc_processor.MyProcessor" + ) as server_and_stub: + requests = [] + schema = { + "type": "record", + "name": "Test", + "namespace": "test", + "fields": [{"name": "field", "type": {"type": "string"}}], + } + canonical_schema = fastavro.schema.to_parsing_canonical_form(schema) requests.append( ProcessorRequest( - records=[ - GrpcRecord( - record_id=43, - value=Value(schema_id=42, avro_value=fp.getvalue()), + schema=Schema(schema_id=42, value=canonical_schema.encode("utf-8")) + ) + ) + + fp = BytesIO() + try: + fastavro.schemaless_writer(fp, schema, {"field": "test"}) + requests.append( + ProcessorRequest( + records=[ + GrpcRecord( + record_id=43, + value=Value(schema_id=42, avro_value=fp.getvalue()), + ) + ] + ) + ) + finally: + fp.close() + + responses: list[ProcessorResponse] + responses = list(server_and_stub.stub.process(iter(requests))) + response_schema = responses[0] + assert len(response_schema.results) == 0 + assert response_schema.HasField("schema") + assert response_schema.schema.schema_id == 1 + assert response_schema.schema.value.decode("utf-8") == canonical_schema + + response_record = responses[1] + assert len(response_record.results) == 1 + result = response_record.results[0] + assert result.record_id == 43 + assert len(result.records) == 1 + assert result.records[0].value.schema_id == 1 + fp = BytesIO(result.records[0].value.avro_value) + try: + decoded = fastavro.schemaless_reader(fp, json.loads(canonical_schema)) + assert decoded == {"field": "test"} + finally: + fp.close() + + +def test_empty_record(): + with ServerAndStub( + "langstream_grpc.tests.test_grpc_processor.MyProcessor" + ) as server_and_stub: + response: ProcessorResponse + for response in server_and_stub.stub.process( + iter([ProcessorRequest(records=[GrpcRecord()])]) + ): + assert len(response.results) == 1 + assert response.results[0].record_id == 0 + assert response.results[0].HasField("error") is False + assert len(response.results[0].records) == 1 + result = response.results[0].records[0] + assert result.HasField("key") is False + assert result.HasField("value") is False + assert len(result.headers) == 0 + assert result.origin == "" + assert result.HasField("timestamp") is False + + +def test_failing_record(): + with ServerAndStub( + "langstream_grpc.tests.test_grpc_processor.MyFailingProcessor" + ) as server_and_stub: + for response in server_and_stub.stub.process( + iter([ProcessorRequest(records=[GrpcRecord()])]) + ): + assert len(response.results) == 1 + assert response.results[0].HasField("error") is True + assert response.results[0].error == "failure" + + +def test_future_record(): + with ServerAndStub( + "langstream_grpc.tests.test_grpc_processor.MyFutureProcessor" + ) as server_and_stub: + response: ProcessorResponse + for response in server_and_stub.stub.process( + iter( + [ + ProcessorRequest( + records=[GrpcRecord(value=Value(string_value="test"))] ) ] ) - ) - finally: - fp.close() - - responses: list[ProcessorResponse] - responses = list(stub.process(iter(requests))) - response_schema = responses[0] - assert len(response_schema.results) == 0 - assert response_schema.HasField("schema") - assert response_schema.schema.schema_id == 1 - assert response_schema.schema.value.decode("utf-8") == canonical_schema - - response_record = responses[1] - assert len(response_record.results) == 1 - result = response_record.results[0] - assert result.record_id == 43 - assert len(result.records) == 1 - assert result.records[0].value.schema_id == 1 - fp = BytesIO(result.records[0].value.avro_value) - try: - decoded = fastavro.schemaless_reader(fp, json.loads(canonical_schema)) - assert decoded == {"field": "test"} - finally: - fp.close() - - -def test_empty_record(request): - stub = request.getfixturevalue("stub") - record = GrpcRecord() - for response in stub.process(iter([ProcessorRequest(records=[record])])): - assert len(response.results) == 1 - assert response.results[0].record_id == 0 - assert response.results[0].HasField("error") is False - assert len(response.results[0].records) == 1 - result = response.results[0].records[0] - assert result.HasField("key") is False - assert result.HasField("value") is False - assert len(result.headers) == 0 - assert result.origin == "" - assert result.HasField("timestamp") is False - - -def test_failing_record(request): - stub = request.getfixturevalue("stub") - record = GrpcRecord(origin="failing-record") - for response in stub.process(iter([ProcessorRequest(records=[record])])): - assert len(response.results) == 1 - assert response.results[0].HasField("error") is True - assert response.results[0].error == "failure" - - -def test_info(stub): - info: InfoResponse = stub.agent_info(empty_pb2.Empty()) - assert info.json_info == '{"test-info-key": "test-info-value"}' - - -class MyProcessor(SingleRecordProcessor): + ): + assert len(response.results) == 1 + assert response.results[0].HasField("error") is False + assert len(response.results[0].records) == 1 + assert response.results[0].records[0].value.string_value == "test" + + +def test_info(): + with ServerAndStub( + "langstream_grpc.tests.test_grpc_processor.MyProcessor" + ) as server_and_stub: + info: InfoResponse = server_and_stub.stub.agent_info(empty_pb2.Empty()) + assert info.json_info == '{"test-info-key": "test-info-value"}' + + +class MyProcessor(Processor): def agent_info(self) -> Dict[str, Any]: return {"test-info-key": "test-info-value"} - def process_record(self, record: Record) -> List[RecordType]: - if record.origin() == "failing-record": - raise Exception("failure") + def process(self, record: Record) -> List[RecordType]: if isinstance(record.value(), str): return [record] if isinstance(record.value(), float): @@ -221,3 +238,17 @@ def process_record(self, record: Record) -> List[RecordType]: record.timestamp(), ) ] + + +class MyFailingProcessor(Processor): + def process(self, record: Record) -> List[RecordType]: + raise Exception("failure") + + +class MyFutureProcessor(Processor): + def __init__(self): + self.written_records = [] + self.executor = ThreadPoolExecutor(max_workers=10) + + def process(self, record: Record) -> Future[List[RecordType]]: + return self.executor.submit(lambda r: [r], record) diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_sink.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_sink.py index 7898b05d2..9956d8b3f 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_sink.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_sink.py @@ -14,15 +14,12 @@ # limitations under the License. # +from concurrent.futures import ThreadPoolExecutor, Future from io import BytesIO -from typing import List, Optional import fastavro -import grpc -import pytest -from langstream_grpc.api import Record, Sink, CommitCallback -from langstream_grpc.grpc_service import AgentServer +from langstream_grpc.api import Record, Sink from langstream_grpc.proto.agent_pb2 import ( Record as GrpcRecord, SinkRequest, @@ -30,91 +27,113 @@ Value, SinkResponse, ) -from langstream_grpc.proto.agent_pb2_grpc import AgentServiceStub - - -@pytest.fixture(autouse=True) -def server_and_stub(): - config = """{ - "className": "langstream_grpc.tests.test_grpc_sink.MySink" - }""" - server = AgentServer("[::]:0", config) - server.start() - channel = grpc.insecure_channel("localhost:%d" % server.port) - - yield server, AgentServiceStub(channel=channel) - - channel.close() - server.stop() - - -def test_write(server_and_stub): - server, stub = server_and_stub +from langstream_grpc.tests.server_and_stub import ServerAndStub + + +def test_write(): + with ServerAndStub( + "langstream_grpc.tests.test_grpc_sink.MySink" + ) as server_and_stub: + + def requests(): + schema = { + "type": "record", + "name": "Test", + "namespace": "test", + "fields": [{"name": "field", "type": {"type": "string"}}], + } + canonical_schema = fastavro.schema.to_parsing_canonical_form(schema) + yield SinkRequest( + schema=Schema(schema_id=42, value=canonical_schema.encode("utf-8")) + ) - def requests(): - schema = { - "type": "record", - "name": "Test", - "namespace": "test", - "fields": [{"name": "field", "type": {"type": "string"}}], - } - canonical_schema = fastavro.schema.to_parsing_canonical_form(schema) - yield SinkRequest( - schema=Schema(schema_id=42, value=canonical_schema.encode("utf-8")) + fp = BytesIO() + try: + fastavro.schemaless_writer(fp, schema, {"field": "test"}) + yield SinkRequest( + record=GrpcRecord( + record_id=43, + value=Value(schema_id=42, avro_value=fp.getvalue()), + ) + ) + finally: + fp.close() + + responses: list[SinkResponse] + responses = list(server_and_stub.stub.write(iter(requests()))) + assert len(responses) == 1 + assert responses[0].record_id == 43 + assert len(server_and_stub.server.agent.written_records) == 1 + assert ( + server_and_stub.server.agent.written_records[0].value().value["field"] + == "test" ) - fp = BytesIO() - try: - fastavro.schemaless_writer(fp, schema, {"field": "test"}) - yield SinkRequest( - record=GrpcRecord( - record_id=43, - value=Value(schema_id=42, avro_value=fp.getvalue()), + +def test_write_error(): + with ServerAndStub( + "langstream_grpc.tests.test_grpc_sink.MyErrorSink" + ) as server_and_stub: + responses: list[SinkResponse] + responses = list( + server_and_stub.stub.write( + iter( + [ + SinkRequest( + record=GrpcRecord( + value=Value(string_value="test"), + ) + ) + ] ) ) - finally: - fp.close() - - responses: list[SinkResponse] - responses = list(stub.write(iter(requests()))) - assert len(responses) == 1 - assert responses[0].record_id == 43 - assert len(server.agent.written_records) == 1 - assert server.agent.written_records[0].value().value["field"] == "test" - - -def test_write_error(server_and_stub): - server, stub = server_and_stub - - responses: list[SinkResponse] - responses = list( - stub.write( - iter( - [ - SinkRequest( - record=GrpcRecord( - value=Value(string_value="test"), origin="failing-record" + ) + assert len(responses) == 1 + assert responses[0].error == "test-error" + + +def test_write_future(): + with ServerAndStub( + "langstream_grpc.tests.test_grpc_sink.MyFutureSink" + ) as server_and_stub: + responses: list[SinkResponse] + responses = list( + server_and_stub.stub.write( + iter( + [ + SinkRequest( + record=GrpcRecord( + record_id=42, + value=Value(string_value="test"), + ) ) - ) - ] + ] + ) ) ) - ) - assert len(responses) == 1 - assert responses[0].error == "test-error" + assert len(responses) == 1 + assert responses[0].record_id == 42 + assert len(server_and_stub.server.agent.written_records) == 1 + assert server_and_stub.server.agent.written_records[0].value() == "test" class MySink(Sink): def __init__(self): - self.commit_callback: Optional[CommitCallback] = None self.written_records = [] - def write(self, records: List[Record]): - for record in records: - if record.origin() == "failing-record": - raise RuntimeError("test-error") - self.written_records.extend(records) - self.commit_callback.commit(records) + def write(self, record: Record): + self.written_records.append(record) + + +class MyErrorSink(Sink): + def write(self, record: Record): + raise RuntimeError("test-error") + + +class MyFutureSink(Sink): + def __init__(self): + self.written_records = [] + self.executor = ThreadPoolExecutor(max_workers=10) - def set_commit_callback(self, commit_callback: CommitCallback): - self.commit_callback = commit_callback + def write(self, record: Record) -> Future[None]: + return self.executor.submit(lambda r: self.written_records.append(r), record) diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py index 7b7963a16..64e2d39ec 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py @@ -25,34 +25,24 @@ import pytest from langstream_grpc.api import Record, RecordType, Source -from langstream_grpc.grpc_service import AgentServer from langstream_grpc.proto.agent_pb2 import ( SourceResponse, SourceRequest, PermanentFailure, ) -from langstream_grpc.proto.agent_pb2_grpc import AgentServiceStub +from langstream_grpc.tests.server_and_stub import ServerAndStub from langstream_grpc.util import AvroValue, SimpleRecord -@pytest.fixture(autouse=True) +@pytest.fixture def server_and_stub(): - config = """{ - "className": "langstream_grpc.tests.test_grpc_source.MySource" - }""" - server = AgentServer("[::]:0", config) - server.start() - channel = grpc.insecure_channel("localhost:%d" % server.port) - - yield server, AgentServiceStub(channel=channel) - - channel.close() - server.stop() + with ServerAndStub( + "langstream_grpc.tests.test_grpc_source.MySource" + ) as server_and_stub: + yield server_and_stub def test_read(server_and_stub): - server, stub = server_and_stub - stop = False def requests(): @@ -62,7 +52,7 @@ def requests(): responses: list[SourceResponse] = [] i = 0 - for response in stub.read(iter(requests())): + for response in server_and_stub.stub.read(iter(requests())): responses.append(response) i += 1 stop = i == 4 @@ -103,7 +93,6 @@ def requests(): def test_commit(server_and_stub): - server, stub = server_and_stub to_commit = queue.Queue() def send_commit(): @@ -118,17 +107,18 @@ def send_commit(): with pytest.raises(grpc.RpcError): response: SourceResponse - for response in stub.read(iter(send_commit())): + for response in server_and_stub.stub.read(iter(send_commit())): for record in response.records: to_commit.put(record.record_id) - assert len(server.agent.committed) == 2 - assert server.agent.committed[0] == server.agent.sent[0] - assert server.agent.committed[1].value() == server.agent.sent[1]["value"] + sent = server_and_stub.server.agent.sent + committed = server_and_stub.server.agent.committed + assert len(committed) == 2 + assert committed[0] == sent[0] + assert committed[1].value() == sent[1]["value"] def test_permanent_failure(server_and_stub): - server, stub = server_and_stub to_fail = queue.Queue() def send_failure(): @@ -143,13 +133,14 @@ def send_failure(): pass response: SourceResponse - for response in stub.read(iter(send_failure())): + for response in server_and_stub.stub.read(iter(send_failure())): for record in response.records: to_fail.put(record.record_id) - assert len(server.agent.failures) == 1 - assert server.agent.failures[0][0] == server.agent.sent[0] - assert str(server.agent.failures[0][1]) == "failure" + failures = server_and_stub.server.agent.failures + assert len(failures) == 1 + assert failures[0][0] == server_and_stub.server.agent.sent[0] + assert str(failures[0][1]) == "failure" class MySource(Source): @@ -180,11 +171,10 @@ def read(self) -> List[RecordType]: return [record] return [] - def commit(self, records: List[Record]): - for record in records: - if record.value() == 43: - raise Exception("test error") - self.committed.extend(records) + def commit(self, record: Record): + if record.value() == 43: + raise Exception("test error") + self.committed.append(record) def permanent_failure(self, record: Record, error: Exception): self.failures.append((record, error)) diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/util.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/util.py index 8ec3556fb..98f06444f 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/util.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/util.py @@ -14,13 +14,12 @@ # limitations under the License. # -from abc import abstractmethod from dataclasses import dataclass -from typing import Any, List, Tuple, Union +from typing import Any, List, Tuple -from .api import Record, Processor, RecordType +from .api import Record -__all__ = ["SimpleRecord", "SingleRecordProcessor", "AvroValue"] +__all__ = ["SimpleRecord", "AvroValue"] class SimpleRecord(Record): @@ -57,46 +56,15 @@ def timestamp(self) -> int: def __str__(self): return ( - f"Record(value={self._value}, key={self._key}, origin={self._origin}, " - f"timestamp={self._timestamp}, headers={self._headers})" + f"SimpleRecord(value={self._value}, key={self._key}, " + f"origin={self._origin},timestamp={self._timestamp}, " + f"headers={self._headers})" ) def __repr__(self): return self.__str__() -class SingleRecordProcessor(Processor): - """A Processor that processes records one-by-one""" - - @abstractmethod - def process_record(self, record: Record) -> List[RecordType]: - """Process one record and return a list of records or raise an exception. - - :returns: the list of processed records. The records must either respect the - Record API contract (have methods value(), key() and so on) or be tuples/list. - If the records are tuples/list, the framework will automatically construct - Record objects from them with the values in the following order : value, key, - headers, origin, timestamp. - Eg: - * if you return [("foo",)] a record Record(value="foo") will be built. - * if you return [("foo", "bar")] a record Record(value="foo", key="bar") will - be built. - """ - pass - - def process( - self, records: List[Record] - ) -> List[Tuple[Record, Union[List[RecordType], Exception]]]: - results = [] - for record in records: - try: - processed = self.process_record(record) - results.append((record, processed)) - except Exception as e: - results.append((record, e)) - return results - - @dataclass class AvroValue(object): schema: dict