[ML] Reduce chance of timeout in serverless ML autoscaling
If ML serverless autoscaling fails to return a response within
the configured timeout period, the control plane autoscaler
will log an error. Too many of these errors will raise an alert,
so the ML side should do as much as possible to _not_ time out.

Previously there were two possible causes of timeouts:

1. If a request for node stats from all ML nodes timed out
2. If a request to refresh the ML memory tracker timed out

The first case can happen if a node leaves the cluster at a bad
time and the message sent to it gets lost. The second case can
happen if searching the ML results indices for model size stats
documents is slow.

We can avoid timeouts in these two situations as follows:

1. There is no need to use the node stats API to get the only value
   from the node stats that the autoscaler needs: the total amount
   of memory on each ML node is stored in a node attribute on
   startup, so it is already available in cluster state (see the
   sketch after this list)
2. When the ML memory tracker needs refreshing we can trigger an
   asynchronous refresh and immediately return stats that instruct
   the autoscaler to do nothing until the refresh is complete - this
   is functionally the same as timing out each request, but without
   generating error messages
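
The following minimal sketch illustrates point 1. It is not the committed code
(which delegates to NodeLoadDetector.getNodeSize(node)); the attribute name
ml.machine_memory and the helper class are assumptions used only to show how
the per-node memory can be read from cluster state without any transport call:

```java
import java.util.Map;
import java.util.OptionalLong;
import java.util.stream.Collectors;

import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeRole;

/**
 * Illustrative sketch only: read each ML node's total memory from a node
 * attribute held in cluster state, so no node stats request is needed.
 */
final class MlNodeSizeFromClusterState {

    // Assumed attribute name - ML nodes publish their machine memory as a
    // node attribute at startup; the committed code resolves the value via
    // NodeLoadDetector.getNodeSize(node) rather than reading it directly.
    private static final String MACHINE_MEMORY_ATTR = "ml.machine_memory";

    private MlNodeSizeFromClusterState() {}

    // Total memory of one node in bytes, if the attribute is present and parseable.
    static OptionalLong nodeSize(DiscoveryNode node) {
        String value = node.getAttributes().get(MACHINE_MEMORY_ATTR);
        if (value == null) {
            return OptionalLong.empty();
        }
        try {
            return OptionalLong.of(Long.parseLong(value));
        } catch (NumberFormatException e) {
            return OptionalLong.empty();
        }
    }

    // Same shape as the map built in the diff: ML node id -> total memory in bytes.
    static Map<String, Long> nodeSizeByMlNode(ClusterState state) {
        return state.nodes()
            .stream()
            .filter(node -> node.getRoles().contains(DiscoveryNodeRole.ML_ROLE))
            .collect(Collectors.toMap(DiscoveryNode::getId, node -> nodeSize(node).orElse(0L)));
    }
}
```

Point 2 needs no extra plumbing: as the diff below shows, when the memory
tracker has not been refreshed recently the transport action now calls
mlMemoryTracker.asyncRefresh() and immediately responds with
MlAutoscalingResourceTracker.noScaleStats(state), so the control plane
receives a valid "do nothing" answer instead of a timeout.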
droberts195 committed Oct 16, 2023
1 parent 31736fc commit 30926e5
Showing 3 changed files with 55 additions and 256 deletions.
@@ -7,13 +7,9 @@

package org.elasticsearch.xpack.ml.action;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.ListenerTimeouts;
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.ParentTaskAssigningClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
@@ -22,10 +18,7 @@
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.GetMlAutoscalingStats;
@@ -41,7 +34,6 @@
*/
public class TransportGetMlAutoscalingStats extends TransportMasterNodeAction<Request, Response> {

private final Client client;
private final MlMemoryTracker mlMemoryTracker;
private final Settings settings;
private final Executor timeoutExecutor;
@@ -53,7 +45,6 @@ public TransportGetMlAutoscalingStats(
ThreadPool threadPool,
ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver,
Client client,
Settings settings,
MlMemoryTracker mlMemoryTracker
) {
@@ -68,63 +59,27 @@ public TransportGetMlAutoscalingStats(
Response::new,
EsExecutors.DIRECT_EXECUTOR_SERVICE
);
this.client = client;
this.mlMemoryTracker = mlMemoryTracker;
this.settings = settings;
this.timeoutExecutor = threadPool.generic();
}

@Override
protected void masterOperation(Task task, Request request, ClusterState state, ActionListener<Response> listener) {
TaskId parentTaskId = new TaskId(clusterService.localNode().getId(), task.getId());
ParentTaskAssigningClient parentTaskAssigningClient = new ParentTaskAssigningClient(client, parentTaskId);

if (mlMemoryTracker.isRecentlyRefreshed()) {
MlAutoscalingResourceTracker.getMlAutoscalingStats(
state,
clusterService.getClusterSettings(),
parentTaskAssigningClient,
request.timeout(),
mlMemoryTracker,
settings,
ActionListener.wrap(autoscalingResources -> listener.onResponse(new Response(autoscalingResources)), listener::onFailure)
);
} else {
// recent memory statistics aren't available at the moment, trigger a refresh,
// if a refresh has been triggered before, this will wait until refresh has happened
// on busy cluster with many jobs this could take a while, therefore timeout and return a 408 in case
mlMemoryTracker.refresh(
state.getMetadata().custom(PersistentTasksCustomMetadata.TYPE),
ListenerTimeouts.wrapWithTimeout(
threadPool,
request.timeout(),
timeoutExecutor,
ActionListener.wrap(
ignored -> MlAutoscalingResourceTracker.getMlAutoscalingStats(
state,
clusterService.getClusterSettings(),
parentTaskAssigningClient,
request.timeout(),
mlMemoryTracker,
settings,
ActionListener.wrap(
autoscalingResources -> listener.onResponse(new Response(autoscalingResources)),
listener::onFailure
)
),
listener::onFailure
),
timeoutTrigger -> {
// Timeout triggered
listener.onFailure(
new ElasticsearchStatusException(
"ML autoscaling metrics could not be retrieved in time, but should be available shortly.",
RestStatus.REQUEST_TIMEOUT
)
);
}
)
);
// Recent memory statistics aren't available at the moment, trigger a refresh and return a no-scale.
// (If a refresh is already in progress, this won't trigger a new one.)
mlMemoryTracker.asyncRefresh();
listener.onResponse(new Response(MlAutoscalingResourceTracker.noScaleStats(state)));
}
}

@@ -10,29 +10,25 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.monitor.os.OsStats;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
import org.elasticsearch.xpack.core.ml.autoscaling.MlAutoscalingStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
import org.elasticsearch.xpack.core.ml.inference.assignment.Priority;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.job.NodeLoadDetector;
import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
import org.elasticsearch.xpack.ml.utils.MlProcessors;
import org.elasticsearch.xpack.ml.utils.NativeMemoryCalculator;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
@@ -67,93 +63,58 @@ private MlAutoscalingResourceTracker() {}
public static void getMlAutoscalingStats(
ClusterState clusterState,
ClusterSettings clusterSettings,
Client client,
TimeValue timeout,
MlMemoryTracker mlMemoryTracker,
Settings settings,
ActionListener<MlAutoscalingStats> listener
) {
String[] mlNodes = clusterState.nodes()
Map<String, Long> nodeSizeByMlNode = clusterState.nodes()
.stream()
.filter(node -> node.getRoles().contains(DiscoveryNodeRole.ML_ROLE))
.map(DiscoveryNode::getId)
.toArray(String[]::new);
.collect(Collectors.toMap(DiscoveryNode::getId, node -> NodeLoadDetector.getNodeSize(node).orElse(0L)));

String firstMlNode = (nodeSizeByMlNode.size() > 0) ? nodeSizeByMlNode.keySet().iterator().next() : null;

// the next 2 values are only used iff > 0 and iff all nodes have the same container size
long modelMemoryAvailableFirstNode = mlNodes.length > 0
? NativeMemoryCalculator.allowedBytesForMl(clusterState.nodes().get(mlNodes[0]), settings).orElse(0L)
long modelMemoryAvailableFirstNode = (firstMlNode != null)
? NativeMemoryCalculator.allowedBytesForMl(clusterState.nodes().get(firstMlNode), settings).orElse(0L)
: 0L;
int processorsAvailableFirstNode = mlNodes.length > 0
? MlProcessors.get(clusterState.nodes().get(mlNodes[0]), clusterSettings.get(MachineLearning.ALLOCATED_PROCESSORS_SCALE))
int processorsAvailableFirstNode = (firstMlNode != null)
? MlProcessors.get(clusterState.nodes().get(firstMlNode), clusterSettings.get(MachineLearning.ALLOCATED_PROCESSORS_SCALE))
.roundDown()
: 0;

// Todo: MAX_LOW_PRIORITY_MODELS_PER_NODE not checked yet
int maxOpenJobsPerNode = MAX_OPEN_JOBS_PER_NODE.get(settings);

getMlNodeStats(
mlNodes,
client,
timeout,
ActionListener.wrap(
osStatsPerNode -> getMemoryAndProcessors(
new MlAutoscalingContext(clusterState),
mlMemoryTracker,
osStatsPerNode,
modelMemoryAvailableFirstNode,
processorsAvailableFirstNode,
maxOpenJobsPerNode,
listener
),
listener::onFailure
)
getMemoryAndProcessors(
new MlAutoscalingContext(clusterState),
mlMemoryTracker,
nodeSizeByMlNode,
modelMemoryAvailableFirstNode,
processorsAvailableFirstNode,
maxOpenJobsPerNode,
listener
);
}

static void getMlNodeStats(String[] mlNodes, Client client, TimeValue timeout, ActionListener<Map<String, OsStats>> listener) {

// if the client is configured with no nodes, it automatically calls all
if (mlNodes.length == 0) {
listener.onResponse(Collections.emptyMap());
return;
}

client.admin()
.cluster()
.prepareNodesStats(mlNodes)
.clear()
.setOs(true)
.setTimeout(timeout)
.execute(
ActionListener.wrap(
nodesStatsResponse -> listener.onResponse(
nodesStatsResponse.getNodes()
.stream()
.collect(Collectors.toMap(nodeStats -> nodeStats.getNode().getId(), NodeStats::getOs))
),
listener::onFailure
)
);
}

static void getMemoryAndProcessors(
MlAutoscalingContext autoscalingContext,
MlMemoryTracker mlMemoryTracker,
Map<String, OsStats> osStatsPerNode,
Map<String, Long> nodeSizeByMlNode,
long perNodeAvailableModelMemoryInBytes,
int perNodeAvailableProcessors,
int maxOpenJobsPerNode,
ActionListener<MlAutoscalingStats> listener
) {
Map<String, List<MlJobRequirements>> perNodeModelMemoryInBytes = new HashMap<>();

int numberMlNodes = nodeSizeByMlNode.size();

// If the ML nodes in the cluster have different sizes, return 0.
// Otherwise, return the size, in bytes, of the container size of the ML nodes for a single container.
long perNodeMemoryInBytes = osStatsPerNode.values()
.stream()
.map(s -> s.getMem().getAdjustedTotal().getBytes())
.distinct()
.count() != 1 ? 0 : osStatsPerNode.values().iterator().next().getMem().getAdjustedTotal().getBytes();
long perNodeMemoryInBytes = nodeSizeByMlNode.values().stream().distinct().count() != 1
? 0L
: nodeSizeByMlNode.values().iterator().next();

long modelMemoryBytesSum = 0;
long extraSingleNodeModelMemoryInBytes = 0;
@@ -297,8 +258,8 @@ static void getMemoryAndProcessors(
&& perNodeAvailableModelMemoryInBytes > 0
&& extraModelMemoryInBytes == 0
&& extraProcessors == 0
&& modelMemoryBytesSum <= perNodeMemoryInBytes * (osStatsPerNode.size() - 1)
&& (perNodeModelMemoryInBytes.size() < osStatsPerNode.size() // a node has no assigned jobs
&& modelMemoryBytesSum <= perNodeMemoryInBytes * (numberMlNodes - 1)
&& (perNodeModelMemoryInBytes.size() < numberMlNodes // a node has no assigned jobs
|| checkIfOneNodeCouldBeRemoved(
perNodeModelMemoryInBytes,
perNodeAvailableModelMemoryInBytes,
@@ -310,7 +271,7 @@

listener.onResponse(
new MlAutoscalingStats(
osStatsPerNode.size(),
numberMlNodes,
perNodeMemoryInBytes,
modelMemoryBytesSum,
processorsSum,
@@ -325,6 +286,26 @@
);
}

/**
* Return some autoscaling stats that tell the autoscaler not to change anything, but without making it think an error has occurred.
*/
public static MlAutoscalingStats noScaleStats(ClusterState clusterState) {
int numberMlNodes = (int) clusterState.nodes().stream().filter(node -> node.getRoles().contains(DiscoveryNodeRole.ML_ROLE)).count();
return new MlAutoscalingStats(
numberMlNodes,
0,
0,
0,
Math.min(3, numberMlNodes),
0,
0,
0,
0,
0,
MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes()
);
}

/**
* Check if one node can be removed by placing the jobs of the least loaded node to others.
*
