Skip to content

Commit

Permalink
Catch all exceptions in agent gRPC observers and fail critically (#735)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet authored Nov 22, 2023
1 parent b376906 commit b951b61
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,41 +126,31 @@ private StreamObserver<ProcessorResponse> getResponseObserver() {
return new StreamObserver<>() {
@Override
public void onNext(ProcessorResponse response) {
if (response.hasSchema()) {
org.apache.avro.Schema schema =
new org.apache.avro.Schema.Parser()
.parse(response.getSchema().getValue().toStringUtf8());
serverSchemas.put(response.getSchema().getSchemaId(), schema);
try {
if (response.hasSchema()) {
org.apache.avro.Schema schema =
new org.apache.avro.Schema.Parser()
.parse(response.getSchema().getValue().toStringUtf8());
serverSchemas.put(response.getSchema().getSchemaId(), schema);
}
for (ProcessorResult result : response.getResultsList()) {
RecordAndSink recordAndSink = sourceRecords.remove(result.getRecordId());
if (recordAndSink == null) {
throw new IllegalArgumentException(
"Received unknown record id " + result.getRecordId());
} else {
recordAndSink
.sink()
.emit(fromGrpc(recordAndSink.sourceRecord(), result));
}
}
} catch (Exception e) {
agentContext.criticalFailure(
new RuntimeException(
"GrpcAgentProcessor error while processing record: %s"
.formatted(e.getMessage()),
e));
}
response.getResultsList()
.forEach(
result -> {
RecordAndSink recordAndSink =
sourceRecords.remove(result.getRecordId());
if (recordAndSink == null) {
agentContext.criticalFailure(
new RuntimeException(
"Received unknown record id "
+ result.getRecordId()));
} else {
try {
recordAndSink
.sink()
.emit(
fromGrpc(
recordAndSink.sourceRecord(),
result));
} catch (Exception e) {
agentContext.criticalFailure(
new RuntimeException(
"Error while processing record %s: %s"
.formatted(
result.getRecordId(),
e.getMessage()),
e));
}
}
});
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ private StreamObserver<SinkResponse> getResponseObserver() {
return new StreamObserver<>() {
@Override
public void onNext(SinkResponse response) {
if (!writeHandles.containsKey(response.getRecordId())) {
agentContext.criticalFailure(
new RuntimeException(
"GrpcAgentSink received unknown record id: %s"
.formatted(response.getRecordId())));
return;
}
CompletableFuture<?> handle = writeHandles.get(response.getRecordId());
if (response.hasError()) {
handle.completeExceptionally(new RuntimeException(response.getError()));
Expand Down Expand Up @@ -124,7 +131,7 @@ public void onError(Throwable throwable) {
public void onCompleted() {
if (startFailedButDevelopmentMode || restarting.get()) {
log.info(
"Ignoring server complietion during restart in dev mode, "
"Ignoring server completion during restart in dev mode, "
+ "ignoring records {}",
writeHandles);
writeHandles.forEach((id, handle) -> handle.complete(null));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,19 +105,20 @@ private StreamObserver<SourceResponse> getResponseObserver() {
return new StreamObserver<>() {
@Override
public void onNext(SourceResponse response) {
if (response.hasSchema()) {
org.apache.avro.Schema schema =
new org.apache.avro.Schema.Parser()
.parse(response.getSchema().getValue().toStringUtf8());
serverSchemas.put(response.getSchema().getSchemaId(), schema);
}
try {
if (response.hasSchema()) {
org.apache.avro.Schema schema =
new org.apache.avro.Schema.Parser()
.parse(response.getSchema().getValue().toStringUtf8());
serverSchemas.put(response.getSchema().getSchemaId(), schema);
}

for (ai.langstream.agents.grpc.Record record : response.getRecordsList()) {
readRecords.add(fromGrpc(record));
}
} catch (Exception e) {
agentContext.criticalFailure(
new RuntimeException("Error while processing records", e));
new RuntimeException("GrpcAgentSource error while reading records", e));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ void testFailingRecord() throws Exception {
@CsvSource({
"failing-server,gRPC server sent error: INTERNAL: server error",
"completing-server,gRPC server completed the stream unexpectedly",
"wrong-record-id,Received unknown record id 2",
"wrong-schema-id,Error while processing record 1: Unknown schema id 1"
"wrong-record-id,GrpcAgentProcessor error while processing record: Received unknown record id 2",
"wrong-schema-id,GrpcAgentProcessor error while processing record: Unknown schema id 1"
})
void testServerError(String origin, String error) throws Exception {
Record inputRecord = SimpleRecord.builder().origin(origin).build();
Expand Down

0 comments on commit b951b61

Please sign in to comment.