Skip to content

Commit

Permalink
[ML] Validate streaming HTTP Response (elastic#112481) (elastic#113134)
Browse files Browse the repository at this point in the history
Read the first HttpResult from the stream and validate the HttpResponse
object, invoking the listener to retry or fail the request. On success,
ResponseHandlers will create an InferenceServiceResults instance with the
Flow.Publisher so the provider's bytes can be parsed as they stream in.

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Co-authored-by: David Kyle <david.kyle@elastic.co>
  • Loading branch information
3 people authored Sep 18, 2024
1 parent 88a4fb2 commit 2a94927
Show file tree
Hide file tree
Showing 7 changed files with 413 additions and 18 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/112481.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 112481
summary: Validate streaming HTTP Response
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,17 @@ private ActionListener<InferenceServiceResults> createListener(
) {
if (request.isStreaming()) {
return listener.delegateFailureAndWrap((l, inferenceResults) -> {
var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
inferenceResults.publisher().subscribe(taskProcessor);
l.onResponse(new InferenceAction.Response(inferenceResults, taskProcessor));
if (inferenceResults.isStreaming()) {
var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(
STREAMING_INFERENCE_TASK_TYPE,
STREAMING_TASK_ACTION
);
inferenceResults.publisher().subscribe(taskProcessor);
l.onResponse(new InferenceAction.Response(inferenceResults, taskProcessor));
} else {
// if we asked for streaming but the provider doesn't support it, for now we're going to get back the single response
l.onResponse(new InferenceAction.Response(inferenceResults));
}
});
}
return listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,15 @@ default boolean canHandleStreamingResponses() {
}

/**
* A method for parsing the streamed response from the server.
* A method for parsing the streamed response from the server. Implementations must invoke the
* {@link Flow.Publisher#subscribe(Flow.Subscriber)} method on the {@code Flow.Publisher<HttpResult> flow} parameter in order to stream
* HttpResults to the InferenceServiceResults.
*
* @param request The original request sent to the server
* @param result The first result that initiated the stream. If the result is HTTP 200, this result will not contain content bytes
* @param flow The remaining stream of results from the server. If the result is HTTP 200, these results will contain content bytes
* @return an inference results with {@link InferenceServiceResults#publisher()} set and {@link InferenceServiceResults#isStreaming()}
* set to true
* set to true.
*/
default InferenceServiceResults parseResult(Request request, HttpResult result, Flow.Publisher<HttpResult> flow) {
assert canHandleStreamingResponses() == false : "This must be implemented when canHandleStreamingResponses() == true";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,23 +108,29 @@ public void tryAction(ActionListener<InferenceServiceResults> listener) {
return;
}

ActionListener<HttpResult> responseListener = ActionListener.wrap(result -> {
try {
responseHandler.validateResponse(throttlerManager, logger, request, result);
InferenceServiceResults inferenceResults = responseHandler.parseResult(request, result);

listener.onResponse(inferenceResults);
} catch (Exception e) {
logException(logger, request, result, responseHandler.getRequestType(), e);
listener.onFailure(e);
}
}, e -> {
var retryableListener = listener.delegateResponse((l, e) -> {
logException(logger, request, responseHandler.getRequestType(), e);
listener.onFailure(transformIfRetryable(e));
l.onFailure(transformIfRetryable(e));
});

try {
httpClient.send(request.createHttpRequest(), context, responseListener);
if (request.isStreaming() && responseHandler.canHandleStreamingResponses()) {
httpClient.stream(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> {
r.subscribe(new StreamingResponseHandler(throttlerManager, logger, request, responseHandler, l));
}));
} else {
httpClient.send(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> {
try {
responseHandler.validateResponse(throttlerManager, logger, request, r);
InferenceServiceResults inferenceResults = responseHandler.parseResult(request, r);

l.onResponse(inferenceResults);
} catch (Exception e) {
logException(logger, request, r, responseHandler.getRequestType(), e);
listener.onFailure(e); // skip retrying
}
}));
}
} catch (Exception e) {
logException(logger, request, responseHandler.getRequestType(), e);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.retry;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicBoolean;

import static org.elasticsearch.core.Strings.format;

class StreamingResponseHandler implements Flow.Processor<HttpResult, HttpResult> {
private static final Logger log = LogManager.getLogger(StreamingResponseHandler.class);
private final ThrottlerManager throttlerManager;
private final Logger throttlerLogger;
private final Request request;
private final ResponseHandler responseHandler;
private final ActionListener<InferenceServiceResults> listener;

private final AtomicBoolean upstreamIsClosed = new AtomicBoolean(false);
private final AtomicBoolean processedFirstItem = new AtomicBoolean(false);

private volatile Flow.Subscription upstream;
private volatile Flow.Subscriber<? super HttpResult> downstream;

StreamingResponseHandler(
ThrottlerManager throttlerManager,
Logger throttlerLogger,
Request request,
ResponseHandler responseHandler,
ActionListener<InferenceServiceResults> listener
) {
this.throttlerManager = throttlerManager;
this.throttlerLogger = throttlerLogger;
this.request = request;
this.responseHandler = responseHandler;
this.listener = listener;
}

@Override
public void subscribe(Flow.Subscriber<? super HttpResult> subscriber) {
if (downstream != null) {
subscriber.onError(
new IllegalStateException("Failed to initialize streaming response. Another subscriber is already subscribed.")
);
return;
}

downstream = subscriber;
subscriber.onSubscribe(forwardingSubscription());
}

private Flow.Subscription forwardingSubscription() {
return new Flow.Subscription() {
@Override
public void request(long n) {
if (upstreamIsClosed.get()) {
downstream.onComplete(); // shouldn't happen, but reinforce that we're no longer listening
} else if (upstream != null) {
upstream.request(n);
} else {
// this shouldn't happen, the expected call pattern is onNext -> subscribe after the listener is invoked
var errorMessage = "Failed to initialize streaming response. onSubscribe must be called first to set the upstream";
assert false : errorMessage;
downstream.onError(new IllegalStateException(errorMessage));
}
}

@Override
public void cancel() {
if (upstreamIsClosed.compareAndSet(false, true) && upstream != null) {
upstream.cancel();
}
}
};
}

@Override
public void onSubscribe(Flow.Subscription subscription) {
upstream = subscription;
// start the first request, which will call onNext and validate the first HttpResult
upstream.request(1);
}

@Override
public void onNext(HttpResult item) {
if (processedFirstItem.compareAndSet(false, true)) {
try {
responseHandler.validateResponse(throttlerManager, throttlerLogger, request, item);
var inferenceServiceResults = responseHandler.parseResult(request, item, this);
assert downstream != null : "the responseHandler must invoke the subscribe method";
listener.onResponse(inferenceServiceResults);
} catch (Exception e) {
logException(throttlerLogger, request, item, responseHandler.getRequestType(), e);
listener.onFailure(e);
upstream.cancel();
onError(e);
}
} else {
downstream.onNext(item);
}
}

@Override
public void onError(Throwable throwable) {
if (upstreamIsClosed.compareAndSet(false, true)) {
if (downstream != null) {
downstream.onError(throwable);
} else {
log.warn(
"Flow failed before the InferenceServiceResults were generated. The error should go to the listener directly.",
throwable
);
}
}
}

@Override
public void onComplete() {
if (upstreamIsClosed.compareAndSet(false, true)) {
if (downstream != null) {
downstream.onComplete();
} else {
log.debug("Flow completed before the InferenceServiceResults were generated. Shutting down this Processor.");
}
}
}

private void logException(Logger logger, Request request, HttpResult result, String requestType, Exception exception) {
var causeException = ExceptionsHelper.unwrapCause(exception);

throttlerManager.warn(
logger,
format(
"Failed to process the stream connection for request from inference entity id [%s] of type [%s] with status [%s] [%s]",
request.getInferenceEntityId(),
requestType,
result.response().getStatusLine().getStatusCode(),
result.response().getStatusLine().getReasonPhrase()
),
causeException
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,17 @@

import java.io.IOException;
import java.net.UnknownHostException;
import java.util.concurrent.Flow;

import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.createDefaultRetrySettings;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.sameInstance;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.only;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
Expand Down Expand Up @@ -454,6 +457,56 @@ public void testSend_ReturnsFailure_WhenHttpResultsListenerCallsOnFailure_WithNo
verifyNoMoreInteractions(httpClient);
}

public void testStream() throws IOException {
var httpClient = mock(HttpClient.class);
Flow.Publisher<HttpResult> publisher = mock();
doAnswer(ans -> {
ActionListener<Flow.Publisher<HttpResult>> listener = ans.getArgument(2);
listener.onResponse(publisher);
return null;
}).when(httpClient).stream(any(), any(), any());

var retrier = createRetrier(httpClient);

ActionListener<InferenceServiceResults> listener = mock();
var request = mockRequest();
when(request.isStreaming()).thenReturn(true);
var responseHandler = mock(ResponseHandler.class);
when(responseHandler.canHandleStreamingResponses()).thenReturn(true);
executeTasks(() -> retrier.send(mock(Logger.class), request, () -> false, responseHandler, listener), 0);

verify(httpClient, times(1)).stream(any(), any(), any());
verifyNoMoreInteractions(httpClient);
verify(publisher, only()).subscribe(any(StreamingResponseHandler.class));
}

public void testStream_ResponseHandlerDoesNotHandleStreams() throws IOException {
var httpClient = mock(HttpClient.class);
doAnswer(ans -> {
ActionListener<HttpResult> listener = ans.getArgument(2);
listener.onResponse(new HttpResult(mock(), new byte[0]));
return null;
}).when(httpClient).send(any(), any(), any());

var expectedResponse = mock(InferenceServiceResults.class);

var retrier = createRetrier(httpClient);

var listener = new PlainActionFuture<InferenceServiceResults>();
var request = mockRequest();
when(request.isStreaming()).thenReturn(true);
var responseHandler = mock(ResponseHandler.class);
when(responseHandler.parseResult(any(), any())).thenReturn(expectedResponse);
when(responseHandler.canHandleStreamingResponses()).thenReturn(false);
executeTasks(() -> retrier.send(mock(Logger.class), request, () -> false, responseHandler, listener), 0);

var actualResponse = listener.actionGet(TIMEOUT);

verify(httpClient, times(1)).send(any(), any(), any());
verifyNoMoreInteractions(httpClient);
assertThat(actualResponse, sameInstance(expectedResponse));
}

public void testSend_DoesNotRetryIndefinitely() throws IOException {
var threadPool = new TestThreadPool(getTestName());
try {
Expand Down
Loading

0 comments on commit 2a94927

Please sign in to comment.