Skip to content

Commit

Permalink
Ensure remote pipeline early termination
Browse files Browse the repository at this point in the history
  • Loading branch information
smalyshev committed Dec 12, 2024
1 parent b53e1f1 commit 514e22d
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ static TransportVersion def(int id) {
public static final TransportVersion SOURCE_MODE_TELEMETRY = def(8_802_00_0);
public static final TransportVersion NEW_REFRESH_CLUSTER_BLOCK = def(8_803_00_0);
public static final TransportVersion RETRIES_AND_OPERATIONS_IN_BLOBSTORE_STATS = def(8_804_00_0);
public static final TransportVersion COMPUTE_RESPONSE_PARTIAL = def(8_805_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

package org.elasticsearch.compute.operator;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.action.support.SubscribableListener;
Expand Down Expand Up @@ -174,6 +176,8 @@ public DriverContext driverContext() {
return driverContext;
}

private static final Logger LOGGER = LogManager.getLogger(Driver.class);

/**
* Runs computations on the chain of operators for a given maximum amount of time or iterations.
* Returns a blocked future when the chain of operators is blocked, allowing the caller
Expand Down Expand Up @@ -239,6 +243,7 @@ private void checkForEarlyTermination() throws DriverEarlyTerminationException {
for (int i = activeOperators.size() - 2; i >= 0; i--) {
Operator op = activeOperators.get(i);
if (op.isFinished() == false) {
LOGGER.debug("Early terminated!");
throw new DriverEarlyTerminationException();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.compute.operator.exchange;

import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.compute.data.BlockFactory;
Expand Down Expand Up @@ -41,6 +42,7 @@ public final class ExchangeSinkHandler {
private final LongSupplier nowInMillis;
private final AtomicLong lastUpdatedInMillis;
private final BlockFactory blockFactory;
private final SetOnce<ExchangeSourceHandler> source = new SetOnce<>();

public ExchangeSinkHandler(BlockFactory blockFactory, int maxBufferSize, LongSupplier nowInMillis) {
this.blockFactory = blockFactory;
Expand Down Expand Up @@ -98,6 +100,10 @@ public IsBlockedResult waitForWriting() {
public void fetchPageAsync(boolean sourceFinished, ActionListener<ExchangeResponse> listener) {
if (sourceFinished) {
buffer.finish(true);
var subSource = source.get();
if (subSource != null) {
subSource.finishEarly(true, ActionListener.noop());
}
}
listeners.add(listener);
onChanged();
Expand Down Expand Up @@ -150,6 +156,10 @@ private void notifyListeners() {
}
}

/**
 * Associates an {@link ExchangeSourceHandler} with this sink handler.
 * <p>
 * When a source handler has been set, {@code fetchPageAsync} calls {@code finishEarly(true, ...)} on it
 * as soon as a fetch request reports that the requesting source is finished.
 * The reference is stored in a {@code SetOnce} holder, so this is presumably meant to be called
 * at most once per sink handler.
 *
 * @param sub the source handler to associate with this sink handler
 */
public void setSource(ExchangeSourceHandler sub) {
    this.source.set(sub);
}

/**
* Create a new exchange sink for exchanging data
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ public ExchangeSourceHandler(int maxBufferSize, Executor fetchExecutor, ActionLi
}

/**
 * Registers the callback that {@code finishEarly} runs when this exchange source is finished early.
 * A later call overwrites the previously registered callback.
 *
 * @param finishEarlyHandler the callback to run on early finish
 */
public void onFinishEarly(Runnable finishEarlyHandler) {
// TODO: revisit whether a callback is the best mechanism here; callers need to know when the
// exchange source is finished early so that they can update the execution info.
this.finishEarlyHandler = finishEarlyHandler;
}

Expand Down Expand Up @@ -320,6 +321,10 @@ public void finishEarly(boolean drainingPages, ActionListener<Void> listener) {
finishEarlyHandler.run();
}
buffer.finish(drainingPages);
if (remoteSinks.isEmpty()) {
listener.onResponse(null);
return;
}
try (EsqlRefCountingListener refs = new EsqlRefCountingListener(listener)) {
for (RemoteSink remoteSink : remoteSinks.values()) {
remoteSink.close(refs.acquire());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.core.Predicates;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.rest.action.RestActions;
import org.elasticsearch.transport.NoSuchRemoteClusterException;
import org.elasticsearch.transport.RemoteClusterAware;
import org.elasticsearch.transport.RemoteClusterService;
import org.elasticsearch.xcontent.ParseField;
Expand Down Expand Up @@ -189,7 +190,7 @@ public Set<String> clusterAliases() {
/**
* @param clusterAlias to lookup skip_unavailable from
* @return skip_unavailable setting (true/false)
* @throws org.elasticsearch.transport.NoSuchRemoteClusterException if clusterAlias is unknown to this node's RemoteClusterService
* @throws NoSuchRemoteClusterException if clusterAlias is unknown to this node's RemoteClusterService
*/
public boolean isSkipUnavailable(String clusterAlias) {
if (RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY.equals(clusterAlias)) {
Expand Down Expand Up @@ -279,10 +280,20 @@ public boolean isPartial() {
return isPartial;
}

public void setPartial() {
/**
 * Flags the query as a whole as having partial results.
 * After this call {@link #isPartial()} returns {@code true}.
 */
public void markAsPartial() {
    this.isPartial = true;
}

/**
 * Flags a single cluster as having partial results.
 * <p>
 * Calls {@code swapCluster} with a remapping function that builds a new {@code Cluster} from the
 * current one with its status set to {@code Cluster.Status.PARTIAL}.
 *
 * @param clusterAlias alias of the cluster to flag
 */
public void markClusterAsPartial(String clusterAlias) {
    swapCluster(clusterAlias, (alias, current) -> {
        var partial = new Cluster.Builder(current).setStatus(Cluster.Status.PARTIAL);
        return partial.build();
    });
}

/**
* Represents the search metadata about a particular cluster involved in a cross-cluster search.
* The Cluster object can represent either the local cluster or a remote cluster.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ private ComputeListener(
cluster.getTotalShards(),
cluster.getSuccessfulShards(),
cluster.getSkippedShards(),
cluster.getFailedShards()
cluster.getFailedShards(),
cluster.getStatus() == EsqlExecutionInfo.Cluster.Status.PARTIAL
);
} else {
result = new ComputeResponse(collectedProfiles.isEmpty() ? List.of() : collectedProfiles.stream().toList());
Expand All @@ -135,10 +136,14 @@ private ComputeListener(

private static void setFinalStatusAndShardCounts(String clusterAlias, EsqlExecutionInfo executionInfo) {
executionInfo.swapCluster(clusterAlias, (k, v) -> {
// TODO: once PARTIAL status is supported (partial results work to come), modify this code as needed
if (v.getStatus() != EsqlExecutionInfo.Cluster.Status.SKIPPED) {
assert v.getTotalShards() != null && v.getSkippedShards() != null : "Null total or skipped shard count: " + v;
return new EsqlExecutionInfo.Cluster.Builder(v).setStatus(EsqlExecutionInfo.Cluster.Status.SUCCESSFUL)
EsqlExecutionInfo.Cluster.Status newStatus = v.getStatus();
// Do not update the status if it is already set to e.g. PARTIAL
if (newStatus == EsqlExecutionInfo.Cluster.Status.RUNNING) {
newStatus = EsqlExecutionInfo.Cluster.Status.SUCCESSFUL;
}
return new EsqlExecutionInfo.Cluster.Builder(v).setStatus(newStatus)
/*
* Total and skipped shard counts are set early in execution (after can-match).
* Until ES|QL supports shard-level partial results, we just set all non-skipped shards
Expand Down Expand Up @@ -244,15 +249,16 @@ ActionListener<ComputeResponse> acquireCompute(@Nullable String computeClusterAl

private void updateExecutionInfoWithRemoteResponse(String computeClusterAlias, ComputeResponse resp) {
TimeValue tookOnCluster;
EsqlExecutionInfo.Cluster.Status resultStatus = resp.isPartial()
? EsqlExecutionInfo.Cluster.Status.PARTIAL
: EsqlExecutionInfo.Cluster.Status.SUCCESSFUL;
if (resp.getTook() != null) {
TimeValue remoteExecutionTime = resp.getTook();
TimeValue planningTookTime = esqlExecutionInfo.planningTookTime();
tookOnCluster = new TimeValue(planningTookTime.nanos() + remoteExecutionTime.nanos(), TimeUnit.NANOSECONDS);
esqlExecutionInfo.swapCluster(
computeClusterAlias,
(k, v) -> new EsqlExecutionInfo.Cluster.Builder(v)
// for now ESQL doesn't return partial results, so set status to SUCCESSFUL
.setStatus(EsqlExecutionInfo.Cluster.Status.SUCCESSFUL)
(k, v) -> new EsqlExecutionInfo.Cluster.Builder(v).setStatus(resultStatus)
.setTook(tookOnCluster)
.setTotalShards(resp.getTotalShards())
.setSuccessfulShards(resp.getSuccessfulShards())
Expand All @@ -267,11 +273,7 @@ private void updateExecutionInfoWithRemoteResponse(String computeClusterAlias, C
tookOnCluster = new TimeValue(remoteTook, TimeUnit.NANOSECONDS);
esqlExecutionInfo.swapCluster(
computeClusterAlias,
(k, v) -> new EsqlExecutionInfo.Cluster.Builder(v)
// for now ESQL doesn't return partial results, so set status to SUCCESSFUL
.setStatus(EsqlExecutionInfo.Cluster.Status.SUCCESSFUL)
.setTook(tookOnCluster)
.build()
(k, v) -> new EsqlExecutionInfo.Cluster.Builder(v).setStatus(resultStatus).setTook(tookOnCluster).build()
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ final class ComputeResponse extends TransportResponse {
public final int successfulShards;
public final int skippedShards;
public final int failedShards;
public final boolean isPartial;

ComputeResponse(List<DriverProfile> profiles) {
this(profiles, null, null, null, null, null);
this(profiles, null, null, null, null, null, false);
}

ComputeResponse(
Expand All @@ -40,14 +41,16 @@ final class ComputeResponse extends TransportResponse {
Integer totalShards,
Integer successfulShards,
Integer skippedShards,
Integer failedShards
Integer failedShards,
boolean isPartial
) {
this.profiles = profiles;
this.took = took;
this.totalShards = totalShards == null ? 0 : totalShards.intValue();
this.successfulShards = successfulShards == null ? 0 : successfulShards.intValue();
this.skippedShards = skippedShards == null ? 0 : skippedShards.intValue();
this.failedShards = failedShards == null ? 0 : failedShards.intValue();
this.isPartial = isPartial;
}

ComputeResponse(StreamInput in) throws IOException {
Expand All @@ -74,6 +77,11 @@ final class ComputeResponse extends TransportResponse {
this.skippedShards = 0;
this.failedShards = 0;
}
if (in.getTransportVersion().onOrAfter(TransportVersions.COMPUTE_RESPONSE_PARTIAL)) {
this.isPartial = in.readBoolean();
} else {
this.isPartial = false;
}
}

@Override
Expand All @@ -93,6 +101,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(skippedShards);
out.writeVInt(failedShards);
}
if (out.getTransportVersion().onOrAfter(TransportVersions.COMPUTE_RESPONSE_PARTIAL)) {
out.writeBoolean(isPartial);
}
}

public List<DriverProfile> getProfiles() {
Expand All @@ -118,4 +129,8 @@ public int getSkippedShards() {
/**
 * Returns the failed shard count carried by this response.
 *
 * @return the failed shard count; {@code 0} when no count was supplied at construction
 */
public int getFailedShards() {
    return this.failedShards;
}

/**
 * Returns the partial flag of this response.
 *
 * @return the value of the {@code isPartial} field; always {@code false} when the response was read from
 *         a stream whose transport version predates {@code TransportVersions.COMPUTE_RESPONSE_PARTIAL}
 */
public boolean isPartial() {
    return this.isPartial;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,7 @@ public void execute(
transportService.getThreadPool().executor(ThreadPool.Names.SEARCH),
ActionListener.runBefore(computeListener.acquireAvoid(), () -> exchangeService.removeExchangeSourceHandler(sessionId))
);
exchangeSource.onFinishEarly(() -> {
execInfo.setPartial();
});
exchangeSource.onFinishEarly(execInfo::markAsPartial);
exchangeService.addExchangeSourceHandler(sessionId, exchangeSource);
try (Releasable ignored = exchangeSource.addEmptySink()) {
// run compute on the coordinator
Expand Down Expand Up @@ -802,6 +800,7 @@ private void runComputeOnDataNode(
task.addListener(() -> exchangeService.finishSinkHandler(externalId, new TaskCancelledException(task.getReasonCancelled())));
var exchangeSource = new ExchangeSourceHandler(1, esqlExecutor, computeListener.acquireAvoid());
exchangeSource.addRemoteSink(internalSink::fetchPageAsync, true, 1, ActionListener.noop());
externalSink.setSource(exchangeSource);
ActionListener<ComputeResponse> reductionListener = computeListener.acquireCompute();
runCompute(
task,
Expand Down Expand Up @@ -940,6 +939,8 @@ void runComputeOnRemoteCluster(
transportService.getThreadPool().executor(ThreadPool.Names.SEARCH),
computeListener.acquireAvoid()
);
exchangeSink.setSource(exchangeSource);
exchangeSource.onFinishEarly(() -> executionInfo.markClusterAsPartial(clusterAlias));
try (Releasable ignored = exchangeSource.addEmptySink()) {
exchangeSink.addCompletionListener(computeListener.acquireAvoid());
runCompute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ private ComputeResponse randomResponse(boolean includeExecutionInfo) {
10,
10,
randomIntBetween(0, 3),
0
0,
false
);
} else {
return new ComputeResponse(profiles);
Expand Down

0 comments on commit 514e22d

Please sign in to comment.