Skip to content

Commit

Permalink
Fix service gateway with topics (#683)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicoloboschi authored Nov 3, 2023
1 parent 4171599 commit f3de246
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ public void startReadingAsync(
log.debug("[{}] Started reader", logRef);
readMessages(stop, onMessage);
} catch (Throwable ex) {
log.error("[{}] Error reading messages", logRef, ex);
throw new RuntimeException(ex);
} finally {
closeReader();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,20 @@ protected TopicProducer setupProducer(String topic, StreamingCluster streamingCl
}

public void produceMessage(String payload) throws ProduceException {
final ProduceRequest produceRequest = parseProduceRequest(payload);
produceMessage(produceRequest);
}

public static ProduceRequest parseProduceRequest(String payload) throws ProduceException {
final ProduceRequest produceRequest;
try {
produceRequest = mapper.readValue(payload, ProduceRequest.class);
} catch (JsonProcessingException err) {
throw new ProduceException(err.getMessage(), ProduceResponse.Status.BAD_REQUEST);
throw new ProduceException(
"Error while parsing JSON payload: " + err.getMessage(),
ProduceResponse.Status.BAD_REQUEST);
}
produceMessage(produceRequest);
return produceRequest;
}

public void produceMessage(ProduceRequest produceRequest) throws ProduceException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand Down Expand Up @@ -74,6 +76,7 @@ public class GatewayResource {
protected static final String GATEWAY_SERVICE_PATH =
"/service/{tenant}/{application}/{gateway}/**";
protected static final ObjectMapper MAPPER = new ObjectMapper();
protected static final String SERVICE_REQUEST_ID_HEADER = "langstream-service-request-id";
private final TopicConnectionsRuntimeProviderBean topicConnectionsRuntimeRegistryProvider;
private final TopicProducerCache topicProducerCache;
private final ApplicationStore applicationStore;
Expand All @@ -87,15 +90,13 @@ public class GatewayResource {
Executors.newCachedThreadPool(
new BasicThreadFactory.Builder().namingPattern("http-consume-%d").build());

@PostMapping(
value = "/produce/{tenant}/{application}/{gateway}",
consumes = MediaType.APPLICATION_JSON_VALUE)
@PostMapping(value = "/produce/{tenant}/{application}/{gateway}", consumes = "*/*")
ProduceResponse produce(
WebRequest request,
@NotBlank @PathVariable("tenant") String tenant,
@NotBlank @PathVariable("application") String application,
@NotBlank @PathVariable("gateway") String gateway,
@RequestBody ProduceRequest produceRequest)
@RequestBody String payload)
throws ProduceGateway.ProduceException {

final Map<String, String> queryString = computeQueryString(request);
Expand Down Expand Up @@ -125,11 +126,26 @@ ProduceResponse produce(
ProduceGateway.getProducerCommonHeaders(
context.gateway().getProduceOptions(), authContext);
produceGateway.start(context.gateway().getTopic(), commonHeaders, authContext);
final ProduceRequest produceRequest = parseProduceRequest(request, payload);
produceGateway.produceMessage(produceRequest);
return ProduceResponse.OK;
}
}

private ProduceRequest parseProduceRequest(WebRequest request, String payload)
throws ProduceGateway.ProduceException {
final String contentType = request.getHeader("Content-Type");
if (contentType == null || contentType.equals(MediaType.TEXT_PLAIN_VALUE)) {
return new ProduceRequest(null, payload, null);
} else if (contentType.equals(MediaType.APPLICATION_JSON_VALUE)) {
return ProduceGateway.parseProduceRequest(payload);
} else {
throw new ResponseStatusException(
HttpStatus.BAD_REQUEST,
String.format("Unsupported content type: %s", contentType));
}
}

private Map<String, String> computeHeaders(WebRequest request) {
final Map<String, String> headers = new HashMap<>();
request.getHeaderNames()
Expand Down Expand Up @@ -187,7 +203,7 @@ private CompletableFuture<ResponseEntity> handleServiceCall(
String tenant,
String application,
String gateway)
throws IOException {
throws IOException, ProduceGateway.ProduceException {
final Map<String, String> queryString = computeQueryString(request);
final Map<String, String> headers = computeHeaders(request);
final GatewayRequestContext context =
Expand Down Expand Up @@ -225,24 +241,37 @@ public void validateOptions(Map<String, String> options) {}
throw new ResponseStatusException(
HttpStatus.BAD_REQUEST, "Only POST method is supported");
}
final ProduceRequest produceRequest =
MAPPER.readValue(servletRequest.getInputStream(), ProduceRequest.class);
final String payload =
new String(
servletRequest.getInputStream().readAllBytes(), StandardCharsets.UTF_8);
final ProduceRequest produceRequest = parseProduceRequest(request, payload);
return handleServiceWithTopics(produceRequest, authContext);
}
}

private CompletableFuture<ResponseEntity> handleServiceWithTopics(
ProduceRequest produceRequest, AuthenticatedGatewayRequestContext authContext) {

final String langstreamServiceRequestId = UUID.randomUUID().toString();

final CompletableFuture<ResponseEntity> completableFuture = new CompletableFuture<>();
try (final ConsumeGateway consumeGateway =
new ConsumeGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry());
final ProduceGateway produceGateway =
new ProduceGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry(),
topicProducerCache); ) {
try (final ProduceGateway produceGateway =
new ProduceGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry(),
topicProducerCache); ) {

final ConsumeGateway consumeGateway =
new ConsumeGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry());
completableFuture.thenRunAsync(
() -> {
if (consumeGateway != null) {
consumeGateway.close();
}
},
consumeThreadPool);

final Gateway.ServiceOptions serviceOptions = authContext.gateway().getServiceOptions();
try {
Expand All @@ -251,7 +280,15 @@ private CompletableFuture<ResponseEntity> handleServiceWithTopics(
serviceOptions.getHeaders(),
authContext.userParameters(),
authContext.principalValues());
consumeGateway.setup(serviceOptions.getInputTopic(), messageFilters, authContext);
messageFilters.add(
record -> {
final Header header = record.getHeader(SERVICE_REQUEST_ID_HEADER);
if (header == null) {
return false;
}
return langstreamServiceRequestId.equals(header.valueAsString());
});
consumeGateway.setup(serviceOptions.getOutputTopic(), messageFilters, authContext);
final AtomicBoolean stop = new AtomicBoolean(false);
consumeGateway.startReadingAsync(
consumeThreadPool,
Expand All @@ -266,9 +303,18 @@ record -> {
}
final List<Header> commonHeaders =
ProduceGateway.getProducerCommonHeaders(serviceOptions, authContext);
produceGateway.start(serviceOptions.getOutputTopic(), commonHeaders, authContext);
produceGateway.produceMessage(produceRequest);
produceGateway.start(serviceOptions.getInputTopic(), commonHeaders, authContext);

Map<String, String> passedHeaders = produceRequest.headers();
if (passedHeaders == null) {
passedHeaders = new HashMap<>();
}
passedHeaders.put(SERVICE_REQUEST_ID_HEADER, langstreamServiceRequestId);
produceGateway.produceMessage(
new ProduceRequest(
produceRequest.key(), produceRequest.value(), passedHeaders));
} catch (Throwable t) {
log.error("Error on service gateway", t);
completableFuture.completeExceptionally(t);
}
return completableFuture;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.springframework.http.ProblemDetail;
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.context.request.async.AsyncRequestTimeoutException;
import org.springframework.web.server.ResponseStatusException;

@ControllerAdvice
Expand All @@ -38,6 +39,11 @@ ProblemDetail handleAll(Throwable exception) {
log.error("Bad request", exception);
return ProblemDetail.forStatusAndDetail(HttpStatus.BAD_REQUEST, exception.getMessage());
}
if (exception instanceof AsyncRequestTimeoutException) {
log.error("Request timed out", exception);
return ProblemDetail.forStatusAndDetail(
HttpStatus.REQUEST_TIMEOUT, "Request timed out");
}
log.error("Internal error", exception);
return ProblemDetail.forStatusAndDetail(
HttpStatus.INTERNAL_SERVER_ERROR, exception.getMessage());
Expand Down
Loading

0 comments on commit f3de246

Please sign in to comment.