Skip to content

Commit

Permalink
Support cancellation for admin apis
Browse files Browse the repository at this point in the history
  • Loading branch information
aasom143 committed Jun 4, 2024
1 parent 8ac6806 commit da592e1
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@

import org.opensearch.action.admin.indices.stats.CommonStatsFlags;
import org.opensearch.action.support.nodes.BaseNodesRequest;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.annotation.PublicApi;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.tasks.TaskId;
import org.opensearch.rest.action.admin.cluster.ClusterTask;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.*;
import java.util.stream.Collectors;

/**
Expand All @@ -57,6 +57,8 @@ public class NodesStatsRequest extends BaseNodesRequest<NodesStatsRequest> {
private CommonStatsFlags indices = new CommonStatsFlags();
private final Set<String> requestedMetrics = new HashSet<>();

private TimeValue cancelAfterTimeInterval;

public NodesStatsRequest() {
super((String[]) null);
}
Expand Down Expand Up @@ -95,6 +97,20 @@ public NodesStatsRequest clear() {
return this;
}

@Override
public ClusterTask createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new ClusterTask(id, type, action, parentTaskId, headers, cancelAfterTimeInterval);
}

public void setCancelAfterTimeInterval(TimeValue cancelAfterTimeInterval) {
this.cancelAfterTimeInterval = cancelAfterTimeInterval;
}

public TimeValue getCancelAfterTimeInterval() {
return cancelAfterTimeInterval;
}


/**
* Get indices. Handles separately from other metrics because it may or
* may not have submetrics.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.nodes.TransportNodesAction;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.core.common.io.stream.StreamInput;
Expand Down Expand Up @@ -62,6 +63,29 @@ public class TransportNodesStatsAction extends TransportNodesAction<
private final NodeService nodeService;

@Inject
public TransportNodesStatsAction(
NodeClient client,
ThreadPool threadPool,
ClusterService clusterService,
TransportService transportService,
NodeService nodeService,
ActionFilters actionFilters
) {
super(
client,
NodesStatsAction.NAME,
threadPool,
clusterService,
transportService,
actionFilters,
NodesStatsRequest::new,
NodeStatsRequest::new,
ThreadPool.Names.MANAGEMENT,
NodeStats.class
);
this.nodeService = nodeService;
}

public TransportNodesStatsAction(
ThreadPool threadPool,
ClusterService clusterService,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,17 @@
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.TimeoutTaskCancellationUtility;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.tasks.TaskCancelledException;
import org.opensearch.node.Node;
import org.opensearch.tasks.CancellableTask;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.NodeShouldNotConnectException;
Expand Down Expand Up @@ -81,6 +86,8 @@ public abstract class TransportNodesAction<

private final String finalExecutor;

private final NodeClient client;

/**
* @param actionName action name
* @param threadPool thread-pool
Expand All @@ -94,6 +101,7 @@ public abstract class TransportNodesAction<
* @param nodeResponseClass class of the node responses
*/
protected TransportNodesAction(
NodeClient client,
String actionName,
ThreadPool threadPool,
ClusterService clusterService,
Expand All @@ -113,6 +121,31 @@ protected TransportNodesAction(

this.transportNodeAction = actionName + "[n]";
this.finalExecutor = finalExecutor;
this.client = client;
transportService.registerRequestHandler(transportNodeAction, nodeExecutor, nodeRequest, new NodeTransportHandler());
}

protected TransportNodesAction(
String actionName,
ThreadPool threadPool,
ClusterService clusterService,
TransportService transportService,
ActionFilters actionFilters,
Writeable.Reader<NodesRequest> request,
Writeable.Reader<NodeRequest> nodeRequest,
String nodeExecutor,
String finalExecutor,
Class<NodeResponse> nodeResponseClass
) {
super(actionName, transportService, actionFilters, request);
this.threadPool = threadPool;
this.clusterService = Objects.requireNonNull(clusterService);
this.transportService = Objects.requireNonNull(transportService);
this.nodeResponseClass = Objects.requireNonNull(nodeResponseClass);

this.transportNodeAction = actionName + "[n]";
this.finalExecutor = finalExecutor;
this.client = null;
transportService.registerRequestHandler(transportNodeAction, nodeExecutor, nodeRequest, new NodeTransportHandler());
}

Expand All @@ -123,6 +156,33 @@ protected TransportNodesAction(
* This constructor should only be used for actions for which the creation of the final response is fast enough to be safely executed
* on a transport thread.
*/
protected TransportNodesAction(
NodeClient client,
String actionName,
ThreadPool threadPool,
ClusterService clusterService,
TransportService transportService,
ActionFilters actionFilters,
Writeable.Reader<NodesRequest> request,
Writeable.Reader<NodeRequest> nodeRequest,
String nodeExecutor,
Class<NodeResponse> nodeResponseClass
) {
this(
client,
actionName,
threadPool,
clusterService,
transportService,
actionFilters,
request,
nodeRequest,
nodeExecutor,
ThreadPool.Names.SAME,
nodeResponseClass
);
}

protected TransportNodesAction(
String actionName,
ThreadPool threadPool,
Expand Down Expand Up @@ -150,6 +210,16 @@ protected TransportNodesAction(

@Override
protected void doExecute(Task task, NodesRequest request, ActionListener<NodesResponse> listener) {

if (task instanceof CancellableTask){
listener = TimeoutTaskCancellationUtility.wrapWithCancellationListener(
client,
(CancellableTask) task,
clusterService.getClusterSettings(),
listener
);
}

new AsyncAction(task, request, listener).start();
}

Expand Down Expand Up @@ -256,6 +326,9 @@ void start() {
final DiscoveryNode node = nodes[i];
final String nodeId = node.getId();
try {
if (task instanceof CancellableTask && ((CancellableTask) task).isCancelled()){
throw new TaskCancelledException("cancelled task with reason: " + ((CancellableTask) task).getReasonCancelled());
}
TransportRequest nodeRequest = newNodeRequest(request);
if (task != null) {
nodeRequest.setParentTask(clusterService.localNode().getId(), task.getId());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@

/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.rest.action.admin.cluster;

import org.opensearch.common.annotation.PublicApi;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.tasks.TaskId;
import org.opensearch.tasks.CancellableTask;

import java.util.Map;

import static org.opensearch.search.SearchService.NO_TIMEOUT;
/**
* Task storing information about a currently running ClusterRequest.
*
* @opensearch.api
*/
@PublicApi(since = "1.0.0")
public class ClusterTask extends CancellableTask {

public ClusterTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
this(id, type, action, parentTaskId, headers, NO_TIMEOUT);
}

public ClusterTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers, TimeValue cancelAfterTimeInterval) {
super(id, type, action, null, parentTaskId, headers, cancelAfterTimeInterval);
}

@Override
public boolean shouldCancelChildrenOnCancellation() {
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.opensearch.action.admin.indices.stats.CommonStatsFlags.Flag;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.cache.CacheType;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.common.Strings;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
Expand Down Expand Up @@ -110,6 +111,7 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC

NodesStatsRequest nodesStatsRequest = new NodesStatsRequest(nodesIds);
nodesStatsRequest.timeout(request.param("timeout"));
nodesStatsRequest.setCancelAfterTimeInterval(request.paramAsTime("cancel_after_time_interval", TimeValue.timeValueSeconds(30)));

if (metrics.size() == 1 && metrics.contains("_all")) {
if (request.hasParam("index_metric")) {
Expand Down

0 comments on commit da592e1

Please sign in to comment.