From 118260d187bed6a9944d8c6359b13e2b647a694c Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Thu, 21 Sep 2023 09:21:30 +0200 Subject: [PATCH] Add permanent failure to Python gRPC Source (#461) --- .../agents/grpc/GrpcAgentSource.java | 16 ++++++++++ .../proto/langstream_grpc/proto/agent.proto | 6 ++++ .../agents/grpc/GrpcAgentSourceTest.java | 16 ++++++++++ .../python/langstream_grpc/grpc_service.py | 9 ++++++ .../python/langstream_grpc/proto/agent_pb2.py | 28 ++++++++-------- .../langstream_grpc/proto/agent_pb2.pyi | 20 ++++++++++-- .../langstream_grpc/tests/test_grpc_source.py | 32 ++++++++++++++++++- 7 files changed, 111 insertions(+), 16 deletions(-) diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSource.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSource.java index e4f2b0f13..f14fa597a 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSource.java +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSource.java @@ -68,6 +68,22 @@ public List read() throws Exception { return read; } + @Override + public void permanentFailure(Record record, Exception error) throws Exception { + if (record instanceof GrpcAgentRecord grpcAgentRecord) { + request.onNext( + SourceRequest.newBuilder() + .setPermanentFailure( + PermanentFailure.newBuilder() + .setRecordId(grpcAgentRecord.id()) + .setErrorMessage(error.getMessage())) + .build()); + } else { + throw new IllegalArgumentException( + "Record %s is not a GrpcAgentRecord".formatted(record)); + } + } + @Override public void commit(List records) throws Exception { SourceRequest.Builder requestBuilder = SourceRequest.newBuilder(); diff --git a/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto b/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto index 4bdbebfc5..215383850 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto +++ b/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto @@ -67,8 +67,14 @@ message Record { optional int64 timestamp = 6; } +message PermanentFailure { + int64 record_id = 1; + string error_message = 2; +} + message SourceRequest { repeated int64 committed_records = 1; + PermanentFailure permanent_failure = 2; } message SourceResponse { diff --git a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSourceTest.java b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSourceTest.java index 8598c4c86..bace6dc1e 100644 --- a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSourceTest.java +++ b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSourceTest.java @@ -126,6 +126,18 @@ void testAvroAndSchema() throws Exception { source.close(); } + @Test + void testPermanentFailure() throws Exception { + GrpcAgentSource source = new GrpcAgentSource(channel); + source.setContext(new TestAgentContext()); + source.start(); + List read = readRecords(source, 1); + source.permanentFailure(read.get(0), new RuntimeException("permanent-failure")); + assertEquals(testSourceService.permanentFailure.getRecordId(), 42); + assertEquals(testSourceService.permanentFailure.getErrorMessage(), "permanent-failure"); + source.close(); + } + static List readRecords(GrpcAgentSource source, int numberOfRecords) { List read = new ArrayList<>(); await().atMost(5, TimeUnit.SECONDS) @@ -151,6 +163,7 @@ static byte[] serializeGenericRecord(GenericRecord record) throws IOException { static class TestSourceService extends AgentServiceGrpc.AgentServiceImplBase { final List committedRecords = new CopyOnWriteArrayList<>(); + PermanentFailure permanentFailure; @Override public StreamObserver read(StreamObserver responseObserver) { @@ -200,6 +213,9 @@ public StreamObserver read(StreamObserver respons @Override public void onNext(SourceRequest request) { committedRecords.addAll(request.getCommittedRecordsList()); + if (request.hasPermanentFailure()) { + permanentFailure = request.getPermanentFailure(); + } if (request.getCommittedRecordsList().contains(43L)) { responseObserver.onError(new RuntimeException("test error")); } else if (request.getCommittedRecordsList().contains(44L)) { diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py index 1a1f266d0..db43aa431 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py @@ -72,6 +72,15 @@ def handle_requests( if record is not None: records.append(record) call_method_if_exists(agent, "commit", records) + if request.HasField("permanent_failure"): + failure = request.permanent_failure + record = read_records.pop(failure.record_id, None) + call_method_if_exists( + agent, + "permanent_failure", + record, + RuntimeError(failure.error_message), + ) read_result.append(True) except Exception as e: read_result.append(e) diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.py index f99deaba8..eee8d6e15 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.py @@ -32,7 +32,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n!langstream_grpc/proto/agent.proto\x1a\x1bgoogle/protobuf/empty.proto"!\n\x0cInfoResponse\x12\x11\n\tjson_info\x18\x01 \x01(\t"\xa3\x02\n\x05Value\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\x15\n\x0b\x62ytes_value\x18\x02 \x01(\x0cH\x00\x12\x17\n\rboolean_value\x18\x03 \x01(\x08H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x14\n\nbyte_value\x18\x05 \x01(\x05H\x00\x12\x15\n\x0bshort_value\x18\x06 \x01(\x05H\x00\x12\x13\n\tint_value\x18\x07 \x01(\x05H\x00\x12\x14\n\nlong_value\x18\x08 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\t \x01(\x02H\x00\x12\x16\n\x0c\x64ouble_value\x18\n \x01(\x01H\x00\x12\x14\n\njson_value\x18\x0b \x01(\tH\x00\x12\x14\n\navro_value\x18\x0c \x01(\x0cH\x00\x42\x0c\n\ntype_oneof"-\n\x06Header\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x05value\x18\x02 \x01(\x0b\x32\x06.Value"*\n\x06Schema\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x0c"\xb3\x01\n\x06Record\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x18\n\x03key\x18\x02 \x01(\x0b\x32\x06.ValueH\x00\x88\x01\x01\x12\x1a\n\x05value\x18\x03 \x01(\x0b\x32\x06.ValueH\x01\x88\x01\x01\x12\x18\n\x07headers\x18\x04 \x03(\x0b\x32\x07.Header\x12\x0e\n\x06origin\x18\x05 \x01(\t\x12\x16\n\ttimestamp\x18\x06 \x01(\x03H\x02\x88\x01\x01\x42\x06\n\x04_keyB\x08\n\x06_valueB\x0c\n\n_timestamp"*\n\rSourceRequest\x12\x19\n\x11\x63ommitted_records\x18\x01 \x03(\x03"C\n\x0eSourceResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"E\n\x10ProcessorRequest\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"O\n\x11ProcessorResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12!\n\x07results\x18\x02 \x03(\x0b\x32\x10.ProcessorResult"\\\n\x0fProcessorResult\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x18\n\x07records\x18\x03 \x03(\x0b\x32\x07.RecordB\x08\n\x06_error2\xac\x01\n\x0c\x41gentService\x12\x35\n\nagent_info\x12\x16.google.protobuf.Empty\x1a\r.InfoResponse"\x00\x12-\n\x04read\x12\x0e.SourceRequest\x1a\x0f.SourceResponse"\x00(\x01\x30\x01\x12\x36\n\x07process\x12\x11.ProcessorRequest\x1a\x12.ProcessorResponse"\x00(\x01\x30\x01\x42\x1d\n\x19\x61i.langstream.agents.grpcP\x01\x62\x06proto3' + b'\n!langstream_grpc/proto/agent.proto\x1a\x1bgoogle/protobuf/empty.proto"!\n\x0cInfoResponse\x12\x11\n\tjson_info\x18\x01 \x01(\t"\xa3\x02\n\x05Value\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\x15\n\x0b\x62ytes_value\x18\x02 \x01(\x0cH\x00\x12\x17\n\rboolean_value\x18\x03 \x01(\x08H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x14\n\nbyte_value\x18\x05 \x01(\x05H\x00\x12\x15\n\x0bshort_value\x18\x06 \x01(\x05H\x00\x12\x13\n\tint_value\x18\x07 \x01(\x05H\x00\x12\x14\n\nlong_value\x18\x08 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\t \x01(\x02H\x00\x12\x16\n\x0c\x64ouble_value\x18\n \x01(\x01H\x00\x12\x14\n\njson_value\x18\x0b \x01(\tH\x00\x12\x14\n\navro_value\x18\x0c \x01(\x0cH\x00\x42\x0c\n\ntype_oneof"-\n\x06Header\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x05value\x18\x02 \x01(\x0b\x32\x06.Value"*\n\x06Schema\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x0c"\xb3\x01\n\x06Record\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x18\n\x03key\x18\x02 \x01(\x0b\x32\x06.ValueH\x00\x88\x01\x01\x12\x1a\n\x05value\x18\x03 \x01(\x0b\x32\x06.ValueH\x01\x88\x01\x01\x12\x18\n\x07headers\x18\x04 \x03(\x0b\x32\x07.Header\x12\x0e\n\x06origin\x18\x05 \x01(\t\x12\x16\n\ttimestamp\x18\x06 \x01(\x03H\x02\x88\x01\x01\x42\x06\n\x04_keyB\x08\n\x06_valueB\x0c\n\n_timestamp"<\n\x10PermanentFailure\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x15\n\rerror_message\x18\x02 \x01(\t"X\n\rSourceRequest\x12\x19\n\x11\x63ommitted_records\x18\x01 \x03(\x03\x12,\n\x11permanent_failure\x18\x02 \x01(\x0b\x32\x11.PermanentFailure"C\n\x0eSourceResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"E\n\x10ProcessorRequest\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"O\n\x11ProcessorResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12!\n\x07results\x18\x02 \x03(\x0b\x32\x10.ProcessorResult"\\\n\x0fProcessorResult\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x18\n\x07records\x18\x03 \x03(\x0b\x32\x07.RecordB\x08\n\x06_error2\xac\x01\n\x0c\x41gentService\x12\x35\n\nagent_info\x12\x16.google.protobuf.Empty\x1a\r.InfoResponse"\x00\x12-\n\x04read\x12\x0e.SourceRequest\x1a\x0f.SourceResponse"\x00(\x01\x30\x01\x12\x36\n\x07process\x12\x11.ProcessorRequest\x1a\x12.ProcessorResponse"\x00(\x01\x30\x01\x42\x1d\n\x19\x61i.langstream.agents.grpcP\x01\x62\x06proto3' ) _globals = globals() @@ -53,16 +53,18 @@ _globals["_SCHEMA"]._serialized_end = 484 _globals["_RECORD"]._serialized_start = 487 _globals["_RECORD"]._serialized_end = 666 - _globals["_SOURCEREQUEST"]._serialized_start = 668 - _globals["_SOURCEREQUEST"]._serialized_end = 710 - _globals["_SOURCERESPONSE"]._serialized_start = 712 - _globals["_SOURCERESPONSE"]._serialized_end = 779 - _globals["_PROCESSORREQUEST"]._serialized_start = 781 - _globals["_PROCESSORREQUEST"]._serialized_end = 850 - _globals["_PROCESSORRESPONSE"]._serialized_start = 852 - _globals["_PROCESSORRESPONSE"]._serialized_end = 931 - _globals["_PROCESSORRESULT"]._serialized_start = 933 - _globals["_PROCESSORRESULT"]._serialized_end = 1025 - _globals["_AGENTSERVICE"]._serialized_start = 1028 - _globals["_AGENTSERVICE"]._serialized_end = 1200 + _globals["_PERMANENTFAILURE"]._serialized_start = 668 + _globals["_PERMANENTFAILURE"]._serialized_end = 728 + _globals["_SOURCEREQUEST"]._serialized_start = 730 + _globals["_SOURCEREQUEST"]._serialized_end = 818 + _globals["_SOURCERESPONSE"]._serialized_start = 820 + _globals["_SOURCERESPONSE"]._serialized_end = 887 + _globals["_PROCESSORREQUEST"]._serialized_start = 889 + _globals["_PROCESSORREQUEST"]._serialized_end = 958 + _globals["_PROCESSORRESPONSE"]._serialized_start = 960 + _globals["_PROCESSORRESPONSE"]._serialized_end = 1039 + _globals["_PROCESSORRESULT"]._serialized_start = 1041 + _globals["_PROCESSORRESULT"]._serialized_end = 1133 + _globals["_AGENTSERVICE"]._serialized_start = 1136 + _globals["_AGENTSERVICE"]._serialized_end = 1308 # @@protoc_insertion_point(module_scope) diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.pyi b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.pyi index 02ea9364b..820162357 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.pyi +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.pyi @@ -119,11 +119,27 @@ class Record(_message.Message): timestamp: _Optional[int] = ..., ) -> None: ... +class PermanentFailure(_message.Message): + __slots__ = ["record_id", "error_message"] + RECORD_ID_FIELD_NUMBER: _ClassVar[int] + ERROR_MESSAGE_FIELD_NUMBER: _ClassVar[int] + record_id: int + error_message: str + def __init__( + self, record_id: _Optional[int] = ..., error_message: _Optional[str] = ... + ) -> None: ... + class SourceRequest(_message.Message): - __slots__ = ["committed_records"] + __slots__ = ["committed_records", "permanent_failure"] COMMITTED_RECORDS_FIELD_NUMBER: _ClassVar[int] + PERMANENT_FAILURE_FIELD_NUMBER: _ClassVar[int] committed_records: _containers.RepeatedScalarFieldContainer[int] - def __init__(self, committed_records: _Optional[_Iterable[int]] = ...) -> None: ... + permanent_failure: PermanentFailure + def __init__( + self, + committed_records: _Optional[_Iterable[int]] = ..., + permanent_failure: _Optional[_Union[PermanentFailure, _Mapping]] = ..., + ) -> None: ... class SourceResponse(_message.Message): __slots__ = ["schema", "records"] diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py index 5b33d3373..59577ab67 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py @@ -27,6 +27,7 @@ from langstream_grpc.proto.agent_pb2 import ( SourceResponse, SourceRequest, + PermanentFailure, ) from langstream_grpc.proto.agent_pb2_grpc import AgentServiceStub from langstream_runtime.api import Record, RecordType, Source @@ -85,7 +86,7 @@ def send_commit(): committed = 0 while committed < 2: try: - commit_id = to_commit.get(True) + commit_id = to_commit.get(True, 1) yield SourceRequest(committed_records=[commit_id]) committed += 1 except queue.Empty: @@ -101,6 +102,31 @@ def send_commit(): assert server.agent.committed[0] == server.agent.sent[0] +def test_permanent_failure(server_and_stub): + server, stub = server_and_stub + to_fail = queue.Queue() + + def send_failure(): + try: + record_id = to_fail.get(True) + yield SourceRequest( + permanent_failure=PermanentFailure( + record_id=record_id, error_message="failure" + ) + ) + except queue.Empty: + pass + + response: SourceResponse + for response in stub.read(iter(send_failure())): + for record in response.records: + to_fail.put(record.record_id) + + assert len(server.agent.failures) == 1 + assert server.agent.failures[0][0] == server.agent.sent[0] + assert str(server.agent.failures[0][1]) == "failure" + + class MySource(Source): def __init__(self): self.records = [ @@ -119,6 +145,7 @@ def __init__(self): ] self.sent = [] self.committed = [] + self.failures = [] def read(self) -> List[RecordType]: if len(self.records) > 0: @@ -132,3 +159,6 @@ def commit(self, records: List[Record]): if record.value() == 42: raise Exception("test error") self.committed.extend(records) + + def permanent_failure(self, record: Record, error: Exception): + self.failures.append((record, error))