From b951b61185b772b3e2432d06e047968c20187ab4 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Wed, 22 Nov 2023 18:12:40 +0100 Subject: [PATCH] Catch all exceptions in agent gRPC observers and fail critically (#735) --- .../agents/grpc/GrpcAgentProcessor.java | 58 ++++++++----------- .../langstream/agents/grpc/GrpcAgentSink.java | 9 ++- .../agents/grpc/GrpcAgentSource.java | 15 ++--- .../agents/grpc/GrpcAgentProcessorTest.java | 4 +- 4 files changed, 42 insertions(+), 44 deletions(-) diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentProcessor.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentProcessor.java index 049d2eb4f..528d9c79c 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentProcessor.java +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentProcessor.java @@ -126,41 +126,31 @@ private StreamObserver getResponseObserver() { return new StreamObserver<>() { @Override public void onNext(ProcessorResponse response) { - if (response.hasSchema()) { - org.apache.avro.Schema schema = - new org.apache.avro.Schema.Parser() - .parse(response.getSchema().getValue().toStringUtf8()); - serverSchemas.put(response.getSchema().getSchemaId(), schema); + try { + if (response.hasSchema()) { + org.apache.avro.Schema schema = + new org.apache.avro.Schema.Parser() + .parse(response.getSchema().getValue().toStringUtf8()); + serverSchemas.put(response.getSchema().getSchemaId(), schema); + } + for (ProcessorResult result : response.getResultsList()) { + RecordAndSink recordAndSink = sourceRecords.remove(result.getRecordId()); + if (recordAndSink == null) { + throw new IllegalArgumentException( + "Received unknown record id " + result.getRecordId()); + } else { + recordAndSink + .sink() + .emit(fromGrpc(recordAndSink.sourceRecord(), result)); + } + } + } catch (Exception e) { + agentContext.criticalFailure( + new RuntimeException( + "GrpcAgentProcessor error while processing record: %s" + .formatted(e.getMessage()), + e)); } - response.getResultsList() - .forEach( - result -> { - RecordAndSink recordAndSink = - sourceRecords.remove(result.getRecordId()); - if (recordAndSink == null) { - agentContext.criticalFailure( - new RuntimeException( - "Received unknown record id " - + result.getRecordId())); - } else { - try { - recordAndSink - .sink() - .emit( - fromGrpc( - recordAndSink.sourceRecord(), - result)); - } catch (Exception e) { - agentContext.criticalFailure( - new RuntimeException( - "Error while processing record %s: %s" - .formatted( - result.getRecordId(), - e.getMessage()), - e)); - } - } - }); } @Override diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSink.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSink.java index 8954d5154..c96fc9895 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSink.java +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSink.java @@ -94,6 +94,13 @@ private StreamObserver getResponseObserver() { return new StreamObserver<>() { @Override public void onNext(SinkResponse response) { + if (!writeHandles.containsKey(response.getRecordId())) { + agentContext.criticalFailure( + new RuntimeException( + "GrpcAgentSink received unknown record id: %s" + .formatted(response.getRecordId()))); + return; + } CompletableFuture handle = writeHandles.get(response.getRecordId()); if (response.hasError()) { handle.completeExceptionally(new RuntimeException(response.getError())); @@ -124,7 +131,7 @@ public void onError(Throwable throwable) { public void onCompleted() { if (startFailedButDevelopmentMode || restarting.get()) { log.info( - "Ignoring server complietion during restart in dev mode, " + "Ignoring server completion during restart in dev mode, " + "ignoring records {}", writeHandles); writeHandles.forEach((id, handle) -> handle.complete(null)); diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSource.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSource.java index 3a377ea8a..30a75ae41 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSource.java +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSource.java @@ -105,19 +105,20 @@ private StreamObserver getResponseObserver() { return new StreamObserver<>() { @Override public void onNext(SourceResponse response) { - if (response.hasSchema()) { - org.apache.avro.Schema schema = - new org.apache.avro.Schema.Parser() - .parse(response.getSchema().getValue().toStringUtf8()); - serverSchemas.put(response.getSchema().getSchemaId(), schema); - } try { + if (response.hasSchema()) { + org.apache.avro.Schema schema = + new org.apache.avro.Schema.Parser() + .parse(response.getSchema().getValue().toStringUtf8()); + serverSchemas.put(response.getSchema().getSchemaId(), schema); + } + for (ai.langstream.agents.grpc.Record record : response.getRecordsList()) { readRecords.add(fromGrpc(record)); } } catch (Exception e) { agentContext.criticalFailure( - new RuntimeException("Error while processing records", e)); + new RuntimeException("GrpcAgentSource error while reading records", e)); } } diff --git a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentProcessorTest.java b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentProcessorTest.java index 6ff615da2..1dab229f1 100644 --- a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentProcessorTest.java +++ b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentProcessorTest.java @@ -227,8 +227,8 @@ void testFailingRecord() throws Exception { @CsvSource({ "failing-server,gRPC server sent error: INTERNAL: server error", "completing-server,gRPC server completed the stream unexpectedly", - "wrong-record-id,Received unknown record id 2", - "wrong-schema-id,Error while processing record 1: Unknown schema id 1" + "wrong-record-id,GrpcAgentProcessor error while processing record: Received unknown record id 2", + "wrong-schema-id,GrpcAgentProcessor error while processing record: Unknown schema id 1" }) void testServerError(String origin, String error) throws Exception { Record inputRecord = SimpleRecord.builder().origin(origin).build();