Skip to content

Commit

Permalink
Add capability to stop async query on demand
Browse files Browse the repository at this point in the history
  • Loading branch information
smalyshev committed Dec 4, 2024
1 parent 5633936 commit 037564d
Show file tree
Hide file tree
Showing 9 changed files with 249 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.core.async;

import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;

import java.io.IOException;
import java.util.Objects;

public class AsyncStopRequest extends ActionRequest {
private final String id;

/**
* Creates a new request
*
* @param id The id of the search progress request.
*/
public AsyncStopRequest(String id) {
this.id = id;
}

public AsyncStopRequest(StreamInput in) throws IOException {
super(in);
this.id = in.readString();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(id);
}

@Override
public ActionRequestValidationException validate() {
return null;
}

/**
* Returns the id of the async search.
*/
public String getId() {
return id;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
AsyncStopRequest request = (AsyncStopRequest) o;
return Objects.equals(id, request.id);
}

@Override
public int hashCode() {
return Objects.hash(id);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
*/
public class EsqlAsyncActionNames {
public static final String ESQL_ASYNC_GET_RESULT_ACTION_NAME = "indices:data/read/esql/async/get";
public static final String ESQL_ASYNC_STOP_ACTION_NAME = "indices:data/read/esql/async/stop";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.action;

import org.elasticsearch.action.ActionType;
import org.elasticsearch.xpack.core.esql.EsqlAsyncActionNames;

public class EsqlAsyncStopAction extends ActionType<EsqlQueryResponse> {

public static final EsqlAsyncStopAction INSTANCE = new EsqlAsyncStopAction();

public static final String NAME = EsqlAsyncActionNames.ESQL_ASYNC_STOP_ACTION_NAME;

private EsqlAsyncStopAction() {
super(NAME);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.action;

import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.Scope;
import org.elasticsearch.rest.ServerlessScope;
import org.elasticsearch.rest.action.RestRefCountedChunkedToXContentListener;
import org.elasticsearch.xpack.core.async.AsyncStopRequest;

import java.util.List;
import java.util.Set;

import static org.elasticsearch.rest.RestRequest.Method.POST;
import static org.elasticsearch.xpack.esql.action.EsqlQueryResponse.DROP_NULL_COLUMNS_OPTION;
import static org.elasticsearch.xpack.esql.formatter.TextFormat.URL_PARAM_DELIMITER;

@ServerlessScope(Scope.PUBLIC)
public class RestEsqlStopAsyncAction extends BaseRestHandler {
@Override
public List<Route> routes() {
return List.of(new Route(POST, "/_query/async/{id}/stop"));
}

@Override
public String getName() {
return "esql_async_stop";
}

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) {
AsyncStopRequest get = new AsyncStopRequest(request.param("id"));
return channel -> client.execute(EsqlAsyncStopAction.INSTANCE, get, new RestRefCountedChunkedToXContentListener<>(channel));
}

@Override
protected Set<String> responseParams() {
return Set.of(URL_PARAM_DELIMITER, DROP_NULL_COLUMNS_OPTION);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.elasticsearch.xpack.esql.EsqlInfoTransportAction;
import org.elasticsearch.xpack.esql.EsqlUsageTransportAction;
import org.elasticsearch.xpack.esql.action.EsqlAsyncGetResultAction;
import org.elasticsearch.xpack.esql.action.EsqlAsyncStopAction;
import org.elasticsearch.xpack.esql.action.EsqlQueryAction;
import org.elasticsearch.xpack.esql.action.EsqlQueryRequestBuilder;
import org.elasticsearch.xpack.esql.action.EsqlResolveFieldsAction;
Expand All @@ -60,6 +61,7 @@
import org.elasticsearch.xpack.esql.action.RestEsqlDeleteAsyncResultAction;
import org.elasticsearch.xpack.esql.action.RestEsqlGetAsyncResultAction;
import org.elasticsearch.xpack.esql.action.RestEsqlQueryAction;
import org.elasticsearch.xpack.esql.action.RestEsqlStopAsyncAction;
import org.elasticsearch.xpack.esql.enrich.EnrichLookupOperator;
import org.elasticsearch.xpack.esql.execution.PlanExecutor;
import org.elasticsearch.xpack.esql.expression.ExpressionWritables;
Expand Down Expand Up @@ -150,7 +152,8 @@ public List<Setting<?>> getSettings() {
new ActionHandler<>(XPackUsageFeatureAction.ESQL, EsqlUsageTransportAction.class),
new ActionHandler<>(XPackInfoFeatureAction.ESQL, EsqlInfoTransportAction.class),
new ActionHandler<>(EsqlResolveFieldsAction.TYPE, EsqlResolveFieldsAction.class),
new ActionHandler<>(EsqlSearchShardsAction.TYPE, EsqlSearchShardsAction.class)
new ActionHandler<>(EsqlSearchShardsAction.TYPE, EsqlSearchShardsAction.class),
new ActionHandler<>(EsqlAsyncStopAction.INSTANCE, TransportEsqlAsyncStopAction.class)
);
}

Expand All @@ -170,6 +173,7 @@ public List<RestHandler> getRestHandlers(
new RestEsqlQueryAction(),
new RestEsqlAsyncQueryAction(),
new RestEsqlGetAsyncResultAction(),
new RestEsqlStopAsyncAction(),
new RestEsqlDeleteAsyncResultAction()
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public Writeable.Reader<EsqlQueryResponse> responseReader() {

/**
* Unwraps the exception in the case of failure. This keeps the exception types
* the same as the sync API, namely ParsingException and ParsingException.
* the same as the sync API, namely ParsingException and VerificationException.
*/
static <R> ActionListener<R> unwrapListener(ActionListener<R> listener) {
return new ActionListener<>() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.plugin;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionListenerResponseHandler;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.async.AsyncExecutionId;
import org.elasticsearch.xpack.core.async.AsyncStopRequest;
import org.elasticsearch.xpack.core.async.GetAsyncResultRequest;
import org.elasticsearch.xpack.esql.action.EsqlAsyncStopAction;
import org.elasticsearch.xpack.esql.action.EsqlQueryResponse;

import java.util.concurrent.TimeUnit;

/**
* This action will stop running async request and collect the results.
* If the request is already finished, it will do the same thing as the regular async get.
*/
public class TransportEsqlAsyncStopAction extends HandledTransportAction<AsyncStopRequest, EsqlQueryResponse> {

private final TransportEsqlQueryAction queryAction;
private final TransportEsqlAsyncGetResultsAction getResultsAction;
private final BlockFactory blockFactory;
private final ClusterService clusterService;
private final TransportService transportService;

@Inject
public TransportEsqlAsyncStopAction(
TransportService transportService,
ClusterService clusterService,
ActionFilters actionFilters,
TransportEsqlQueryAction queryAction,
TransportEsqlAsyncGetResultsAction getResultsAction,
BlockFactory blockFactory
) {
super(EsqlAsyncStopAction.NAME, transportService, actionFilters, AsyncStopRequest::new, EsExecutors.DIRECT_EXECUTOR_SERVICE);
this.queryAction = queryAction;
this.getResultsAction = getResultsAction;
this.blockFactory = blockFactory;
this.transportService = transportService;
this.clusterService = clusterService;
}

@Override
protected void doExecute(Task task, AsyncStopRequest request, ActionListener<EsqlQueryResponse> listener) {
AsyncExecutionId searchId = AsyncExecutionId.decode(request.getId());
DiscoveryNode node = clusterService.state().nodes().get(searchId.getTaskId().getNodeId());
if (clusterService.localNode().getId().equals(searchId.getTaskId().getNodeId()) || node == null) {
// Don't use original request ID here because base64 decoding may not need some padding, but we want to match the original ID
// for the map lookup
stopQueryAndReturnResult(task, searchId.getEncoded(), listener);
} else {
transportService.sendRequest(
node,
EsqlAsyncStopAction.NAME,
request,
new ActionListenerResponseHandler<>(listener, EsqlQueryResponse.reader(blockFactory), EsExecutors.DIRECT_EXECUTOR_SERVICE)
);
}
}

private void stopQueryAndReturnResult(Task task, String asyncId, ActionListener<EsqlQueryResponse> listener) {
var asyncListener = queryAction.getAsyncListener(asyncId);
if (asyncListener == null) {
// This should mean one of the two things: either bad request ID, or the query has already finished
// In both cases, let regular async get deal with it.
var getAsyncResultRequest = new GetAsyncResultRequest(asyncId);
// TODO: this should not be happening, but if the listener is not registered and the query is not finished,
// we give it some time to finish
getAsyncResultRequest.setWaitForCompletionTimeout(new TimeValue(1, TimeUnit.SECONDS));
getResultsAction.execute(task, getAsyncResultRequest, listener);
return;
}
asyncListener.addListener(listener);
// TODO: send the finish signal to the source
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.operator.exchange.ExchangeService;
Expand Down Expand Up @@ -69,6 +71,8 @@ public class TransportEsqlQueryAction extends HandledTransportAction<EsqlQueryRe
private final AsyncTaskManagementService<EsqlQueryRequest, EsqlQueryResponse, EsqlQueryTask> asyncTaskManagementService;
private final RemoteClusterService remoteClusterService;

private final Map<String, SubscribableListener<EsqlQueryResponse>> asyncListeners = ConcurrentCollections.newConcurrentMap();

@Inject
@SuppressWarnings("this-escape")
public TransportEsqlQueryAction(
Expand Down Expand Up @@ -153,7 +157,18 @@ private void doExecuteForked(Task task, EsqlQueryRequest request, ActionListener
public void execute(EsqlQueryRequest request, EsqlQueryTask task, ActionListener<EsqlQueryResponse> listener) {
// set EsqlExecutionInfo on async-search task so that it is accessible to GET _query/async while the query is still running
task.setExecutionInfo(createEsqlExecutionInfo(request));
ActionListener.run(listener, l -> innerExecute(task, request, l));
// If the request is async, we need to wrap the listener in a SubscribableListener so that we can collect the results from other
// endpoints
var subListener = new SubscribableListener<EsqlQueryResponse>();
String asyncExecutionId = task.getExecutionId().getEncoded();
// TODO: is runBefore correct here?
subListener.addListener(ActionListener.runBefore(listener, () -> asyncListeners.remove(asyncExecutionId)));
asyncListeners.put(asyncExecutionId, subListener);
ActionListener.run(subListener, l -> innerExecute(task, request, l));
}

public SubscribableListener<EsqlQueryResponse> getAsyncListener(String executionId) {
return asyncListeners.get(executionId);
}

private void innerExecute(Task task, EsqlQueryRequest request, ActionListener<EsqlQueryResponse> listener) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,7 @@ private static boolean isAsyncRelatedAction(String action) {
|| action.equals(TransportDeleteAsyncResultAction.TYPE.name())
|| action.equals(EqlAsyncActionNames.EQL_ASYNC_GET_RESULT_ACTION_NAME)
|| action.equals(EsqlAsyncActionNames.ESQL_ASYNC_GET_RESULT_ACTION_NAME)
|| action.equals(EsqlAsyncActionNames.ESQL_ASYNC_STOP_ACTION_NAME)
|| action.equals(SqlAsyncActionNames.SQL_ASYNC_GET_RESULT_ACTION_NAME);
}

Expand Down

0 comments on commit 037564d

Please sign in to comment.