Skip to content

Commit

Permalink
Add permanent failure to Python gRPC Source (#461)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet authored Sep 21, 2023
1 parent cff0481 commit 118260d
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,22 @@ public List<Record> read() throws Exception {
return read;
}

/**
 * Reports a permanently failed record to the remote agent over the request stream.
 *
 * <p>A {@code SourceRequest} carrying a {@code PermanentFailure} (record id and error message)
 * is pushed through {@code request.onNext}.
 *
 * @param record the record that failed; must be a {@code GrpcAgentRecord}
 * @param error the error that caused the failure
 * @throws IllegalArgumentException if {@code record} is not a {@code GrpcAgentRecord}
 */
@Override
public void permanentFailure(Record record, Exception error) throws Exception {
    if (!(record instanceof GrpcAgentRecord grpcAgentRecord)) {
        throw new IllegalArgumentException(
                "Record %s is not a GrpcAgentRecord".formatted(record));
    }
    // Throwable#getMessage() may be null (e.g. new RuntimeException()), and protobuf
    // string setters throw NullPointerException on null, so fall back to toString().
    String errorMessage = error.getMessage();
    if (errorMessage == null) {
        errorMessage = error.toString();
    }
    request.onNext(
            SourceRequest.newBuilder()
                    .setPermanentFailure(
                            PermanentFailure.newBuilder()
                                    .setRecordId(grpcAgentRecord.id())
                                    .setErrorMessage(errorMessage))
                    .build());
}

@Override
public void commit(List<Record> records) throws Exception {
SourceRequest.Builder requestBuilder = SourceRequest.newBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,14 @@ message Record {
optional int64 timestamp = 6;
}

// Describes a permanent failure of a single record: the id of the record and
// a textual error message.
message PermanentFailure {
// Id of the record that failed.
int64 record_id = 1;
// Message of the error that caused the failure.
string error_message = 2;
}

// Request message streamed to the agent service's `read` RPC.
message SourceRequest {
// Ids of the records being committed.
repeated int64 committed_records = 1;
// Set when the request reports a permanent failure; receivers check for its
// presence before reading it.
PermanentFailure permanent_failure = 2;
}

message SourceResponse {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,18 @@ void testAvroAndSchema() throws Exception {
source.close();
}

/**
 * Verifies that a permanent failure reported on the source reaches the gRPC service with the
 * failed record's id and the error message.
 */
@Test
void testPermanentFailure() throws Exception {
    GrpcAgentSource source = new GrpcAgentSource(channel);
    source.setContext(new TestAgentContext());
    source.start();
    List<Record> read = readRecords(source, 1);
    source.permanentFailure(read.get(0), new RuntimeException("permanent-failure"));
    // The service stores the failure from its request observer, which may run after
    // permanentFailure() returns; wait for it instead of dereferencing a possibly-null field.
    await().atMost(5, TimeUnit.SECONDS)
            .until(() -> testSourceService.permanentFailure != null);
    // JUnit convention: expected value first, actual value second.
    assertEquals(42L, testSourceService.permanentFailure.getRecordId());
    assertEquals("permanent-failure", testSourceService.permanentFailure.getErrorMessage());
    source.close();
}

static List<Record> readRecords(GrpcAgentSource source, int numberOfRecords) {
List<Record> read = new ArrayList<>();
await().atMost(5, TimeUnit.SECONDS)
Expand All @@ -151,6 +163,7 @@ static byte[] serializeGenericRecord(GenericRecord record) throws IOException {
static class TestSourceService extends AgentServiceGrpc.AgentServiceImplBase {

final List<Long> committedRecords = new CopyOnWriteArrayList<>();
PermanentFailure permanentFailure;

@Override
public StreamObserver<SourceRequest> read(StreamObserver<SourceResponse> responseObserver) {
Expand Down Expand Up @@ -200,6 +213,9 @@ public StreamObserver<SourceRequest> read(StreamObserver<SourceResponse> respons
@Override
public void onNext(SourceRequest request) {
committedRecords.addAll(request.getCommittedRecordsList());
if (request.hasPermanentFailure()) {
permanentFailure = request.getPermanentFailure();
}
if (request.getCommittedRecordsList().contains(43L)) {
responseObserver.onError(new RuntimeException("test error"));
} else if (request.getCommittedRecordsList().contains(44L)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ def handle_requests(
if record is not None:
records.append(record)
call_method_if_exists(agent, "commit", records)
if request.HasField("permanent_failure"):
failure = request.permanent_failure
record = read_records.pop(failure.record_id, None)
call_method_if_exists(
agent,
"permanent_failure",
record,
RuntimeError(failure.error_message),
)
read_result.append(True)
except Exception as e:
read_result.append(e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n!langstream_grpc/proto/agent.proto\x1a\x1bgoogle/protobuf/empty.proto"!\n\x0cInfoResponse\x12\x11\n\tjson_info\x18\x01 \x01(\t"\xa3\x02\n\x05Value\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\x15\n\x0b\x62ytes_value\x18\x02 \x01(\x0cH\x00\x12\x17\n\rboolean_value\x18\x03 \x01(\x08H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x14\n\nbyte_value\x18\x05 \x01(\x05H\x00\x12\x15\n\x0bshort_value\x18\x06 \x01(\x05H\x00\x12\x13\n\tint_value\x18\x07 \x01(\x05H\x00\x12\x14\n\nlong_value\x18\x08 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\t \x01(\x02H\x00\x12\x16\n\x0c\x64ouble_value\x18\n \x01(\x01H\x00\x12\x14\n\njson_value\x18\x0b \x01(\tH\x00\x12\x14\n\navro_value\x18\x0c \x01(\x0cH\x00\x42\x0c\n\ntype_oneof"-\n\x06Header\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x05value\x18\x02 \x01(\x0b\x32\x06.Value"*\n\x06Schema\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x0c"\xb3\x01\n\x06Record\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x18\n\x03key\x18\x02 \x01(\x0b\x32\x06.ValueH\x00\x88\x01\x01\x12\x1a\n\x05value\x18\x03 \x01(\x0b\x32\x06.ValueH\x01\x88\x01\x01\x12\x18\n\x07headers\x18\x04 \x03(\x0b\x32\x07.Header\x12\x0e\n\x06origin\x18\x05 \x01(\t\x12\x16\n\ttimestamp\x18\x06 \x01(\x03H\x02\x88\x01\x01\x42\x06\n\x04_keyB\x08\n\x06_valueB\x0c\n\n_timestamp"*\n\rSourceRequest\x12\x19\n\x11\x63ommitted_records\x18\x01 \x03(\x03"C\n\x0eSourceResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"E\n\x10ProcessorRequest\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"O\n\x11ProcessorResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12!\n\x07results\x18\x02 \x03(\x0b\x32\x10.ProcessorResult"\\\n\x0fProcessorResult\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x18\n\x07records\x18\x03 
\x03(\x0b\x32\x07.RecordB\x08\n\x06_error2\xac\x01\n\x0c\x41gentService\x12\x35\n\nagent_info\x12\x16.google.protobuf.Empty\x1a\r.InfoResponse"\x00\x12-\n\x04read\x12\x0e.SourceRequest\x1a\x0f.SourceResponse"\x00(\x01\x30\x01\x12\x36\n\x07process\x12\x11.ProcessorRequest\x1a\x12.ProcessorResponse"\x00(\x01\x30\x01\x42\x1d\n\x19\x61i.langstream.agents.grpcP\x01\x62\x06proto3'
b'\n!langstream_grpc/proto/agent.proto\x1a\x1bgoogle/protobuf/empty.proto"!\n\x0cInfoResponse\x12\x11\n\tjson_info\x18\x01 \x01(\t"\xa3\x02\n\x05Value\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\x15\n\x0b\x62ytes_value\x18\x02 \x01(\x0cH\x00\x12\x17\n\rboolean_value\x18\x03 \x01(\x08H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x14\n\nbyte_value\x18\x05 \x01(\x05H\x00\x12\x15\n\x0bshort_value\x18\x06 \x01(\x05H\x00\x12\x13\n\tint_value\x18\x07 \x01(\x05H\x00\x12\x14\n\nlong_value\x18\x08 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\t \x01(\x02H\x00\x12\x16\n\x0c\x64ouble_value\x18\n \x01(\x01H\x00\x12\x14\n\njson_value\x18\x0b \x01(\tH\x00\x12\x14\n\navro_value\x18\x0c \x01(\x0cH\x00\x42\x0c\n\ntype_oneof"-\n\x06Header\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x05value\x18\x02 \x01(\x0b\x32\x06.Value"*\n\x06Schema\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x0c"\xb3\x01\n\x06Record\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x18\n\x03key\x18\x02 \x01(\x0b\x32\x06.ValueH\x00\x88\x01\x01\x12\x1a\n\x05value\x18\x03 \x01(\x0b\x32\x06.ValueH\x01\x88\x01\x01\x12\x18\n\x07headers\x18\x04 \x03(\x0b\x32\x07.Header\x12\x0e\n\x06origin\x18\x05 \x01(\t\x12\x16\n\ttimestamp\x18\x06 \x01(\x03H\x02\x88\x01\x01\x42\x06\n\x04_keyB\x08\n\x06_valueB\x0c\n\n_timestamp"<\n\x10PermanentFailure\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x15\n\rerror_message\x18\x02 \x01(\t"X\n\rSourceRequest\x12\x19\n\x11\x63ommitted_records\x18\x01 \x03(\x03\x12,\n\x11permanent_failure\x18\x02 \x01(\x0b\x32\x11.PermanentFailure"C\n\x0eSourceResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"E\n\x10ProcessorRequest\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"O\n\x11ProcessorResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12!\n\x07results\x18\x02 
\x03(\x0b\x32\x10.ProcessorResult"\\\n\x0fProcessorResult\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x18\n\x07records\x18\x03 \x03(\x0b\x32\x07.RecordB\x08\n\x06_error2\xac\x01\n\x0c\x41gentService\x12\x35\n\nagent_info\x12\x16.google.protobuf.Empty\x1a\r.InfoResponse"\x00\x12-\n\x04read\x12\x0e.SourceRequest\x1a\x0f.SourceResponse"\x00(\x01\x30\x01\x12\x36\n\x07process\x12\x11.ProcessorRequest\x1a\x12.ProcessorResponse"\x00(\x01\x30\x01\x42\x1d\n\x19\x61i.langstream.agents.grpcP\x01\x62\x06proto3'
)

_globals = globals()
Expand All @@ -53,16 +53,18 @@
_globals["_SCHEMA"]._serialized_end = 484
_globals["_RECORD"]._serialized_start = 487
_globals["_RECORD"]._serialized_end = 666
_globals["_SOURCEREQUEST"]._serialized_start = 668
_globals["_SOURCEREQUEST"]._serialized_end = 710
_globals["_SOURCERESPONSE"]._serialized_start = 712
_globals["_SOURCERESPONSE"]._serialized_end = 779
_globals["_PROCESSORREQUEST"]._serialized_start = 781
_globals["_PROCESSORREQUEST"]._serialized_end = 850
_globals["_PROCESSORRESPONSE"]._serialized_start = 852
_globals["_PROCESSORRESPONSE"]._serialized_end = 931
_globals["_PROCESSORRESULT"]._serialized_start = 933
_globals["_PROCESSORRESULT"]._serialized_end = 1025
_globals["_AGENTSERVICE"]._serialized_start = 1028
_globals["_AGENTSERVICE"]._serialized_end = 1200
_globals["_PERMANENTFAILURE"]._serialized_start = 668
_globals["_PERMANENTFAILURE"]._serialized_end = 728
_globals["_SOURCEREQUEST"]._serialized_start = 730
_globals["_SOURCEREQUEST"]._serialized_end = 818
_globals["_SOURCERESPONSE"]._serialized_start = 820
_globals["_SOURCERESPONSE"]._serialized_end = 887
_globals["_PROCESSORREQUEST"]._serialized_start = 889
_globals["_PROCESSORREQUEST"]._serialized_end = 958
_globals["_PROCESSORRESPONSE"]._serialized_start = 960
_globals["_PROCESSORRESPONSE"]._serialized_end = 1039
_globals["_PROCESSORRESULT"]._serialized_start = 1041
_globals["_PROCESSORRESULT"]._serialized_end = 1133
_globals["_AGENTSERVICE"]._serialized_start = 1136
_globals["_AGENTSERVICE"]._serialized_end = 1308
# @@protoc_insertion_point(module_scope)
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,27 @@ class Record(_message.Message):
timestamp: _Optional[int] = ...,
) -> None: ...

# Type stub for the PermanentFailure protobuf message: an integer ``record_id``
# and a string ``error_message``, both optional constructor arguments.
class PermanentFailure(_message.Message):
__slots__ = ["record_id", "error_message"]
RECORD_ID_FIELD_NUMBER: _ClassVar[int]
ERROR_MESSAGE_FIELD_NUMBER: _ClassVar[int]
record_id: int
error_message: str
def __init__(
self, record_id: _Optional[int] = ..., error_message: _Optional[str] = ...
) -> None: ...

# Type stub for the SourceRequest protobuf message.
# NOTE(review): ``__slots__`` and ``__init__`` each appear twice in this span;
# it looks like a diff rendering that keeps both the replaced and the replacing
# lines. Confirm which lines are current before relying on it.
class SourceRequest(_message.Message):
__slots__ = ["committed_records"]
__slots__ = ["committed_records", "permanent_failure"]
COMMITTED_RECORDS_FIELD_NUMBER: _ClassVar[int]
PERMANENT_FAILURE_FIELD_NUMBER: _ClassVar[int]
committed_records: _containers.RepeatedScalarFieldContainer[int]
def __init__(self, committed_records: _Optional[_Iterable[int]] = ...) -> None: ...
permanent_failure: PermanentFailure
def __init__(
self,
committed_records: _Optional[_Iterable[int]] = ...,
permanent_failure: _Optional[_Union[PermanentFailure, _Mapping]] = ...,
) -> None: ...

class SourceResponse(_message.Message):
__slots__ = ["schema", "records"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from langstream_grpc.proto.agent_pb2 import (
SourceResponse,
SourceRequest,
PermanentFailure,
)
from langstream_grpc.proto.agent_pb2_grpc import AgentServiceStub
from langstream_runtime.api import Record, RecordType, Source
Expand Down Expand Up @@ -85,7 +86,7 @@ def send_commit():
committed = 0
while committed < 2:
try:
commit_id = to_commit.get(True)
commit_id = to_commit.get(True, 1)
yield SourceRequest(committed_records=[commit_id])
committed += 1
except queue.Empty:
Expand All @@ -101,6 +102,31 @@ def send_commit():
assert server.agent.committed[0] == server.agent.sent[0]


def test_permanent_failure(server_and_stub):
    """Check that a ``PermanentFailure`` request is delivered to the agent.

    The id of each record received from ``stub.read`` is queued; the request
    generator sends a single ``SourceRequest`` reporting the first queued id
    as permanently failed with the message ``"failure"``. The agent must
    then have recorded exactly one failure, for the first record it sent.
    """
    server, stub = server_and_stub
    to_fail = queue.Queue()

    def send_failure():
        # Bounded wait: a blocking get() without a timeout never raises
        # queue.Empty, so a missing record would hang the test forever
        # instead of failing it.
        try:
            record_id = to_fail.get(True, 5)
        except queue.Empty:
            return
        yield SourceRequest(
            permanent_failure=PermanentFailure(
                record_id=record_id, error_message="failure"
            )
        )

    response: SourceResponse
    for response in stub.read(iter(send_failure())):
        for record in response.records:
            to_fail.put(record.record_id)

    assert len(server.agent.failures) == 1
    assert server.agent.failures[0][0] == server.agent.sent[0]
    assert str(server.agent.failures[0][1]) == "failure"


class MySource(Source):
def __init__(self):
self.records = [
Expand All @@ -119,6 +145,7 @@ def __init__(self):
]
self.sent = []
self.committed = []
self.failures = []

def read(self) -> List[RecordType]:
if len(self.records) > 0:
Expand All @@ -132,3 +159,6 @@ def commit(self, records: List[Record]):
if record.value() == 42:
raise Exception("test error")
self.committed.extend(records)

def permanent_failure(self, record: Record, error: Exception):
    """Store the failed record together with its error in ``self.failures``."""
    failure = (record, error)
    self.failures.append(failure)

0 comments on commit 118260d

Please sign in to comment.