diff --git a/docs/changelog/87366.yaml b/docs/changelog/87366.yaml new file mode 100644 index 0000000000000..0b2881e7c4778 --- /dev/null +++ b/docs/changelog/87366.yaml @@ -0,0 +1,5 @@ +pr: 87366 +summary: Improve scalability of NLP models +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/changelog/87735.yaml b/docs/changelog/87735.yaml new file mode 100644 index 0000000000000..43f296f1411ee --- /dev/null +++ b/docs/changelog/87735.yaml @@ -0,0 +1,5 @@ +pr: 87735 +summary: Use desired nodes during data tier allocation decisions +area: Allocation +type: enhancement +issues: [] diff --git a/docs/reference/ml/anomaly-detection/functions/ml-time-functions.asciidoc b/docs/reference/ml/anomaly-detection/functions/ml-time-functions.asciidoc index d67dfe084e11e..096fd817ccc4c 100644 --- a/docs/reference/ml/anomaly-detection/functions/ml-time-functions.asciidoc +++ b/docs/reference/ml/anomaly-detection/functions/ml-time-functions.asciidoc @@ -2,8 +2,8 @@ = Time functions The time functions detect events that happen at unusual times, either of the day -or of the week. These functions can be used to find unusual patterns of behavior, -typically associated with suspicious user activity. +or of the week. These functions can be used to find unusual patterns of +behavior, typically associated with suspicious user activity. The {ml-features} include the following time functions: @@ -77,6 +77,12 @@ its past behavior. The `time_of_week` function detects when events occur that are outside normal usage patterns. For example, it detects login events on the weekend. +IMPORTANT: The `time_of_week` function models time in epoch seconds modulo the + duration of a week in seconds. It means that the `typical` and `actual` values + are seconds after a whole number of weeks since 1/1/1970 in UTC which is a + Thursday. For example, a value of `475` is 475 seconds after midnight on + Thursday in UTC. 
+ This function supports the following properties: * `by_field_name` (optional) @@ -102,3 +108,5 @@ models when events occur throughout the week for each `eventcode`. It detects when a workstation event occurs at an unusual time during the week for that `eventcode` compared to other workstations. It detects events for a particular workstation that are outside the normal usage pattern. + + diff --git a/libs/core/src/main/java/org/elasticsearch/core/IOUtils.java b/libs/core/src/main/java/org/elasticsearch/core/IOUtils.java index 0398418e503bc..fc2b10bf5b480 100644 --- a/libs/core/src/main/java/org/elasticsearch/core/IOUtils.java +++ b/libs/core/src/main/java/org/elasticsearch/core/IOUtils.java @@ -31,8 +31,6 @@ import java.nio.file.Path; import java.nio.file.StandardOpenOption; import java.nio.file.attribute.BasicFileAttributes; -import java.util.Arrays; -import java.util.Collection; import java.util.LinkedHashMap; import java.util.Map; @@ -62,7 +60,7 @@ private IOUtils() { * @param objects objects to close */ public static void close(final Closeable... objects) throws IOException { - close(null, Arrays.asList(objects)); + close(null, objects); } /** @@ -82,8 +80,28 @@ public static void close(@Nullable Closeable closeable) throws IOException { * * @param objects objects to close */ - public static void close(final Exception e, final Closeable... objects) throws IOException { - close(e, Arrays.asList(objects)); + public static void close(final Exception ex, final Closeable... 
objects) throws IOException { + Exception firstException = ex; + for (final Closeable object : objects) { + try { + close(object); + } catch (final IOException | RuntimeException e) { + firstException = addOrSuppress(firstException, e); + } + } + + if (firstException != null) { + throwRuntimeOrIOException(firstException); + } + } + + private static void throwRuntimeOrIOException(Exception firstException) throws IOException { + if (firstException instanceof IOException) { + throw (IOException) firstException; + } else { + // since we only assigned an IOException or a RuntimeException to ex above, in this case ex must be a RuntimeException + throw (RuntimeException) firstException; + } } /** @@ -95,50 +113,38 @@ public static void close(final Exception e, final Closeable... objects) throws I * @param objects objects to close */ public static void close(final Iterable objects) throws IOException { - close(null, objects); - } - - /** - * Closes all given {@link Closeable}s. If a non-null exception is passed in, or closing a - * stream causes an exception, throws the exception with other {@link RuntimeException} or - * {@link IOException} exceptions added as suppressed. - * - * @param ex existing Exception to add exceptions occurring during close to - * @param objects objects to close - * - * @see #close(Closeable...) 
- */ - public static void close(final Exception ex, final Iterable objects) throws IOException { - Exception firstException = ex; + Exception firstException = null; for (final Closeable object : objects) { try { close(object); } catch (final IOException | RuntimeException e) { - if (firstException == null) { - firstException = e; - } else { - firstException.addSuppressed(e); - } + firstException = addOrSuppress(firstException, e); } } if (firstException != null) { - if (firstException instanceof IOException) { - throw (IOException) firstException; - } else { - // since we only assigned an IOException or a RuntimeException to ex above, in this case ex must be a RuntimeException - throw (RuntimeException) firstException; - } + throwRuntimeOrIOException(firstException); } } + private static Exception addOrSuppress(Exception firstException, Exception e) { + if (firstException == null) { + firstException = e; + } else { + firstException.addSuppressed(e); + } + return firstException; + } + /** * Closes all given {@link Closeable}s, suppressing all thrown exceptions. Some of the {@link Closeable}s may be null, they are ignored. * * @param objects objects to close */ public static void closeWhileHandlingException(final Closeable... objects) { - closeWhileHandlingException(Arrays.asList(objects)); + for (final Closeable object : objects) { + closeWhileHandlingException(object); + } } /** @@ -170,15 +176,6 @@ public static void closeWhileHandlingException(final Closeable closeable) { * @param files the paths of files to delete */ public static void deleteFilesIgnoringExceptions(final Path... files) { - deleteFilesIgnoringExceptions(Arrays.asList(files)); - } - - /** - * Deletes all given files, suppressing all thrown {@link IOException}s. Some of the files may be null, if so they are ignored. 
- * - * @param files the paths of files to delete - */ - public static void deleteFilesIgnoringExceptions(final Collection files) { for (final Path name : files) { if (name != null) { // noinspection EmptyCatchBlock diff --git a/libs/core/src/main/java/org/elasticsearch/core/Releasables.java b/libs/core/src/main/java/org/elasticsearch/core/Releasables.java index 5cd8794428b96..05e3d154929a7 100644 --- a/libs/core/src/main/java/org/elasticsearch/core/Releasables.java +++ b/libs/core/src/main/java/org/elasticsearch/core/Releasables.java @@ -10,29 +10,22 @@ import java.io.IOException; import java.io.UncheckedIOException; -import java.util.Arrays; import java.util.concurrent.atomic.AtomicBoolean; /** Utility methods to work with {@link Releasable}s. */ public enum Releasables { ; - private static void close(Iterable releasables, boolean ignoreException) { + /** Release the provided {@link Releasable}s. */ + public static void close(Iterable releasables) { try { // this does the right thing with respect to add suppressed and not wrapping errors etc. IOUtils.close(releasables); } catch (IOException e) { - if (ignoreException == false) { - throw new UncheckedIOException(e); - } + throw new UncheckedIOException(e); } } - /** Release the provided {@link Releasable}s. */ - public static void close(Iterable releasables) { - close(releasables, false); - } - /** Release the provided {@link Releasable}. */ public static void close(@Nullable Releasable releasable) { try { @@ -44,7 +37,7 @@ public static void close(@Nullable Releasable releasable) { /** Release the provided {@link Releasable}s. */ public static void close(Releasable... releasables) { - close(Arrays.asList(releasables)); + close(true, releasables); } /** Release the provided {@link Releasable}s expecting no exception to by thrown by any of them. */ @@ -69,17 +62,19 @@ public static void closeExpectNoException(Releasable releasable) { /** Release the provided {@link Releasable}s, ignoring exceptions. 
*/ public static void closeWhileHandlingException(Releasable... releasables) { - close(Arrays.asList(releasables), true); + close(false, releasables); } /** Release the provided {@link Releasable}s, ignoring exceptions if success is {@code false}. */ - public static void close(boolean success, Iterable releasables) { - close(releasables, success == false); - } - - /** Release the provided {@link Releasable}s, ignoring exceptions if success is {@code false}. */ - public static void close(boolean success, Releasable... releasables) { - close(success, Arrays.asList(releasables)); + private static void close(boolean success, Releasable... releasables) { + try { + // this does the right thing with respect to add suppressed and not wrapping errors etc. + IOUtils.close(releasables); + } catch (IOException e) { + if (success) { + throw new UncheckedIOException(e); + } + } } /** Wrap several releasables into a single one. This is typically useful for use with try-with-resources: for example let's assume diff --git a/libs/core/src/test/java/org/elasticsearch/core/IOUtilsTests.java b/libs/core/src/test/java/org/elasticsearch/core/IOUtilsTests.java index 719c2e3ef4846..0d7576a56acb0 100644 --- a/libs/core/src/test/java/org/elasticsearch/core/IOUtilsTests.java +++ b/libs/core/src/test/java/org/elasticsearch/core/IOUtilsTests.java @@ -117,10 +117,6 @@ public void testDeleteFilesIgnoringExceptionsArray() throws IOException { runDeleteFilesIgnoringExceptionsTest(Function.identity(), IOUtils::deleteFilesIgnoringExceptions); } - public void testDeleteFilesIgnoringExceptionsIterable() throws IOException { - runDeleteFilesIgnoringExceptionsTest((Function>) Arrays::asList, IOUtils::deleteFilesIgnoringExceptions); - } - private void runDeleteFilesIgnoringExceptionsTest( final Function function, CheckedConsumer deleteFilesIgnoringExceptions diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/mget.json b/rest-api-spec/src/main/resources/rest-api-spec/api/mget.json index 
1b771b772bf69..28542bd91a7cb 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/mget.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/mget.json @@ -35,6 +35,12 @@ ] }, "params":{ + "force_synthetic_source": { + "type": "boolean", + "description": "Should this request force synthetic _source? Use this to test if the mapping supports synthetic _source and to get a sense of the worst case performance. Fetches with this enabled will be slower the enabling synthetic source natively in the index.", + "visibility": "feature_flag", + "feature_flag": "es.index_mode_feature_flag_registered" + }, "stored_fields":{ "type":"list", "description":"A comma-separated list of stored fields to return in the response" diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/mget/90_synthetic_source.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/mget/90_synthetic_source.yml index 05d7f41445601..222f29733ef14 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/mget/90_synthetic_source.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/mget/90_synthetic_source.yml @@ -45,3 +45,117 @@ keyword: - match: docs.1._source: kwd: bar + +--- +force_synthetic_source_ok: + - skip: + version: " - 8.3.99" + reason: introduced in 8.4.0 + + - do: + indices.create: + index: test + body: + mappings: + _source: + synthetic: false + properties: + kwd: + type: keyword + + - do: + index: + index: test + id: 1 + body: + kwd: foo + + - do: + index: + index: test + id: 2 + body: + kwd: bar + + # When _source is used in the fetch the original _source is perfect + - do: + mget: + index: test + body: + ids: [1, 2] + - match: + docs.0._source: + kwd: foo + - match: + docs.1._source: + kwd: bar + + # When we force synthetic source dots in field names get turned into objects + - do: + mget: + index: test + force_synthetic_source: true + body: + ids: [ 1, 2 ] + - match: + docs.0._source: + kwd: foo + - match: + 
docs.1._source: + kwd: bar + +--- +force_synthetic_source_bad_mapping: + - skip: + version: " - 8.3.99" + reason: introduced in 8.4.0 + + - do: + indices.create: + index: test + body: + settings: + number_of_shards: 1 # Use a single shard to get consistent error messages + mappings: + _source: + synthetic: false + properties: + text: + type: text + + - do: + index: + index: test + id: 1 + body: + text: foo + + - do: + index: + index: test + id: 2 + body: + text: bar + + # When _source is used in the fetch the original _source is perfect + - do: + mget: + index: test + body: + ids: [ 1, 2 ] + - match: + docs.0._source: + text: foo + - match: + docs.1._source: + text: bar + + # Forcing synthetic source fails because the mapping is invalid + - do: + mget: + index: test + force_synthetic_source: true + body: + ids: [ 1, 2 ] + - match: {docs.0.error.reason: "field [text] of type [text] doesn't support synthetic source unless it has a sub-field of type [keyword] with doc values enabled and without ignore_above or a normalizer"} + - match: {docs.1.error.reason: "field [text] of type [text] doesn't support synthetic source unless it has a sub-field of type [keyword] with doc values enabled and without ignore_above or a normalizer"} diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportUpdateDesiredNodesAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportUpdateDesiredNodesAction.java index 08446fbcb4dcf..a4e857054f9c4 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportUpdateDesiredNodesAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportUpdateDesiredNodesAction.java @@ -25,6 +25,7 @@ import org.elasticsearch.cluster.metadata.DesiredNodesMetadata; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.routing.RerouteService; +import 
org.elasticsearch.cluster.routing.allocation.AllocationService; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Priority; import org.elasticsearch.common.inject.Inject; @@ -50,7 +51,8 @@ public TransportUpdateDesiredNodesAction( ThreadPool threadPool, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, - DesiredNodesSettingsValidator settingsValidator + DesiredNodesSettingsValidator settingsValidator, + AllocationService allocationService ) { super( UpdateDesiredNodesAction.NAME, @@ -65,7 +67,7 @@ public TransportUpdateDesiredNodesAction( ThreadPool.Names.SAME ); this.settingsValidator = settingsValidator; - this.taskExecutor = new UpdateDesiredNodesExecutor(clusterService.getRerouteService()); + this.taskExecutor = new UpdateDesiredNodesExecutor(clusterService.getRerouteService(), allocationService); } @Override @@ -167,9 +169,11 @@ private static class UpdateDesiredNodesExecutor implements ClusterStateTaskExecu ); private final RerouteService rerouteService; + private final AllocationService allocationService; - UpdateDesiredNodesExecutor(RerouteService rerouteService) { + UpdateDesiredNodesExecutor(RerouteService rerouteService, AllocationService allocationService) { this.rerouteService = rerouteService; + this.allocationService = allocationService; } @Override @@ -194,7 +198,12 @@ public ClusterState execute(ClusterState currentState, List items = new ArrayList<>(); + /** + * Should this request force {@link SourceLoader.Synthetic synthetic source}? + * Use this to test if the mapping supports synthetic _source and to get a sense + * of the worst case performance. Fetches with this enabled will be slower the + * enabling synthetic source natively in the index. 
+ */ + private boolean forceSyntheticSource = false; + public MultiGetRequest() {} public MultiGetRequest(StreamInput in) throws IOException { @@ -254,6 +263,11 @@ public MultiGetRequest(StreamInput in) throws IOException { refresh = in.readBoolean(); realtime = in.readBoolean(); items = in.readList(Item::new); + if (in.getVersion().onOrAfter(Version.V_8_4_0)) { + forceSyntheticSource = in.readBoolean(); + } else { + forceSyntheticSource = false; + } } @Override @@ -263,6 +277,13 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(refresh); out.writeBoolean(realtime); out.writeList(items); + if (out.getVersion().onOrAfter(Version.V_8_4_0)) { + out.writeBoolean(forceSyntheticSource); + } else { + if (forceSyntheticSource) { + throw new IllegalArgumentException("force_synthetic_source is not supported before 8.4.0"); + } + } } public List getItems() { @@ -331,6 +352,26 @@ public MultiGetRequest refresh(boolean refresh) { return this; } + /** + * Should this request force {@link SourceLoader.Synthetic synthetic source}? + * Use this to test if the mapping supports synthetic _source and to get a sense + * of the worst case performance. Fetches with this enabled will be slower the + * enabling synthetic source natively in the index. + */ + public void setForceSyntheticSource(boolean forceSyntheticSource) { + this.forceSyntheticSource = forceSyntheticSource; + } + + /** + * Should this request force {@link SourceLoader.Synthetic synthetic source}? + * Use this to test if the mapping supports synthetic _source and to get a sense + * of the worst case performance. Fetches with this enabled will be slower the + * enabling synthetic source natively in the index. 
+ */ + public boolean isForceSyntheticSource() { + return forceSyntheticSource; + } + public MultiGetRequest add( @Nullable String defaultIndex, @Nullable String[] defaultFields, diff --git a/server/src/main/java/org/elasticsearch/action/get/MultiGetShardRequest.java b/server/src/main/java/org/elasticsearch/action/get/MultiGetShardRequest.java index 60f38ce7b9883..fc641c31512a7 100644 --- a/server/src/main/java/org/elasticsearch/action/get/MultiGetShardRequest.java +++ b/server/src/main/java/org/elasticsearch/action/get/MultiGetShardRequest.java @@ -8,10 +8,12 @@ package org.elasticsearch.action.get; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.support.single.shard.SingleShardRequest; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.index.mapper.SourceLoader; import java.io.IOException; import java.util.ArrayList; @@ -27,6 +29,14 @@ public class MultiGetShardRequest extends SingleShardRequest locations; List items; + /** + * Should this request force {@link SourceLoader.Synthetic synthetic source}? + * Use this to test if the mapping supports synthetic _source and to get a sense + * of the worst case performance. Fetches with this enabled will be slower the + * enabling synthetic source natively in the index. + */ + private final boolean forceSyntheticSource; + MultiGetShardRequest(MultiGetRequest multiGetRequest, String index, int shardId) { super(index); this.shardId = shardId; @@ -35,6 +45,7 @@ public class MultiGetShardRequest extends SingleShardRequest + * Desired nodes represents the cluster topology that the operator of the cluster is aiming for. + * Therefore, it is possible that the desired nodes contain nodes that are not part of the + * cluster in contrast to {@link DiscoveryNodes} that contains only nodes that are part of the cluster. + *

+ * + *

+ * This concept is useful as it provides more context about future topology changes to the system + * as well as the desired set of nodes in the cluster, allowing it to make better decisions + * about allocation, autoscaling, auto-expand replicas, etc. + *

+ * + *

+ * Additionally, settings validation is done during desired nodes updates avoiding boot-looping + * when an invalid setting is provided before the node is started. + *

+ * + *

+ * To modify the desired nodes it is necessary to provide the entire collection of nodes that will + * be part of the proposed cluster topology. + *

+ * + *

+ * Desired nodes are expected to be part of a lineage defined by the provided {@code historyId}. + * The {@code historyId} is provided by the orchestrator taking care of managing the cluster. + * In order to identify the different proposed desired nodes within the same history, it is + * also expected that the orchestrator provides a monotonically increasing {@code version} + * when it communicates about future topology changes. + * The cluster rejects desired nodes updated with a {@code version} less than or equal + * than the current {@code version} for the same {@code historyId}. + *

+ * + *

+ * The {@code historyId} is expected to remain stable during the cluster lifecycle, but it is + * possible that the orchestrator loses its own state and needs to be restored to a + * previous point in time with an older desired nodes {@code version}. In those cases it is + * expected to use new {@code historyId} that would allow starting from a different version. + *

+ * + *

+ * Each {@link DesiredNode} part of {@link DesiredNodes} has a {@link DesiredNodeWithStatus.Status} + * depending on whether or not the node has been part of the cluster at some point. + *

+ * + * The two possible statuses {@link DesiredNodeWithStatus.Status} are: + *
    + *
  • {@code PENDING}: The {@link DesiredNode} is not part of the cluster yet
  • + *
  • {@code ACTUALIZED}: The {@link DesiredNode} is or has been part of the cluster. + * Notice that it is possible that a node has {@code ACTUALIZED} status but it is not part of {@link DiscoveryNodes}, + * this is a conscious decision as it is expected that nodes can leave the cluster momentarily due to network issues, + * gc pressure, restarts, hardware failures etc, but are expected to still be part of the cluster. + *
  • + *
+ * + *

+ * See {@code JoinTaskExecutor} and {@code TransportUpdateDesiredNodesAction} for more details about + * desired nodes status tracking. + *

+ * + *

+ * Finally, each {@link DesiredNode} is expected to provide a way of identifying the node when it joins, + * {@link Node#NODE_EXTERNAL_ID_SETTING} allows providing that identity through settings. + *

+ * + */ public class DesiredNodes implements Writeable, ToXContentObject, Iterable { public static final String CONTEXT_MODE_PARAM = "desired_nodes_x_content_context"; public static final String CONTEXT_MODE_API = SerializationContext.GET_DESIRED_NODES_API.toString(); diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/RoutingAllocation.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/RoutingAllocation.java index 87c08b6139f6d..e502362aa91f6 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/RoutingAllocation.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/RoutingAllocation.java @@ -11,6 +11,7 @@ import org.elasticsearch.cluster.ClusterInfo; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.RestoreInProgress; +import org.elasticsearch.cluster.metadata.DesiredNodes; import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata; import org.elasticsearch.cluster.node.DiscoveryNodes; @@ -87,6 +88,8 @@ public RoutingAllocation( * @param deciders {@link AllocationDeciders} to used to make decisions for routing allocations * @param routingNodes Routing nodes in the current cluster or {@code null} if using those in the given cluster state * @param clusterState cluster state before rerouting + * @param clusterInfo information about node disk usage and shard disk usage + * @param shardSizeInfo information about snapshot shard sizes * @param currentNanoTime the nano time to use for all delay allocation calculation (typically {@link System#nanoTime()}) */ public RoutingAllocation( @@ -168,6 +171,11 @@ public SnapshotShardSizeInfo snapshotShardSizeInfo() { return shardSizeInfo; } + @Nullable + public DesiredNodes desiredNodes() { + return DesiredNodes.latestFromClusterState(clusterState); + } + /** * Returns the map of node id to shutdown metadata currently in the cluster */ diff 
--git a/server/src/main/java/org/elasticsearch/common/Rounding.java b/server/src/main/java/org/elasticsearch/common/Rounding.java index f99da4902db66..5fbed758c9842 100644 --- a/server/src/main/java/org/elasticsearch/common/Rounding.java +++ b/server/src/main/java/org/elasticsearch/common/Rounding.java @@ -191,10 +191,6 @@ public TemporalField getField() { return field; } - public static DateTimeUnit resolve(String name) { - return DateTimeUnit.valueOf(name.toUpperCase(Locale.ROOT)); - } - public String shortName() { return shortName; } diff --git a/server/src/main/java/org/elasticsearch/common/collect/ImmutableOpenMap.java b/server/src/main/java/org/elasticsearch/common/collect/ImmutableOpenMap.java index 416bb6c88a020..df5b57055bda9 100644 --- a/server/src/main/java/org/elasticsearch/common/collect/ImmutableOpenMap.java +++ b/server/src/main/java/org/elasticsearch/common/collect/ImmutableOpenMap.java @@ -27,7 +27,6 @@ import java.util.function.BiConsumer; import java.util.function.BiPredicate; import java.util.function.Consumer; -import java.util.function.Predicate; /** * An immutable map implementation based on open hash map. 
@@ -458,13 +457,6 @@ public VType getOrDefault(KType kType, VType vType) { return mutableMap.getOrDefault(kType, vType); } - public void putAll(Builder builder) { - maybeCloneMap(); - for (var entry : builder.mutableMap) { - mutableMap.put(entry.key, entry.value); - } - } - public VType remove(KType key) { maybeCloneMap(); return mutableMap.remove(key); @@ -480,16 +472,6 @@ public int size() { return mutableMap.size(); } - public boolean isEmpty() { - maybeCloneMap(); - return mutableMap.isEmpty(); - } - - public int removeAll(Predicate predicate) { - maybeCloneMap(); - return mutableMap.removeAll(predicate::test); - } - public void clear() { maybeCloneMap(); mutableMap.clear(); @@ -505,15 +487,5 @@ public int removeAll(BiPredicate predicate) { return mutableMap.removeAll(predicate::test); } - public int indexOf(KType key) { - maybeCloneMap(); - return mutableMap.indexOf(key); - } - - public void release() { - maybeCloneMap(); - mutableMap.release(); - } - } } diff --git a/server/src/main/java/org/elasticsearch/common/io/Streams.java b/server/src/main/java/org/elasticsearch/common/io/Streams.java index 839ef1f72288c..1089884ee0e04 100644 --- a/server/src/main/java/org/elasticsearch/common/io/Streams.java +++ b/server/src/main/java/org/elasticsearch/common/io/Streams.java @@ -246,11 +246,6 @@ public void close() throws IOException { flush(); } - @Override - public void reset() throws IOException { - delegate.reset(); - } - @Override public BytesReference bytes() { return delegate.bytes(); diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/BytesRefStreamOutput.java b/server/src/main/java/org/elasticsearch/common/io/stream/BytesRefStreamOutput.java index cc2c6afd8a46e..07475bdd4bb61 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/BytesRefStreamOutput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/BytesRefStreamOutput.java @@ -58,7 +58,6 @@ public void flush() {} @Override public void close() {} - @Override 
public void reset() { builder.clear(); } diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/BytesStreamOutput.java b/server/src/main/java/org/elasticsearch/common/io/stream/BytesStreamOutput.java index 5d2682f6ffd85..1a7f04b7b2e52 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/BytesStreamOutput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/BytesStreamOutput.java @@ -92,7 +92,6 @@ public void writeBytes(byte[] b, int offset, int length) { count += length; } - @Override public void reset() { // shrink list of pages if (bytes != null && bytes.size() > PageCacheRecycler.PAGE_SIZE_IN_BYTES) { diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/DataOutputStreamOutput.java b/server/src/main/java/org/elasticsearch/common/io/stream/DataOutputStreamOutput.java index b464ea8fba5f1..e7077c31b25d2 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/DataOutputStreamOutput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/DataOutputStreamOutput.java @@ -35,11 +35,6 @@ public void flush() throws IOException { // nothing to do there... } - @Override - public void reset() throws IOException { - // nothing to do there... 
- } - @Override public void close() throws IOException { if (out instanceof Closeable) { diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/DelayableWriteable.java b/server/src/main/java/org/elasticsearch/common/io/stream/DelayableWriteable.java index 2c934858454d3..5c637f27dfd19 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/DelayableWriteable.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/DelayableWriteable.java @@ -256,11 +256,5 @@ public void flush() throws IOException {} @Override public void close() throws IOException {} - - @Override - public void reset() throws IOException { - size = 0; - } - } } diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/OutputStreamStreamOutput.java b/server/src/main/java/org/elasticsearch/common/io/stream/OutputStreamStreamOutput.java index ce08e8e7fed81..de08b78fe9a12 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/OutputStreamStreamOutput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/OutputStreamStreamOutput.java @@ -38,9 +38,4 @@ public void flush() throws IOException { public void close() throws IOException { out.close(); } - - @Override - public void reset() throws IOException { - throw new UnsupportedOperationException(); - } } diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/RecyclerBytesStreamOutput.java b/server/src/main/java/org/elasticsearch/common/io/stream/RecyclerBytesStreamOutput.java index 2f50fbfeccf12..265e3e5bf1a4c 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/RecyclerBytesStreamOutput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/RecyclerBytesStreamOutput.java @@ -144,7 +144,6 @@ public void writeWithSizePrefix(Writeable writeable) throws IOException { } } - @Override public void reset() { Releasables.close(pages); pages.clear(); diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java 
b/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java index 1c9990b2dfac8..ee76260f3d73e 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java @@ -494,8 +494,6 @@ public void writeOptionalBoolean(@Nullable Boolean b) throws IOException { @Override public abstract void close() throws IOException; - public abstract void reset() throws IOException; - @Override public void write(int b) throws IOException { writeByte((byte) b); diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/VersionCheckingStreamOutput.java b/server/src/main/java/org/elasticsearch/common/io/stream/VersionCheckingStreamOutput.java index bd9ace99eb276..a686f35f394a4 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/VersionCheckingStreamOutput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/VersionCheckingStreamOutput.java @@ -45,11 +45,6 @@ public void close() throws IOException { } - @Override - public void reset() throws IOException { - // no-op - } - @Override public void writeNamedWriteable(NamedWriteable namedWriteable) throws IOException { if (namedWriteable instanceof VersionedNamedWriteable vnw) { diff --git a/server/src/main/java/org/elasticsearch/common/lucene/search/MoreLikeThisQuery.java b/server/src/main/java/org/elasticsearch/common/lucene/search/MoreLikeThisQuery.java index 02d9975397b55..d16e927efc973 100644 --- a/server/src/main/java/org/elasticsearch/common/lucene/search/MoreLikeThisQuery.java +++ b/server/src/main/java/org/elasticsearch/common/lucene/search/MoreLikeThisQuery.java @@ -24,7 +24,6 @@ import org.apache.lucene.search.similarities.Similarity; import org.apache.lucene.search.similarities.TFIDFSimilarity; import org.apache.lucene.util.BytesRef; -import org.elasticsearch.common.Strings; import org.elasticsearch.index.analysis.NamedAnalyzer; import java.io.IOException; @@ -32,7 +31,6 @@ 
import java.io.StringReader; import java.util.Arrays; import java.util.HashSet; -import java.util.List; import java.util.Objects; import java.util.Set; @@ -213,10 +211,6 @@ public String getLikeText() { return (likeText == null ? null : likeText[0]); } - public String[] getLikeTexts() { - return likeText; - } - public void setLikeText(String... likeText) { this.likeText = likeText; } @@ -229,10 +223,6 @@ public void setLikeFields(Fields... likeFields) { this.likeFields = likeFields; } - public void setLikeText(List likeText) { - setLikeText(likeText.toArray(Strings.EMPTY_ARRAY)); - } - public void setUnlikeFields(Fields... unlikeFields) { this.unlikeFields = unlikeFields; } @@ -249,10 +239,6 @@ public void setMoreLikeFields(String[] moreLikeFields) { this.moreLikeFields = moreLikeFields; } - public Similarity getSimilarity() { - return similarity; - } - public void setSimilarity(Similarity similarity) { if (similarity == null || similarity instanceof TFIDFSimilarity) { // LUCENE 4 UPGRADE we need TFIDF similarity here so I only set it if it is an instance of it @@ -269,16 +255,6 @@ public void setAnalyzer(String analyzerName, Analyzer analyzer) { this.analyzerName = analyzerName; } - /** - * Number of terms that must match the generated query expressed in the - * common syntax for minimum should match. - * - * @see org.elasticsearch.common.lucene.search.Queries#calculateMinShouldMatch(int, String) - */ - public String getMinimumShouldMatch() { - return minimumShouldMatch; - } - /** * Number of terms that must match the generated query expressed in the * common syntax for minimum should match. Defaults to {@code 30%}. 
@@ -308,58 +284,30 @@ public void setMaxQueryTerms(int maxQueryTerms) { this.maxQueryTerms = maxQueryTerms; } - public Set getStopWords() { - return stopWords; - } - public void setStopWords(Set stopWords) { this.stopWords = stopWords; } - public int getMinDocFreq() { - return minDocFreq; - } - public void setMinDocFreq(int minDocFreq) { this.minDocFreq = minDocFreq; } - public int getMaxDocFreq() { - return maxDocFreq; - } - public void setMaxDocFreq(int maxDocFreq) { this.maxDocFreq = maxDocFreq; } - public int getMinWordLen() { - return minWordLen; - } - public void setMinWordLen(int minWordLen) { this.minWordLen = minWordLen; } - public int getMaxWordLen() { - return maxWordLen; - } - public void setMaxWordLen(int maxWordLen) { this.maxWordLen = maxWordLen; } - public boolean isBoostTerms() { - return boostTerms; - } - public void setBoostTerms(boolean boostTerms) { this.boostTerms = boostTerms; } - public float getBoostTermsFactor() { - return boostTermsFactor; - } - public void setBoostTermsFactor(float boostTermsFactor) { this.boostTermsFactor = boostTermsFactor; } diff --git a/server/src/main/java/org/elasticsearch/common/lucene/search/XMoreLikeThis.java b/server/src/main/java/org/elasticsearch/common/lucene/search/XMoreLikeThis.java index 2bc909bcab1fe..1765e598c30a2 100644 --- a/server/src/main/java/org/elasticsearch/common/lucene/search/XMoreLikeThis.java +++ b/server/src/main/java/org/elasticsearch/common/lucene/search/XMoreLikeThis.java @@ -23,11 +23,8 @@ import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; -import org.apache.lucene.document.Document; -import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.Fields; import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.Term; import org.apache.lucene.index.Terms; @@ 
-38,7 +35,6 @@ import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; -import org.apache.lucene.search.similarities.ClassicSimilarity; import org.apache.lucene.search.similarities.TFIDFSimilarity; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.CharsRefBuilder; @@ -47,9 +43,6 @@ import java.io.IOException; import java.io.Reader; -import java.io.StringReader; -import java.util.ArrayList; -import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.Map; @@ -121,11 +114,9 @@ *
  • {@link #setMinTermFreq setMinTermFreq(...)} *
  • {@link #setMinDocFreq setMinDocFreq(...)} *
  • {@link #setMaxDocFreq setMaxDocFreq(...)} - *
  • {@link #setMaxDocFreqPct setMaxDocFreqPct(...)} *
  • {@link #setMinWordLen setMinWordLen(...)} *
  • {@link #setMaxWordLen setMaxWordLen(...)} *
  • {@link #setMaxQueryTerms setMaxQueryTerms(...)} - *
  • {@link #setMaxNumTokensParsed setMaxNumTokensParsed(...)} *
  • {@link #setStopWords setStopWord(...)} * *
    @@ -148,15 +139,12 @@ public final class XMoreLikeThis { /** * Default maximum number of tokens to parse in each example doc field that is not stored with TermVector support. - * - * @see #getMaxNumTokensParsed */ public static final int DEFAULT_MAX_NUM_TOKENS_PARSED = 5000; /** * Ignore terms with less than this frequency in the source doc. * - * @see #getMinTermFreq * @see #setMinTermFreq */ public static final int DEFAULT_MIN_TERM_FREQ = 2; @@ -164,7 +152,6 @@ public final class XMoreLikeThis { /** * Ignore words which do not occur in at least this many docs. * - * @see #getMinDocFreq * @see #setMinDocFreq */ public static final int DEFAULT_MIN_DOC_FREQ = 5; @@ -172,16 +159,13 @@ public final class XMoreLikeThis { /** * Ignore words which occur in more than this many docs. * - * @see #getMaxDocFreq * @see #setMaxDocFreq - * @see #setMaxDocFreqPct */ public static final int DEFAULT_MAX_DOC_FREQ = Integer.MAX_VALUE; /** * Boost terms in query based on score. * - * @see #isBoost * @see #setBoost */ public static final boolean DEFAULT_BOOST = false; @@ -195,7 +179,6 @@ public final class XMoreLikeThis { /** * Ignore words less than this length or if 0 then this has no effect. * - * @see #getMinWordLen * @see #setMinWordLen */ public static final int DEFAULT_MIN_WORD_LENGTH = 0; @@ -203,7 +186,6 @@ public final class XMoreLikeThis { /** * Ignore words greater than this length or if 0 then this has no effect. * - * @see #getMaxWordLen * @see #setMaxWordLen */ public static final int DEFAULT_MAX_WORD_LENGTH = 0; @@ -213,7 +195,6 @@ public final class XMoreLikeThis { * If null means to allow stop words. * * @see #setStopWords - * @see #getStopWords */ public static final Set DEFAULT_STOP_WORDS = null; @@ -226,7 +207,6 @@ public final class XMoreLikeThis { * Return a Query with no more than this many terms. 
* * @see BooleanQuery#getMaxClauseCount - * @see #getMaxQueryTerms * @see #setMaxQueryTerms */ public static final int DEFAULT_MAX_QUERY_TERMS = 25; @@ -266,11 +246,6 @@ public final class XMoreLikeThis { */ private String[] fieldNames = DEFAULT_FIELD_NAMES; - /** - * The maximum number of tokens to parse in each example doc field that is not stored with TermVector support - */ - private int maxNumTokensParsed = DEFAULT_MAX_NUM_TOKENS_PARSED; - /** * Ignore words if less than this len. */ @@ -289,7 +264,7 @@ public final class XMoreLikeThis { /** * For idf() calculations. */ - private TFIDFSimilarity similarity;// = new ClassicSimilarity(); + private final TFIDFSimilarity similarity;// = new ClassicSimilarity(); /** * IndexReader to use @@ -301,20 +276,8 @@ public final class XMoreLikeThis { */ private float boostFactor = 1; - /** - * Returns the boost factor used when boosting terms - * - * @return the boost factor used when boosting terms - * @see #setBoostFactor(float) - */ - public float getBoostFactor() { - return boostFactor; - } - /** * Sets the boost factor to use when boosting terms - * - * @see #getBoostFactor() */ public void setBoostFactor(float boostFactor) { this.boostFactor = boostFactor; @@ -327,39 +290,13 @@ public void setSkipTerms(Set skipTerms) { this.skipTerms = skipTerms; } - /** - * Constructor requiring an IndexReader. - */ - public XMoreLikeThis(IndexReader ir) { - this(ir, new ClassicSimilarity()); - } - public XMoreLikeThis(IndexReader ir, TFIDFSimilarity sim) { this.ir = ir; this.similarity = sim; } - public TFIDFSimilarity getSimilarity() { - return similarity; - } - - public void setSimilarity(TFIDFSimilarity similarity) { - this.similarity = similarity; - } - /** - * Returns an analyzer that will be used to parse source doc with. The default analyzer - * is not set. - * - * @return the analyzer that will be used to parse source doc with. 
- */ - public Analyzer getAnalyzer() { - return analyzer; - } - - /** - * Sets the analyzer to use. An analyzer is not required for generating a query with the - * {@link #like(int)} method, all other 'like' methods require an analyzer. + * Sets the analyzer to use. All 'like' methods require an analyzer. * * @param analyzer the analyzer to use to tokenize text. */ @@ -367,16 +304,6 @@ public void setAnalyzer(Analyzer analyzer) { this.analyzer = analyzer; } - /** - * Returns the frequency below which terms will be ignored in the source doc. The default - * frequency is the {@link #DEFAULT_MIN_TERM_FREQ}. - * - * @return the frequency below which terms will be ignored in the source doc. - */ - public int getMinTermFreq() { - return minTermFreq; - } - /** * Sets the frequency below which terms will be ignored in the source doc. * @@ -386,17 +313,6 @@ public void setMinTermFreq(int minTermFreq) { this.minTermFreq = minTermFreq; } - /** - * Returns the frequency at which words will be ignored which do not occur in at least this - * many docs. The default frequency is {@link #DEFAULT_MIN_DOC_FREQ}. - * - * @return the frequency at which words will be ignored which do not occur in at least this - * many docs. - */ - public int getMinDocFreq() { - return minDocFreq; - } - /** * Sets the frequency at which words will be ignored which do not occur in at least this * many docs. @@ -408,18 +324,6 @@ public void setMinDocFreq(int minDocFreq) { this.minDocFreq = minDocFreq; } - /** - * Returns the maximum frequency in which words may still appear. - * Words that appear in more than this many docs will be ignored. The default frequency is - * {@link #DEFAULT_MAX_DOC_FREQ}. - * - * @return get the maximum frequency at which words are still allowed, - * words which occur in more docs than this are ignored. - */ - public int getMaxDocFreq() { - return maxDocFreq; - } - /** * Set the maximum frequency in which words may still appear. 
Words that appear * in more than this many docs will be ignored. @@ -431,48 +335,15 @@ public void setMaxDocFreq(int maxFreq) { this.maxDocFreq = maxFreq; } - /** - * Set the maximum percentage in which words may still appear. Words that appear - * in more than this many percent of all docs will be ignored. - * - * @param maxPercentage the maximum percentage of documents (0-100) that a term may appear - * in to be still considered relevant - */ - public void setMaxDocFreqPct(int maxPercentage) { - this.maxDocFreq = maxPercentage * ir.numDocs() / 100; - } - - /** - * Returns whether to boost terms in query based on "score" or not. The default is - * {@link #DEFAULT_BOOST}. - * - * @return whether to boost terms in query based on "score" or not. - * @see #setBoost - */ - public boolean isBoost() { - return boost; - } - /** * Sets whether to boost terms in query based on "score" or not. * * @param boost true to boost terms in query based on "score", false otherwise. - * @see #isBoost */ public void setBoost(boolean boost) { this.boost = boost; } - /** - * Returns the field names that will be used when generating the 'More Like This' query. - * The default field names that will be used is {@link #DEFAULT_FIELD_NAMES}. - * - * @return the field names that will be used when generating the 'More Like This' query. - */ - public String[] getFieldNames() { - return fieldNames; - } - /** * Sets the field names that will be used when generating the 'More Like This' query. * Set this to null for the field names to be determined at runtime from the IndexReader @@ -485,16 +356,6 @@ public void setFieldNames(String[] fieldNames) { this.fieldNames = fieldNames; } - /** - * Returns the minimum word length below which words will be ignored. Set this to 0 for no - * minimum word length. The default is {@link #DEFAULT_MIN_WORD_LENGTH}. - * - * @return the minimum word length below which words will be ignored. 
- */ - public int getMinWordLen() { - return minWordLen; - } - /** * Sets the minimum word length below which words will be ignored. * @@ -504,16 +365,6 @@ public void setMinWordLen(int minWordLen) { this.minWordLen = minWordLen; } - /** - * Returns the maximum word length above which words will be ignored. Set this to 0 for no - * maximum word length. The default is {@link #DEFAULT_MAX_WORD_LENGTH}. - * - * @return the maximum word length above which words will be ignored. - */ - public int getMaxWordLen() { - return maxWordLen; - } - /** * Sets the maximum word length above which words will be ignored. * @@ -530,31 +381,11 @@ public void setMaxWordLen(int maxWordLen) { * for the purposes of document similarity it seems reasonable to assume that "a stop word is never interesting". * * @param stopWords set of stopwords, if null it means to allow stop words - * @see #getStopWords */ public void setStopWords(Set stopWords) { this.stopWords = stopWords; } - /** - * Get the current stop words being used. - * - * @see #setStopWords - */ - public Set getStopWords() { - return stopWords; - } - - /** - * Returns the maximum number of query terms that will be included in any generated query. - * The default is {@link #DEFAULT_MAX_QUERY_TERMS}. - * - * @return the maximum number of query terms that will be included in any generated query. - */ - public int getMaxQueryTerms() { - return maxQueryTerms; - } - /** * Sets the maximum number of query terms that will be included in any generated query. 
* @@ -565,37 +396,6 @@ public void setMaxQueryTerms(int maxQueryTerms) { this.maxQueryTerms = maxQueryTerms; } - /** - * @return The maximum number of tokens to parse in each example doc field that is not stored with TermVector support - * @see #DEFAULT_MAX_NUM_TOKENS_PARSED - */ - public int getMaxNumTokensParsed() { - return maxNumTokensParsed; - } - - /** - * @param i The maximum number of tokens to parse in each example doc field that is not stored with TermVector support - */ - public void setMaxNumTokensParsed(int i) { - maxNumTokensParsed = i; - } - - /** - * Return a query that will return docs like the passed lucene document ID. - * - * @param docNum the documentID of the lucene doc to generate the 'More Like This" query for. - * @return a query that will return docs like the passed lucene document ID. - */ - public Query like(int docNum) throws IOException { - if (fieldNames == null) { - // gather list of valid fields from lucene - Collection fields = FieldInfos.getIndexedFields(ir); - fieldNames = fields.toArray(new String[fields.size()]); - } - - return createQuery(retrieveTerms(docNum)); - } - /** * Return a query that will return docs like the passed Readers. * This was added in order to treat multi-value fields. @@ -610,19 +410,6 @@ public Query like(String fieldName, Reader... readers) throws IOException { return createQuery(createQueue(words)); } - /** - * Return a query that will return docs like the passed Terms. - * - * @return a query that will return docs like the passed Terms. - */ - public Query like(Terms... likeTerms) throws IOException { - Map termFreqMap = new HashMap<>(); - for (Terms vector : likeTerms) { - addTermFrequencies(termFreqMap, vector); - } - return createQuery(createQueue(termFreqMap)); - } - /** * Return a query that will return docs like the passed Fields. * @@ -751,71 +538,6 @@ private PriorityQueue createQueue(Map words, String... 
f return queue; } - /** - * Describe the parameters that control how the "more like this" query is formed. - */ - public String describeParams() { - StringBuilder sb = new StringBuilder(); - sb.append("\t").append("maxQueryTerms : ").append(maxQueryTerms).append("\n"); - sb.append("\t").append("minWordLen : ").append(minWordLen).append("\n"); - sb.append("\t").append("maxWordLen : ").append(maxWordLen).append("\n"); - sb.append("\t").append("fieldNames : "); - String delim = ""; - for (String fieldName : fieldNames) { - sb.append(delim).append(fieldName); - delim = ", "; - } - sb.append("\n"); - sb.append("\t").append("boost : ").append(boost).append("\n"); - sb.append("\t").append("minTermFreq : ").append(minTermFreq).append("\n"); - sb.append("\t").append("minDocFreq : ").append(minDocFreq).append("\n"); - return sb.toString(); - } - - /** - * Find words for a more-like-this query former. - * - * @param docNum the id of the lucene document from which to find terms - */ - private PriorityQueue retrieveTerms(int docNum) throws IOException { - Map termFreqMap = new HashMap<>(); - for (String fieldName : fieldNames) { - final Fields vectors = ir.getTermVectors(docNum); - final Terms vector; - if (vectors != null) { - vector = vectors.terms(fieldName); - } else { - vector = null; - } - - // field does not store term vector info - if (vector == null) { - Document d = ir.document(docNum); - IndexableField fields[] = d.getFields(fieldName); - for (IndexableField field : fields) { - final String stringValue = field.stringValue(); - if (stringValue != null) { - addTermFrequencies(new StringReader(stringValue), termFreqMap, fieldName); - } - } - } else { - addTermFrequencies(termFreqMap, vector, fieldName); - } - } - - return createQueue(termFreqMap); - } - - /** - * Adds terms and frequencies found in vector into the Map termFreqMap - * - * @param termFreqMap a Map of terms and their frequencies - * @param vector List of terms and their frequencies for a doc/field - */ - 
private void addTermFrequencies(Map termFreqMap, Terms vector) throws IOException { - addTermFrequencies(termFreqMap, vector, null); - } - /** * Adds terms and frequencies found in vector into the Map termFreqMap * @@ -874,7 +596,10 @@ private void addTermFrequencies(Reader r, Map termFreqMap, String f while (ts.incrementToken()) { String word = termAtt.toString(); tokenCount++; - if (tokenCount > maxNumTokensParsed) { + /** + * The maximum number of tokens to parse in each example doc field that is not stored with TermVector support + */ + if (tokenCount > DEFAULT_MAX_NUM_TOKENS_PARSED) { break; } if (isNoiseWord(word)) { @@ -920,73 +645,6 @@ private boolean isSkipTerm(@Nullable String field, String value) { return field != null && skipTerms != null && skipTerms.contains(new Term(field, value)); } - /** - * Find words for a more-like-this query former. - * The result is a priority queue of arrays with one entry for every word in the document. - * Each array has 6 elements. - * The elements are: - *
      - *
    1. The word (String) - *
    2. The top field that this word comes from (String) - *
    3. The score for this word (Float) - *
    4. The IDF value (Float) - *
    5. The frequency of this word in the index (Integer) - *
    6. The frequency of this word in the source document (Integer) - *
    - * This is a somewhat "advanced" routine, and in general only the 1st entry in the array is of interest. - * This method is exposed so that you can identify the "interesting words" in a document. - * For an easier method to call see {@link #retrieveInterestingTerms retrieveInterestingTerms()}. - * - * @param r the reader that has the content of the document - * @param fieldName field passed to the analyzer to use when analyzing the content - * @return the most interesting words in the document ordered by score, with the highest scoring, or best entry, first - * @see #retrieveInterestingTerms - */ - private PriorityQueue retrieveTerms(Reader r, String fieldName) throws IOException { - Map words = new HashMap<>(); - addTermFrequencies(r, words, fieldName); - return createQueue(words); - } - - /** - * @see #retrieveInterestingTerms(java.io.Reader, String) - */ - public String[] retrieveInterestingTerms(int docNum) throws IOException { - ArrayList al = new ArrayList<>(maxQueryTerms); - PriorityQueue pq = retrieveTerms(docNum); - ScoreTerm scoreTerm; - int lim = maxQueryTerms; // have to be careful, retrieveTerms returns all words but that's probably not useful to our caller... - // we just want to return the top words - while (((scoreTerm = pq.pop()) != null) && lim-- > 0) { - al.add(scoreTerm.word); // the 1st entry is the interesting word - } - String[] res = new String[al.size()]; - return al.toArray(res); - } - - /** - * Convenience routine to make it easy to return the most interesting words in a document. - * More advanced users will call {@link #retrieveTerms(Reader, String) retrieveTerms()} directly. 
- * - * @param r the source document - * @param fieldName field passed to analyzer to use when analyzing the content - * @return the most interesting words in the document - * @see #retrieveTerms(java.io.Reader, String) - * @see #setMaxQueryTerms - */ - public String[] retrieveInterestingTerms(Reader r, String fieldName) throws IOException { - ArrayList al = new ArrayList<>(maxQueryTerms); - PriorityQueue pq = retrieveTerms(r, fieldName); - ScoreTerm scoreTerm; - int lim = maxQueryTerms; // have to be careful, retrieveTerms returns all words but that's probably not useful to our caller... - // we just want to return the top words - while (((scoreTerm = pq.pop()) != null) && lim-- > 0) { - al.add(scoreTerm.word); // the 1st entry is the interesting word - } - String[] res = new String[al.size()]; - return al.toArray(res); - } - /** * PriorityQueue that orders words by score. */ diff --git a/server/src/main/java/org/elasticsearch/common/unit/DistanceUnit.java b/server/src/main/java/org/elasticsearch/common/unit/DistanceUnit.java index a7a212ffb0906..0ef00f4a389b6 100644 --- a/server/src/main/java/org/elasticsearch/common/unit/DistanceUnit.java +++ b/server/src/main/java/org/elasticsearch/common/unit/DistanceUnit.java @@ -59,24 +59,6 @@ public double getEarthCircumference() { return GeoUtils.EARTH_EQUATOR / meters; } - /** - * Measures the radius of earth in this unit - * - * @return length of earth radius in this unit - */ - public double getEarthRadius() { - return GeoUtils.EARTH_SEMI_MAJOR_AXIS / meters; - } - - /** - * Measures a longitude in this unit - * - * @return length of a longitude degree in this unit - */ - public double getDistancePerDegree() { - return GeoUtils.EARTH_EQUATOR / (360.0 * meters); - } - /** * Convert a value into meters * @@ -181,24 +163,6 @@ public static DistanceUnit fromString(String unit) { throw new IllegalArgumentException("No distance unit match [" + unit + "]"); } - /** - * Parses the suffix of a given distance string and return 
the corresponding {@link DistanceUnit} - * - * @param distance string representing a distance - * @param defaultUnit default unit to use, if no unit is provided by the string - * @return unit of the given distance - */ - public static DistanceUnit parseUnit(String distance, DistanceUnit defaultUnit) { - for (DistanceUnit unit : values()) { - for (String name : unit.names) { - if (distance.endsWith(name)) { - return unit; - } - } - } - return defaultUnit; - } - /** * This class implements a value+unit tuple. */ diff --git a/server/src/main/java/org/elasticsearch/common/unit/RelativeByteSizeValue.java b/server/src/main/java/org/elasticsearch/common/unit/RelativeByteSizeValue.java index 47637e4e2e174..7475f89f5910c 100644 --- a/server/src/main/java/org/elasticsearch/common/unit/RelativeByteSizeValue.java +++ b/server/src/main/java/org/elasticsearch/common/unit/RelativeByteSizeValue.java @@ -18,7 +18,6 @@ */ public class RelativeByteSizeValue { - public static final String MAX_HEADROOM_PREFIX = "max_headroom="; private final ByteSizeValue absolute; private final RatioValue ratio; diff --git a/server/src/main/java/org/elasticsearch/common/unit/SizeValue.java b/server/src/main/java/org/elasticsearch/common/unit/SizeValue.java index ea06fb454b58f..9584a5c55f503 100644 --- a/server/src/main/java/org/elasticsearch/common/unit/SizeValue.java +++ b/server/src/main/java/org/elasticsearch/common/unit/SizeValue.java @@ -47,90 +47,46 @@ public long singles() { return sizeUnit.toSingles(size); } - public long getSingles() { - return singles(); - } - public long kilo() { return sizeUnit.toKilo(size); } - public long getKilo() { - return kilo(); - } - public long mega() { return sizeUnit.toMega(size); } - public long getMega() { - return mega(); - } - public long giga() { return sizeUnit.toGiga(size); } - public long getGiga() { - return giga(); - } - public long tera() { return sizeUnit.toTera(size); } - public long getTera() { - return tera(); - } - public long peta() { return 
sizeUnit.toPeta(size); } - public long getPeta() { - return peta(); - } - public double kiloFrac() { return ((double) singles()) / SizeUnit.C1; } - public double getKiloFrac() { - return kiloFrac(); - } - public double megaFrac() { return ((double) singles()) / SizeUnit.C2; } - public double getMegaFrac() { - return megaFrac(); - } - public double gigaFrac() { return ((double) singles()) / SizeUnit.C3; } - public double getGigaFrac() { - return gigaFrac(); - } - public double teraFrac() { return ((double) singles()) / SizeUnit.C4; } - public double getTeraFrac() { - return teraFrac(); - } - public double petaFrac() { return ((double) singles()) / SizeUnit.C5; } - public double getPetaFrac() { - return petaFrac(); - } - @Override public String toString() { long singles = singles(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/KeywordFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/KeywordFieldMapper.java index c1fe5ee41b87b..1dfd44d8a3d0d 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/KeywordFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/KeywordFieldMapper.java @@ -1078,21 +1078,27 @@ public BytesSyntheticFieldLoader(String name, String simpleName) { @Override public Leaf leaf(LeafReader reader) throws IOException { SortedSetDocValues leaf = DocValues.getSortedSet(reader, name); + if (leaf.getValueCount() == 0) { + return SourceLoader.SyntheticFieldLoader.NOTHING.leaf(reader); + } return new SourceLoader.SyntheticFieldLoader.Leaf() { private boolean hasValue; @Override - public void advanceToDoc(int docId) throws IOException { - hasValue = leaf.advanceExact(docId); + public boolean empty() { + return false; } @Override - public boolean hasValue() { - return hasValue; + public boolean advanceToDoc(int docId) throws IOException { + return hasValue = leaf.advanceExact(docId); } @Override - public void load(XContentBuilder b) throws IOException { + public void write(XContentBuilder 
b) throws IOException { + if (false == hasValue) { + return; + } long first = leaf.nextOrd(); long next = leaf.nextOrd(); if (next == SortedSetDocValues.NO_MORE_ORDS) { diff --git a/server/src/main/java/org/elasticsearch/index/mapper/NumberFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/NumberFieldMapper.java index c7199b1d20ba1..5b47d43fca80b 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/NumberFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/NumberFieldMapper.java @@ -17,6 +17,7 @@ import org.apache.lucene.index.DocValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.NumericDocValues; import org.apache.lucene.index.SortedNumericDocValues; import org.apache.lucene.sandbox.document.HalfFloatPoint; import org.apache.lucene.sandbox.search.IndexSortSortedNumericDocValuesRangeQuery; @@ -1631,22 +1632,28 @@ protected NumericSyntheticFieldLoader(String name, String simpleName) { @Override public Leaf leaf(LeafReader reader) throws IOException { - SortedNumericDocValues leaf = DocValues.getSortedNumeric(reader, name); + SortedNumericDocValues leaf = dv(reader); + if (leaf == null) { + return SourceLoader.SyntheticFieldLoader.NOTHING.leaf(reader); + } return new SourceLoader.SyntheticFieldLoader.Leaf() { private boolean hasValue; @Override - public void advanceToDoc(int docId) throws IOException { - hasValue = leaf.advanceExact(docId); + public boolean empty() { + return false; } @Override - public boolean hasValue() { - return hasValue; + public boolean advanceToDoc(int docId) throws IOException { + return hasValue = leaf.advanceExact(docId); } @Override - public void load(XContentBuilder b) throws IOException { + public void write(XContentBuilder b) throws IOException { + if (false == hasValue) { + return; + } if (leaf.docValueCount() == 1) { b.field(simpleName); loadNextValue(b, leaf.nextValue()); @@ -1662,5 +1669,23 @@ 
public void load(XContentBuilder b) throws IOException { } protected abstract void loadNextValue(XContentBuilder b, long value) throws IOException; + + /** + * Returns a {@link SortedNumericDocValues} or null if it doesn't have any doc values. + * See {@link DocValues#getSortedNumeric} which is *nearly* the same, but it returns + * an "empty" implementation if there aren't any doc values. We need to be able to + * tell if there aren't any and return our empty leaf source loader. + */ + private SortedNumericDocValues dv(LeafReader reader) throws IOException { + SortedNumericDocValues dv = reader.getSortedNumericDocValues(name); + if (dv != null) { + return dv; + } + NumericDocValues single = reader.getNumericDocValues(name); + if (single != null) { + return DocValues.singleton(single); + } + return null; + } } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/ObjectMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/ObjectMapper.java index 3e629b4a21119..e98ccda06e352 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/ObjectMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/ObjectMapper.java @@ -554,53 +554,51 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep @Override public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { - List fields = new ArrayList<>(); - mappers.values().stream().sorted(Comparator.comparing(Mapper::name)).forEach(sub -> { - SourceLoader.SyntheticFieldLoader subLoader = sub.syntheticFieldLoader(); - if (subLoader != null) { - fields.add(subLoader); - } - }); + List fields = mappers.values() + .stream() + .sorted(Comparator.comparing(Mapper::name)) + .map(Mapper::syntheticFieldLoader) + .filter(l -> l != null) + .toList(); return new SourceLoader.SyntheticFieldLoader() { @Override public Leaf leaf(LeafReader reader) throws IOException { - List leaves = new ArrayList<>(); + List l = new ArrayList<>(); for (SourceLoader.SyntheticFieldLoader 
field : fields) { - leaves.add(field.leaf(reader)); + Leaf leaf = field.leaf(reader); + if (false == leaf.empty()) { + l.add(leaf); + } } + SourceLoader.SyntheticFieldLoader.Leaf[] leaves = l.toArray(SourceLoader.SyntheticFieldLoader.Leaf[]::new); return new SourceLoader.SyntheticFieldLoader.Leaf() { + private boolean hasValue; + @Override - public void advanceToDoc(int docId) throws IOException { - for (SourceLoader.SyntheticFieldLoader.Leaf leaf : leaves) { - leaf.advanceToDoc(docId); - } + public boolean empty() { + return leaves.length == 0; } @Override - public boolean hasValue() { + public boolean advanceToDoc(int docId) throws IOException { + hasValue = false; for (SourceLoader.SyntheticFieldLoader.Leaf leaf : leaves) { - if (leaf.hasValue()) { - return true; - } + boolean leafHasValue = leaf.advanceToDoc(docId); + hasValue |= leafHasValue; } - return false; + return hasValue; } @Override - public void load(XContentBuilder b) throws IOException { - boolean started = false; - for (SourceLoader.SyntheticFieldLoader.Leaf leaf : leaves) { - if (leaf.hasValue()) { - if (false == started) { - started = true; - startSyntheticField(b); - } - leaf.load(b); - } + public void write(XContentBuilder b) throws IOException { + if (hasValue == false) { + return; } - if (started) { - b.endObject(); + startSyntheticField(b); + for (SourceLoader.SyntheticFieldLoader.Leaf leaf : leaves) { + leaf.write(b); } + b.endObject(); } }; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/SourceLoader.java b/server/src/main/java/org/elasticsearch/index/mapper/SourceLoader.java index 89a4638d66a79..03e41a45b7a9f 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/SourceLoader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/SourceLoader.java @@ -42,6 +42,16 @@ interface Leaf { * @param docId the doc to load */ BytesReference source(FieldsVisitor fieldsVisitor, int docId) throws IOException; + + Leaf EMPTY_OBJECT = new Leaf() { + @Override + 
public BytesReference source(FieldsVisitor fieldsVisitor, int docId) throws IOException { + // TODO accept a requested xcontent type + try (XContentBuilder b = new XContentBuilder(JsonXContent.jsonXContent, new ByteArrayOutputStream())) { + return BytesReference.bytes(b.startObject().endObject()); + } + } + }; } /** @@ -82,14 +92,16 @@ public boolean reordersFieldValues() { @Override public Leaf leaf(LeafReader reader) throws IOException { SyntheticFieldLoader.Leaf leaf = loader.leaf(reader); + if (leaf.empty()) { + return Leaf.EMPTY_OBJECT; + } return new Leaf() { @Override public BytesReference source(FieldsVisitor fieldsVisitor, int docId) throws IOException { // TODO accept a requested xcontent type try (XContentBuilder b = new XContentBuilder(JsonXContent.jsonXContent, new ByteArrayOutputStream())) { - leaf.advanceToDoc(docId); - if (leaf.hasValue()) { - leaf.load(b); + if (leaf.advanceToDoc(docId)) { + leaf.write(b); } else { b.startObject().endObject(); } @@ -109,15 +121,17 @@ interface SyntheticFieldLoader { */ SyntheticFieldLoader NOTHING = r -> new Leaf() { @Override - public void advanceToDoc(int docId) throws IOException {} + public boolean empty() { + return true; + } @Override - public boolean hasValue() { + public boolean advanceToDoc(int docId) throws IOException { return false; } @Override - public void load(XContentBuilder b) throws IOException {} + public void write(XContentBuilder b) throws IOException {} }; /** @@ -130,19 +144,19 @@ public void load(XContentBuilder b) throws IOException {} */ interface Leaf { /** - * Position the loader at a document. + * Is this entirely empty? */ - void advanceToDoc(int docId) throws IOException; + boolean empty(); /** - * Is there a value for this field in this document? + * Position the loader at a document. */ - boolean hasValue(); + boolean advanceToDoc(int docId) throws IOException; /** - * Load values for this document. + * Write values for this document. 
*/ - void load(XContentBuilder b) throws IOException; + void write(XContentBuilder b) throws IOException; } } diff --git a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java index d55ceebd24c4b..cd5a04e48b054 100644 --- a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java +++ b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java @@ -1515,7 +1515,7 @@ private Engine.Searcher wrapSearcher(Engine.Searcher searcher) { throw new ElasticsearchException("failed to wrap searcher", ex); } finally { if (success == false) { - Releasables.close(success, searcher); + Releasables.closeWhileHandlingException(searcher); } } } diff --git a/server/src/main/java/org/elasticsearch/index/translog/BufferedChecksumStreamOutput.java b/server/src/main/java/org/elasticsearch/index/translog/BufferedChecksumStreamOutput.java index 8ae63f59937d7..0f0e6bd0241ba 100644 --- a/server/src/main/java/org/elasticsearch/index/translog/BufferedChecksumStreamOutput.java +++ b/server/src/main/java/org/elasticsearch/index/translog/BufferedChecksumStreamOutput.java @@ -54,12 +54,6 @@ public void close() throws IOException { out.close(); } - @Override - public void reset() throws IOException { - out.reset(); - digest.reset(); - } - public void resetDigest() { digest.reset(); } diff --git a/server/src/main/java/org/elasticsearch/rest/action/document/RestMultiGetAction.java b/server/src/main/java/org/elasticsearch/rest/action/document/RestMultiGetAction.java index e3b0e3ecbbcc0..80735e269c97a 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/document/RestMultiGetAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/document/RestMultiGetAction.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.RestApiVersion; +import org.elasticsearch.index.IndexSettings; import 
org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.RestToXContentListener; @@ -60,6 +61,10 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC multiGetRequest.refresh(request.paramAsBoolean("refresh", multiGetRequest.refresh())); multiGetRequest.preference(request.param("preference")); multiGetRequest.realtime(request.paramAsBoolean("realtime", multiGetRequest.realtime())); + if (IndexSettings.isTimeSeriesModeEnabled() && request.paramAsBoolean("force_synthetic_source", false)) { + multiGetRequest.setForceSyntheticSource(true); + } + if (request.param("fields") != null) { throw new IllegalArgumentException( "The parameter [fields] is no longer supported, " diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportUpdateDesiredNodesActionTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportUpdateDesiredNodesActionTests.java index 0d97570e04a77..175b49d83972a 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportUpdateDesiredNodesActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportUpdateDesiredNodesActionTests.java @@ -24,6 +24,7 @@ import org.elasticsearch.cluster.metadata.DesiredNodesTestCase; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.routing.allocation.AllocationService; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; @@ -59,7 +60,8 @@ public void testWriteBlocks() { mock(ThreadPool.class), mock(ActionFilters.class), mock(IndexNameExpressionResolver.class), - NO_OP_SETTINGS_VALIDATOR + NO_OP_SETTINGS_VALIDATOR, + mock(AllocationService.class) ); final ClusterBlocks blocks = 
ClusterBlocks.builder() @@ -83,7 +85,8 @@ public void testNoBlocks() { mock(ThreadPool.class), mock(ActionFilters.class), mock(IndexNameExpressionResolver.class), - NO_OP_SETTINGS_VALIDATOR + NO_OP_SETTINGS_VALIDATOR, + mock(AllocationService.class) ); final ClusterBlocks blocks = ClusterBlocks.builder().build(); @@ -106,7 +109,8 @@ public void validate(List desiredNodes) { mock(ThreadPool.class), mock(ActionFilters.class), mock(IndexNameExpressionResolver.class), - validator + validator, + mock(AllocationService.class) ); final ClusterState state = ClusterState.builder(new ClusterName(randomAlphaOfLength(10))).build(); diff --git a/server/src/test/java/org/elasticsearch/action/get/GetRequestTests.java b/server/src/test/java/org/elasticsearch/action/get/GetRequestTests.java index 76f7e8e91a996..6e318ca60e744 100644 --- a/server/src/test/java/org/elasticsearch/action/get/GetRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/get/GetRequestTests.java @@ -7,7 +7,10 @@ */ package org.elasticsearch.action.get; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.test.ESTestCase; import static org.hamcrest.CoreMatchers.hasItems; @@ -33,4 +36,13 @@ public void testValidation() { assertThat(validate.validationErrors(), hasItems("id is missing")); } } + + public void testForceSyntheticUnsupported() { + GetRequest request = new GetRequest("index", "id"); + request.setForceSyntheticSource(true); + StreamOutput out = new BytesStreamOutput(); + out.setVersion(Version.V_8_3_0); + Exception e = expectThrows(IllegalArgumentException.class, () -> request.writeTo(out)); + assertEquals(e.getMessage(), "force_synthetic_source is not supported before 8.4.0"); + } } diff --git a/server/src/test/java/org/elasticsearch/action/get/MultiGetRequestTests.java 
b/server/src/test/java/org/elasticsearch/action/get/MultiGetRequestTests.java index 79def1398e967..49b726b9d93e7 100644 --- a/server/src/test/java/org/elasticsearch/action/get/MultiGetRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/get/MultiGetRequestTests.java @@ -8,8 +8,11 @@ package org.elasticsearch.action.get; +import org.elasticsearch.Version; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.VersionType; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.test.ESTestCase; @@ -117,6 +120,15 @@ public void testXContentSerialization() throws IOException { } } + public void testForceSyntheticUnsupported() { + MultiGetRequest request = createTestInstance(); + request.setForceSyntheticSource(true); + StreamOutput out = new BytesStreamOutput(); + out.setVersion(Version.V_8_3_0); + Exception e = expectThrows(IllegalArgumentException.class, () -> request.writeTo(out)); + assertEquals(e.getMessage(), "force_synthetic_source is not supported before 8.4.0"); + } + private MultiGetRequest createTestInstance() { int numItems = randomIntBetween(0, 128); MultiGetRequest request = new MultiGetRequest(); @@ -149,6 +161,9 @@ private MultiGetRequest createTestInstance() { } request.add(item); } + if (randomBoolean()) { + request.setForceSyntheticSource(true); + } return request; } diff --git a/server/src/test/java/org/elasticsearch/action/get/MultiGetShardRequestTests.java b/server/src/test/java/org/elasticsearch/action/get/MultiGetShardRequestTests.java index adeeb78c4bf05..2546e4ef2a0ec 100644 --- a/server/src/test/java/org/elasticsearch/action/get/MultiGetShardRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/get/MultiGetShardRequestTests.java @@ -8,19 +8,62 @@ package 
org.elasticsearch.action.get; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.VersionType; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.test.ESTestCase; import java.io.IOException; -import static org.elasticsearch.test.VersionUtils.randomVersion; +import static org.elasticsearch.test.VersionUtils.randomVersionBetween; import static org.hamcrest.CoreMatchers.equalTo; public class MultiGetShardRequestTests extends ESTestCase { public void testSerialization() throws IOException { + MultiGetShardRequest multiGetShardRequest = createTestInstance(randomBoolean()); + + BytesStreamOutput out = new BytesStreamOutput(); + Version minVersion = Version.CURRENT.minimumCompatibilityVersion(); + if (multiGetShardRequest.isForceSyntheticSource()) { + minVersion = Version.V_8_4_0; + } + out.setVersion(randomVersionBetween(random(), minVersion, Version.CURRENT)); + multiGetShardRequest.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + in.setVersion(out.getVersion()); + MultiGetShardRequest multiGetShardRequest2 = new MultiGetShardRequest(in); + assertThat(multiGetShardRequest2.index(), equalTo(multiGetShardRequest.index())); + assertThat(multiGetShardRequest2.preference(), equalTo(multiGetShardRequest.preference())); + assertThat(multiGetShardRequest2.realtime(), equalTo(multiGetShardRequest.realtime())); + assertThat(multiGetShardRequest2.refresh(), equalTo(multiGetShardRequest.refresh())); + assertThat(multiGetShardRequest2.items.size(), equalTo(multiGetShardRequest.items.size())); + for (int i = 0; i < multiGetShardRequest2.items.size(); i++) { + MultiGetRequest.Item item = multiGetShardRequest.items.get(i); + MultiGetRequest.Item item2 = multiGetShardRequest2.items.get(i); + assertThat(item2.index(), equalTo(item.index())); + 
assertThat(item2.id(), equalTo(item.id())); + assertThat(item2.storedFields(), equalTo(item.storedFields())); + assertThat(item2.version(), equalTo(item.version())); + assertThat(item2.versionType(), equalTo(item.versionType())); + assertThat(item2.fetchSourceContext(), equalTo(item.fetchSourceContext())); + } + assertThat(multiGetShardRequest2.indices(), equalTo(multiGetShardRequest.indices())); + assertThat(multiGetShardRequest2.indicesOptions(), equalTo(multiGetShardRequest.indicesOptions())); + } + + public void testForceSyntheticUnsupported() { + MultiGetShardRequest request = createTestInstance(true); + StreamOutput out = new BytesStreamOutput(); + out.setVersion(Version.V_8_3_0); + Exception e = expectThrows(IllegalArgumentException.class, () -> request.writeTo(out)); + assertEquals(e.getMessage(), "force_synthetic_source is not supported before 8.4.0"); + } + + private MultiGetShardRequest createTestInstance(boolean forceSyntheticSource) { MultiGetRequest multiGetRequest = new MultiGetRequest(); if (randomBoolean()) { multiGetRequest.preference(randomAlphaOfLength(randomIntBetween(1, 10))); @@ -31,6 +74,9 @@ public void testSerialization() throws IOException { if (randomBoolean()) { multiGetRequest.refresh(true); } + if (forceSyntheticSource) { + multiGetRequest.setForceSyntheticSource(true); + } MultiGetShardRequest multiGetShardRequest = new MultiGetShardRequest(multiGetRequest, "index", 0); int numItems = iterations(10, 30); for (int i = 0; i < numItems; i++) { @@ -52,30 +98,7 @@ public void testSerialization() throws IOException { } multiGetShardRequest.add(0, item); } - - BytesStreamOutput out = new BytesStreamOutput(); - out.setVersion(randomVersion(random())); - multiGetShardRequest.writeTo(out); - - StreamInput in = out.bytes().streamInput(); - in.setVersion(out.getVersion()); - MultiGetShardRequest multiGetShardRequest2 = new MultiGetShardRequest(in); - assertThat(multiGetShardRequest2.index(), equalTo(multiGetShardRequest.index())); - 
assertThat(multiGetShardRequest2.preference(), equalTo(multiGetShardRequest.preference())); - assertThat(multiGetShardRequest2.realtime(), equalTo(multiGetShardRequest.realtime())); - assertThat(multiGetShardRequest2.refresh(), equalTo(multiGetShardRequest.refresh())); - assertThat(multiGetShardRequest2.items.size(), equalTo(multiGetShardRequest.items.size())); - for (int i = 0; i < multiGetShardRequest2.items.size(); i++) { - MultiGetRequest.Item item = multiGetShardRequest.items.get(i); - MultiGetRequest.Item item2 = multiGetShardRequest2.items.get(i); - assertThat(item2.index(), equalTo(item.index())); - assertThat(item2.id(), equalTo(item.id())); - assertThat(item2.storedFields(), equalTo(item.storedFields())); - assertThat(item2.version(), equalTo(item.version())); - assertThat(item2.versionType(), equalTo(item.versionType())); - assertThat(item2.fetchSourceContext(), equalTo(item.fetchSourceContext())); - } - assertThat(multiGetShardRequest2.indices(), equalTo(multiGetShardRequest.indices())); - assertThat(multiGetShardRequest2.indicesOptions(), equalTo(multiGetShardRequest.indicesOptions())); + return multiGetShardRequest; } + } diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/DesiredNodeSerializationTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/DesiredNodeSerializationTests.java index 611d177d6e853..bd10abf7cb6ca 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/DesiredNodeSerializationTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/DesiredNodeSerializationTests.java @@ -16,8 +16,6 @@ import java.io.IOException; -import static org.elasticsearch.cluster.metadata.DesiredNodesTestCase.randomProcessorRange; - public class DesiredNodeSerializationTests extends AbstractSerializingTestCase { @Override protected DesiredNode doParseInstance(XContentParser parser) throws IOException { @@ -52,15 +50,14 @@ public static DesiredNode mutateDesiredNode(DesiredNode instance) { ); case 1 
-> new DesiredNode( instance.settings(), - randomFloat() + randomIntBetween(1, 128), + randomValueOtherThan(instance.processors(), () -> randomFloat() + randomIntBetween(1, 128)), instance.memory(), instance.storage(), instance.version() ); - case 2 -> new DesiredNode( instance.settings(), - randomProcessorRange(), + randomValueOtherThan(instance.processorsRange(), DesiredNodesTestCase::randomProcessorRange), instance.memory(), instance.storage(), instance.version() @@ -69,7 +66,7 @@ public static DesiredNode mutateDesiredNode(DesiredNode instance) { instance.settings(), instance.processors(), instance.processorsRange(), - ByteSizeValue.ofGb(randomIntBetween(1, 128)), + ByteSizeValue.ofGb(randomValueOtherThan(instance.memory().getGb(), () -> (long) randomIntBetween(1, 128))), instance.storage(), instance.version() ); @@ -78,7 +75,7 @@ public static DesiredNode mutateDesiredNode(DesiredNode instance) { instance.processors(), instance.processorsRange(), instance.memory(), - ByteSizeValue.ofGb(randomIntBetween(1, 128)), + ByteSizeValue.ofGb(randomValueOtherThan(instance.storage().getGb(), () -> (long) randomIntBetween(1, 128))), instance.version() ); case 5 -> new DesiredNode( diff --git a/server/src/test/java/org/elasticsearch/common/unit/SizeValueTests.java b/server/src/test/java/org/elasticsearch/common/unit/SizeValueTests.java index 45c5351a4464e..4bbd56bd6e633 100644 --- a/server/src/test/java/org/elasticsearch/common/unit/SizeValueTests.java +++ b/server/src/test/java/org/elasticsearch/common/unit/SizeValueTests.java @@ -83,7 +83,7 @@ public void testCompareUnits() { public void testConversionHashCode() { SizeValue firstValue = new SizeValue(randomIntBetween(0, Integer.MAX_VALUE), SizeUnit.GIGA); - SizeValue secondValue = new SizeValue(firstValue.getSingles(), SizeUnit.SINGLE); + SizeValue secondValue = new SizeValue(firstValue.singles(), SizeUnit.SINGLE); assertEquals(firstValue.hashCode(), secondValue.hashCode()); } } diff --git 
a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/filter/FiltersAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/filter/FiltersAggregatorTests.java index eb52417db40c5..dd99b100592d5 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/filter/FiltersAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/filter/FiltersAggregatorTests.java @@ -99,6 +99,7 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.nullValue; import static org.mockito.Mockito.mock; @@ -694,7 +695,10 @@ public void onCache(ShardId shardId, Accountable accountable) {} .entry("segments_with_deleted_docs", 0) .entry( "filters", - matchesList().item(matchesMap().entry("query", "*:*").entry("segments_counted_in_constant_time", 1)) + matchesList().item( + matchesMap().entry("query", "*:*") + .entry("segments_counted_in_constant_time", searcher.getLeafContexts().size()) + ) ) ); } @@ -706,8 +710,10 @@ public void onCache(ShardId shardId, Accountable accountable) {} * the index set up kind of like document level security. As a bonus, this * "looks" to the agg just like an index with deleted documents. *

    - * This can't use the constant time counting because {@code term} doesn't - * know how to count in constant time when there are deleted documents. + * Segments with a filter that doesn't rewrite to {@code match_all} can't + * take the fast path. But segments whose filter rewrites to {@code match_all} + * can use the fast path - thus the assertion at the bottom of this: + * {@code "segments_counted_in_constant_time", lessThan(searcher.getLeafContexts().size())}. */ public void testTermOnFilteredIndex() throws IOException { KeywordFieldType ft = new KeywordFieldType("foo"); @@ -758,7 +764,75 @@ public void onCache(ShardId shardId, Accountable accountable) {} .entry("segments_with_deleted_docs", 0) .entry( "filters", - matchesList().item(matchesMap().entry("query", "foo:bar").entry("segments_counted_in_constant_time", 0)) + matchesList().item( + matchesMap().entry("query", "foo:bar") + .entry("segments_counted_in_constant_time", lessThan(searcher.getLeafContexts().size())) + ) + ) + ); + } + } + } + + /** + * This runs {@code filters} with a single {@code term} filter with + * the index set up kind of like document level security where the + * document level security query matches all documents. These can + * always take the fast path in filter-by-filter. 
+ */ + public void testTermOnFilterWithMatchAll() throws IOException { + KeywordFieldType ft = new KeywordFieldType("foo"); + AggregationBuilder builder = new FiltersAggregationBuilder("test", new KeyedFilter("q1", new TermQueryBuilder("foo", "bar"))); + try (Directory directory = newDirectory()) { + RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory); + for (int i = 0; i < 10; i++) { + indexWriter.addDocument(List.of(new Field("foo", "bar", KeywordFieldMapper.Defaults.FIELD_TYPE), new LongPoint("t", i))); + } + indexWriter.close(); + + try (DirectoryReader directoryReader = DirectoryReader.open(directory)) { + BitsetFilterCache bitsetFilterCache = new BitsetFilterCache(createIndexSettings(), new BitsetFilterCache.Listener() { + @Override + public void onRemoval(ShardId shardId, Accountable accountable) {} + + @Override + public void onCache(ShardId shardId, Accountable accountable) {} + }); + IndexReader limitedReader = new DocumentSubsetDirectoryReader( + ElasticsearchDirectoryReader.wrap(directoryReader, new ShardId(bitsetFilterCache.index(), 0)), + bitsetFilterCache, + LongPoint.newRangeQuery("t", Long.MIN_VALUE, Long.MAX_VALUE) + ); + IndexSearcher searcher = newIndexSearcher(limitedReader); + AggregationContext context = createAggregationContext(searcher, new MatchAllDocsQuery(), ft); + FilterByFilterAggregator aggregator = createAggregator(builder, context); + aggregator.preCollection(); + searcher.search(context.query(), aggregator); + aggregator.postCollection(); + + InternalAggregation result = aggregator.buildTopLevel(); + result = result.reduce( + List.of(result), + new AggregationReduceContext.ForFinal(context.bigArrays(), getMockScriptService(), () -> false, null, b -> {}) + ); + InternalFilters filters = (InternalFilters) result; + assertThat(filters.getBuckets(), hasSize(1)); + assertThat(filters.getBucketByKey("q1").getDocCount(), equalTo(10L)); + + Map debug = new HashMap<>(); + aggregator.collectDebugInfo(debug::put); + 
assertMap( + debug, + matchesMap().entry("segments_counted", greaterThanOrEqualTo(1)) + .entry("segments_collected", 0) + .entry("segments_with_doc_count_field", 0) + .entry("segments_with_deleted_docs", 0) + .entry( + "filters", + matchesList().item( + matchesMap().entry("query", "foo:bar") + .entry("segments_counted_in_constant_time", searcher.getLeafContexts().size()) + ) ) ); } diff --git a/server/src/test/java/org/elasticsearch/threadpool/UpdateThreadPoolSettingsTests.java b/server/src/test/java/org/elasticsearch/threadpool/UpdateThreadPoolSettingsTests.java index 736a24b679f46..8e5cfce61938a 100644 --- a/server/src/test/java/org/elasticsearch/threadpool/UpdateThreadPoolSettingsTests.java +++ b/server/src/test/java/org/elasticsearch/threadpool/UpdateThreadPoolSettingsTests.java @@ -142,7 +142,7 @@ public void testShutdownNowInterrupts() throws Exception { .put("node.name", "testShutdownNowInterrupts") .build(); threadPool = new ThreadPool(nodeSettings); - assertEquals(info(threadPool, threadPoolName).getQueueSize().getSingles(), 1000L); + assertEquals(info(threadPool, threadPoolName).getQueueSize().singles(), 1000L); final CountDownLatch shutDownLatch = new CountDownLatch(1); final CountDownLatch latch = new CountDownLatch(1); diff --git a/test/framework/src/main/java/org/elasticsearch/search/geo/GeoShapeQueryTestCase.java b/test/framework/src/main/java/org/elasticsearch/search/geo/GeoShapeQueryTestCase.java index d93c086581618..2634fc6dafedc 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/geo/GeoShapeQueryTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/search/geo/GeoShapeQueryTestCase.java @@ -157,8 +157,7 @@ public void testShapeFetchingPath() throws Exception { public void testRandomGeoCollectionQuery() throws Exception { // Create a random geometry collection to index. 
GeometryCollection randomIndexCollection = GeometryTestUtils.randomGeometryCollectionWithoutCircle(false); - org.apache.lucene.geo.Polygon randomPoly = GeoTestUtil.nextPolygon(); - Polygon polygon = new Polygon(new LinearRing(randomPoly.getPolyLons(), randomPoly.getPolyLats())); + Polygon polygon = GeometryTestUtils.randomPolygon(false); List indexGeometries = new ArrayList<>(); for (Geometry geometry : randomIndexCollection) { indexGeometries.add(geometry); diff --git a/x-pack/plugin/autoscaling/src/main/java/org/elasticsearch/xpack/autoscaling/storage/ReactiveStorageDeciderService.java b/x-pack/plugin/autoscaling/src/main/java/org/elasticsearch/xpack/autoscaling/storage/ReactiveStorageDeciderService.java index fbe59764371ef..4db5b5d451033 100644 --- a/x-pack/plugin/autoscaling/src/main/java/org/elasticsearch/xpack/autoscaling/storage/ReactiveStorageDeciderService.java +++ b/x-pack/plugin/autoscaling/src/main/java/org/elasticsearch/xpack/autoscaling/storage/ReactiveStorageDeciderService.java @@ -13,6 +13,7 @@ import org.elasticsearch.cluster.DiskUsage; import org.elasticsearch.cluster.metadata.DataStream; import org.elasticsearch.cluster.metadata.DataStreamMetadata; +import org.elasticsearch.cluster.metadata.DesiredNodes; import org.elasticsearch.cluster.metadata.IndexAbstraction; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.Metadata; @@ -401,7 +402,7 @@ private IndexMetadata indexMetadata(ShardRouting shard, RoutingAllocation alloca return allocation.metadata().getIndexSafe(shard.index()); } - private Optional highestPreferenceTier(List preferredTiers, DiscoveryNodes unused) { + private Optional highestPreferenceTier(List preferredTiers, DiscoveryNodes unused, DesiredNodes desiredNodes) { assert preferredTiers.isEmpty() == false; return Optional.of(preferredTiers.get(0)); } diff --git 
a/x-pack/plugin/core/src/internalClusterTest/java/org/elasticsearch/xpack/cluster/routing/allocation/DataTierAllocationDeciderIT.java b/x-pack/plugin/core/src/internalClusterTest/java/org/elasticsearch/xpack/cluster/routing/allocation/DataTierAllocationDeciderIT.java index 039bb8b840be1..3aa7dd0a7f8c3 100644 --- a/x-pack/plugin/core/src/internalClusterTest/java/org/elasticsearch/xpack/cluster/routing/allocation/DataTierAllocationDeciderIT.java +++ b/x-pack/plugin/core/src/internalClusterTest/java/org/elasticsearch/xpack/cluster/routing/allocation/DataTierAllocationDeciderIT.java @@ -7,30 +7,51 @@ package org.elasticsearch.xpack.cluster.routing.allocation; +import org.elasticsearch.Version; +import org.elasticsearch.action.admin.cluster.desirednodes.UpdateDesiredNodesAction; +import org.elasticsearch.action.admin.cluster.desirednodes.UpdateDesiredNodesRequest; import org.elasticsearch.action.admin.indices.settings.put.UpdateSettingsRequest; import org.elasticsearch.action.admin.indices.shrink.ResizeType; import org.elasticsearch.action.admin.indices.template.put.PutComposableIndexTemplateAction; import org.elasticsearch.cluster.health.ClusterHealthStatus; import org.elasticsearch.cluster.metadata.ComposableIndexTemplate; +import org.elasticsearch.cluster.metadata.DesiredNode; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.Template; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.cluster.routing.allocation.DataTier; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.core.Nullable; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.xpack.core.DataTiersFeatureSetUsage; import org.elasticsearch.xpack.core.action.XPackUsageRequestBuilder; import 
org.elasticsearch.xpack.core.action.XPackUsageResponse; +import org.junit.Before; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.List; import java.util.Map; +import java.util.stream.Collectors; +import static org.elasticsearch.cluster.metadata.IndexMetadata.INDEX_AUTO_EXPAND_REPLICAS_SETTING; +import static org.elasticsearch.cluster.metadata.IndexMetadata.INDEX_NUMBER_OF_REPLICAS_SETTING; +import static org.elasticsearch.node.Node.NODE_EXTERNAL_ID_SETTING; +import static org.elasticsearch.node.Node.NODE_NAME_SETTING; +import static org.elasticsearch.node.NodeRoleSettings.NODE_ROLES_SETTING; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; -@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0, numClientNodes = 0) +@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0, numClientNodes = 0, autoManageMasterNodes = false) public class DataTierAllocationDeciderIT extends ESIntegTestCase { private static final String index = "myindex"; @@ -39,6 +60,13 @@ protected Collection> nodePlugins() { return Collections.singleton(DataTierTelemetryPlugin.class); } + @Before + public void setUpMasterNode() { + // Ensure that master nodes cannot hold any data + internalCluster().setBootstrapMasterNodeIndex(0); + internalCluster().startMasterOnlyNode(); + } + public void testDefaultIndexAllocateToContent() { startWarmOnlyNode(); startColdOnlyNode(); @@ -67,6 +95,194 @@ public void testDefaultIndexAllocateToContent() { ensureYellow(index); } + public void testDesiredNodesAreConsideredDuringAllocation() throws Exception { + final var warmDesiredNode = desiredNode(randomAlphaOfLength(10), DiscoveryNodeRole.DATA_WARM_NODE_ROLE); + final var coldDesiredNode 
= desiredNode(randomAlphaOfLength(15), DiscoveryNodeRole.DATA_COLD_NODE_ROLE); + final var masterDesiredNode = desiredNode(internalCluster().getMasterName(), DiscoveryNodeRole.MASTER_ROLE); + updateDesiredNodes(warmDesiredNode, coldDesiredNode, masterDesiredNode); + + startWarmOnlyNode(warmDesiredNode.externalId()); + final var coldNodeName = startColdOnlyNode(coldDesiredNode.externalId()); + + createIndexWithTierPreference(DataTier.DATA_COLD, DataTier.DATA_WARM); + + ensureGreen(index); + + assertPrimaryShardIsAllocatedInNodeWithRole(0, DiscoveryNodeRole.DATA_COLD_NODE_ROLE); + + // Remove the cold tier + updateDesiredNodes(masterDesiredNode, warmDesiredNode); + + assertBusy(() -> assertPrimaryShardIsAllocatedInNodeWithRole(0, DiscoveryNodeRole.DATA_WARM_NODE_ROLE)); + + ensureGreen(index); + } + + public void testShardsAreKeptInPreferredTierUntilTheNextTierIsInItsFinalState() throws Exception { + final var hotDesiredNode = desiredNode("hot-node-0", DiscoveryNodeRole.DATA_HOT_NODE_ROLE); + final var warmDesiredNode = desiredNode("warn-node-0", DiscoveryNodeRole.DATA_WARM_NODE_ROLE); + final var coldDesiredNode = desiredNode("cold-node-0", DiscoveryNodeRole.DATA_COLD_NODE_ROLE); + final var masterDesiredNode = desiredNode(internalCluster().getMasterName(), DiscoveryNodeRole.MASTER_ROLE); + updateDesiredNodes(hotDesiredNode, warmDesiredNode, coldDesiredNode, masterDesiredNode); + + startHotOnlyNode(hotDesiredNode.externalId()); + startWarmOnlyNode(warmDesiredNode.externalId()); + startColdOnlyNode(coldDesiredNode.externalId()); + + createIndexWithTierPreference(DataTier.DATA_COLD, DataTier.DATA_WARM, DataTier.DATA_HOT); + + ensureGreen(index); + + assertPrimaryShardIsAllocatedInNodeWithRole(0, DiscoveryNodeRole.DATA_COLD_NODE_ROLE); + + final List newDesiredNodesInLeastPreferredTiers = new ArrayList<>(); + final var numberOfNewNodes = randomIntBetween(1, 5); + for (int i = 1; i <= numberOfNewNodes; i++) { + if (randomBoolean()) { + 
newDesiredNodesInLeastPreferredTiers.add(desiredNode("hot-node-" + i, DiscoveryNodeRole.DATA_HOT_NODE_ROLE)); + } else { + newDesiredNodesInLeastPreferredTiers.add(desiredNode("warm-node-" + i, DiscoveryNodeRole.DATA_WARM_NODE_ROLE)); + } + } + + // Remove the cold tier and grow the next preferred tiers + final List newDesiredNodes = new ArrayList<>(newDesiredNodesInLeastPreferredTiers); + newDesiredNodes.add(masterDesiredNode); + newDesiredNodes.add(hotDesiredNode); + newDesiredNodes.add(warmDesiredNode); + updateDesiredNodes(newDesiredNodes); + + ensureGreen(index); + + assertBusy(() -> assertPrimaryShardIsAllocatedInNodeWithRole(0, DiscoveryNodeRole.DATA_COLD_NODE_ROLE)); + + for (final var newDesiredNode : newDesiredNodesInLeastPreferredTiers) { + if (newDesiredNode.getRoles().contains(DiscoveryNodeRole.DATA_HOT_NODE_ROLE)) { + startHotOnlyNode(newDesiredNode.externalId()); + } else { + startWarmOnlyNode(newDesiredNode.externalId()); + } + } + + ensureGreen(index); + + assertBusy(() -> assertPrimaryShardIsAllocatedInNodeWithRole(0, DiscoveryNodeRole.DATA_WARM_NODE_ROLE)); + } + + public void testSimpleAllocationDecisionWithDesiredNodes() { + final var warmDesiredNode = desiredNode("warm-node-0", DiscoveryNodeRole.DATA_WARM_NODE_ROLE); + final var warmDesiredNode2 = desiredNode("warm-node-1", DiscoveryNodeRole.DATA_WARM_NODE_ROLE); + final var masterDesiredNode = desiredNode(internalCluster().getMasterName(), DiscoveryNodeRole.MASTER_ROLE); + updateDesiredNodes(warmDesiredNode, warmDesiredNode2, masterDesiredNode); + + startWarmOnlyNode(warmDesiredNode.externalId()); + + createIndexWithTierPreference(DataTier.DATA_COLD, DataTier.DATA_WARM); + + ensureGreen(index); + + assertPrimaryShardIsAllocatedInNodeWithRole(0, DiscoveryNodeRole.DATA_WARM_NODE_ROLE); + } + + public void testGrowAndShrinkSingleNodeInTier() throws Exception { + final var warmDesiredNode = desiredNode("warm-node", DiscoveryNodeRole.DATA_WARM_NODE_ROLE); + final var coldDesiredNode = 
desiredNode("cold-node-1", DiscoveryNodeRole.DATA_COLD_NODE_ROLE); + final var masterDesiredNode = desiredNode(internalCluster().getMasterName(), DiscoveryNodeRole.MASTER_ROLE); + updateDesiredNodes(warmDesiredNode, coldDesiredNode, masterDesiredNode); + + startWarmOnlyNode(warmDesiredNode.externalId()); + var coldNodeName = startColdOnlyNode(coldDesiredNode.externalId()); + + createIndexWithTierPreference(DataTier.DATA_COLD, DataTier.DATA_WARM); + + ensureGreen(index); + + assertPrimaryShardIsAllocatedInNodeWithRole(0, DiscoveryNodeRole.DATA_COLD_NODE_ROLE); + + final var newColdDesiredNode = desiredNode("cold-node-2", DiscoveryNodeRole.DATA_COLD_NODE_ROLE); + updateDesiredNodes(warmDesiredNode, newColdDesiredNode, masterDesiredNode); + + // Exclude the node that we want to decommission, so it can move to the new cold node + client().admin() + .indices() + .prepareUpdateSettings(index) + .setSettings(Settings.builder().put("index.routing.allocation.exclude._name", coldNodeName).build()) + .get(); + + assertBusy(() -> assertPrimaryShardIsAllocatedInNodeWithRole(0, DiscoveryNodeRole.DATA_COLD_NODE_ROLE)); + + startColdOnlyNode(newColdDesiredNode.externalId()); + + ensureGreen(index); + + assertBusy(() -> assertPrimaryShardIsAllocatedInNode(0, newColdDesiredNode)); + + internalCluster().stopNode(coldNodeName); + + ensureGreen(index); + } + + public void testDesiredNodesAreTakenIntoAccountInAutoExpandReplicas() throws Exception { + final var masterDesiredNode = desiredNode(internalCluster().getMasterName(), DiscoveryNodeRole.MASTER_ROLE); + final int numberOfColdNodes = randomIntBetween(2, 5); + final List coldDesiredNodes = new ArrayList<>(); + for (int i = 0; i < numberOfColdNodes; i++) { + final var coldDesiredNode = desiredNode("cold-node-" + i, DiscoveryNodeRole.DATA_COLD_NODE_ROLE); + coldDesiredNodes.add(coldDesiredNode); + startColdOnlyNode(coldDesiredNode.externalId()); + } + final int numberOfWarmNodes = randomIntBetween(numberOfColdNodes + 1, 10); + final 
List warmDesiredNodes = new ArrayList<>(); + for (int i = 0; i < numberOfWarmNodes; i++) { + final var warmDesiredNode = desiredNode("warm-node-" + i, DiscoveryNodeRole.DATA_WARM_NODE_ROLE); + warmDesiredNodes.add(warmDesiredNode); + startWarmOnlyNode(warmDesiredNode.externalId()); + } + final List desiredNodesWithWarmAndColdTier = new ArrayList<>(); + desiredNodesWithWarmAndColdTier.addAll(warmDesiredNodes); + desiredNodesWithWarmAndColdTier.addAll(coldDesiredNodes); + desiredNodesWithWarmAndColdTier.add(masterDesiredNode); + + updateDesiredNodes(desiredNodesWithWarmAndColdTier); + + client().admin() + .indices() + .prepareCreate(index) + .setWaitForActiveShards(0) + .setSettings( + Settings.builder() + .put(DataTier.TIER_PREFERENCE, String.join(",", DataTier.DATA_COLD, DataTier.DATA_WARM)) + .put(INDEX_NUMBER_OF_REPLICAS_SETTING.getKey(), 0) + .put(INDEX_AUTO_EXPAND_REPLICAS_SETTING.getKey(), "0-all") + ) + .get(); + + var replicas = client().admin() + .indices() + .prepareGetIndex() + .setIndices(index) + .get() + .getSetting(index, INDEX_NUMBER_OF_REPLICAS_SETTING.getKey()); + + assertThat(Integer.parseInt(replicas), is(equalTo(numberOfColdNodes - 1))); + + final List desiredNodesWithoutColdTier = new ArrayList<>(warmDesiredNodes); + desiredNodesWithoutColdTier.add(masterDesiredNode); + + updateDesiredNodes(desiredNodesWithoutColdTier); + + assertBusy(() -> { + var newReplicaCount = client().admin() + .indices() + .prepareGetIndex() + .setIndices(index) + .get() + .getSetting(index, INDEX_NUMBER_OF_REPLICAS_SETTING.getKey()); + + assertThat(Integer.parseInt(newReplicaCount), is(equalTo(numberOfWarmNodes - 1))); + }); + } + public void testOverrideDefaultAllocation() { startWarmOnlyNode(); startColdOnlyNode(); @@ -293,27 +509,50 @@ public void startContentOnlyNode() { } public void startHotOnlyNode() { - Settings nodeSettings = Settings.builder() + startHotOnlyNode(null); + } + + public void startHotOnlyNode(@Nullable String externalId) { + Settings.Builder 
nodeSettings = Settings.builder() .putList("node.roles", Arrays.asList("master", "data_hot", "ingest")) - .put("node.attr.box", "hot") - .build(); + .put("node.attr.box", "hot"); + + if (externalId != null) { + nodeSettings.put(NODE_EXTERNAL_ID_SETTING.getKey(), externalId); + } + internalCluster().startNode(nodeSettings); } public void startWarmOnlyNode() { - Settings nodeSettings = Settings.builder() + startWarmOnlyNode(null); + } + + public String startWarmOnlyNode(@Nullable String externalId) { + Settings.Builder nodeSettings = Settings.builder() .putList("node.roles", Arrays.asList("master", "data_warm", "ingest")) - .put("node.attr.box", "warm") - .build(); - internalCluster().startNode(nodeSettings); + .put("node.attr.box", "warm"); + + if (externalId != null) { + nodeSettings.put(NODE_EXTERNAL_ID_SETTING.getKey(), externalId); + } + return internalCluster().startNode(nodeSettings); } public void startColdOnlyNode() { - Settings nodeSettings = Settings.builder() + startColdOnlyNode(null); + } + + public String startColdOnlyNode(@Nullable String externalId) { + Settings.Builder nodeSettings = Settings.builder() .putList("node.roles", Arrays.asList("master", "data_cold", "ingest")) - .put("node.attr.box", "cold") - .build(); - internalCluster().startNode(nodeSettings); + .put("node.attr.box", "cold"); + + if (externalId != null) { + nodeSettings.put(NODE_EXTERNAL_ID_SETTING.getKey(), externalId); + } + + return internalCluster().startNode(nodeSettings); } public void startFrozenOnlyNode() { @@ -323,4 +562,68 @@ public void startFrozenOnlyNode() { .build(); internalCluster().startNode(nodeSettings); } + + private DesiredNode desiredNode(String externalId, DiscoveryNodeRole... 
roles) { + assertThat(roles.length, is(greaterThan(0))); + + final var nodeRoles = Arrays.stream(roles).map(DiscoveryNodeRole::roleName).collect(Collectors.joining(",")); + final var settings = Settings.builder() + .put(NODE_ROLES_SETTING.getKey(), nodeRoles) + .put(NODE_EXTERNAL_ID_SETTING.getKey(), externalId) + .put(NODE_NAME_SETTING.getKey(), externalId) + .build(); + return new DesiredNode(settings, 1, ByteSizeValue.ONE, ByteSizeValue.ONE, Version.CURRENT); + } + + private void updateDesiredNodes(DesiredNode... desiredNodes) { + assertThat(desiredNodes.length, is(greaterThan(0))); + updateDesiredNodes(Arrays.asList(desiredNodes)); + } + + private void updateDesiredNodes(List desiredNodes) { + assertThat(desiredNodes.size(), is(greaterThan(0))); + + final var request = new UpdateDesiredNodesRequest(randomAlphaOfLength(10), 1, desiredNodes); + internalCluster().client().execute(UpdateDesiredNodesAction.INSTANCE, request).actionGet(); + } + + private void assertPrimaryShardIsAllocatedInNodeWithRole(int shard, DiscoveryNodeRole expectedRole) { + final var discoveryNode = getPrimaryShardAssignedNode(shard); + assertThat(explainAllocation(shard), discoveryNode.getRoles().contains(expectedRole), is(true)); + } + + private void assertPrimaryShardIsAllocatedInNode(int shard, DesiredNode expectedNode) { + final var discoveryNode = getPrimaryShardAssignedNode(shard); + assertThat(explainAllocation(shard), discoveryNode.getExternalId(), is(equalTo(expectedNode.externalId()))); + } + + private DiscoveryNode getPrimaryShardAssignedNode(int shard) { + final var state = client().admin().cluster().prepareState().get().getState(); + final var routingTable = state.routingTable().index(index).shard(shard); + final var primaryShard = routingTable.primaryShard(); + final var discoveryNode = state.nodes().get(primaryShard.currentNodeId()); + assertThat(discoveryNode, is(notNullValue())); + return discoveryNode; + } + + private String explainAllocation(int shard) { + return 
Strings.toString( + client().admin().cluster().prepareAllocationExplain().setIndex(index).setShard(shard).setPrimary(true).get().getExplanation(), + true, + true + ); + } + + private void createIndexWithTierPreference(String... tiers) { + assertThat(tiers.length, is(greaterThan(0))); + + client().admin() + .indices() + .prepareCreate(index) + .setWaitForActiveShards(0) + .setSettings( + Settings.builder().put(DataTier.TIER_PREFERENCE, String.join(",", tiers)).put(INDEX_NUMBER_OF_REPLICAS_SETTING.getKey(), 0) + ) + .get(); + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/cluster/routing/allocation/DataTierAllocationDecider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/cluster/routing/allocation/DataTierAllocationDecider.java index 59704538b3a4a..cb2260a4950fd 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/cluster/routing/allocation/DataTierAllocationDecider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/cluster/routing/allocation/DataTierAllocationDecider.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.cluster.routing.allocation; +import org.elasticsearch.cluster.metadata.DesiredNode; +import org.elasticsearch.cluster.metadata.DesiredNodes; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeRole; @@ -19,6 +21,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.Decision; import org.elasticsearch.common.Strings; +import java.util.Collection; import java.util.List; import java.util.Optional; import java.util.Set; @@ -65,7 +68,7 @@ private static Decision shouldFilter(IndexMetadata indexMd, Set apply(List tierPreference, DiscoveryNodes nodes); + Optional apply(List tierPreference, DiscoveryNodes nodes, DesiredNodes desiredNodes); } private static final Decision YES_PASSES = Decision.single(Decision.YES.type(), NAME, "node passes tier preference filters"); 
@@ -80,7 +83,7 @@ public static Decision shouldFilter( if (tierPreference.isEmpty() != false) { return YES_PASSES; } - Optional tier = preferredTierFunction.apply(tierPreference, allocation.nodes()); + Optional tier = preferredTierFunction.apply(tierPreference, allocation.nodes(), allocation.desiredNodes()); if (tier.isPresent()) { String tierName = tier.get(); if (allocationAllowed(tierName, roles)) { @@ -133,9 +136,74 @@ private static Decision debugYesAllowed(RoutingAllocation allocation, List}. + * {@code Optional}. This method takes into account the desired nodes + * in order to know if there are planned topology changes in the cluster + * that can remove a tier that's part of the cluster now. */ - public static Optional preferredAvailableTier(List prioritizedTiers, DiscoveryNodes nodes) { + public static Optional preferredAvailableTier(List prioritizedTiers, DiscoveryNodes nodes, DesiredNodes desiredNodes) { + final var desiredNodesPreferredTier = getPreferredTierFromDesiredNodes(prioritizedTiers, nodes, desiredNodes); + + if (desiredNodesPreferredTier.isPresent()) { + return desiredNodesPreferredTier; + } + + return getPreferredAvailableTierFromClusterMembers(prioritizedTiers, nodes); + } + + /** + * Given a list of tiers in descending order, return the tier that's present + * in the desired nodes with the highest priority, if none is present returns an + * {@code Optional.empty()}. 
+ */ + public static Optional getPreferredTierFromDesiredNodes( + List prioritizedTiers, + DiscoveryNodes discoveryNodes, + DesiredNodes desiredNodes + ) { + if (desiredNodes == null) { + return Optional.empty(); + } + + for (int tierIndex = 0; tierIndex < prioritizedTiers.size(); tierIndex++) { + final var tier = prioritizedTiers.get(tierIndex); + if (tierNodesPresent(tier, desiredNodes.actualized()) + || isDesiredNodeWithinTierJoining(tier, discoveryNodes, desiredNodes) + || nextTierIsGrowingAndCurrentTierCanHoldTheIndex(prioritizedTiers, tierIndex, discoveryNodes, desiredNodes)) { + return Optional.of(tier); + } + } + return Optional.empty(); + } + + private static boolean nextTierIsGrowingAndCurrentTierCanHoldTheIndex( + List prioritizedTiers, + int tierIndex, + DiscoveryNodes discoveryNodes, + DesiredNodes desiredNodes + ) { + final var tier = prioritizedTiers.get(tierIndex); + assert tierNodesPresent(tier, desiredNodes.actualized()) == false; + // If there's a plan to grow the next preferred tier, and it hasn't materialized yet, + // wait until all the nodes in the next tier have joined. This would avoid overwhelming + // the next tier if within the same plan one tier is removed and the next preferred tier + // grows. + boolean nextPreferredTierIsGrowing = false; + for (int i = tierIndex + 1; i < prioritizedTiers.size(); i++) { + final var nextTier = prioritizedTiers.get(i); + nextPreferredTierIsGrowing |= tierNodesPresent(nextTier, desiredNodes.pending()); + } + return tierNodesPresent(tier, discoveryNodes) && nextPreferredTierIsGrowing; + } + + private static boolean isDesiredNodeWithinTierJoining(String tier, DiscoveryNodes discoveryNodes, DesiredNodes desiredNodes) { + assert tierNodesPresent(tier, desiredNodes.actualized()) == false; + // Take into account the case when the desired nodes have been updated and the node in the tier would be replaced by + // a new one. 
In that case the desired node in the tier won't be actualized as it has to join, but we still need to ensure + // that at least one cluster member has the requested tier as we would prefer to minimize the shard movements in these cases. + return tierNodesPresent(tier, desiredNodes.pending()) && tierNodesPresent(tier, discoveryNodes); + } + + private static Optional getPreferredAvailableTierFromClusterMembers(List prioritizedTiers, DiscoveryNodes nodes) { for (String tier : prioritizedTiers) { if (tierNodesPresent(tier, nodes)) { return Optional.of(tier); @@ -144,15 +212,16 @@ public static Optional preferredAvailableTier(List prioritizedTi return Optional.empty(); } + static boolean tierNodesPresent(String singleTier, Collection nodes) { + assert singleTier.equals(DiscoveryNodeRole.DATA_ROLE.roleName()) || DataTier.validTierName(singleTier) + : "tier " + singleTier + " is an invalid tier name"; + return nodes.stream().anyMatch(node -> allocationAllowed(singleTier, node.getRoles())); + } + static boolean tierNodesPresent(String singleTier, DiscoveryNodes nodes) { assert singleTier.equals(DiscoveryNodeRole.DATA_ROLE.roleName()) || DataTier.validTierName(singleTier) : "tier " + singleTier + " is an invalid tier name"; - for (DiscoveryNode node : nodes) { - if (allocationAllowed(singleTier, node.getRoles())) { - return true; - } - } - return false; + return nodes.stream().anyMatch(node -> allocationAllowed(singleTier, node.getRoles())); } private static boolean allocationAllowed(String tierName, Set roles) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/DataTierMigrationRoutedStep.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/DataTierMigrationRoutedStep.java index 7a6ccc1300f19..4385326a7c9b2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/DataTierMigrationRoutedStep.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/DataTierMigrationRoutedStep.java @@ 
-10,6 +10,7 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.action.support.ActiveShardCount; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.DesiredNodes; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders; import org.elasticsearch.index.Index; @@ -54,7 +55,8 @@ public Result isConditionMet(Index index, ClusterState clusterState) { List preferredTierConfiguration = idxMeta.getTierPreference(); Optional availableDestinationTier = DataTierAllocationDecider.preferredAvailableTier( preferredTierConfiguration, - clusterState.getNodes() + clusterState.getNodes(), + DesiredNodes.latestFromClusterState(clusterState) ); if (ActiveShardCount.ALL.enoughShardsActive(clusterState, index.getName()) == false) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/WaitForDataTierStep.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/WaitForDataTierStep.java index 0b557e7c3e034..e326f591c64cd 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/WaitForDataTierStep.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/WaitForDataTierStep.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ilm; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.DesiredNodes; import org.elasticsearch.cluster.routing.allocation.DataTier; import org.elasticsearch.index.Index; import org.elasticsearch.xpack.cluster.routing.allocation.DataTierAllocationDecider; @@ -33,8 +34,11 @@ public WaitForDataTierStep(StepKey key, StepKey nextStepKey, String tierPreferen @Override public Result isConditionMet(Index index, ClusterState clusterState) { - boolean present = DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList(tierPreference), clusterState.nodes()) - .isPresent(); + boolean present = 
DataTierAllocationDecider.preferredAvailableTier( + DataTier.parseTierList(tierPreference), + clusterState.nodes(), + DesiredNodes.latestFromClusterState(clusterState) + ).isPresent(); SingleMessageFieldInfo info = present ? null : new SingleMessageFieldInfo("no nodes for tiers [" + tierPreference + "] available"); return new Result(present, info); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAssignmentStateAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAssignmentRoutingInfoAction.java similarity index 72% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAssignmentStateAction.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAssignmentRoutingInfoAction.java index 5b48777885da5..a8a56f09c7801 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAssignmentStateAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAssignmentRoutingInfoAction.java @@ -13,36 +13,36 @@ import org.elasticsearch.action.support.master.MasterNodeRequest; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReason; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfoUpdate; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.Objects; -public class UpdateTrainedModelAssignmentStateAction extends ActionType { - public static final UpdateTrainedModelAssignmentStateAction INSTANCE = new UpdateTrainedModelAssignmentStateAction(); +public class UpdateTrainedModelAssignmentRoutingInfoAction extends ActionType { + public static final 
UpdateTrainedModelAssignmentRoutingInfoAction INSTANCE = new UpdateTrainedModelAssignmentRoutingInfoAction(); public static final String NAME = "cluster:internal/xpack/ml/model_allocation/update"; - private UpdateTrainedModelAssignmentStateAction() { + private UpdateTrainedModelAssignmentRoutingInfoAction() { super(NAME, AcknowledgedResponse::readFrom); } public static class Request extends MasterNodeRequest { private final String nodeId; private final String modelId; - private final RoutingStateAndReason routingState; + private final RoutingInfoUpdate update; - public Request(String nodeId, String modelId, RoutingStateAndReason routingState) { + public Request(String nodeId, String modelId, RoutingInfoUpdate update) { this.nodeId = ExceptionsHelper.requireNonNull(nodeId, "node_id"); this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id"); - this.routingState = ExceptionsHelper.requireNonNull(routingState, "routing_state"); + this.update = ExceptionsHelper.requireNonNull(update, "update"); } public Request(StreamInput in) throws IOException { super(in); this.nodeId = in.readString(); this.modelId = in.readString(); - this.routingState = new RoutingStateAndReason(in); + this.update = new RoutingInfoUpdate(in); } public String getNodeId() { @@ -53,8 +53,8 @@ public String getModelId() { return modelId; } - public RoutingStateAndReason getRoutingState() { - return routingState; + public RoutingInfoUpdate getUpdate() { + return update; } @Override @@ -67,7 +67,7 @@ public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(nodeId); out.writeString(modelId); - routingState.writeTo(out); + update.writeTo(out); } @Override @@ -77,17 +77,17 @@ public boolean equals(Object o) { Request request = (Request) o; return Objects.equals(nodeId, request.nodeId) && Objects.equals(modelId, request.modelId) - && Objects.equals(routingState, request.routingState); + && Objects.equals(update, request.update); } @Override public int 
hashCode() { - return Objects.hash(nodeId, modelId, routingState); + return Objects.hash(nodeId, modelId, update); } @Override public String toString() { - return "Request{" + "nodeId='" + nodeId + '\'' + ", modelId='" + modelId + '\'' + ", routingState=" + routingState + '}'; + return "Request{" + "nodeId='" + nodeId + '\'' + ", modelId='" + modelId + '\'' + ", update=" + update + '}'; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfo.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfo.java new file mode 100644 index 0000000000000..967634aea63b9 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfo.java @@ -0,0 +1,155 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.assignment; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class RoutingInfo implements ToXContentObject, Writeable { + + private static final ParseField CURRENT_ALLOCATIONS = new ParseField("current_allocations"); + private static final ParseField TARGET_ALLOCATIONS = new ParseField("target_allocations"); + private static final ParseField ROUTING_STATE = new ParseField("routing_state"); + private static final ParseField REASON = new ParseField("reason"); + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "trained_model_routing_state", + a -> new RoutingInfo((Integer) a[0], (Integer) a[1], RoutingState.fromString((String) a[2]), (String) a[3]) + ); + static { + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), CURRENT_ALLOCATIONS); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), TARGET_ALLOCATIONS); + PARSER.declareString(ConstructingObjectParser.constructorArg(), ROUTING_STATE); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), REASON); + } + + public static RoutingInfo fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final int currentAllocations; + private final int targetAllocations; + private final RoutingState state; + private final String reason; + + // There may be 
objects in cluster state prior to 8.4 that do not contain values for currentAllocations and targetAllocations. + private RoutingInfo( + @Nullable Integer currentAllocations, + @Nullable Integer targetAllocations, + RoutingState state, + @Nullable String reason + ) { + this(currentAllocations == null ? 0 : currentAllocations, targetAllocations == null ? 0 : targetAllocations, state, reason); + } + + public RoutingInfo(int currentAllocations, int targetAllocations, RoutingState state, String reason) { + this.currentAllocations = currentAllocations; + this.targetAllocations = targetAllocations; + this.state = ExceptionsHelper.requireNonNull(state, ROUTING_STATE); + this.reason = reason; + } + + public RoutingInfo(StreamInput in) throws IOException { + if (in.getVersion().onOrAfter(Version.V_8_4_0)) { + this.currentAllocations = in.readVInt(); + this.targetAllocations = in.readVInt(); + } else { + this.currentAllocations = 0; + this.targetAllocations = 0; + } + this.state = in.readEnum(RoutingState.class); + this.reason = in.readOptionalString(); + } + + public int getCurrentAllocations() { + return currentAllocations; + } + + public int getTargetAllocations() { + return targetAllocations; + } + + public RoutingState getState() { + return state; + } + + @Nullable + public String getReason() { + return reason; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (out.getVersion().onOrAfter(Version.V_8_4_0)) { + out.writeVInt(currentAllocations); + out.writeVInt(targetAllocations); + } + out.writeEnum(state); + out.writeOptionalString(reason); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CURRENT_ALLOCATIONS.getPreferredName(), currentAllocations); + builder.field(TARGET_ALLOCATIONS.getPreferredName(), targetAllocations); + builder.field(ROUTING_STATE.getPreferredName(), state); + if (reason != null) { + 
builder.field(REASON.getPreferredName(), reason); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RoutingInfo that = (RoutingInfo) o; + return currentAllocations == that.currentAllocations + && targetAllocations == that.targetAllocations + && state == that.state + && Objects.equals(reason, that.reason); + } + + @Override + public int hashCode() { + return Objects.hash(currentAllocations, targetAllocations, state, reason); + } + + @Override + public String toString() { + return "RoutingInfo{" + + "current_allocations=" + + currentAllocations + + ", target_allocations=" + + targetAllocations + + ", reason='" + + reason + + '\'' + + ", state=" + + state + + '}'; + } + + public boolean isRoutable() { + return state == RoutingState.STARTED && currentAllocations > 0; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfoUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfoUpdate.java new file mode 100644 index 0000000000000..ce08941dfe02d --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfoUpdate.java @@ -0,0 +1,94 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.assignment; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.Objects; +import java.util.Optional; + +public class RoutingInfoUpdate implements Writeable { + + private final Optional numberOfAllocations; + private final Optional stateAndReason; + + public static RoutingInfoUpdate updateNumberOfAllocations(int numberOfAllocations) { + return new RoutingInfoUpdate(Optional.of(numberOfAllocations), Optional.empty()); + } + + public static RoutingInfoUpdate updateStateAndReason(RoutingStateAndReason routingStateAndReason) { + return new RoutingInfoUpdate(Optional.empty(), Optional.of(routingStateAndReason)); + } + + private RoutingInfoUpdate(Optional numberOfAllocations, Optional stateAndReason) { + this.numberOfAllocations = Objects.requireNonNull(numberOfAllocations); + this.stateAndReason = Objects.requireNonNull(stateAndReason); + } + + public RoutingInfoUpdate(StreamInput in) throws IOException { + if (in.getVersion().onOrAfter(Version.V_8_4_0)) { + numberOfAllocations = Optional.ofNullable(in.readOptionalVInt()); + stateAndReason = Optional.ofNullable(in.readOptionalWriteable(RoutingStateAndReason::new)); + } else { + numberOfAllocations = Optional.empty(); + stateAndReason = Optional.of(new RoutingStateAndReason(in)); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (out.getVersion().onOrAfter(Version.V_8_4_0)) { + out.writeOptionalVInt(numberOfAllocations.orElse(null)); + out.writeOptionalWriteable(stateAndReason.orElse(null)); + } else { + assert stateAndReason.isPresent() : "updating routing info while nodes prior to 8.4.0 should only contain state and reason"; + stateAndReason.get().writeTo(out); + } + } + + @Override + public boolean equals(Object o) { + if (this 
== o) return true; + if (o == null || getClass() != o.getClass()) return false; + RoutingInfoUpdate that = (RoutingInfoUpdate) o; + return Objects.equals(numberOfAllocations, that.numberOfAllocations) && Objects.equals(stateAndReason, that.stateAndReason); + } + + @Override + public int hashCode() { + return Objects.hash(numberOfAllocations, stateAndReason); + } + + @Override + public String toString() { + return "RoutingInfoUpdate{" + "numberOfAllocations=" + numberOfAllocations + ", stateAndReason=" + stateAndReason + '}'; + } + + public Optional getNumberOfAllocations() { + return numberOfAllocations; + } + + public Optional getStateAndReason() { + return stateAndReason; + } + + public RoutingInfo apply(RoutingInfo routingInfo) { + int currentAllocations = numberOfAllocations.orElse(routingInfo.getCurrentAllocations()); + RoutingState state = routingInfo.getState(); + String reason = routingInfo.getReason(); + if (stateAndReason.isPresent()) { + state = stateAndReason.get().getState(); + reason = stateAndReason.get().getReason(); + } + return new RoutingInfo(currentAllocations, routingInfo.getTargetAllocations(), state, reason); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java index a600f34ccadfd..f66c88f2fe08d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java @@ -10,7 +10,7 @@ import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.cluster.SimpleDiffable; -import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.Randomness; import 
org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xcontent.ConstructingObjectParser; @@ -25,12 +25,14 @@ import java.io.IOException; import java.time.Instant; +import java.util.ArrayList; import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; // TODO implement better diffable logic so that whole diff does not need to be serialized if only one part changes /** @@ -53,7 +55,7 @@ public class TrainedModelAssignment implements SimpleDiffable new TrainedModelAssignment( (StartTrainedModelDeploymentAction.TaskParams) a[0], - (Map) a[1], + (Map) a[1], a[2] == null ? null : AssignmentState.fromString((String) a[2]), a[3] == null ? null : AssignmentState.fromString((String) a[3]), (String) a[4], @@ -68,7 +70,7 @@ public class TrainedModelAssignment implements SimpleDiffable p.map(LinkedHashMap::new, RoutingStateAndReason::fromXContent), + (p, c) -> p.map(LinkedHashMap::new, RoutingInfo::fromXContent), ROUTING_TABLE ); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), ASSIGNMENT_STATE); @@ -83,7 +85,7 @@ public class TrainedModelAssignment implements SimpleDiffable nodeRoutingTable; + private final Map nodeRoutingTable; private final AssignmentState assignmentState; private final String reason; private final Instant startTime; @@ -94,7 +96,7 @@ public static TrainedModelAssignment fromXContent(XContentParser parser) throws private TrainedModelAssignment( StartTrainedModelDeploymentAction.TaskParams taskParams, - Map nodeRoutingTable, + Map nodeRoutingTable, AssignmentState assignmentState, AssignmentState legacyAssignmentState, String reason, @@ -105,7 +107,7 @@ private TrainedModelAssignment( TrainedModelAssignment( StartTrainedModelDeploymentAction.TaskParams taskParams, - Map nodeRoutingTable, + Map nodeRoutingTable, AssignmentState 
assignmentState, String reason, Instant startTime @@ -119,7 +121,7 @@ private TrainedModelAssignment( public TrainedModelAssignment(StreamInput in) throws IOException { this.taskParams = new StartTrainedModelDeploymentAction.TaskParams(in); - this.nodeRoutingTable = in.readOrderedMap(StreamInput::readString, RoutingStateAndReason::new); + this.nodeRoutingTable = in.readOrderedMap(StreamInput::readString, RoutingInfo::new); this.assignmentState = in.readEnum(AssignmentState.class); this.reason = in.readOptionalString(); this.startTime = in.readInstant(); @@ -129,7 +131,7 @@ public boolean isRoutedToNode(String nodeId) { return nodeRoutingTable.containsKey(nodeId); } - public Map getNodeRoutingTable() { + public Map getNodeRoutingTable() { return Collections.unmodifiableMap(nodeRoutingTable); } @@ -153,6 +155,32 @@ public String[] getStartedNodes() { .toArray(String[]::new); } + public Optional selectRandomStartedNodeWeighedOnAllocations() { + List nodeIds = new ArrayList<>(nodeRoutingTable.size()); + List cumulativeAllocations = new ArrayList<>(nodeRoutingTable.size()); + int allocationSum = 0; + for (Map.Entry routingEntry : nodeRoutingTable.entrySet()) { + if (RoutingState.STARTED.equals(routingEntry.getValue().getState())) { + nodeIds.add(routingEntry.getKey()); + allocationSum += routingEntry.getValue().getCurrentAllocations(); + cumulativeAllocations.add(allocationSum); + } + } + + if (allocationSum == 0) { + // If we are in a mixed cluster where there are assignments prior to introducing allocation distribution + // we could have a zero-sum of allocations. We fall back to returning a random started node. + return nodeIds.isEmpty() ? 
Optional.empty() : Optional.of(nodeIds.get(Randomness.get().nextInt(nodeIds.size()))); + } + + int randomInt = Randomness.get().ints(1, 1, allocationSum + 1).iterator().nextInt(); + int nodeIndex = Collections.binarySearch(cumulativeAllocations, randomInt); + if (nodeIndex < 0) { + nodeIndex = -nodeIndex - 1; + } + return Optional.of(nodeIds.get(nodeIndex)); + } + public Optional getReason() { return Optional.ofNullable(reason); } @@ -161,6 +189,16 @@ public Instant getStartTime() { return startTime; } + public boolean isSatisfied(Set assignableNodeIds) { + int allocations = nodeRoutingTable.entrySet() + .stream() + .filter(e -> assignableNodeIds.contains(e.getKey())) + .filter(e -> e.getValue().getState().isAnyOf(RoutingState.STARTING, RoutingState.STARTED)) + .mapToInt(e -> e.getValue().getTargetAllocations()) + .sum(); + return allocations >= taskParams.getNumberOfAllocations(); + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -201,31 +239,22 @@ public void writeTo(StreamOutput out) throws IOException { out.writeInstant(startTime); } - public Optional calculateAllocationStatus(List allocatableNodes) { + public Optional calculateAllocationStatus() { if (assignmentState.equals(AssignmentState.STOPPING)) { return Optional.empty(); } - int numAllocatableNodes = 0; - int numStarted = 0; - for (DiscoveryNode node : allocatableNodes) { - if (StartTrainedModelDeploymentAction.TaskParams.mayAssignToNode(node)) { - RoutingState nodeState = Optional.ofNullable(nodeRoutingTable.get(node.getId())) - .map(RoutingStateAndReason::getState) - .orElse(RoutingState.STOPPED); - numAllocatableNodes++; - if (nodeState.equals(RoutingState.STARTED)) { - numStarted++; - } - } - } - return Optional.of(new AllocationStatus(numStarted, numAllocatableNodes)); + int numStarted = nodeRoutingTable.values() + .stream() + .filter(RoutingInfo::isRoutable) + .mapToInt(RoutingInfo::getCurrentAllocations) + .sum(); + return Optional.of(new AllocationStatus(numStarted, 
taskParams.getNumberOfAllocations())); } public static class Builder { - private final Map nodeRoutingTable; + private final Map nodeRoutingTable; private final StartTrainedModelDeploymentAction.TaskParams taskParams; private AssignmentState assignmentState; - private boolean isChanged; private String reason; private Instant startTime; @@ -245,7 +274,7 @@ public static Builder empty(StartTrainedModelDeploymentAction.TaskParams taskPar private Builder( StartTrainedModelDeploymentAction.TaskParams taskParams, - Map nodeRoutingTable, + Map nodeRoutingTable, AssignmentState assignmentState, String reason, Instant startTime @@ -261,7 +290,7 @@ private Builder(StartTrainedModelDeploymentAction.TaskParams taskParams) { this(taskParams, new LinkedHashMap<>(), AssignmentState.STARTING, null, Instant.now()); } - public Builder addNewRoutingEntry(String nodeId) { + public Builder addRoutingEntry(String nodeId, RoutingInfo routingInfo) { if (nodeRoutingTable.containsKey(nodeId)) { throw new ResourceAlreadyExistsException( "routing entry for node [{}] for model [{}] already exists", @@ -269,51 +298,28 @@ public Builder addNewRoutingEntry(String nodeId) { taskParams.getModelId() ); } - isChanged = true; - nodeRoutingTable.put(nodeId, new RoutingStateAndReason(RoutingState.STARTING, "")); + nodeRoutingTable.put(nodeId, routingInfo); return this; } - // For testing purposes - Builder addRoutingEntry(String nodeId, RoutingState state) { - nodeRoutingTable.put(nodeId, new RoutingStateAndReason(state, "")); - return this; - } - - public Builder addNewFailedRoutingEntry(String nodeId, String failureReason) { - if (nodeRoutingTable.containsKey(nodeId)) { - throw new ResourceAlreadyExistsException( - "routing entry for node [{}] for model [{}] already exists", - nodeId, - taskParams.getModelId() - ); - } - isChanged = true; - nodeRoutingTable.put(nodeId, new RoutingStateAndReason(RoutingState.FAILED, failureReason)); - return this; - } - - public Builder 
updateExistingRoutingEntry(String nodeId, RoutingStateAndReason state) { - RoutingStateAndReason stateAndReason = nodeRoutingTable.get(nodeId); - if (stateAndReason == null) { + public Builder updateExistingRoutingEntry(String nodeId, RoutingInfo routingInfo) { + RoutingInfo existingRoutingInfo = nodeRoutingTable.get(nodeId); + if (existingRoutingInfo == null) { throw new ResourceNotFoundException( "routing entry for node [{}] for model [{}] does not exist", nodeId, taskParams.getModelId() ); } - if (stateAndReason.equals(state)) { + if (existingRoutingInfo.equals(routingInfo)) { return this; } - nodeRoutingTable.put(nodeId, state); - isChanged = true; + nodeRoutingTable.put(nodeId, routingInfo); return this; } public Builder removeRoutingEntry(String nodeId) { - if (nodeRoutingTable.remove(nodeId) != null) { - isChanged = true; - } + nodeRoutingTable.remove(nodeId); return this; } @@ -321,7 +327,6 @@ public Builder setReason(String reason) { if (Objects.equals(reason, this.reason)) { return this; } - isChanged = true; this.reason = reason; return this; } @@ -330,7 +335,6 @@ public Builder stopAssignment(String stopReason) { if (assignmentState.equals(AssignmentState.STOPPING)) { return this; } - isChanged = true; this.reason = stopReason; assignmentState = AssignmentState.STOPPING; return this; @@ -357,7 +361,6 @@ public Builder setAssignmentState(AssignmentState state) { if (assignmentState.equals(state)) { return this; } - isChanged = true; assignmentState = state; return this; } @@ -366,18 +369,12 @@ public Builder clearReason() { if (this.reason == null) { return this; } - isChanged = true; reason = null; return this; } - public boolean isChanged() { - return isChanged; - } - public TrainedModelAssignment build() { return new TrainedModelAssignment(taskParams, nodeRoutingTable, assignmentState, reason, startTime); } } - } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/cluster/routing/allocation/DataTierAllocationDeciderTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/cluster/routing/allocation/DataTierAllocationDeciderTests.java index aa8c9eefec378..712053287c683 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/cluster/routing/allocation/DataTierAllocationDeciderTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/cluster/routing/allocation/DataTierAllocationDeciderTests.java @@ -10,23 +10,24 @@ import joptsimple.internal.Strings; import org.elasticsearch.Version; +import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ESAllocationTestCase; -import org.elasticsearch.cluster.EmptyClusterInfoService; +import org.elasticsearch.cluster.metadata.DesiredNode; +import org.elasticsearch.cluster.metadata.DesiredNodeWithStatus; +import org.elasticsearch.cluster.metadata.DesiredNodes; +import org.elasticsearch.cluster.metadata.DesiredNodesMetadata; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.routing.RecoverySource; -import org.elasticsearch.cluster.routing.RoutingNode; import org.elasticsearch.cluster.routing.RoutingNodesHelper; import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.cluster.routing.UnassignedInfo; -import org.elasticsearch.cluster.routing.allocation.AllocationService; import org.elasticsearch.cluster.routing.allocation.DataTier; import org.elasticsearch.cluster.routing.allocation.RoutingAllocation; -import org.elasticsearch.cluster.routing.allocation.allocator.BalancedShardsAllocator; import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders; import org.elasticsearch.cluster.routing.allocation.decider.Decision; import 
org.elasticsearch.cluster.routing.allocation.decider.ReplicaAfterPrimaryActiveAllocationDecider; @@ -35,20 +36,24 @@ import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.index.IndexModule; import org.elasticsearch.index.shard.ShardId; -import org.elasticsearch.snapshots.EmptySnapshotsInfoService; import org.elasticsearch.snapshots.SearchableSnapshotsSettings; -import org.elasticsearch.test.gateway.TestGatewayAllocator; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Locale; import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; import static org.elasticsearch.cluster.routing.allocation.DataTier.DATA_COLD; import static org.elasticsearch.cluster.routing.allocation.DataTier.DATA_FROZEN; +import static org.elasticsearch.node.Node.NODE_EXTERNAL_ID_SETTING; +import static org.elasticsearch.node.NodeRoleSettings.NODE_ROLES_SETTING; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -63,6 +68,12 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase { ); private static final DiscoveryNode DATA_NODE = newNode("node-data", Collections.singleton(DiscoveryNodeRole.DATA_ROLE)); + private static final DesiredNode HOT_DESIRED_NODE = newDesiredNode("node-hot", DiscoveryNodeRole.DATA_HOT_NODE_ROLE); + private static final DesiredNode WARM_DESIRED_NODE = newDesiredNode("node-warm", DiscoveryNodeRole.DATA_WARM_NODE_ROLE); + private static final DesiredNode COLD_DESIRED_NODE = newDesiredNode("node-cold", DiscoveryNodeRole.DATA_COLD_NODE_ROLE); + private static final DesiredNode CONTENT_DESIRED_NODE = newDesiredNode("node-content", DiscoveryNodeRole.DATA_CONTENT_NODE_ROLE); + private static final DesiredNode DATA_DESIRED_NODE = 
newDesiredNode("node-data", DiscoveryNodeRole.DATA_ROLE); + private final ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); private final AllocationDeciders allocationDeciders = new AllocationDeciders( Arrays.asList( @@ -71,13 +82,6 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase { new ReplicaAfterPrimaryActiveAllocationDecider() ) ); - private final AllocationService service = new AllocationService( - allocationDeciders, - new TestGatewayAllocator(), - new BalancedShardsAllocator(Settings.EMPTY), - EmptyClusterInfoService.INSTANCE, - EmptySnapshotsInfoService.INSTANCE - ); private final ShardRouting shard = ShardRouting.newUnassigned( new ShardId("myindex", "myindex", 0), @@ -87,113 +91,111 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase { ); public void testIndexPrefer() { - ClusterState state = ClusterState.builder(service.reroute(ClusterState.EMPTY_STATE, "initial state")) - .nodes(DiscoveryNodes.builder().add(HOT_NODE).build()) - .metadata( - Metadata.builder() - .put( - IndexMetadata.builder("myindex") - .settings( - Settings.builder() - .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) - .put(IndexMetadata.SETTING_INDEX_UUID, "myindex") - .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) - .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) - .put(DataTier.TIER_PREFERENCE, "data_warm,data_cold") - .build() - ) - ) - .build() - ) - .build(); - RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state, null, null, 0); - allocation.debugDecision(true); - Decision d; - RoutingNode node; + { + final var desiredNodes = randomBoolean() ? 
null : createDesiredNodesWithActualizedNodes(HOT_DESIRED_NODE); + final var clusterState = clusterStateWithIndexAndNodes( + "data_warm,data_cold", + DiscoveryNodes.builder().add(HOT_NODE).build(), + desiredNodes + ); - for (DiscoveryNode n : Arrays.asList(HOT_NODE, WARM_NODE, COLD_NODE)) { - node = RoutingNodesHelper.routingNode(n.getId(), n, shard); - d = DataTierAllocationDecider.INSTANCE.canAllocate(shard, node, allocation); - assertThat(node.toString(), d.type(), equalTo(Decision.Type.NO)); - assertThat( - node.toString(), - d.getExplanation(), - containsString( + for (DiscoveryNode n : Arrays.asList(HOT_NODE, WARM_NODE, COLD_NODE)) { + assertAllocationDecision( + clusterState, + n, + Decision.Type.NO, "index has a preference for tiers [data_warm,data_cold], " + "but no nodes for any of those tiers are available in the cluster" - ) + ); + } + } + + { + final var desiredNodes = randomBoolean() ? null : createDesiredNodesWithActualizedNodes(HOT_DESIRED_NODE, COLD_DESIRED_NODE); + final var clusterState = clusterStateWithIndexAndNodes( + "data_warm,data_cold", + DiscoveryNodes.builder().add(HOT_NODE).add(COLD_NODE).build(), + desiredNodes ); - d = DataTierAllocationDecider.INSTANCE.canRemain(shard, node, allocation); - assertThat(node.toString(), d.type(), equalTo(Decision.Type.NO)); - assertThat( - node.toString(), - d.getExplanation(), - containsString( - "index has a preference for tiers [data_warm,data_cold], " - + "but no nodes for any of those tiers are available in the cluster" - ) + + for (DiscoveryNode n : Arrays.asList(HOT_NODE, WARM_NODE)) { + assertAllocationDecision( + clusterState, + n, + Decision.Type.NO, + "index has a preference for tiers [data_warm,data_cold] and node does not meet the required [data_cold] tier" + ); + } + + assertAllocationDecision( + clusterState, + COLD_NODE, + Decision.Type.YES, + "index has a preference for tiers [data_warm,data_cold] and node has tier [data_cold]" ); } - state = 
ClusterState.builder(service.reroute(ClusterState.EMPTY_STATE, "initial state")) - .nodes(DiscoveryNodes.builder().add(HOT_NODE).add(COLD_NODE).build()) - .metadata( - Metadata.builder() - .put( - IndexMetadata.builder("myindex") - .settings( - Settings.builder() - .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) - .put(IndexMetadata.SETTING_INDEX_UUID, "myindex") - .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) - .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) - .put(DataTier.TIER_PREFERENCE, "data_warm,data_cold") - .build() - ) - ) - .build() - ) - .build(); - allocation = new RoutingAllocation(allocationDeciders, state, null, null, 0); - allocation.debugDecision(true); - - for (DiscoveryNode n : Arrays.asList(HOT_NODE, WARM_NODE)) { - node = RoutingNodesHelper.routingNode(n.getId(), n, shard); - d = DataTierAllocationDecider.INSTANCE.canAllocate(shard, node, allocation); - assertThat(node.toString(), d.type(), equalTo(Decision.Type.NO)); - assertThat( - node.toString(), - d.getExplanation(), - containsString( - "index has a preference for tiers [data_warm,data_cold] " + "and node does not meet the required [data_cold] tier" - ) + { + // Remove the cold tier from desired nodes + final var desiredNodes = createDesiredNodesWithActualizedNodes(WARM_DESIRED_NODE); + final var state = clusterStateWithIndexAndNodes( + "data_cold,data_warm", + DiscoveryNodes.builder().add(WARM_NODE).add(COLD_NODE).build(), + desiredNodes ); - d = DataTierAllocationDecider.INSTANCE.canRemain(shard, node, allocation); - assertThat(node.toString(), d.type(), equalTo(Decision.Type.NO)); - assertThat( - node.toString(), - d.getExplanation(), - containsString( - "index has a preference for tiers [data_warm,data_cold] " + "and node does not meet the required [data_cold] tier" - ) + + for (DiscoveryNode node : List.of(HOT_NODE, COLD_NODE)) { + assertAllocationDecision( + state, + node, + Decision.Type.NO, + "index has a preference for tiers [data_cold,data_warm] and node does 
not meet the required [data_warm] tier" + ); + } + + assertAllocationDecision( + state, + WARM_NODE, + Decision.Type.YES, + "index has a preference for tiers [data_cold,data_warm] and node has tier [data_warm]" ); } - node = RoutingNodesHelper.routingNode(COLD_NODE.getId(), COLD_NODE, shard); - d = DataTierAllocationDecider.INSTANCE.canAllocate(shard, node, allocation); - assertThat(node.toString(), d.type(), equalTo(Decision.Type.YES)); - assertThat( - node.toString(), - d.getExplanation(), - containsString("index has a preference for tiers [data_warm,data_cold] and node has tier [data_cold]") - ); - d = DataTierAllocationDecider.INSTANCE.canRemain(shard, node, allocation); - assertThat(node.toString(), d.type(), equalTo(Decision.Type.YES)); - assertThat( - node.toString(), - d.getExplanation(), - containsString("index has a preference for tiers [data_warm,data_cold] and node has tier [data_cold]") - ); + { + // There's a warm node in the desired nodes, but it hasn't joined the cluster yet, + // in that case we consider that there aren't any nodes with the preferred tier in the cluster + final ClusterState clusterState; + final String tierPreference; + if (randomBoolean()) { + tierPreference = "data_warm,data_cold"; + clusterState = clusterStateWithIndexAndNodes( + tierPreference, + DiscoveryNodes.builder().add(HOT_NODE).build(), + DesiredNodes.create("history", 1, List.of(pendingDesiredNode(WARM_DESIRED_NODE))) + ); + } else { + tierPreference = "data_warm,data_hot"; + clusterState = clusterStateWithIndexAndNodes( + tierPreference, + DiscoveryNodes.builder().add(COLD_NODE).build(), + DesiredNodes.create("history", 1, List.of(pendingDesiredNode(WARM_DESIRED_NODE))) + ); + } + + for (DiscoveryNode node : List.of(HOT_NODE, WARM_NODE, COLD_NODE)) { + assertAllocationDecision( + clusterState, + node, + Decision.Type.NO, + String.format( + Locale.ROOT, + "index has a preference for tiers [%s], but no nodes for any of those tiers are available in the cluster", + 
tierPreference + ) + ); + } + + } } public void testTierNodesPresent() { @@ -222,43 +224,269 @@ public void testTierNodesPresent() { assertTrue(DataTierAllocationDecider.tierNodesPresent("data_content", nodes)); } + public void testTierNodesPresentDesiredNodes() { + Set nodes = Collections.emptySet(); + + assertFalse(DataTierAllocationDecider.tierNodesPresent("data", nodes)); + assertFalse(DataTierAllocationDecider.tierNodesPresent("data_hot", nodes)); + assertFalse(DataTierAllocationDecider.tierNodesPresent("data_warm", nodes)); + assertFalse(DataTierAllocationDecider.tierNodesPresent("data_cold", nodes)); + assertFalse(DataTierAllocationDecider.tierNodesPresent("data_content", nodes)); + + nodes = Set.of(WARM_DESIRED_NODE, CONTENT_DESIRED_NODE); + + assertFalse(DataTierAllocationDecider.tierNodesPresent("data", nodes)); + assertFalse(DataTierAllocationDecider.tierNodesPresent("data_hot", nodes)); + assertTrue(DataTierAllocationDecider.tierNodesPresent("data_warm", nodes)); + assertFalse(DataTierAllocationDecider.tierNodesPresent("data_cold", nodes)); + assertTrue(DataTierAllocationDecider.tierNodesPresent("data_content", nodes)); + + nodes = Set.of(DATA_DESIRED_NODE); + + assertTrue(DataTierAllocationDecider.tierNodesPresent("data", nodes)); + assertTrue(DataTierAllocationDecider.tierNodesPresent("data_hot", nodes)); + assertTrue(DataTierAllocationDecider.tierNodesPresent("data_warm", nodes)); + assertTrue(DataTierAllocationDecider.tierNodesPresent("data_cold", nodes)); + assertTrue(DataTierAllocationDecider.tierNodesPresent("data_content", nodes)); + } + public void testPreferredTierAvailable() { - DiscoveryNodes nodes = DiscoveryNodes.builder().build(); + { + final var nodes = DiscoveryNodes.builder().build(); + final DesiredNodes desiredNodes = randomBoolean() + ? 
null + : createDesiredNodesWithPendingNodes(HOT_DESIRED_NODE, WARM_DESIRED_NODE, COLD_DESIRED_NODE); - assertThat(DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data"), nodes), equalTo(Optional.empty())); - assertThat( - DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_hot,data_warm"), nodes), - equalTo(Optional.empty()) - ); - assertThat( - DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_warm,data_content"), nodes), - equalTo(Optional.empty()) - ); - assertThat(DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_cold"), nodes), equalTo(Optional.empty())); + assertThat( + DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data"), nodes, desiredNodes), + equalTo(Optional.empty()) + ); + assertThat( + DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_hot,data_warm"), nodes, desiredNodes), + equalTo(Optional.empty()) + ); + assertThat( + DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_warm,data_content"), nodes, desiredNodes), + equalTo(Optional.empty()) + ); + assertThat( + DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_cold"), nodes, desiredNodes), + equalTo(Optional.empty()) + ); + } - nodes = DiscoveryNodes.builder().add(WARM_NODE).add(CONTENT_NODE).build(); + { + final var nodes = DiscoveryNodes.builder().add(WARM_NODE).add(CONTENT_NODE).build(); + final var desiredNodes = randomBoolean() + ? 
null + : createDesiredNodesWithActualizedNodes(WARM_DESIRED_NODE, CONTENT_DESIRED_NODE); - assertThat(DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data"), nodes), equalTo(Optional.empty())); - assertThat( - DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_hot,data_warm"), nodes), - equalTo(Optional.of("data_warm")) - ); - assertThat( - DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_warm,data_content"), nodes), - equalTo(Optional.of("data_warm")) - ); - assertThat( - DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_content,data_warm"), nodes), - equalTo(Optional.of("data_content")) - ); - assertThat( - DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_hot,data_content,data_warm"), nodes), - equalTo(Optional.of("data_content")) - ); - assertThat( - DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_hot,data_cold,data_warm"), nodes), - equalTo(Optional.of("data_warm")) - ); + assertThat( + DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data"), nodes, desiredNodes), + equalTo(Optional.empty()) + ); + assertThat( + DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_hot,data_warm"), nodes, desiredNodes), + equalTo(Optional.of("data_warm")) + ); + assertThat( + DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_warm,data_content"), nodes, desiredNodes), + equalTo(Optional.of("data_warm")) + ); + assertThat( + DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_content,data_warm"), nodes, desiredNodes), + equalTo(Optional.of("data_content")) + ); + assertThat( + DataTierAllocationDecider.preferredAvailableTier( + DataTier.parseTierList("data_hot,data_content,data_warm"), + nodes, + desiredNodes + ), + equalTo(Optional.of("data_content")) + ); + assertThat( + 
DataTierAllocationDecider.preferredAvailableTier( + DataTier.parseTierList("data_hot,data_cold,data_warm"), + nodes, + desiredNodes + ), + equalTo(Optional.of("data_warm")) + ); + } + + { + final var nodes = DiscoveryNodes.builder().add(WARM_NODE).add(CONTENT_NODE).build(); + final var desiredNodes = createDesiredNodesWithActualizedNodes(HOT_DESIRED_NODE, WARM_DESIRED_NODE, CONTENT_DESIRED_NODE); + + assertThat( + DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data"), nodes, desiredNodes), + equalTo(Optional.empty()) + ); + assertThat( + DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_hot,data_warm"), nodes, desiredNodes), + equalTo(Optional.of("data_hot")) + ); + assertThat( + DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_warm,data_content"), nodes, desiredNodes), + equalTo(Optional.of("data_warm")) + ); + assertThat( + DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_content,data_warm"), nodes, desiredNodes), + equalTo(Optional.of("data_content")) + ); + assertThat( + DataTierAllocationDecider.preferredAvailableTier( + DataTier.parseTierList("data_hot,data_content,data_warm"), + nodes, + desiredNodes + ), + equalTo(Optional.of("data_hot")) + ); + assertThat( + DataTierAllocationDecider.preferredAvailableTier( + DataTier.parseTierList("data_hot,data_cold,data_warm"), + nodes, + desiredNodes + ), + equalTo(Optional.of("data_hot")) + ); + } + + { + // When there are desired nodes that haven't joined the cluster yet, those are not considered + final var nodes = DiscoveryNodes.builder().add(WARM_NODE).add(CONTENT_NODE).build(); + // i.e. 
HOT_DESIRED_NODE might be part of the DesiredNodes, but it is not part of the cluster yet + final var desiredNodes = DesiredNodes.create( + randomAlphaOfLength(10), + 1, + List.of( + pendingDesiredNode(HOT_DESIRED_NODE), + actualizedDesiredNode(WARM_DESIRED_NODE), + actualizedDesiredNode(CONTENT_DESIRED_NODE) + ) + ); + + assertThat( + DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data"), nodes, desiredNodes), + equalTo(Optional.empty()) + ); + assertThat( + DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_hot,data_warm"), nodes, desiredNodes), + equalTo(Optional.of("data_warm")) + ); + assertThat( + DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_warm,data_content"), nodes, desiredNodes), + equalTo(Optional.of("data_warm")) + ); + assertThat( + DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_content,data_warm"), nodes, desiredNodes), + equalTo(Optional.of("data_content")) + ); + assertThat( + DataTierAllocationDecider.preferredAvailableTier( + DataTier.parseTierList("data_hot,data_content,data_warm"), + nodes, + desiredNodes + ), + equalTo(Optional.of("data_content")) + ); + assertThat( + DataTierAllocationDecider.preferredAvailableTier( + DataTier.parseTierList("data_hot,data_cold,data_warm"), + nodes, + desiredNodes + ), + equalTo(Optional.of("data_warm")) + ); + } + + { + // Cold tier is planned to be removed + final var nodes = DiscoveryNodes.builder().add(HOT_NODE).add(WARM_NODE).add(COLD_NODE).build(); + final var desiredNodes = createDesiredNodesWithActualizedNodes(HOT_DESIRED_NODE, WARM_DESIRED_NODE); + + assertThat( + DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_cold,data_warm"), nodes, desiredNodes), + equalTo(Optional.of("data_warm")) + ); + } + + { + // During grow and shrink (i.e. 
a way to replace a node) we should avoid moving the shard from a preferred tier to a less + // preferred tier if there's a node that can hold that shard and we know that a new desired node would substitute the old one + final var nodes = DiscoveryNodes.builder().add(HOT_NODE).add(WARM_NODE).add(COLD_NODE).build(); + final var desiredNodes = DesiredNodes.create( + "history", + 1, + List.of( + actualizedDesiredNode(HOT_DESIRED_NODE), + actualizedDesiredNode(WARM_DESIRED_NODE), + pendingDesiredNode(COLD_DESIRED_NODE) + ) + ); + + assertThat( + DataTierAllocationDecider.preferredAvailableTier(DataTier.parseTierList("data_cold,data_warm"), nodes, desiredNodes), + equalTo(Optional.of("data_cold")) + ); + } + + { + // Ensure that when we are removing a tier and growing the next preferred tier we wait until all the new + // nodes have joined the cluster avoiding filling the new nodes with shards from the removed tier + final var nodes = DiscoveryNodes.builder().add(HOT_NODE).add(WARM_NODE).add(COLD_NODE).build(); + final DesiredNodes desiredNodes; + // Grow any of the next preferred tiers + if (randomBoolean()) { + final var newWarmNode = newDesiredNode("node-warm-2", DiscoveryNodeRole.DATA_WARM_NODE_ROLE); + desiredNodes = DesiredNodes.create( + "history", + 1, + List.of( + actualizedDesiredNode(HOT_DESIRED_NODE), + actualizedDesiredNode(WARM_DESIRED_NODE), + pendingDesiredNode(newWarmNode) + ) + ); + } else { + final var newHotNode = newDesiredNode("node-hot-2", DiscoveryNodeRole.DATA_HOT_NODE_ROLE); + desiredNodes = DesiredNodes.create( + "history", + 1, + List.of( + actualizedDesiredNode(HOT_DESIRED_NODE), + pendingDesiredNode(newHotNode), + actualizedDesiredNode(WARM_DESIRED_NODE) + ) + ); + } + + assertThat( + DataTierAllocationDecider.preferredAvailableTier( + DataTier.parseTierList("data_cold,data_warm,data_hot"), + nodes, + desiredNodes + ), + equalTo(Optional.of("data_cold")) + ); + + // Once all the nodes have joined, we can move the shard to the next tier + 
final var updatedDesiredNodes = DesiredNodes.create( + "history", + 2, + desiredNodes.nodes().stream().map(DesiredNodeWithStatus::desiredNode).map(this::actualizedDesiredNode).toList() + ); + + assertThat( + DataTierAllocationDecider.preferredAvailableTier( + DataTier.parseTierList("data_cold,data_warm,data_hot"), + nodes, + updatedDesiredNodes + ), + equalTo(Optional.of("data_warm")) + ); + } } public void testFrozenIllegalForRegularIndices() { @@ -354,4 +582,81 @@ public void testDefaultValueForPreference() { Settings settings = builder.build(); assertThat(DataTier.TIER_PREFERENCE_SETTING.get(settings), equalTo(DATA_FROZEN)); } + + private ClusterState clusterStateWithIndexAndNodes(String tierPreference, DiscoveryNodes discoveryNodes, DesiredNodes desiredNodes) { + final Metadata.Builder metadata = Metadata.builder() + .put( + IndexMetadata.builder(shard.getIndexName()) + .settings( + Settings.builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .put(IndexMetadata.SETTING_INDEX_UUID, shard.getIndexName()) + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + .put(DataTier.TIER_PREFERENCE, tierPreference) + .build() + ) + ); + if (desiredNodes != null) { + metadata.putCustom(DesiredNodesMetadata.TYPE, new DesiredNodesMetadata(desiredNodes)); + } + return ClusterState.builder(new ClusterName("test")).nodes(discoveryNodes).metadata(metadata).build(); + } + + private static DesiredNode newDesiredNode(String externalId, DiscoveryNodeRole... roles) { + assert roles.length > 0; + + return new DesiredNode( + Settings.builder() + .put(NODE_EXTERNAL_ID_SETTING.getKey(), externalId) + .put(NODE_ROLES_SETTING.getKey(), Arrays.stream(roles).map(DiscoveryNodeRole::roleName).collect(Collectors.joining(","))) + .build(), + 1, + ByteSizeValue.ONE, + ByteSizeValue.ONE, + Version.CURRENT + ); + } + + private DesiredNodes createDesiredNodesWithActualizedNodes(DesiredNode... 
nodes) { + return createDesiredNodesWithStatus(DesiredNodeWithStatus.Status.ACTUALIZED, nodes); + } + + private DesiredNodes createDesiredNodesWithPendingNodes(DesiredNode... nodes) { + return createDesiredNodesWithStatus(DesiredNodeWithStatus.Status.PENDING, nodes); + } + + private DesiredNodes createDesiredNodesWithStatus(DesiredNodeWithStatus.Status status, DesiredNode... nodes) { + return DesiredNodes.create( + randomAlphaOfLength(10), + 1, + Arrays.stream(nodes).map(desiredNode -> new DesiredNodeWithStatus(desiredNode, status)).toList() + ); + } + + private void assertAllocationDecision(ClusterState state, DiscoveryNode node, Decision.Type decisionType, String explanationMessage) { + final var allocation = new RoutingAllocation(allocationDeciders, null, state, null, null, 0); + allocation.debugDecision(true); + + final var routingNode = RoutingNodesHelper.routingNode(node.getId(), node, shard); + { + final var decision = DataTierAllocationDecider.INSTANCE.canAllocate(shard, routingNode, allocation); + assertThat(routingNode.toString(), decision.type(), equalTo(decisionType)); + assertThat(routingNode.toString(), decision.getExplanation(), containsString(explanationMessage)); + } + + { + final var decision = DataTierAllocationDecider.INSTANCE.canRemain(shard, routingNode, allocation); + assertThat(routingNode.toString(), decision.type(), equalTo(decisionType)); + assertThat(routingNode.toString(), decision.getExplanation(), containsString(explanationMessage)); + } + } + + private DesiredNodeWithStatus actualizedDesiredNode(DesiredNode desiredNode) { + return new DesiredNodeWithStatus(desiredNode, DesiredNodeWithStatus.Status.ACTUALIZED); + } + + private DesiredNodeWithStatus pendingDesiredNode(DesiredNode desiredNode) { + return new DesiredNodeWithStatus(desiredNode, DesiredNodeWithStatus.Status.PENDING); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAssignmentStateActionRequestTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAssignmentRoutingInfoActionRequestTests.java similarity index 77% rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAssignmentStateActionRequestTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAssignmentRoutingInfoActionRequestTests.java index d1cb84bfe9dc7..866e6526034dd 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAssignmentStateActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAssignmentRoutingInfoActionRequestTests.java @@ -8,14 +8,14 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentStateAction.Request; -import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReasonTests; +import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction.Request; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfoUpdateTests; -public class UpdateTrainedModelAssignmentStateActionRequestTests extends AbstractWireSerializingTestCase { +public class UpdateTrainedModelAssignmentRoutingInfoActionRequestTests extends AbstractWireSerializingTestCase { @Override protected Request createTestInstance() { - return new Request(randomAlphaOfLength(10), randomAlphaOfLength(10), RoutingStateAndReasonTests.randomInstance()); + return new Request(randomAlphaOfLength(10), randomAlphaOfLength(10), RoutingInfoUpdateTests.randomInstance()); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfoTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfoTests.java new file mode 100644 index 0000000000000..c6ccc825954f1 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfoTests.java @@ -0,0 +1,67 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.assignment; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; + +import static org.hamcrest.Matchers.is; + +public class RoutingInfoTests extends AbstractSerializingTestCase { + + @Override + protected RoutingInfo doParseInstance(XContentParser parser) throws IOException { + return RoutingInfo.fromXContent(parser); + } + + @Override + protected Writeable.Reader instanceReader() { + return RoutingInfo::new; + } + + @Override + protected RoutingInfo createTestInstance() { + return randomInstance(); + } + + public static RoutingInfo randomInstance() { + return new RoutingInfo( + randomIntBetween(1, 10), + randomIntBetween(1, 10), + randomFrom(RoutingState.values()), + randomBoolean() ? null : randomAlphaOfLength(10) + ); + } + + public static RoutingInfo randomInstance(RoutingState state) { + return new RoutingInfo(randomIntBetween(1, 10), randomIntBetween(1, 10), state, randomBoolean() ? 
null : randomAlphaOfLength(10)); + } + + public void testIsRoutable_GivenNonStarted() { + RoutingInfo routingInfo = new RoutingInfo( + 1, + 1, + randomFrom(RoutingState.STARTING, RoutingState.FAILED, RoutingState.STOPPING, RoutingState.STOPPED), + "" + ); + assertThat(routingInfo.isRoutable(), is(false)); + } + + public void testIsRoutable_GivenStartedWithZeroAllocations() { + RoutingInfo routingInfo = new RoutingInfo(0, 1, RoutingState.STARTED, ""); + assertThat(routingInfo.isRoutable(), is(false)); + } + + public void testIsRoutable_GivenStartedWithNonZeroAllocations() { + RoutingInfo routingInfo = new RoutingInfo(randomIntBetween(1, 10), 1, RoutingState.STARTED, ""); + assertThat(routingInfo.isRoutable(), is(true)); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfoUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfoUpdateTests.java new file mode 100644 index 0000000000000..3adb679431250 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfoUpdateTests.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.assignment; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import static org.hamcrest.Matchers.equalTo; + +public class RoutingInfoUpdateTests extends AbstractWireSerializingTestCase { + + @Override + protected Writeable.Reader instanceReader() { + return RoutingInfoUpdate::new; + } + + @Override + protected RoutingInfoUpdate createTestInstance() { + return randomInstance(); + } + + public static RoutingInfoUpdate randomInstance() { + if (randomBoolean()) { + return RoutingInfoUpdate.updateNumberOfAllocations(randomIntBetween(1, Integer.MAX_VALUE)); + } else { + return RoutingInfoUpdate.updateStateAndReason(RoutingStateAndReasonTests.randomInstance()); + } + } + + public void testApply_GivenUpdatingNumberOfAllocations() { + RoutingInfo updatedRoutingInfo = RoutingInfoUpdate.updateNumberOfAllocations(4) + .apply(new RoutingInfo(3, 5, RoutingState.STARTED, "some text")); + assertThat(updatedRoutingInfo, equalTo(new RoutingInfo(4, 5, RoutingState.STARTED, "some text"))); + } + + public void testApply_GivenUpdatingStateAndReason() { + RoutingInfo updatedRoutingInfo = RoutingInfoUpdate.updateStateAndReason(new RoutingStateAndReason(RoutingState.STOPPING, "test")) + .apply(new RoutingInfo(3, 5, RoutingState.STARTED, "")); + assertThat(updatedRoutingInfo, equalTo(new RoutingInfo(3, 5, RoutingState.STOPPING, "test"))); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java index a5a8b2584c8c7..14b2d65a5c5be 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java +++ 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java @@ -9,25 +9,29 @@ import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceNotFoundException; -import org.elasticsearch.Version; -import org.elasticsearch.cluster.node.DiscoveryNode; -import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentTaskParamsTests; +import org.elasticsearch.xpack.core.ml.stats.CountAccumulator; import java.io.IOException; +import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.function.Function; +import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; import static org.hamcrest.Matchers.arrayContainingInAnyOrder; +import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThanOrEqualTo; public class TrainedModelAssignmentTests extends AbstractSerializingTestCase { @@ -35,11 +39,7 @@ public static TrainedModelAssignment randomInstance() { TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomParams()); List nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).collect(Collectors.toList()); for (String node : nodes) { - if (randomBoolean()) { - builder.addNewFailedRoutingEntry(node, randomAlphaOfLength(10)); - } else { - builder.addNewRoutingEntry(node); - } + builder.addRoutingEntry(node, 
RoutingInfoTests.randomInstance()); } builder.setAssignmentState(randomFrom(AssignmentState.values())); if (randomBoolean()) { @@ -63,48 +63,20 @@ protected TrainedModelAssignment createTestInstance() { return randomInstance(); } - public void testBuilderChanged() { - TrainedModelAssignment original = randomInstance(); - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.fromAssignment(original); - assertThat(builder.isChanged(), is(false)); - String addingNode = "foo"; - - assertUnchanged(builder, b -> b.removeRoutingEntry(addingNode)); - - if (randomBoolean()) { - builder.addNewRoutingEntry(addingNode); - } else { - builder.addNewFailedRoutingEntry(addingNode, "test failed"); - } - assertThat(builder.isChanged(), is(true)); - - TrainedModelAssignment.Builder builderWithNode = TrainedModelAssignment.Builder.fromAssignment(builder.build()); - assertThat(builderWithNode.isChanged(), is(false)); - - builderWithNode.removeRoutingEntry(addingNode); - assertThat(builderWithNode.isChanged(), is(true)); - } - public void testBuilderAddingExistingRoute() { - TrainedModelAssignment original = randomInstance(); - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.fromAssignment(original); + TrainedModelAssignment.Builder assignment = TrainedModelAssignment.Builder.empty(randomParams()); String addingNode = "new-node"; - if (randomBoolean()) { - builder.addNewRoutingEntry(addingNode); - } else { - builder.addNewFailedRoutingEntry(addingNode, "test failed"); - } - expectThrows(ResourceAlreadyExistsException.class, () -> builder.addNewFailedRoutingEntry("new-node", "anything")); - expectThrows(ResourceAlreadyExistsException.class, () -> builder.addNewRoutingEntry("new-node")); + assignment.addRoutingEntry(addingNode, RoutingInfoTests.randomInstance()); + + expectThrows(ResourceAlreadyExistsException.class, () -> assignment.addRoutingEntry("new-node", RoutingInfoTests.randomInstance())); } public void testBuilderUpdatingMissingRoute() { 
- TrainedModelAssignment original = randomInstance(); - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.fromAssignment(original); + TrainedModelAssignment.Builder assignment = TrainedModelAssignment.Builder.empty(randomParams()); String addingNode = "new-node"; expectThrows( ResourceNotFoundException.class, - () -> builder.updateExistingRoutingEntry(addingNode, RoutingStateAndReasonTests.randomInstance()) + () -> assignment.updateExistingRoutingEntry(addingNode, RoutingInfoTests.randomInstance()) ); } @@ -114,143 +86,197 @@ public void testGetStartedNodes() { String nodeInAnotherState1 = "another-state-node-1"; String nodeInAnotherState2 = "another-state-node-2"; TrainedModelAssignment allocation = TrainedModelAssignment.Builder.empty(randomParams()) - .addNewRoutingEntry(startedNode1) - .addNewRoutingEntry(startedNode2) - .addNewRoutingEntry(nodeInAnotherState1) - .addNewRoutingEntry(nodeInAnotherState2) - .updateExistingRoutingEntry(startedNode1, new RoutingStateAndReason(RoutingState.STARTED, "")) - .updateExistingRoutingEntry(startedNode2, new RoutingStateAndReason(RoutingState.STARTED, "")) - .updateExistingRoutingEntry( + .addRoutingEntry(startedNode1, RoutingInfoTests.randomInstance(RoutingState.STARTED)) + .addRoutingEntry(startedNode2, RoutingInfoTests.randomInstance(RoutingState.STARTED)) + .addRoutingEntry( nodeInAnotherState1, - new RoutingStateAndReason( - randomFrom(RoutingState.STARTING, RoutingState.FAILED, RoutingState.STOPPED, RoutingState.STOPPING), - randomAlphaOfLength(10) + RoutingInfoTests.randomInstance( + randomFrom(RoutingState.STARTING, RoutingState.STOPPING, RoutingState.STOPPED, RoutingState.FAILED) ) ) - .updateExistingRoutingEntry( + .addRoutingEntry( nodeInAnotherState2, - new RoutingStateAndReason( - randomFrom(RoutingState.STARTING, RoutingState.FAILED, RoutingState.STOPPED, RoutingState.STOPPING), - randomAlphaOfLength(10) + RoutingInfoTests.randomInstance( + randomFrom(RoutingState.STARTING, 
RoutingState.STOPPING, RoutingState.STOPPED, RoutingState.FAILED) ) ) .build(); assertThat(allocation.getStartedNodes(), arrayContainingInAnyOrder(startedNode1, startedNode2)); } - public void testCalculateAllocationStatus() { - List nodes = Stream.generate(TrainedModelAssignmentTests::buildNode).limit(5).collect(Collectors.toList()); - final boolean includeNodes = randomBoolean(); + public void testCalculateAllocationStatus_GivenNoAllocations() { assertThat( - TrainedModelAssignment.Builder.empty(randomParams()) - .build() - .calculateAllocationStatus(includeNodes ? nodes : List.of()) - .orElseThrow(), - equalTo(new AllocationStatus(0, includeNodes ? 5 : 0)) - ); - assertThat( - TrainedModelAssignment.Builder.empty(randomParams()) - .stopAssignment("test") - .build() - .calculateAllocationStatus(includeNodes ? nodes : List.of()) - .isPresent(), - is(false) + TrainedModelAssignment.Builder.empty(randomTaskParams(5)).build().calculateAllocationStatus().get(), + equalTo(new AllocationStatus(0, 5)) ); + } - { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomParams()); - int count = randomInt(4); - for (int i = 0; i < count; i++) { - builder.addRoutingEntry(nodes.get(i).getId(), RoutingState.STARTED); - } - assertThat(builder.build().calculateAllocationStatus(nodes).orElseThrow(), equalTo(new AllocationStatus(count, 5))); - } - { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomParams()); - for (DiscoveryNode node : nodes) { - builder.addRoutingEntry( - node.getId(), - randomFrom(RoutingState.FAILED, RoutingState.STOPPED, RoutingState.STARTING, RoutingState.STOPPING) - ); - } - int count = randomIntBetween(1, 4); - for (int i = 0; i < count; i++) { - builder.addRoutingEntry(nodes.get(i).getId(), RoutingState.STARTED); - } - assertThat(builder.build().calculateAllocationStatus(nodes).orElseThrow(), equalTo(new AllocationStatus(count, 5))); - } - { - TrainedModelAssignment.Builder builder = 
TrainedModelAssignment.Builder.empty(randomParams()); - for (DiscoveryNode node : nodes) { - builder.addRoutingEntry(node.getId(), RoutingState.STARTED); - } - assertThat(builder.build().calculateAllocationStatus(nodes).orElseThrow(), equalTo(new AllocationStatus(5, 5))); - } + public void testCalculateAllocationStatus_GivenStoppingAssignment() { + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + builder.addRoutingEntry("node-1", new RoutingInfo(1, 2, RoutingState.STARTED, "")); + builder.addRoutingEntry("node-2", new RoutingInfo(2, 1, RoutingState.STARTED, "")); + assertThat(builder.stopAssignment("test").build().calculateAllocationStatus().isEmpty(), is(true)); + } + + public void testCalculateAllocationStatus_GivenPartiallyAllocated() { + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + builder.addRoutingEntry("node-1", new RoutingInfo(1, 2, RoutingState.STARTED, "")); + builder.addRoutingEntry("node-2", new RoutingInfo(2, 1, RoutingState.STARTED, "")); + builder.addRoutingEntry("node-3", new RoutingInfo(3, 3, RoutingState.STARTING, "")); + assertThat(builder.build().calculateAllocationStatus().get(), equalTo(new AllocationStatus(3, 5))); + } + + public void testCalculateAllocationStatus_GivenFullyAllocated() { + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, "")); + builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")); + assertThat(builder.build().calculateAllocationStatus().get(), equalTo(new AllocationStatus(5, 5))); } - public void testCalculateAllocationState() { - List nodes = Stream.generate(TrainedModelAssignmentTests::buildNode).limit(5).collect(Collectors.toList()); - assertThat(TrainedModelAssignment.Builder.empty(randomParams()).calculateAssignmentState(), 
equalTo(AssignmentState.STARTING)); + public void testCalculateAssignmentState_GivenNoStartedAssignments() { + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTING, "")); + builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTING, "")); + assertThat(builder.calculateAssignmentState(), equalTo(AssignmentState.STARTING)); + } + + public void testCalculateAssignmentState_GivenOneStartedAssignment() { + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTING, "")); + builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")); + assertThat(builder.calculateAssignmentState(), equalTo(AssignmentState.STARTED)); + } + + public void testCalculateAndSetAssignmentState_GivenStoppingAssignment() { + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, "")); + builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")); assertThat( - TrainedModelAssignment.Builder.empty(randomParams()).stopAssignment("test").calculateAssignmentState(), + builder.stopAssignment("test").calculateAndSetAssignmentState().build().getAssignmentState(), equalTo(AssignmentState.STOPPING) ); + } - { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomParams()); - int count = randomInt(4); - for (int i = 0; i < count; i++) { - builder.addRoutingEntry( - nodes.get(i).getId(), - randomFrom(RoutingState.FAILED, RoutingState.STOPPED, RoutingState.STARTING, RoutingState.STOPPING) - ); - } - assertThat(builder.calculateAssignmentState(), equalTo(AssignmentState.STARTING)); - } - { - TrainedModelAssignment.Builder builder = 
TrainedModelAssignment.Builder.empty(randomParams()); - for (DiscoveryNode node : nodes) { - builder.addRoutingEntry( - node.getId(), - randomFrom(RoutingState.FAILED, RoutingState.STOPPED, RoutingState.STARTING, RoutingState.STOPPING) - ); - } - int count = randomIntBetween(1, 4); - for (int i = 0; i < count; i++) { - builder.addRoutingEntry(nodes.get(i).getId(), RoutingState.STARTED); - } - assertThat(builder.calculateAssignmentState(), equalTo(AssignmentState.STARTED)); + public void testSelectRandomStartedNodeWeighedOnAllocations_GivenNoStartedAllocations() { + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTING, "")); + builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STOPPED, "")); + TrainedModelAssignment assignment = builder.build(); + + assertThat(assignment.selectRandomStartedNodeWeighedOnAllocations().isEmpty(), is(true)); + } + + public void testSelectRandomStartedNodeWeighedOnAllocations_GivenSingleStartedNode() { + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, "")); + TrainedModelAssignment assignment = builder.build(); + + Optional node = assignment.selectRandomStartedNodeWeighedOnAllocations(); + + assertThat(node.isPresent(), is(true)); + assertThat(node.get(), equalTo("node-1")); + } + + public void testSelectRandomStartedNodeWeighedOnAllocations_GivenMultipleStartedNodes() { + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6)); + builder.addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")); + builder.addRoutingEntry("node-2", new RoutingInfo(2, 2, RoutingState.STARTED, "")); + builder.addRoutingEntry("node-3", new RoutingInfo(3, 3, RoutingState.STARTED, "")); + TrainedModelAssignment assignment = 
builder.build(); + + final long selectionCount = 10000; + final CountAccumulator countsPerNodeAccumulator = new CountAccumulator(); + for (int i = 0; i < selectionCount; i++) { + Optional node = assignment.selectRandomStartedNodeWeighedOnAllocations(); + assertThat(node.isPresent(), is(true)); + countsPerNodeAccumulator.add(node.get(), 1L); } - { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomParams()); - for (DiscoveryNode node : nodes) { - builder.addRoutingEntry(node.getId(), RoutingState.STARTED); - } - assertThat(builder.calculateAssignmentState(), equalTo(AssignmentState.STARTED)); + + Map countsPerNode = countsPerNodeAccumulator.asMap(); + assertThat(countsPerNode.keySet(), contains("node-1", "node-2", "node-3")); + assertThat(countsPerNode.get("node-1") + countsPerNode.get("node-2") + countsPerNode.get("node-3"), equalTo(selectionCount)); + + assertValueWithinPercentageOfExpectedRatio(countsPerNode.get("node-1"), selectionCount, 1.0 / 6.0, 0.2); + assertValueWithinPercentageOfExpectedRatio(countsPerNode.get("node-2"), selectionCount, 2.0 / 6.0, 0.2); + assertValueWithinPercentageOfExpectedRatio(countsPerNode.get("node-3"), selectionCount, 3.0 / 6.0, 0.2); + } + + public void testSelectRandomStartedNodeWeighedOnAllocations_GivenMultipleStartedNodesWithZeroAllocations() { + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6)); + builder.addRoutingEntry("node-1", new RoutingInfo(0, 0, RoutingState.STARTED, "")); + builder.addRoutingEntry("node-2", new RoutingInfo(0, 0, RoutingState.STARTED, "")); + builder.addRoutingEntry("node-3", new RoutingInfo(0, 0, RoutingState.STARTED, "")); + TrainedModelAssignment assignment = builder.build(); + final long selectionCount = 1000; + Set selectedNodes = new HashSet<>(); + for (int i = 0; i < selectionCount; i++) { + Optional selectedNode = assignment.selectRandomStartedNodeWeighedOnAllocations(); + assertThat(selectedNode.isPresent(), 
is(true)); + selectedNodes.add(selectedNode.get()); } + + assertThat(selectedNodes, contains("node-1", "node-2", "node-3")); } - private static DiscoveryNode buildNode() { - return new DiscoveryNode( - randomAlphaOfLength(10), - randomAlphaOfLength(10), - buildNewFakeTransportAddress(), - Map.of(), - DiscoveryNodeRole.roles(), - Version.CURRENT + public void testIsSatisfied_GivenEnoughAllocations() { + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6)); + builder.addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")); + builder.addRoutingEntry("node-2", new RoutingInfo(2, 2, RoutingState.STARTED, "")); + builder.addRoutingEntry("node-3", new RoutingInfo(3, 3, RoutingState.STARTED, "")); + TrainedModelAssignment assignment = builder.build(); + assertThat(assignment.isSatisfied(Sets.newHashSet("node-1", "node-2", "node-3")), is(true)); + } + + public void testIsSatisfied_GivenEnoughAllocations_ButOneNodeIsNotAssignable() { + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6)); + builder.addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")); + builder.addRoutingEntry("node-2", new RoutingInfo(2, 2, RoutingState.STARTED, "")); + builder.addRoutingEntry("node-3", new RoutingInfo(3, 3, RoutingState.STARTED, "")); + TrainedModelAssignment assignment = builder.build(); + assertThat(assignment.isSatisfied(Sets.newHashSet("node-2", "node-3")), is(false)); + } + + public void testIsSatisfied_GivenEnoughAllocations_ButOneNodeIsNeitherStartingNorStarted() { + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6)); + builder.addRoutingEntry( + "node-1", + new RoutingInfo(1, 1, randomFrom(RoutingState.FAILED, RoutingState.STOPPING, RoutingState.STOPPED), "") ); + builder.addRoutingEntry("node-2", new RoutingInfo(2, 2, RoutingState.STARTED, "")); + builder.addRoutingEntry("node-3", new RoutingInfo(3, 3, 
RoutingState.STARTED, "")); + TrainedModelAssignment assignment = builder.build(); + assertThat(assignment.isSatisfied(Sets.newHashSet("node-1", "node-2", "node-3")), is(false)); } - private static StartTrainedModelDeploymentAction.TaskParams randomParams() { - return StartTrainedModelDeploymentTaskParamsTests.createRandom(); + public void testIsSatisfied_GivenNotEnoughAllocations() { + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(7)); + builder.addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")); + builder.addRoutingEntry("node-2", new RoutingInfo(2, 2, RoutingState.STARTED, "")); + builder.addRoutingEntry("node-3", new RoutingInfo(3, 3, RoutingState.STARTED, "")); + TrainedModelAssignment assignment = builder.build(); + assertThat(assignment.isSatisfied(Sets.newHashSet("node-1", "node-2", "node-3")), is(false)); } - private static void assertUnchanged( - TrainedModelAssignment.Builder builder, - Function function - ) { - function.apply(builder); - assertThat(builder.isChanged(), is(false)); + private void assertValueWithinPercentageOfExpectedRatio(long value, long totalCount, double ratio, double tolerance) { + double expected = totalCount * ratio; + double lowerBound = (1.0 - tolerance) * expected; + double upperBound = (1.0 + tolerance) * expected; + logger.info("Checked that: {} <= {} <= {}", lowerBound, value, upperBound); + assertThat((double) value, greaterThanOrEqualTo(lowerBound)); + assertThat((double) value, lessThanOrEqualTo(upperBound)); } + private static StartTrainedModelDeploymentAction.TaskParams randomTaskParams(int numberOfAllocations) { + return new StartTrainedModelDeploymentAction.TaskParams( + randomAlphaOfLength(10), + randomNonNegativeLong(), + randomIntBetween(1, 8), + numberOfAllocations, + randomIntBetween(1, 10000) + ); + } + + private static StartTrainedModelDeploymentAction.TaskParams randomParams() { + return 
StartTrainedModelDeploymentTaskParamsTests.createRandom(); + } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlMemoryIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlMemoryIT.java index 2084ea491e804..e75a94ec78925 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlMemoryIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlMemoryIT.java @@ -117,7 +117,7 @@ public void testMemoryStats() throws Exception { assertThat(stats.getJvmInference().getBytes(), greaterThanOrEqualTo(0L)); } assertThat(mlNodes, is(2)); - assertThat(nodesWithPytorchModel, equalTo(mlNodes)); + assertThat(nodesWithPytorchModel, equalTo(1)); assertThat(nodesWithAnomalyJob, is(1)); // It's possible that the DFA job could have finished before the stats call was made assumeFalse( diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java index 3db34ac4c3a31..8ef942426037f 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java @@ -330,8 +330,8 @@ public void testLiveDeploymentStats() throws IOException { "deployment_stats.nodes", stats.get(0) ); - // 2 of the 3 nodes in the cluster are ML nodes - assertThat(nodes, hasSize(2)); + // 2 of the 3 nodes in the cluster are ML nodes but we have asked for a single allocation + assertThat(nodes, hasSize(1)); for (var node : nodes) { 
assertThat(node.get("number_of_pending_requests"), notNullValue()); } @@ -408,12 +408,10 @@ public void testGetDeploymentStats_WithStartedStoppedDeployments() throws IOExce "deployment_stats.nodes", stats.get(i) ); - // 2 ml nodes - assertThat(nodes, hasSize(2)); - for (int j : new int[] { 0, 1 }) { - Object state = XContentMapValues.extractValue("routing_state.routing_state", nodes.get(j)); - assertEquals("started", state); - } + // 2 ml nodes but we've asked a single allocation for each model + assertThat(nodes, hasSize(1)); + Object state = XContentMapValues.extractValue("routing_state.routing_state", nodes.get(0)); + assertEquals("started", state); } stopDeployment(modelFoo); @@ -425,17 +423,15 @@ public void testGetDeploymentStats_WithStartedStoppedDeployments() throws IOExce assertThat(stats, hasSize(2)); assertThat(stats.get(0), not(hasKey("deployment_stats"))); - // check all nodes are started for the non-stopped deployment + // check a node is started for the non-stopped deployment List> nodes = (List>) XContentMapValues.extractValue( "deployment_stats.nodes", stats.get(1) ); - // 2 ml nodes - assertThat(nodes, hasSize(2)); - for (int j : new int[] { 0, 1 }) { - Object state = XContentMapValues.extractValue("routing_state.routing_state", nodes.get(j)); - assertEquals("started", state); - } + // 2 ml nodes but we've asked a single allocation + assertThat(nodes, hasSize(1)); + Object state = XContentMapValues.extractValue("routing_state.routing_state", nodes.get(0)); + assertEquals("started", state); stopDeployment(modelBar); } diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/JobsAndModelsIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/JobsAndModelsIT.java index fda007a289da2..1c7daea9cc761 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/JobsAndModelsIT.java +++ 
b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/JobsAndModelsIT.java @@ -45,6 +45,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; /** @@ -175,15 +176,39 @@ public void testCluster_GivenAnomalyDetectionJobAndTrainedModelDeployment_Should String lastMlNodeName = internalCluster().startNode(onlyRoles(Set.of(DiscoveryNodeRole.ML_ROLE))); ensureStableCluster(); - // Here we make the assumption that models are assigned before persistent tasks. - // The reason this holds follows. Allocation service is a plugin component listening to - // cluster states updates. Persistent tasks have executors that listen to cluster - // states. Plugin components get created before persistent task executors. Thus, - // the allocation service will be producing each cluster state updates first. - // As this assumption might be critical, the test should break if the assumption - // breaks to give us a warning about potential impact. 
+ // Wait until either the job or the model is assigned + assertBusy(() -> { + GetTrainedModelsStatsAction.Response modelStatsResponse = client().execute( + GetTrainedModelsStatsAction.INSTANCE, + new GetTrainedModelsStatsAction.Request(model.getModelId()) + ).actionGet(); + GetTrainedModelsStatsAction.Response.TrainedModelStats modelStats = modelStatsResponse.getResources().results().get(0); + GetJobsStatsAction.Response jobStatsResponse = client().execute( + GetJobsStatsAction.INSTANCE, + new GetJobsStatsAction.Request(job.getId()) + ).actionGet(); + GetJobsStatsAction.Response.JobStats jobStats = jobStatsResponse.getResponse().results().get(0); + + boolean isModelAssigned = modelStats.getDeploymentStats().getNodeStats().isEmpty() == false; + boolean isJobAssigned = jobStats.getNode() != null; + assertThat(isJobAssigned ^ isModelAssigned, is(true)); + + if (isJobAssigned) { + assertThat(jobStats.getNode().getName(), equalTo(lastMlNodeName)); + assertThat(modelStats.getDeploymentStats().getReason(), containsString("insufficient available memory")); + } else { + assertThat(modelStats.getDeploymentStats().getNodeStats().get(0).getNode().getName(), equalTo(lastMlNodeName)); + assertThat(jobStats.getAssignmentExplanation(), containsString("insufficient available memory")); + } + }); - // Wait until the model is assigned + // Start another new ML node + logger.info("Starting dedicated ml node..."); + internalCluster().startNode(onlyRoles(Set.of(DiscoveryNodeRole.ML_ROLE))); + ensureStableCluster(); + + // Wait until both the job and the model are assigned + // and check they are not on the same node assertBusy(() -> { GetTrainedModelsStatsAction.Response modelStatsResponse = client().execute( GetTrainedModelsStatsAction.INSTANCE, @@ -191,17 +216,15 @@ public void testCluster_GivenAnomalyDetectionJobAndTrainedModelDeployment_Should ).actionGet(); GetTrainedModelsStatsAction.Response.TrainedModelStats modelStats = modelStatsResponse.getResources().results().get(0); 
assertThat(modelStats.getDeploymentStats().getNodeStats().isEmpty(), is(false)); - assertThat(modelStats.getDeploymentStats().getNodeStats().get(0).getNode().getName(), equalTo(lastMlNodeName)); - }); + GetJobsStatsAction.Response jobStatsResponse = client().execute( + GetJobsStatsAction.INSTANCE, + new GetJobsStatsAction.Request(job.getId()) + ).actionGet(); + GetJobsStatsAction.Response.JobStats jobStats = jobStatsResponse.getResponse().results().get(0); + assertThat(jobStats.getNode(), is(notNullValue())); - // Check the job is unassigned due to insufficient memory - GetJobsStatsAction.Response jobStatsResponse = client().execute( - GetJobsStatsAction.INSTANCE, - new GetJobsStatsAction.Request(job.getId()) - ).actionGet(); - GetJobsStatsAction.Response.JobStats jobStats = jobStatsResponse.getResponse().results().get(0); - assertThat(jobStats.getNode(), is(nullValue())); - assertThat(jobStats.getAssignmentExplanation(), containsString("insufficient available memory")); + assertThat(jobStats.getNode(), is(not(equalTo(modelStats.getDeploymentStats().getNodeStats().get(0).getNode())))); + }); // Clean up client().execute(CloseJobAction.INSTANCE, new CloseJobAction.Request(jobId).setForce(true)).actionGet(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index ab28ae4a432d1..4a6b17d1de71f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -170,7 +170,7 @@ import org.elasticsearch.xpack.core.ml.action.UpdateJobAction; import org.elasticsearch.xpack.core.ml.action.UpdateModelSnapshotAction; import org.elasticsearch.xpack.core.ml.action.UpdateProcessAction; -import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentStateAction; +import 
org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction; import org.elasticsearch.xpack.core.ml.action.UpgradeJobModelSnapshotAction; import org.elasticsearch.xpack.core.ml.action.ValidateDetectorAction; import org.elasticsearch.xpack.core.ml.action.ValidateJobConfigAction; @@ -462,6 +462,7 @@ public class MachineLearning extends Plugin public static final String PRE_V7_BASE_PATH = "/_xpack/ml/"; public static final String DATAFEED_THREAD_POOL_NAME = NAME + "_datafeed"; public static final String JOB_COMMS_THREAD_POOL_NAME = NAME + "_job_comms"; + public static final String NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME = NAME + "_native_inference_comms"; public static final String UTILITY_THREAD_POOL_NAME = NAME + "_utility"; public static final String TRAINED_MODEL_CIRCUIT_BREAKER_NAME = "model_inference"; @@ -512,6 +513,7 @@ public Map getProcessors(Processor.Parameters paramet private static final String PRE_V8_MAX_OPEN_JOBS_NODE_ATTR = "ml.max_open_jobs"; public static final String MACHINE_MEMORY_NODE_ATTR = "ml.machine_memory"; public static final String MAX_JVM_SIZE_NODE_ATTR = "ml.max_jvm_size"; + public static final String ALLOCATED_PROCESSORS_NODE_ATTR = "ml.allocated_processors"; public static final Setting CONCURRENT_JOB_ALLOCATIONS = Setting.intSetting( "xpack.ml.node_concurrent_job_allocations", 2, @@ -701,10 +703,12 @@ public List> getSettings() { ); } + @Override public Settings additionalSettings() { String maxOpenJobsPerNodeNodeAttrName = "node.attr." + PRE_V8_MAX_OPEN_JOBS_NODE_ATTR; String machineMemoryAttrName = "node.attr." + MACHINE_MEMORY_NODE_ATTR; String jvmSizeAttrName = "node.attr." + MAX_JVM_SIZE_NODE_ATTR; + String allocatedProcessorsAttrName = "node.attr." 
+ ALLOCATED_PROCESSORS_NODE_ATTR; if (enabled == false) { disallowMlNodeAttributes(maxOpenJobsPerNodeNodeAttrName, machineMemoryAttrName, jvmSizeAttrName); @@ -719,10 +723,11 @@ public Settings additionalSettings() { Long.toString(OsProbe.getInstance().osStats().getMem().getAdjustedTotal().getBytes()) ); addMlNodeAttribute(additionalSettings, jvmSizeAttrName, Long.toString(Runtime.getRuntime().maxMemory())); + addMlNodeAttribute(additionalSettings, allocatedProcessorsAttrName, Integer.toString(getAllocatedProcessors())); // This is not used in v8 and higher, but users are still prevented from setting it directly to avoid confusion disallowMlNodeAttributes(maxOpenJobsPerNodeNodeAttrName); } else { - disallowMlNodeAttributes(maxOpenJobsPerNodeNodeAttrName, machineMemoryAttrName, jvmSizeAttrName); + disallowMlNodeAttributes(maxOpenJobsPerNodeNodeAttrName, machineMemoryAttrName, jvmSizeAttrName, allocatedProcessorsAttrName); } return additionalSettings.build(); } @@ -736,6 +741,10 @@ private void addMlNodeAttribute(Settings.Builder additionalSettings, String attr } } + private int getAllocatedProcessors() { + return EsExecutors.allocatedProcessors(settings); + } + private void disallowMlNodeAttributes(String... 
mlNodeAttributes) { for (String attrName : mlNodeAttributes) { if (settings.get(attrName) != null) { @@ -1061,7 +1070,7 @@ public Collection createComponents( threadPool ); trainedModelAllocationClusterServiceSetOnce.set( - new TrainedModelAssignmentClusterService(settings, clusterService, new NodeLoadDetector(memoryTracker)) + new TrainedModelAssignmentClusterService(settings, clusterService, threadPool, new NodeLoadDetector(memoryTracker)) ); mlAutoscalingDeciderService.set( @@ -1333,7 +1342,10 @@ public List getRestHandlers( new ActionHandler<>(DeleteTrainedModelAssignmentAction.INSTANCE, TransportDeleteTrainedModelAssignmentAction.class), new ActionHandler<>(PutTrainedModelDefinitionPartAction.INSTANCE, TransportPutTrainedModelDefinitionPartAction.class), new ActionHandler<>(PutTrainedModelVocabularyAction.INSTANCE, TransportPutTrainedModelVocabularyAction.class), - new ActionHandler<>(UpdateTrainedModelAssignmentStateAction.INSTANCE, TransportUpdateTrainedModelAssignmentStateAction.class), + new ActionHandler<>( + UpdateTrainedModelAssignmentRoutingInfoAction.INSTANCE, + TransportUpdateTrainedModelAssignmentStateAction.class + ), usageAction, infoAction ); @@ -1372,6 +1384,21 @@ public List> getExecutorBuilders(Settings unused) { "xpack.ml.job_comms_thread_pool" ); + // 3 threads per native inference process: for input, c++ logger output, and result processing. + // As we cannot assign more models than the number of allocated processors, this thread pool's + // size is limited by the number of allocated processors on this node. + // Only use this thread pool for the main long-running process associated with a native inference model deployment. + // (Using it for some other purpose could mean that an unrelated pytorch model assignment fails to start + // or that whatever needed the thread for another purpose has to queue for a very long time.) 
+ ScalingExecutorBuilder pytorchComms = new ScalingExecutorBuilder( + NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME, + 3, + getAllocatedProcessors() * 3, + TimeValue.timeValueMinutes(1), + false, + "xpack.ml.native_inference_comms_thread_pool" + ); + // This pool is used by renormalization, data frame analytics memory estimation, plus some other parts // of ML that need to kick off non-trivial activities that mustn't block other threads. ScalingExecutorBuilder utility = new ScalingExecutorBuilder( @@ -1392,7 +1419,7 @@ public List> getExecutorBuilders(Settings unused) { "xpack.ml.datafeed_thread_pool" ); - return List.of(jobComms, utility, datafeed); + return List.of(jobComms, pytorchComms, utility, datafeed); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java index c7b2b90f1e869..c7eb0635f0f28 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java @@ -13,7 +13,6 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.tasks.TransportTasksAction; import org.elasticsearch.cluster.ClusterState; -import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; @@ -23,11 +22,10 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher; import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; -import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState; import 
org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState; -import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReason; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata; import org.elasticsearch.xpack.ml.inference.deployment.ModelStats; @@ -109,7 +107,7 @@ protected void doExecute( List matchedDeploymentIds = new ArrayList<>(); Set taskNodes = new HashSet<>(); - Map> assignmentNonStartedRoutes = new HashMap<>(); + Map> assignmentNonStartedRoutes = new HashMap<>(); for (var assignmentEntry : assignment.modelAssignments().entrySet()) { String modelId = assignmentEntry.getKey(); if (idsMatcher.idMatches(modelId)) { @@ -117,7 +115,7 @@ protected void doExecute( taskNodes.addAll(Arrays.asList(assignmentEntry.getValue().getStartedNodes())); - Map routings = assignmentEntry.getValue() + Map routings = assignmentEntry.getValue() .getNodeRoutingTable() .entrySet() .stream() @@ -140,20 +138,13 @@ protected void doExecute( ActionListener addFailedListener = listener.delegateFailure((l, response) -> { var updatedResponse = addFailedRoutes(response, assignmentNonStartedRoutes, clusterState.nodes()); - ClusterState latestState = clusterService.state(); - Set nodesShuttingDown = TransportStartTrainedModelDeploymentAction.nodesShuttingDown(latestState); - List nodes = latestState.getNodes() - .stream() - .filter(d -> nodesShuttingDown.contains(d.getId()) == false) - .filter(StartTrainedModelDeploymentAction.TaskParams::mayAssignToNode) - .collect(Collectors.toList()); // Set the allocation state and reason if we have it for (AssignmentStats stats : updatedResponse.getStats().results()) { TrainedModelAssignment trainedModelAssignment = 
assignment.getModelAssignment(stats.getModelId()); if (trainedModelAssignment != null) { stats.setState(trainedModelAssignment.getAssignmentState()).setReason(trainedModelAssignment.getReason().orElse(null)); if (trainedModelAssignment.getAssignmentState().isAnyOf(AssignmentState.STARTED, AssignmentState.STARTING)) { - stats.setAllocationStatus(trainedModelAssignment.calculateAllocationStatus(nodes).orElse(null)); + stats.setAllocationStatus(trainedModelAssignment.calculateAllocationStatus().orElse(null)); } } } @@ -178,7 +169,7 @@ protected void doExecute( */ static GetDeploymentStatsAction.Response addFailedRoutes( GetDeploymentStatsAction.Response tasksResponse, - Map> assignmentNonStartedRoutes, + Map> assignmentNonStartedRoutes, DiscoveryNodes nodes ) { final Map modelToAssignmentWithNonStartedRoutes = assignmentNonStartedRoutes.keySet() @@ -190,7 +181,7 @@ static GetDeploymentStatsAction.Response addFailedRoutes( for (AssignmentStats stat : tasksResponse.getStats().results()) { if (modelToAssignmentWithNonStartedRoutes.containsKey(stat.getModelId())) { // there is merging to be done - Map nodeToRoutingStates = assignmentNonStartedRoutes.get( + Map nodeToRoutingStates = assignmentNonStartedRoutes.get( modelToAssignmentWithNonStartedRoutes.get(stat.getModelId()) ); List updatedNodeStats = new ArrayList<>(); @@ -202,12 +193,12 @@ static GetDeploymentStatsAction.Response addFailedRoutes( // and we have a non-started routing entry. // Prefer the entry from assignmentNonStartedRoutes as we cannot be sure // of the state of the task - it may be starting, started, stopping, or stopped. 
- RoutingStateAndReason stateAndReason = nodeToRoutingStates.get(nodeStat.getNode().getId()); + RoutingInfo routingInfo = nodeToRoutingStates.get(nodeStat.getNode().getId()); updatedNodeStats.add( AssignmentStats.NodeStats.forNotStartedState( nodeStat.getNode(), - stateAndReason.getState(), - stateAndReason.getReason() + routingInfo.getState(), + routingInfo.getReason() ) ); } else { @@ -317,11 +308,14 @@ protected void taskOperation( nodeStats.add(AssignmentStats.NodeStats.forNotStartedState(clusterService.localNode(), RoutingState.STOPPED, "")); } + TrainedModelAssignment assignment = TrainedModelAssignmentMetadata.fromState(clusterService.state()) + .getModelAssignment(task.getModelId()); + listener.onResponse( new AssignmentStats( task.getModelId(), task.getParams().getThreadsPerAllocation(), - task.getParams().getNumberOfAllocations(), + assignment == null ? task.getParams().getNumberOfAllocations() : assignment.getTaskParams().getNumberOfAllocations(), task.getParams().getQueueCapacity(), TrainedModelAssignmentMetadata.fromState(clusterService.state()).getModelAssignment(task.getModelId()).getStartTime(), nodeStats diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java index c7c8f25e15d60..bfd2823a970be 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.ml.action; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.FailedNodeException; @@ -14,7 +16,6 @@ import 
org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.tasks.TransportTasksAction; import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.Randomness; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -32,12 +33,16 @@ import java.util.List; +import static org.elasticsearch.core.Strings.format; + public class TransportInferTrainedModelDeploymentAction extends TransportTasksAction< TrainedModelDeploymentTask, InferTrainedModelDeploymentAction.Request, InferTrainedModelDeploymentAction.Response, InferTrainedModelDeploymentAction.Response> { + private static final Logger logger = LogManager.getLogger(TransportInferTrainedModelDeploymentAction.class); + private final TrainedModelProvider provider; @Inject @@ -94,16 +99,17 @@ protected void doExecute( listener.onFailure(ExceptionsHelper.conflictStatusException(message)); return; } - String[] randomRunningNode = assignment.getStartedNodes(); - if (randomRunningNode.length == 0) { - String message = "Trained model [" + deploymentId + "] is not allocated to any nodes"; - listener.onFailure(ExceptionsHelper.conflictStatusException(message)); - return; - } - // TODO Do better routing for inference calls - int nodeIndex = Randomness.get().nextInt(randomRunningNode.length); - request.setNodes(randomRunningNode[nodeIndex]); - super.doExecute(task, request, listener); + logger.trace(() -> format("[%s] selecting node from routing table: %s", assignment.getModelId(), assignment.getNodeRoutingTable())); + assignment.selectRandomStartedNodeWeighedOnAllocations().ifPresentOrElse(node -> { + logger.trace(() -> format("[%s] selected node [%s]", assignment.getModelId(), node)); + request.setNodes(node); + super.doExecute(task, request, listener); + }, () -> { + logger.trace(() -> format("[%s] model not allocated to any node", assignment.getModelId())); + listener.onFailure( + 
ExceptionsHelper.conflictStatusException("Trained model [" + deploymentId + "] is not allocated to any nodes") + ); + }); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java index a7cd4ea557b21..438b4baacfd2d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java @@ -50,8 +50,8 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState; -import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReason; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation; @@ -301,6 +301,7 @@ private void deleteFailedDeployment( Exception exception, ActionListener listener ) { + logger.trace(() -> format("[%s] Deleting failed deployment", modelId), exception); trainedModelAssignmentService.deleteModelAssignment(modelId, ActionListener.wrap(pTask -> listener.onFailure(exception), e -> { logger.error( () -> format( @@ -452,16 +453,16 @@ public boolean test(ClusterState clusterState) { return true; } - final Set> nodesAndState = trainedModelAssignment.getNodeRoutingTable().entrySet(); + final Set> nodeIdsAndRouting = trainedModelAssignment.getNodeRoutingTable().entrySet(); Map 
nodeFailuresAndReasons = new HashMap<>(); Set nodesStillInitializing = new LinkedHashSet<>(); - for (Map.Entry nodeIdAndState : nodesAndState) { - if (RoutingState.FAILED.equals(nodeIdAndState.getValue().getState())) { - nodeFailuresAndReasons.put(nodeIdAndState.getKey(), nodeIdAndState.getValue().getReason()); + for (Map.Entry nodeIdAndRouting : nodeIdsAndRouting) { + if (RoutingState.FAILED.equals(nodeIdAndRouting.getValue().getState())) { + nodeFailuresAndReasons.put(nodeIdAndRouting.getKey(), nodeIdAndRouting.getValue().getReason()); } - if (RoutingState.STARTING.equals(nodeIdAndState.getValue().getState())) { - nodesStillInitializing.add(nodeIdAndState.getKey()); + if (RoutingState.STARTING.equals(nodeIdAndRouting.getValue().getState())) { + nodesStillInitializing.add(nodeIdAndRouting.getKey()); } } @@ -482,7 +483,7 @@ public boolean test(ClusterState clusterState) { OptionalLong smallestMLNode = nodes.stream().map(NodeLoadDetector::getNodeSize).flatMapToLong(OptionalLong::stream).min(); // No nodes allocated at all! 
- if (nodesAndState.isEmpty() + if (nodeIdsAndRouting.isEmpty() // We cannot scale horizontally && maxLazyMLNodes <= nodes.size() // We cannot scale vertically @@ -500,7 +501,7 @@ public boolean test(ClusterState clusterState) { return true; } - AllocationStatus allocationStatus = trainedModelAssignment.calculateAllocationStatus(nodes).orElse(null); + AllocationStatus allocationStatus = trainedModelAssignment.calculateAllocationStatus().orElse(null); if (allocationStatus == null || allocationStatus.calculateState().compareTo(waitForState) >= 0) { return true; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelAssignmentStateAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelAssignmentStateAction.java index c1841653a3fd4..31558ad127bd8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelAssignmentStateAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelAssignmentStateAction.java @@ -20,8 +20,8 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentStateAction; -import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentStateAction.Request; +import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction; +import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction.Request; import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentClusterService; public class TransportUpdateTrainedModelAssignmentStateAction extends AcknowledgedTransportMasterNodeAction { @@ -38,7 +38,7 @@ public TransportUpdateTrainedModelAssignmentStateAction( IndexNameExpressionResolver indexNameExpressionResolver ) { super( - 
UpdateTrainedModelAssignmentStateAction.NAME, + UpdateTrainedModelAssignmentRoutingInfoAction.NAME, false, transportService, clusterService, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingDeciderService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingDeciderService.java index ae6f5f7f11292..e7fbdd959f709 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingDeciderService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingDeciderService.java @@ -662,7 +662,7 @@ public AutoscalingDeciderResult scale(Settings configuration, AutoscalingDecider // Given maxOpenJobs, could we scale down to just one node? // We have no way of saying "we need X nodes" if (nodeLoads.size() > 1) { - long totalAssignedJobs = nodeLoads.stream().mapToLong(NodeLoad::getNumAssignedJobs).sum(); + long totalAssignedJobs = nodeLoads.stream().mapToLong(NodeLoad::getNumAssignedJobsAndModels).sum(); // one volatile read long maxOpenJobsCopy = this.maxOpenJobs; if (totalAssignedJobs > maxOpenJobsCopy) { @@ -876,7 +876,7 @@ Optional checkForScaleUp( Tuple> modelCapacityAndNewLoad = determineUnassignableJobs( waitingAllocatedModels, this::getAllocatedModelRequirement, - NodeLoad.Builder::incNumAssignedNativeInferenceJobs, + NodeLoad.Builder::incNumAssignedNativeInferenceModels, 0, analyticsCapacityAndNewLoad.v2() ).orElse(Tuple.tuple(NativeMemoryCapacity.ZERO, analyticsCapacityAndNewLoad.v2())); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java index 3ae6d7afc53ff..b78029e088f6e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; @@ -26,15 +25,16 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.gateway.GatewayService; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.MlMetadata; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; -import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentStateAction; +import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction; import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import org.elasticsearch.xpack.ml.MachineLearning; @@ -42,32 +42,38 @@ import org.elasticsearch.xpack.ml.job.NodeLoadDetector; import java.util.Collections; -import java.util.Comparator; -import java.util.Locale; +import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.TreeMap; import java.util.function.Function; import java.util.stream.Collectors; import 
static org.elasticsearch.core.Strings.format; -import static org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata.fromState; public class TrainedModelAssignmentClusterService implements ClusterStateListener { private static final Logger logger = LogManager.getLogger(TrainedModelAssignmentClusterService.class); private static final Version RENAME_ALLOCATION_TO_ASSIGNMENT_VERSION = Version.V_8_3_0; + public static final Version DISTRIBUTED_MODEL_ALLOCATION_VERSION = Version.V_8_4_0; private final ClusterService clusterService; + private final ThreadPool threadPool; private final NodeLoadDetector nodeLoadDetector; private volatile int maxMemoryPercentage; private volatile boolean useAuto; private volatile int maxOpenJobs; - public TrainedModelAssignmentClusterService(Settings settings, ClusterService clusterService, NodeLoadDetector nodeLoadDetector) { + public TrainedModelAssignmentClusterService( + Settings settings, + ClusterService clusterService, + ThreadPool threadPool, + NodeLoadDetector nodeLoadDetector + ) { this.clusterService = clusterService; + this.threadPool = threadPool; this.nodeLoadDetector = nodeLoadDetector; this.maxMemoryPercentage = MachineLearning.MAX_MACHINE_MEMORY_PERCENT.get(settings); this.useAuto = MachineLearning.USE_AUTO_MACHINE_MEMORY_PERCENT.get(settings); @@ -105,38 +111,63 @@ public void clusterChanged(ClusterChangedEvent event) { if (event.state().blocks().hasGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) { return; } - if (event.localNodeMaster() && shouldAllocateModels(event)) { - submitUnbatchedTask("allocating models to nodes", new ClusterStateUpdateTask() { + if (event.localNodeMaster() == false) { + return; + } + + if (event.state().nodes().getMinNodeVersion().before(DISTRIBUTED_MODEL_ALLOCATION_VERSION)) { + // we should not try to rebalance assignments while there may be nodes running on a version + // prior to introducing distributed model allocation. 
+ // But we should remove routing to removed or shutting down nodes. + removeRoutingToRemovedOrShuttingDownNodes(event); + return; + } + + if (shouldRebalanceModels(event)) { + // TODO this has a weird side-effect for allocating to nodes + // If the event indicates there were nodes added/removed, this method only looks at the current state and has + // no previous knowledge of existing nodes. Consequently, if a model was manually removed (task-kill) from a node + // it may get re-allocated to that node when another node is added/removed... + // + // As this produces a cluster state update task, we are certain that if the persistent + // task framework results in assigning some ML tasks on that same cluster state change + // we do not end up over-allocating a node. Both this service and the persistent task service + // will produce a cluster state update but the one that gets applied first wins. The other + // update will be rejected and we will retry to assign getting a correct update on available memory + // on each node. + rebalanceAssignments( + event.state(), + Optional.empty(), + "nodes changed", + ActionListener.wrap( + newMetadata -> logger.debug( + () -> format("rebalanced model assignments [%s]", Strings.toString(newMetadata, false, true)) + ), + e -> logger.warn("failed to rebalance models", e) + ) + ); + } + } + + private void removeRoutingToRemovedOrShuttingDownNodes(ClusterChangedEvent event) { + if (areAssignedNodesRemoved(event)) { + submitUnbatchedTask("removing routing entries for removed or shutting down nodes", new ClusterStateUpdateTask() { @Override public ClusterState execute(ClusterState currentState) { - // TODO this has a weird side-effect for allocating to nodes - // If the event indicates there were nodes added/removed, this method only looks at the current state and has - // no previous knowledge of existing nodes. 
Consequently, if a model was manually removed (task-kill) from a node - // it may get re-allocated to that node when another node is added/removed... - - // As this produces a cluster state update task, we are certain that if the persistent - // task framework results in assigning some ML tasks on that same cluster state change - // we do not end up over-allocating a node. Both this service and the persistant task service - // will produce a cluster state update but the one that gets applied first wins. The other - // update will be rejected and we will retry to assign getting a correct update on available memory - // on each node. - // Also, note that as this service is a returned as a component of the ML plugin, - // and components are created before persistent task executors, we will try to allocate - // trained models before we try to assign ML persistent tasks. - return addRemoveAssignmentNodes(currentState); + return removeRoutingToUnassignableNodes(currentState); } @Override public void onFailure(Exception e) { - logger.warn("failed to allocate models", e); + logger.error("could not remove routing entries for removed or shutting down nodes", e); } @Override public void clusterStateProcessed(ClusterState oldState, ClusterState newState) { - logger.trace( + logger.debug( () -> format( "updated model assignments based on node changes in the cluster; new metadata [%s]", - Strings.toString(fromState(newState), false, true) + Strings.toString(TrainedModelAssignmentMetadata.fromState(newState), false, true) ) ); } @@ -144,10 +175,58 @@ public void clusterStateProcessed(ClusterState oldState, ClusterState newState) } } + // Visible for testing + static boolean areAssignedNodesRemoved(ClusterChangedEvent event) { + boolean nodesShutdownChanged = event.changedCustomMetadataSet().contains(NodesShutdownMetadata.TYPE); + if (event.nodesRemoved() || nodesShutdownChanged) { + Set removedOrShuttingDownNodeIds = new HashSet<>(nodesShuttingDown(event.state())); + 
event.nodesDelta().removedNodes().stream().map(DiscoveryNode::getId).forEach(removedOrShuttingDownNodeIds::add); + + TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.fromState(event.state()); + for (TrainedModelAssignment assignment : metadata.modelAssignments().values()) { + if (Sets.intersection(removedOrShuttingDownNodeIds, assignment.getNodeRoutingTable().keySet()).isEmpty() == false) { + return true; + } + } + } + return false; + } + + // Visible for testing + static ClusterState removeRoutingToUnassignableNodes(ClusterState currentState) { + Set assignableNodes = getAssignableNodes(currentState).stream().map(DiscoveryNode::getId).collect(Collectors.toSet()); + TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.fromState(currentState); + TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.builder(currentState); + for (TrainedModelAssignment assignment : metadata.modelAssignments().values()) { + Set routedNodeIdsToRemove = Sets.difference(assignment.getNodeRoutingTable().keySet(), assignableNodes); + if (routedNodeIdsToRemove.isEmpty() == false) { + logger.debug( + () -> format( + "[%s] removing routing entries to nodes {} because they have been removed or are shutting down", + assignment.getModelId(), + routedNodeIdsToRemove + ) + ); + TrainedModelAssignment.Builder assignmentBuilder = TrainedModelAssignment.Builder.fromAssignment(assignment); + routedNodeIdsToRemove.forEach(assignmentBuilder::removeRoutingEntry); + builder.updateAssignment(assignment.getModelId(), assignmentBuilder.calculateAndSetAssignmentState()); + } + } + return update(currentState, builder); + } + public void updateModelRoutingTable( - UpdateTrainedModelAssignmentStateAction.Request request, + UpdateTrainedModelAssignmentRoutingInfoAction.Request request, ActionListener listener ) { + logger.debug( + () -> format( + "[%s] updating routing table entry for node [%s], update [%s]", + request.getModelId(), + 
request.getNodeId(), + request.getUpdate() + ) + ); submitUnbatchedTask("updating model routing for node assignment", new ClusterStateUpdateTask() { @Override public ClusterState execute(ClusterState currentState) { @@ -170,22 +249,44 @@ public void createNewModelAssignment( StartTrainedModelDeploymentAction.TaskParams params, ActionListener listener ) { - submitUnbatchedTask("create model assignment", new ClusterStateUpdateTask() { - @Override - public ClusterState execute(ClusterState currentState) { - return createModelAssignment(currentState, params); - } + if (clusterService.state().nodes().getMinNodeVersion().before(DISTRIBUTED_MODEL_ALLOCATION_VERSION)) { + listener.onFailure( + new ElasticsearchStatusException( + "cannot create new assignment for model [{}] while there are nodes older than version [{}]", + RestStatus.CONFLICT, + params.getModelId(), + DISTRIBUTED_MODEL_ALLOCATION_VERSION + ) + ); + return; + } - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } + if (MlMetadata.getMlMetadata(clusterService.state()).isResetMode()) { + listener.onFailure( + new ElasticsearchStatusException( + "cannot create new assignment for model [{}] while feature reset is in progress.", + RestStatus.CONFLICT, + params.getModelId() + ) + ); + return; + } - @Override - public void clusterStateProcessed(ClusterState oldState, ClusterState newState) { - listener.onResponse(TrainedModelAssignmentMetadata.fromState(newState).getModelAssignment(params.getModelId())); - } - }); + rebalanceAssignments( + clusterService.state(), + Optional.of(params), + "model [" + params.getModelId() + "] started", + ActionListener.wrap(newMetadata -> { + TrainedModelAssignment assignment = newMetadata.getModelAssignment(params.getModelId()); + if (assignment == null) { + // If we could not allocate the model anywhere then it is possible the assignment + // here is null. 
We should notify the listener of an empty assignment as the + // handling of this is done elsewhere with the wait-to-start predicate. + assignment = TrainedModelAssignment.Builder.empty(params).build(); + } + listener.onResponse(assignment); + }, listener::onFailure) + ); } public void setModelAssignmentToStopping(String modelId, ActionListener listener) { @@ -247,10 +348,12 @@ public void clusterStateProcessed(ClusterState oldState, ClusterState newState) } private static ClusterState update(ClusterState currentState, TrainedModelAssignmentMetadata.Builder modelAssignments) { - if (modelAssignments.isChanged()) { - return forceUpdate(currentState, modelAssignments); - } else { + TrainedModelAssignmentMetadata previousMetadata = TrainedModelAssignmentMetadata.fromState(currentState); + TrainedModelAssignmentMetadata updatedMetadata = modelAssignments.build(); + if (updatedMetadata.equals(previousMetadata)) { return currentState; + } else { + return forceUpdate(currentState, modelAssignments); } } @@ -265,42 +368,94 @@ private static ClusterState forceUpdate(ClusterState currentState, TrainedModelA return ClusterState.builder(currentState).metadata(metadata).build(); } - ClusterState createModelAssignment(ClusterState currentState, StartTrainedModelDeploymentAction.TaskParams params) { - if (MlMetadata.getMlMetadata(currentState).isResetMode()) { - throw new ElasticsearchStatusException( - "cannot create new assignment for model [{}] while feature reset is in progress.", - RestStatus.CONFLICT, - params.getModelId() - ); - } - TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.builder(currentState); - if (builder.hasModel(params.getModelId())) { - throw new ResourceAlreadyExistsException("assignment for model with id [{}] already exist", params.getModelId()); - } - TrainedModelAssignment.Builder assignmentBuilder = TrainedModelAssignment.Builder.empty(params); - - Set shuttingDownNodes = nodesShuttingDown(currentState); - Map nodeToReason 
= new TreeMap<>(); - for (DiscoveryNode node : currentState.getNodes()) { - if (StartTrainedModelDeploymentAction.TaskParams.mayAssignToNode(node) && shuttingDownNodes.contains(node.getId()) == false) { - Optional maybeError = nodeHasCapacity(currentState, params, node); - if (maybeError.isPresent()) { - nodeToReason.put(node.getName(), maybeError.get()); - } else { - assignmentBuilder.addNewRoutingEntry(node.getId()); - } + ClusterState createModelAssignment(ClusterState currentState, StartTrainedModelDeploymentAction.TaskParams params) throws Exception { + return update(currentState, rebalanceAssignments(currentState, Optional.of(params))); + } + + private void rebalanceAssignments( + ClusterState clusterState, + Optional modelToAdd, + String reason, + ActionListener listener + ) { + threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> { + logger.debug(() -> format("Rebalancing model allocations because [%s]", reason)); + TrainedModelAssignmentMetadata.Builder rebalancedMetadata; + try { + rebalancedMetadata = rebalanceAssignments(clusterState, modelToAdd); + } catch (Exception e) { + listener.onFailure(e); + return; } - } - if (nodeToReason.isEmpty() == false) { - assignmentBuilder.setReason( - nodeToReason.entrySet() - .stream() - .map(entry -> String.format(Locale.ROOT, "Not allocating on node [%s]. 
Reason: %s", entry.getKey(), entry.getValue())) - .collect(Collectors.joining("|")) + + submitUnbatchedTask(reason, new ClusterStateUpdateTask() { + @Override + public ClusterState execute(ClusterState currentState) { + + if (areClusterStatesCompatibleForRebalance(clusterState, currentState)) { + return update(currentState, rebalancedMetadata); + } + rebalanceAssignments(currentState, modelToAdd, reason, listener); + return currentState; + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + + @Override + public void clusterStateProcessed(ClusterState oldState, ClusterState newState) { + listener.onResponse(TrainedModelAssignmentMetadata.fromState(newState)); + } + }); + }); + } + + private boolean areClusterStatesCompatibleForRebalance(ClusterState source, ClusterState target) { + List sourceNodes = getAssignableNodes(source); + List targetNodes = getAssignableNodes(target); + // We also compare node loads as it could be that another ML job has been started meanwhile + return sourceNodes.equals(targetNodes) + && detectNodeLoads(sourceNodes, source).equals(detectNodeLoads(targetNodes, target)) + && MlMetadata.getMlMetadata(source).equals(MlMetadata.getMlMetadata(target)) + && TrainedModelAssignmentMetadata.fromState(source).equals(TrainedModelAssignmentMetadata.fromState(target)); + } + + private TrainedModelAssignmentMetadata.Builder rebalanceAssignments( + ClusterState currentState, + Optional modelToAdd + ) throws Exception { + List nodes = getAssignableNodes(currentState); + logger.debug(() -> format("assignable nodes are %s", nodes.stream().map(DiscoveryNode::getId).toList())); + Map nodeLoads = detectNodeLoads(nodes, currentState); + TrainedModelAssignmentRebalancer rebalancer = new TrainedModelAssignmentRebalancer( + TrainedModelAssignmentMetadata.fromState(currentState), + nodeLoads, + modelToAdd + ); + return rebalancer.rebalance(); + } + + private static List getAssignableNodes(ClusterState clusterState) { + final Set 
shuttingDownNodes = nodesShuttingDown(clusterState); + return clusterState.getNodes() + .getNodes() + .values() + .stream() + .filter(StartTrainedModelDeploymentAction.TaskParams::mayAssignToNode) + .filter(n -> shuttingDownNodes.contains(n.getId()) == false) + .toList(); + } + + private Map detectNodeLoads(List nodes, ClusterState clusterState) { + return nodes.stream() + .collect( + Collectors.toMap( + Function.identity(), + n -> nodeLoadDetector.detectNodeLoad(clusterState, null, n, maxOpenJobs, maxMemoryPercentage, useAuto) + ) ); - } - builder.addNewAssignment(params.getModelId(), assignmentBuilder); - return update(currentState, builder); } static ClusterState setToStopping(ClusterState clusterState, String modelId, String reason) { @@ -318,7 +473,7 @@ static ClusterState setToStopping(ClusterState clusterState, String modelId, Str return update(clusterState, builder); } - static ClusterState updateModelRoutingTable(ClusterState currentState, UpdateTrainedModelAssignmentStateAction.Request request) { + static ClusterState updateModelRoutingTable(ClusterState currentState, UpdateTrainedModelAssignmentRoutingInfoAction.Request request) { final String modelId = request.getModelId(); final String nodeId = request.getNodeId(); TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.fromState(currentState); @@ -326,7 +481,8 @@ static ClusterState updateModelRoutingTable(ClusterState currentState, UpdateTra final TrainedModelAssignment existingAssignment = metadata.getModelAssignment(modelId); final TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.builder(currentState); // If state is stopped, this indicates the node process is closed, remove the node from the assignment - if (request.getRoutingState().getState().equals(RoutingState.STOPPED)) { + if (request.getUpdate().getStateAndReason().isPresent() + && request.getUpdate().getStateAndReason().get().getState().equals(RoutingState.STOPPED)) { if (existingAssignment == 
null || existingAssignment.isRoutedToNode(nodeId) == false) { return currentState; } @@ -340,19 +496,17 @@ static ClusterState updateModelRoutingTable(ClusterState currentState, UpdateTra // If we are stopping, don't update anything if (existingAssignment.getAssignmentState().equals(AssignmentState.STOPPING)) { logger.debug( - () -> format( - "[%s] requested update from node [%s] to update route state to [%s]", - modelId, - nodeId, - request.getRoutingState() - ) + () -> format("[%s] requested update from node [%s] while stopping; update was [%s]", modelId, nodeId, request.getUpdate()) ); return currentState; } if (existingAssignment.isRoutedToNode(nodeId) == false) { throw new ResourceNotFoundException("assignment for model with id [{}]] is not routed to node [{}]", modelId, nodeId); } - builder.getAssignment(modelId).updateExistingRoutingEntry(nodeId, request.getRoutingState()).calculateAndSetAssignmentState(); + RoutingInfo routingInfo = existingAssignment.getNodeRoutingTable().get(nodeId); + builder.getAssignment(modelId) + .updateExistingRoutingEntry(nodeId, request.getUpdate().apply(routingInfo)) + .calculateAndSetAssignmentState(); return update(currentState, builder); } @@ -362,6 +516,7 @@ static ClusterState removeAssignment(ClusterState currentState, String modelId) if (builder.hasModel(modelId) == false) { throw new ResourceNotFoundException("assignment for model with id [{}] not found", modelId); } + logger.debug(() -> format("[%s] removing assignment", modelId)); return update(currentState, builder.removeAssignment(modelId)); } @@ -372,71 +527,7 @@ static ClusterState removeAllAssignments(ClusterState currentState) { return forceUpdate(currentState, TrainedModelAssignmentMetadata.Builder.empty()); } - ClusterState addRemoveAssignmentNodes(ClusterState currentState) { - final TrainedModelAssignmentMetadata previousState = TrainedModelAssignmentMetadata.fromState(currentState); - final TrainedModelAssignmentMetadata.Builder builder = 
TrainedModelAssignmentMetadata.builder(currentState); - Set shuttingDownNodes = nodesShuttingDown(currentState); - Map currentEligibleNodes = currentState.getNodes() - .stream() - // TODO: Change when we update `mayAllocateToNode` - .filter( - node -> shuttingDownNodes.contains(node.getId()) == false - && StartTrainedModelDeploymentAction.TaskParams.mayAssignToNode(node) - ) - .collect(Collectors.toMap(DiscoveryNode::getId, Function.identity())); - // TODO: make more efficient, we iterate every entry, sorting by nodes routed (fewest to most) - previousState.modelAssignments() - .entrySet() - .stream() - .filter(entry -> entry.getValue().getAssignmentState().equals(AssignmentState.STOPPING) == false) - .sorted(Comparator.comparing(e -> e.getValue().getNodeRoutingTable().size())) - .forEach(modelAssignmentEntry -> { - final String modelId = modelAssignmentEntry.getKey(); - Map nodeToReason = new TreeMap<>(); - for (DiscoveryNode node : currentEligibleNodes.values()) { - if (modelAssignmentEntry.getValue().isRoutedToNode(node.getId()) == false) { - Optional failure = builder.isChanged() ? - // We use the builder only if we have changed, there is no point in creating a new object if we haven't changed - nodeHasCapacity(currentState, builder, modelAssignmentEntry.getValue().getTaskParams(), node) - : nodeHasCapacity(currentState, modelAssignmentEntry.getValue().getTaskParams(), node); - if (failure.isPresent()) { - nodeToReason.put(node.getName(), failure.get()); - } else { - builder.getAssignment(modelId).addNewRoutingEntry(node.getId()); - } - } - } - if (nodeToReason.isEmpty() == false) { - builder.getAssignment(modelId) - .setReason( - nodeToReason.entrySet() - .stream() - .map( - entry -> String.format( - Locale.ROOT, - "Not allocating on node [%s]. 
Reason: %s", - entry.getKey(), - entry.getValue() - ) - ) - .collect(Collectors.joining("|")) - ); - } else { - builder.getAssignment(modelId).clearReason(); - } - for (String nodeId : modelAssignmentEntry.getValue().getNodeRoutingTable().keySet()) { - if (currentEligibleNodes.containsKey(nodeId) == false) { - builder.getAssignment(modelId).removeRoutingEntry(nodeId); - } - } - // It may be we moved from STARTED to PARTIALLY_STARTED with the addition of new nodes - // Or moved from PARTIALLY_STARTED to STARTED if a node was removed - builder.getAssignment(modelId).calculateAndSetAssignmentState(); - }); - return update(currentState, builder); - } - - static boolean shouldAllocateModels(final ClusterChangedEvent event) { + static boolean shouldRebalanceModels(final ClusterChangedEvent event) { // If there are no assignments created at all, there is nothing to update final TrainedModelAssignmentMetadata newMetadata = TrainedModelAssignmentMetadata.fromState(event.state()); if (newMetadata == null || newMetadata.modelAssignments().isEmpty()) { @@ -481,24 +572,48 @@ static boolean shouldAllocateModels(final ClusterChangedEvent event) { exitingShutDownNodes = Collections.emptySet(); } + logger.debug( + () -> format( + "added nodes %s; removed nodes %s; shutting down nodes %s; exiting shutdown nodes %s", + addedNodes, + removedNodes, + shuttingDownNodes, + exitingShutDownNodes + ) + ); for (TrainedModelAssignment trainedModelAssignment : newMetadata.modelAssignments().values()) { if (trainedModelAssignment.getAssignmentState().equals(AssignmentState.STOPPING)) { continue; } for (var nodeId : exitingShutDownNodes) { if (trainedModelAssignment.isRoutedToNode(nodeId)) { + logger.debug( + () -> format( + "should rebalance because model [%s] has allocations on shutting down node [%s]", + trainedModelAssignment.getModelId(), + nodeId + ) + ); return true; } } for (var nodeId : removedNodes) { if (trainedModelAssignment.isRoutedToNode(nodeId) && 
shuttingDownNodes.contains(nodeId) == false) { + logger.debug( + () -> format( + "should rebalance because model [%s] has allocations on removed node [%s]", + trainedModelAssignment.getModelId(), + nodeId + ) + ); return true; } } for (var nodeId : addedNodes) { if (StartTrainedModelDeploymentAction.TaskParams.mayAssignToNode(event.state().nodes().get(nodeId)) && shuttingDownNodes.contains(nodeId) == false) { + logger.debug(() -> format("should rebalance because ML eligible node [%s] was added", nodeId)); return true; } } @@ -507,63 +622,6 @@ static boolean shouldAllocateModels(final ClusterChangedEvent event) { return false; } - Optional nodeHasCapacity(ClusterState state, StartTrainedModelDeploymentAction.TaskParams params, DiscoveryNode node) { - NodeLoad load = nodeLoadDetector.detectNodeLoad(state, node, maxOpenJobs, maxMemoryPercentage, useAuto); - return handleNodeLoad(load, node.getId(), params); - } - - /** - * Gather current node capacity taking the passed assignment metadata into account instead of the one stored in cluster state. - */ - Optional nodeHasCapacity( - ClusterState state, - TrainedModelAssignmentMetadata.Builder builder, - StartTrainedModelDeploymentAction.TaskParams params, - DiscoveryNode node - ) { - NodeLoad load = nodeLoadDetector.detectNodeLoad(state, builder.build(), node, maxOpenJobs, maxMemoryPercentage, useAuto); - return handleNodeLoad(load, node.getId(), params); - } - - Optional handleNodeLoad(NodeLoad load, String nodeId, StartTrainedModelDeploymentAction.TaskParams params) { - if (Strings.isNullOrEmpty(load.getError()) == false) { - logger.warn("[{}] failed to calculate current node load with error [{}]", params.getModelId(), nodeId); - return Optional.of(load.getError()); - } - if (load.remainingJobs() == 0) { - return Optional.of( - org.elasticsearch.core.Strings.format( - "This node is full. 
Number of opened jobs and allocated native inference processes [%s], %s [%s].", - load.getNumAssignedJobs(), - MachineLearning.MAX_OPEN_JOBS_PER_NODE.getKey(), - maxOpenJobs - ) - ); - } - // If any ML processes are running on a node we require some space to load the shared libraries. - // So if none are currently running then this per-node overhead must be added to the requirement. - long requiredMemory = params.estimateMemoryUsageBytes() + ((load.getNumAssignedJobs() == 0) - ? MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes() - : 0); - if (load.getFreeMemory() < params.estimateMemoryUsageBytes()) { - return Optional.of( - org.elasticsearch.core.Strings.format( - "This node has insufficient available memory. Available memory for ML [%s (%s)], " - + "memory required by existing jobs and models [%s (%s)], " - + "estimated memory required for this model [%s (%s)].", - - load.getMaxMlMemory(), - ByteSizeValue.ofBytes(load.getMaxMlMemory()).toString(), - load.getAssignedJobMemory(), - ByteSizeValue.ofBytes(load.getAssignedJobMemory()).toString(), - requiredMemory, - ByteSizeValue.ofBytes(requiredMemory).toString() - ) - ); - } - return Optional.empty(); - } - /** * Returns the set of nodes that are currently shutting down */ @@ -573,5 +631,4 @@ static Set nodesShuttingDown(final ClusterState state) { .map(Map::keySet) .orElse(Collections.emptySet()); } - } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadata.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadata.java index baad31547e369..392d17e34a5ed 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadata.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadata.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.inference.assignment; import 
org.elasticsearch.ResourceAlreadyExistsException; +import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.Version; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.Diff; @@ -155,7 +156,6 @@ public static Builder empty() { } private final Map modelRoutingEntries; - private boolean isChanged; public static Builder fromMetadata(TrainedModelAssignmentMetadata modelAssignmentMetadata) { return new Builder(modelAssignmentMetadata); @@ -181,7 +181,14 @@ public Builder addNewAssignment(String modelId, TrainedModelAssignment.Builder a throw new ResourceAlreadyExistsException("[{}] assignment already exists", modelId); } modelRoutingEntries.put(modelId, assignment); - isChanged = true; + return this; + } + + public Builder updateAssignment(String modelId, TrainedModelAssignment.Builder assignment) { + if (modelRoutingEntries.containsKey(modelId) == false) { + throw new ResourceNotFoundException("[{}] assignment does not exist", modelId); + } + modelRoutingEntries.put(modelId, assignment); return this; } @@ -190,14 +197,10 @@ public TrainedModelAssignment.Builder getAssignment(String modelId) { } public Builder removeAssignment(String modelId) { - isChanged |= modelRoutingEntries.remove(modelId) != null; + modelRoutingEntries.remove(modelId); return this; } - public boolean isChanged() { - return isChanged || modelRoutingEntries.values().stream().anyMatch(TrainedModelAssignment.Builder::isChanged); - } - public TrainedModelAssignmentMetadata build() { return build(NAME); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java index 891f07f477b4e..ab64f0cec35fe 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java @@ -31,7 +31,9 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.MlMetadata; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; -import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentStateAction; +import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfoUpdate; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReason; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; @@ -64,6 +66,7 @@ public class TrainedModelAssignmentNodeService implements ClusterStateListener { private static final String NODE_NO_LONGER_REFERENCED = "node no longer referenced in model routing table"; private static final String ASSIGNMENT_NO_LONGER_EXISTS = "model assignment no longer exists"; private static final TimeValue MODEL_LOADING_CHECK_INTERVAL = TimeValue.timeValueSeconds(1); + private static final TimeValue UPDATE_NUMBER_OF_ALLOCATIONS_TIMEOUT = TimeValue.timeValueSeconds(60); private static final Logger logger = LogManager.getLogger(TrainedModelAssignmentNodeService.class); private final TrainedModelAssignmentService trainedModelAssignmentService; private final DeploymentManager deploymentManager; @@ -238,16 +241,20 @@ void loadQueuedModels() { } public void stopDeploymentAndNotify(TrainedModelDeploymentTask task, String reason, ActionListener listener) { + final RoutingInfoUpdate updateToStopped = RoutingInfoUpdate.updateStateAndReason( + new RoutingStateAndReason(RoutingState.STOPPED, reason) + ); + ActionListener notifyDeploymentOfStopped = ActionListener.wrap( - 
_void -> updateStoredState(task.getModelId(), new RoutingStateAndReason(RoutingState.STOPPED, reason), listener), + _void -> updateStoredState(task.getModelId(), updateToStopped, listener), failed -> { // if we failed to stop the process, something strange is going on, but we should still notify of stop logger.warn(() -> "[" + task.getModelId() + "] failed to stop due to error", failed); - updateStoredState(task.getModelId(), new RoutingStateAndReason(RoutingState.STOPPED, reason), listener); + updateStoredState(task.getModelId(), updateToStopped, listener); } ); updateStoredState( task.getModelId(), - new RoutingStateAndReason(RoutingState.STOPPING, reason), + RoutingInfoUpdate.updateStateAndReason(new RoutingStateAndReason(RoutingState.STOPPING, reason)), ActionListener.wrap(success -> stopDeploymentAsync(task, reason, notifyDeploymentOfStopped), e -> { if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { logger.debug( @@ -316,22 +323,47 @@ public void clusterChanged(ClusterChangedEvent event) { final boolean isResetMode = MlMetadata.getMlMetadata(event.state()).isResetMode(); TrainedModelAssignmentMetadata modelAssignmentMetadata = TrainedModelAssignmentMetadata.fromState(event.state()); final String currentNode = event.state().nodes().getLocalNodeId(); + final boolean isNewAllocationSupported = event.state() + .getNodes() + .getMinNodeVersion() + .onOrAfter(TrainedModelAssignmentClusterService.DISTRIBUTED_MODEL_ALLOCATION_VERSION); + + if (isResetMode == false && isNewAllocationSupported) { + updateNumberOfAllocations(modelAssignmentMetadata); + } + for (TrainedModelAssignment trainedModelAssignment : modelAssignmentMetadata.modelAssignments().values()) { - RoutingStateAndReason routingStateAndReason = trainedModelAssignment.getNodeRoutingTable().get(currentNode); + RoutingInfo routingInfo = trainedModelAssignment.getNodeRoutingTable().get(currentNode); // Add new models to start loading - if (routingStateAndReason != null - // periodic 
retries of `failed` should be handled in a separate process - && routingStateAndReason.getState().isAnyOf(RoutingState.STARTING, RoutingState.STARTED) - // This means we don't already have a task and should attempt creating one and starting the model loading - // If we don't have a task but are STARTED, this means the cluster state had a started assignment, - // the node crashed and then started again - && modelIdToTask.containsKey(trainedModelAssignment.getTaskParams().getModelId()) == false - // If we are in reset mode, don't start loading a new model on this node. - && isResetMode == false) { - prepareModelToLoad(trainedModelAssignment.getTaskParams()); + if (routingInfo != null && isNewAllocationSupported) { + if (routingInfo.getState() == RoutingState.STARTING + && modelIdToTask.containsKey(trainedModelAssignment.getModelId()) + && modelIdToTask.get(trainedModelAssignment.getModelId()).isFailed()) { + // This is a failed assignment and we are restarting it. For this we need to remove the task first. + taskManager.unregister(modelIdToTask.get(trainedModelAssignment.getModelId())); + modelIdToTask.remove(trainedModelAssignment.getModelId()); + } + if (routingInfo.getState().isAnyOf(RoutingState.STARTING, RoutingState.STARTED) // periodic retries of `failed` should + // be handled in a separate process + // This means we don't already have a task and should attempt creating one and starting the model loading + // If we don't have a task but are STARTED, this means the cluster state had a started assignment, + // the node crashed and then started again + && modelIdToTask.containsKey(trainedModelAssignment.getTaskParams().getModelId()) == false + // If we are in reset mode, don't start loading a new model on this node. 
+ && isResetMode == false) { + prepareModelToLoad( + new StartTrainedModelDeploymentAction.TaskParams( + trainedModelAssignment.getModelId(), + trainedModelAssignment.getTaskParams().getModelBytes(), + trainedModelAssignment.getTaskParams().getThreadsPerAllocation(), + routingInfo.getCurrentAllocations(), + trainedModelAssignment.getTaskParams().getQueueCapacity() + ) + ); + } + } // This model is not routed to the current node at all - if (routingStateAndReason == null) { + if (routingInfo == null) { TrainedModelDeploymentTask task = modelIdToTask.remove(trainedModelAssignment.getTaskParams().getModelId()); if (task != null) { stopDeploymentAsync( @@ -363,6 +395,59 @@ public void clusterChanged(ClusterChangedEvent event) { } } + private void updateNumberOfAllocations(TrainedModelAssignmentMetadata assignments) { + List modelsToUpdate = assignments.modelAssignments() + .values() + .stream() + .filter(a -> hasStartingAssignments(a) == false) + .filter(a -> a.isRoutedToNode(nodeId)) + .filter(a -> { + RoutingInfo routingInfo = a.getNodeRoutingTable().get(nodeId); + return routingInfo.getState() == RoutingState.STARTED + && routingInfo.getCurrentAllocations() != routingInfo.getTargetAllocations(); + }) + .toList(); + + for (TrainedModelAssignment assignment : modelsToUpdate) { + TrainedModelDeploymentTask task = modelIdToTask.get(assignment.getModelId()); + if (task == null) { + logger.debug(() -> format("[%s] task was removed whilst updating number of allocations", assignment.getModelId())); + continue; + } + RoutingInfo routingInfo = assignment.getNodeRoutingTable().get(nodeId); + deploymentManager.updateNumAllocations( + task, + assignment.getNodeRoutingTable().get(nodeId).getTargetAllocations(), + UPDATE_NUMBER_OF_ALLOCATIONS_TIMEOUT, + ActionListener.wrap(threadSettings -> { + logger.debug("[{}] Updated number of allocations to [{}]", assignment.getModelId(), threadSettings.numAllocations()); + task.updateNumberOfAllocations(threadSettings.numAllocations()); + 
updateStoredState( + assignment.getModelId(), + RoutingInfoUpdate.updateNumberOfAllocations(threadSettings.numAllocations()), + ActionListener.noop() + ); + }, + e -> logger.error( + format( + "[%s] Could not update number of allocations to [%s]", + assignment.getModelId(), + routingInfo.getTargetAllocations() + ), + e + ) + ) + ); + } + } + + private boolean hasStartingAssignments(TrainedModelAssignment assignment) { + return assignment.getNodeRoutingTable() + .values() + .stream() + .anyMatch(routingInfo -> routingInfo.getState().isAnyOf(RoutingState.STARTING)); + } + // For testing purposes TrainedModelDeploymentTask getTask(String modelId) { return modelIdToTask.get(modelId); @@ -397,9 +482,10 @@ private void handleLoadSuccess(TrainedModelDeploymentTask task) { ); return; } + updateStoredState( modelId, - new RoutingStateAndReason(RoutingState.STARTED, ""), + RoutingInfoUpdate.updateStateAndReason(new RoutingStateAndReason(RoutingState.STARTED, "")), ActionListener.wrap(r -> logger.debug(() -> "[" + modelId + "] model loaded and accepting routes"), e -> { // This means that either the assignment has been deleted, or this node's particular route has been removed if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { @@ -410,31 +496,25 @@ private void handleLoadSuccess(TrainedModelDeploymentTask task) { ), e ); + } else { + // this is an unexpected error + logger.warn(() -> "[" + modelId + "] model loaded but failed to start accepting routes", e); } - // this is an unexpected error - logger.warn(() -> "[" + modelId + "] model loaded but failed to start accepting routes", e); }) ); } - private void updateStoredState( - String modelId, - RoutingStateAndReason routingStateAndReason, - ActionListener listener - ) { + private void updateStoredState(String modelId, RoutingInfoUpdate update, ActionListener listener) { if (stopped) { return; } trainedModelAssignmentService.updateModelAssignmentState( - new 
UpdateTrainedModelAssignmentStateAction.Request(nodeId, modelId, routingStateAndReason), + new UpdateTrainedModelAssignmentRoutingInfoAction.Request(nodeId, modelId, update), ActionListener.wrap(success -> { - logger.debug(() -> format("[%s] model is [%s] and master notified", modelId, routingStateAndReason.getState())); + logger.debug(() -> format("[%s] model routing info was updated with [%s] and master notified", modelId, update)); listener.onResponse(AcknowledgedResponse.TRUE); }, error -> { - logger.warn( - () -> format("[%s] model is [%s] but failed to notify master", modelId, routingStateAndReason.getState()), - error - ); + logger.warn(() -> format("[%s] failed to update model routing info with [%s]", modelId, update), error); listener.onFailure(error); }) ); @@ -460,7 +540,9 @@ private void handleLoadFailure(TrainedModelDeploymentTask task, Exception ex) { ); updateStoredState( task.getModelId(), - new RoutingStateAndReason(RoutingState.FAILED, ExceptionsHelper.unwrapCause(ex).getMessage()), + RoutingInfoUpdate.updateStateAndReason( + new RoutingStateAndReason(RoutingState.FAILED, ExceptionsHelper.unwrapCause(ex).getMessage()) + ), ActionListener.wrap(r -> stopTask.run(), e -> stopTask.run()) ); } @@ -468,7 +550,7 @@ private void handleLoadFailure(TrainedModelDeploymentTask task, Exception ex) { public void failAssignment(TrainedModelDeploymentTask task, String reason) { updateStoredState( task.getModelId(), - new RoutingStateAndReason(RoutingState.FAILED, reason), + RoutingInfoUpdate.updateStateAndReason(new RoutingStateAndReason(RoutingState.FAILED, reason)), ActionListener.wrap( r -> logger.debug( () -> format( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java new file mode 100644 index 0000000000000..119e2a605e526 --- /dev/null +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java @@ -0,0 +1,277 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.assignment; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.ResourceAlreadyExistsException; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan; +import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlanner; +import org.elasticsearch.xpack.ml.job.NodeLoad; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.Set; +import java.util.TreeMap; +import java.util.stream.Collectors; + +import static org.elasticsearch.core.Strings.format; + +class TrainedModelAssignmentRebalancer { + + private static final Logger logger = LogManager.getLogger(TrainedModelAssignmentRebalancer.class); + + private final TrainedModelAssignmentMetadata currentMetadata; + private final Map nodeLoads; + private final Optional modelToAdd; + + 
TrainedModelAssignmentRebalancer( + TrainedModelAssignmentMetadata currentMetadata, + Map nodeLoads, + Optional modelToAdd + ) { + this.currentMetadata = Objects.requireNonNull(currentMetadata); + this.nodeLoads = Objects.requireNonNull(nodeLoads); + this.modelToAdd = Objects.requireNonNull(modelToAdd); + } + + TrainedModelAssignmentMetadata.Builder rebalance() throws Exception { + TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.Builder.fromMetadata(currentMetadata); + if (modelToAdd.isPresent() && builder.hasModel(modelToAdd.get().getModelId())) { + throw new ResourceAlreadyExistsException("assignment for model with id [{}] already exists", modelToAdd.get().getModelId()); + } + + if (modelToAdd.isEmpty() && areAllModelsSatisfied()) { + logger.trace(() -> "No need to rebalance as all model deployments are satisfied"); + return builder; + } + + AssignmentPlan assignmentPlan = computeAssignmentPlan(); + buildAssignmentsFromPlan(assignmentPlan, builder); + return builder; + } + + private boolean areAllModelsSatisfied() { + Set assignableNodeIds = nodeLoads.keySet().stream().map(DiscoveryNode::getId).collect(Collectors.toSet()); + for (TrainedModelAssignment model : currentMetadata.modelAssignments().values()) { + if (model.isSatisfied(assignableNodeIds) == false) { + return false; + } + } + return true; + } + + AssignmentPlan computeAssignmentPlan() { + List planNodes = nodeLoads.entrySet() + .stream() + .filter(e -> Strings.isNullOrEmpty(e.getValue().getError())) + .map( + e -> new AssignmentPlan.Node( + e.getKey().getId(), + // We subtract native inference memory as the planner expects available memory for + // native inference including current assignments. + getNodeFreeMemoryExcludingPerNodeOverheadAndNativeInference(e.getValue()), + getNodeAllocatedProcessors(e.getKey()).orElse(0) + ) + ) + .toList(); + + final List planModels = new ArrayList<>( + currentMetadata.modelAssignments().size() + (modelToAdd.isPresent() ? 
1 : 0) + ); + final Set assignableNodeIds = planNodes.stream().map(AssignmentPlan.Node::id).collect(Collectors.toSet()); + currentMetadata.modelAssignments().values().stream().map(assignment -> { + Map currentAssignments = assignment.getNodeRoutingTable() + .entrySet() + .stream() + // Filter out nodes that are no longer assignable + .filter(e -> assignableNodeIds.contains(e.getKey())) + // Filter out allocation without current and target allocations as they are from before using the rebalancer + .filter(e -> e.getValue().getCurrentAllocations() > 0 && e.getValue().getTargetAllocations() > 0) + .filter(e -> e.getValue().getState().isAnyOf(RoutingState.STARTING, RoutingState.STARTED, RoutingState.FAILED)) + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getTargetAllocations())); + return new AssignmentPlan.Model( + assignment.getModelId(), + assignment.getTaskParams().estimateMemoryUsageBytes(), + assignment.getTaskParams().getNumberOfAllocations(), + assignment.getTaskParams().getThreadsPerAllocation(), + currentAssignments + ); + }).forEach(planModels::add); + modelToAdd.ifPresent( + taskParams -> planModels.add( + new AssignmentPlan.Model( + taskParams.getModelId(), + taskParams.estimateMemoryUsageBytes(), + taskParams.getNumberOfAllocations(), + taskParams.getThreadsPerAllocation(), + Map.of() + ) + ) + ); + return new AssignmentPlanner(planNodes, planModels).computePlan(); + } + + private static OptionalInt getNodeAllocatedProcessors(DiscoveryNode node) { + String allocatedProcessorsString = node.getAttributes().get(MachineLearning.ALLOCATED_PROCESSORS_NODE_ATTR); + try { + return OptionalInt.of(Integer.parseInt(allocatedProcessorsString)); + } catch (NumberFormatException e) { + assert e == null + : MachineLearning.ALLOCATED_PROCESSORS_NODE_ATTR + + " should parse because we set it internally: invalid value was " + + allocatedProcessorsString; + return OptionalInt.empty(); + } + } + + private static long 
getNodeFreeMemoryExcludingPerNodeOverheadAndNativeInference(NodeLoad load) { + return load.getFreeMemoryExcludingPerNodeOverhead() - load.getAssignedNativeInferenceMemory(); + } + + private void buildAssignmentsFromPlan(AssignmentPlan assignmentPlan, TrainedModelAssignmentMetadata.Builder builder) { + for (AssignmentPlan.Model model : assignmentPlan.models()) { + TrainedModelAssignment existingAssignment = currentMetadata.getModelAssignment(model.id()); + + TrainedModelAssignment.Builder assignmentBuilder = TrainedModelAssignment.Builder.empty( + existingAssignment == null && modelToAdd.isPresent() + ? modelToAdd.get() + : currentMetadata.getModelAssignment(model.id()).getTaskParams() + ); + + Map assignments = assignmentPlan.assignments(model).orElseGet(Map::of); + for (Map.Entry assignment : assignments.entrySet()) { + if (existingAssignment != null && existingAssignment.isRoutedToNode(assignment.getKey().id())) { + RoutingInfo existingRoutingInfo = existingAssignment.getNodeRoutingTable().get(assignment.getKey().id()); + RoutingState state = existingRoutingInfo.getState(); + String reason = existingRoutingInfo.getReason(); + if (state == RoutingState.FAILED) { + state = RoutingState.STARTING; + reason = ""; + } + assignmentBuilder.addRoutingEntry( + assignment.getKey().id(), + new RoutingInfo(existingRoutingInfo.getCurrentAllocations(), assignment.getValue(), state, reason) + ); + } else { + assignmentBuilder.addRoutingEntry( + assignment.getKey().id(), + new RoutingInfo(assignment.getValue(), assignment.getValue(), RoutingState.STARTING, "") + ); + } + } + assignmentBuilder.calculateAndSetAssignmentState(); + + explainAssignments(assignmentPlan, nodeLoads, model).ifPresent(assignmentBuilder::setReason); + if (existingAssignment == null) { + builder.addNewAssignment(model.id(), assignmentBuilder); + } else { + builder.updateAssignment(model.id(), assignmentBuilder); + } + } + } + + private Optional explainAssignments( + AssignmentPlan assignmentPlan, + Map 
nodeLoads, + AssignmentPlan.Model model + ) { + if (assignmentPlan.satisfiesAllocations(model)) { + return Optional.empty(); + } + + if (nodeLoads.isEmpty()) { + return Optional.of("No ML nodes exist in the cluster"); + } + + Map nodeToReason = new TreeMap<>(); + for (Map.Entry nodeAndLoad : nodeLoads.entrySet()) { + Optional reason = explainAssignment(assignmentPlan, nodeAndLoad.getKey(), nodeAndLoad.getValue(), model); + reason.ifPresent(s -> nodeToReason.put(nodeAndLoad.getKey().getId(), s)); + } + + if (nodeToReason.isEmpty() == false) { + return Optional.of( + nodeToReason.entrySet() + .stream() + .map(entry -> format("Could not assign (more) allocations on node [%s]. Reason: %s", entry.getKey(), entry.getValue())) + .collect(Collectors.joining("|")) + ); + } + return Optional.empty(); + } + + private Optional explainAssignment( + AssignmentPlan assignmentPlan, + DiscoveryNode node, + NodeLoad load, + AssignmentPlan.Model model + ) { + if (Strings.isNullOrEmpty(load.getError()) == false) { + return Optional.of(load.getError()); + } + + if (model.memoryBytes() > assignmentPlan.getRemainingNodeMemory(node.getId())) { + // If any ML processes are running on a node we require some space to load the shared libraries. + // So if none are currently running then this per-node overhead must be added to the requirement. + // From node load we know if we had any jobs or models assigned before the rebalance. + // But we should also check if we managed to assign a model during the rebalance for which + // we check if the node has used up any of its allocated processors. + boolean isPerNodeOverheadAccountedFor = load.getNumAssignedJobsAndModels() > 0 + || assignmentPlan.getRemainingNodeCores(load.getNodeId()) < getNodeAllocatedProcessors(node).orElse(0); + long requiredMemory = model.memoryBytes() + (isPerNodeOverheadAccountedFor + ? 
0 + : MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes()); + long nodeFreeMemory = assignmentPlan.getRemainingNodeMemory(node.getId()) + (isPerNodeOverheadAccountedFor + ? 0 + : MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes()); + return Optional.of( + ParameterizedMessage.format( + "This node has insufficient available memory. Available memory for ML [{} ({})], " + + "free memory [{} ({})], " + + "estimated memory required for this model [{} ({})].", + new Object[] { + load.getMaxMlMemory(), + ByteSizeValue.ofBytes(load.getMaxMlMemory()).toString(), + nodeFreeMemory, + ByteSizeValue.ofBytes(nodeFreeMemory).toString(), + requiredMemory, + ByteSizeValue.ofBytes(requiredMemory).toString() } + ) + ); + } + + if (model.threadsPerAllocation() > assignmentPlan.getRemainingNodeCores(node.getId())) { + return Optional.of( + ParameterizedMessage.format( + "This node has insufficient allocated processors. Available processors [{}], free processors [{}], " + + "processors required for each allocation of this model [{}]", + new Object[] { + getNodeAllocatedProcessors(node).orElse(0), + assignmentPlan.getRemainingNodeCores(node.getId()), + model.threadsPerAllocation() } + ) + ); + } + + return Optional.empty(); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentService.java index 74893dc204180..18dc4b239bdbd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentService.java @@ -30,7 +30,7 @@ import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAssignmentAction; import 
org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; -import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentStateAction; +import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import java.util.Objects; @@ -53,7 +53,7 @@ public TrainedModelAssignmentService(Client client, ClusterService clusterServic } public void updateModelAssignmentState( - UpdateTrainedModelAssignmentStateAction.Request request, + UpdateTrainedModelAssignmentRoutingInfoAction.Request request, ActionListener listener ) { ClusterState currentState = clusterService.state(); @@ -61,26 +61,32 @@ public void updateModelAssignmentState( Predicate changePredicate = MasterNodeChangePredicate.build(currentState); DiscoveryNode masterNode = currentState.nodes().getMasterNode(); if (masterNode == null) { - logger.warn( - "[{}] no master known for assignment state update [{}]", - request.getModelId(), - request.getRoutingState().getState() - ); - waitForNewMasterAndRetry(observer, UpdateTrainedModelAssignmentStateAction.INSTANCE, request, listener, changePredicate); + logger.warn("[{}] no master known for assignment update [{}]", request.getModelId(), request.getUpdate()); + waitForNewMasterAndRetry(observer, UpdateTrainedModelAssignmentRoutingInfoAction.INSTANCE, request, listener, changePredicate); return; } - client.execute(UpdateTrainedModelAssignmentStateAction.INSTANCE, request, ActionListener.wrap(listener::onResponse, failure -> { - if (isMasterChannelException(failure)) { - logger.info( - "[{}] master channel exception will retry on new master node for assignment state update [{}]", - request.getModelId(), - request.getRoutingState().getState() - ); - waitForNewMasterAndRetry(observer, UpdateTrainedModelAssignmentStateAction.INSTANCE, request, listener, changePredicate); - return; - } - listener.onFailure(failure); - })); + 
client.execute( + UpdateTrainedModelAssignmentRoutingInfoAction.INSTANCE, + request, + ActionListener.wrap(listener::onResponse, failure -> { + if (isMasterChannelException(failure)) { + logger.info( + "[{}] master channel exception will retry on new master node for assignment update [{}]", + request.getModelId(), + request.getUpdate() + ); + waitForNewMasterAndRetry( + observer, + UpdateTrainedModelAssignmentRoutingInfoAction.INSTANCE, + request, + listener, + changePredicate + ); + return; + } + listener.onFailure(failure); + }) + ); } public void createNewModelAssignment( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java index 2b6a6e97074d7..6aa71bafb4662 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java @@ -34,7 +34,7 @@ private Node modifyNodePreservingAllocations(Node n) { long bytesUsed = 0; int coresUsed = 0; for (Model m : models) { - if (m.currentAllocationByNodeId().containsKey(n.id())) { + if (m.currentAllocationsByNodeId().containsKey(n.id())) { bytesUsed += m.memoryBytes(); coresUsed += calculateUsedCores(n, m); } @@ -48,7 +48,7 @@ List modelsPreservingAllocations() { } Model modifyModelPreservingPreviousAssignments(Model m) { - if (m.currentAllocationByNodeId().isEmpty()) { + if (m.currentAllocationsByNodeId().isEmpty()) { return m; } @@ -78,8 +78,11 @@ AssignmentPlan mergePreservedAllocations(AssignmentPlan assignmentPlan) { for (Model m : models) { for (Node n : nodes) { int allocations = assignmentsByModelNodeIdPair.getOrDefault(Tuple.tuple(m.id(), n.id()), 0); - if (m.currentAllocationByNodeId().containsKey(n.id())) { + if 
(m.currentAllocationsByNodeId().containsKey(n.id())) { allocations += addPreservedAllocations(n, m); + // As the node has all its available memory we need to manually account memory of models with + // current allocations. + mergedPlanBuilder.accountMemory(m, n); } if (allocations > 0) { mergedPlanBuilder.assignModelToNode(m, n, allocations); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java index 9d395dca8d6f8..cf8d325ff7460 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java @@ -32,11 +32,11 @@ public record Model( long memoryBytes, int allocations, int threadsPerAllocation, - Map currentAllocationByNodeId + Map currentAllocationsByNodeId ) { int getPreviouslyAssignedAllocations() { - return currentAllocationByNodeId.values().stream().mapToInt(Integer::intValue).sum(); + return currentAllocationsByNodeId.values().stream().mapToInt(Integer::intValue).sum(); } @Override @@ -49,7 +49,7 @@ public String toString() { + ") (threads_per_allocation = " + threadsPerAllocation + ") (current_allocations = " - + currentAllocationByNodeId + + currentAllocationsByNodeId + ")"; } }; @@ -69,8 +69,22 @@ public String toString() { */ private final Map> assignments; - private AssignmentPlan(Map> assignments) { + private final Map remainingNodeMemory; + private final Map remainingNodeCores; + private final Map remainingModelAllocations; + + private AssignmentPlan( + Map> assignments, + Map remainingNodeMemory, + Map remainingNodeCores, + Map remainingModelAllocations + ) { this.assignments = Objects.requireNonNull(assignments); + this.remainingNodeMemory = remainingNodeMemory.entrySet() + .stream() + 
.collect(Collectors.toMap(e -> e.getKey().id(), e -> e.getValue())); + this.remainingNodeCores = remainingNodeCores.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().id(), e -> e.getValue())); + this.remainingModelAllocations = Objects.requireNonNull(remainingModelAllocations); } public Set models() { @@ -99,7 +113,7 @@ public boolean satisfiesPreviousAssignments() { } private boolean isSatisfyingPreviousAssignmentsForModel(Model m) { - if (m.currentAllocationByNodeId().isEmpty()) { + if (m.currentAllocationsByNodeId().isEmpty()) { return true; } Map nodeAssignments = assignments.get(m); @@ -107,6 +121,18 @@ private boolean isSatisfyingPreviousAssignmentsForModel(Model m) { return currentAllocations >= m.getPreviouslyAssignedAllocations(); } + public boolean satisfiesAllocations(Model m) { + return remainingModelAllocations.getOrDefault(m, 0) == 0; + } + + public int getRemainingNodeCores(String nodeId) { + return remainingNodeCores.getOrDefault(nodeId, 0); + } + + public long getRemainingNodeMemory(String nodeId) { + return remainingNodeMemory.getOrDefault(nodeId, 0L); + } + private Quality computeQuality() { boolean isSatisfyingPreviousAssignments = true; double weighedAllocationsScore = 0; @@ -119,7 +145,7 @@ private Quality computeQuality() { if (modelAssignments != null) { for (Map.Entry nodeAllocations : modelAssignments.entrySet()) { Node n = nodeAllocations.getKey(); - weighedAllocationsScore += (1 + 0.1 * (m.currentAllocationByNodeId().containsKey(n.id()) ? 1 : 0)) * modelAssignments + weighedAllocationsScore += (1 + 0.1 * (m.currentAllocationsByNodeId().containsKey(n.id()) ? 1 : 0)) * modelAssignments .get(n); memoryScore -= (nodeAllocations.getValue() > 0 ? 
m.memoryBytes() : 0); } @@ -265,7 +291,11 @@ Builder assignModelToNode(Model model, Node node, int allocations) { } private boolean isAlreadyAssigned(Model model, Node node) { - return model.currentAllocationByNodeId().containsKey(node.id()) || assignments.get(model).get(node) > 0; + return model.currentAllocationsByNodeId().containsKey(node.id()) || assignments.get(model).get(node) > 0; + } + + void accountMemory(Model m, Node n) { + remainingNodeMemory.computeIfPresent(n, (k, v) -> v - m.memoryBytes()); } AssignmentPlan build() { @@ -279,7 +309,7 @@ AssignmentPlan build() { } finalAssignments.put(m, allocationsPerNode); } - return new AssignmentPlan(finalAssignments); + return new AssignmentPlan(finalAssignments, remainingNodeMemory, remainingNodeCores, remainingModelAllocations); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java index b27731062bcfe..f49495209f77d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java @@ -16,6 +16,8 @@ import java.util.Comparator; import java.util.List; +import static org.elasticsearch.core.Strings.format; + /** * A planner that computes how allocations for model deployments will be * distributed across a set of nodes. @@ -46,6 +48,7 @@ public AssignmentPlanner(List nodes, List models) { } public AssignmentPlan computePlan() { + logger.debug(() -> format("Computing plan for nodes = %s; models = %s", nodes, models)); AssignmentPlan planKeepingOneAllocationOnPreviousAssignments = solveKeepingOneAllocationOnPreviousAssignments(); AssignmentPlan bestPlan = planKeepingOneAllocationOnPreviousAssignments.satisfiesPreviousAssignments() ? 
planKeepingOneAllocationOnPreviousAssignments @@ -58,16 +61,20 @@ public AssignmentPlan computePlan() { private AssignmentPlan solveKeepingOneAllocationOnPreviousAssignments() { // We do not want to ever completely unassign a model from a node so we // can move allocations without having temporary impact on performance. + logger.trace(() -> format("Solving preserving one allocation on previous assignments")); return solvePreservingPreviousAssignments(new PreserveOneAllocation(nodes, models)); } private AssignmentPlan solvePreservingAllPreviousAssignments() { + logger.trace(() -> format("Solving preserving all allocations on previous assignments")); return solvePreservingPreviousAssignments(new PreserveAllAllocations(nodes, models)); } private AssignmentPlan solvePreservingPreviousAssignments(AbstractPreserveAllocations preserveAllocations) { List planNodes = preserveAllocations.nodesPreservingAllocations(); List planModels = preserveAllocations.modelsPreservingAllocations(); + logger.trace(() -> format("Nodes after applying allocation preserving strategy = %s", planNodes)); + logger.trace(() -> format("Models after applying allocation preserving strategy = %s", planModels)); AssignmentPlan assignmentPlan = new LinearProgrammingPlanSolver(planNodes, planModels).solvePlan(); return preserveAllocations.mergePreservedAllocations(assignmentPlan); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java index dd71ca8f64e70..e1628c395e44b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java @@ -21,6 +21,8 @@ import org.ojalgo.type.CalendarDateDuration; import 
org.ojalgo.type.CalendarDateUnit; +import java.security.AccessController; +import java.security.PrivilegedAction; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; @@ -77,7 +79,7 @@ class LinearProgrammingPlanSolver { long maxNodeMemory = nodes.stream().map(Node::availableMemoryBytes).max(Long::compareTo).orElse(0L); this.models = models.stream() // Filter out models that are not already assigned and do not fit on any node - .filter(m -> m.currentAllocationByNodeId().isEmpty() == false || m.memoryBytes() <= maxNodeMemory) + .filter(m -> m.currentAllocationsByNodeId().isEmpty() == false || m.memoryBytes() <= maxNodeMemory) // Also filter out models whose threads per allocation are more than the max node cores .filter(m -> m.threadsPerAllocation() <= maxNodeCores) .toList(); @@ -172,11 +174,11 @@ private Tuple, Double>, AssignmentPlan> calculateWeightsA } private double descendingSizeAnyFitsModelOrder(Model m) { - return (m.currentAllocationByNodeId().isEmpty() ? 1 : 2) * -normalizedMemoryPerModel.get(m) * m.threadsPerAllocation(); + return (m.currentAllocationsByNodeId().isEmpty() ? 1 : 2) * -normalizedMemoryPerModel.get(m) * m.threadsPerAllocation(); } private double descendingSizeAnyFitsNodeOrder(Node n, Model m, AssignmentPlan.Builder assignmentPlan) { - return (m.currentAllocationByNodeId().containsKey(n.id()) ? 0 : 1) + (assignmentPlan.getRemainingCores(n) >= assignmentPlan + return (m.currentAllocationsByNodeId().containsKey(n.id()) ? 0 : 1) + (assignmentPlan.getRemainingCores(n) >= assignmentPlan .getRemainingThreads(m) ? 0 : 1) + (0.01 * distance(assignmentPlan.getRemainingCores(n), assignmentPlan.getRemainingThreads(m))) - (0.01 * assignmentPlan.getRemainingMemory(n)); } @@ -188,11 +190,11 @@ private static int distance(int x, int y) { } private double minWeight(Model m, Node n, double w) { - return m.currentAllocationByNodeId().containsKey(n.id()) ? w / 2 : 0; + return m.currentAllocationsByNodeId().containsKey(n.id()) ? 
w / 2 : 0; } private double maxWeight(Model m, Node n, double w) { - return m.currentAllocationByNodeId().containsKey(n.id()) ? w : w / 2; + return m.currentAllocationsByNodeId().containsKey(n.id()) ? w : w / 2; } private boolean solveLinearProgram( @@ -292,7 +294,7 @@ private boolean solveLinearProgram( // This is the m_i * a_i_j * t_i / N_j constraint. List allocations = new ArrayList<>(); List modelMemories = new ArrayList<>(); - models.stream().filter(m -> m.currentAllocationByNodeId().containsKey(n.id()) == false).forEach(m -> { + models.stream().filter(m -> m.currentAllocationsByNodeId().containsKey(n.id()) == false).forEach(m -> { allocations.add(allocationVars.get(Tuple.tuple(m, n))); modelMemories.add(normalizedMemoryPerModel.get(m) * m.threadsPerAllocation() / (double) coresPerNode.get(n)); }); @@ -301,7 +303,7 @@ private boolean solveLinearProgram( .setLinearFactors(allocations, Access1D.wrap(modelMemories)); } - Optimisation.Result result = model.maximise(); + Optimisation.Result result = privilegedModelMaximise(model); if (result.getState().isFeasible() == false) { logger.debug("Linear programming solution state [{}] is not feasible", result.getState()); @@ -323,6 +325,11 @@ private boolean solveLinearProgram( return true; } + @SuppressWarnings("removal") + private static Optimisation.Result privilegedModelMaximise(ExpressionsBasedModel model) { + return AccessController.doPrivileged((PrivilegedAction) () -> model.maximise()); + } + private int memoryComplexity() { // Looking at the internals of ojalgo, to solve the problem a 2D double array is created with // dimensions of approximately (n + m) * n * m, where n is the number of nodes and m the number of models. 
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocations.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocations.java index 95e35536fa089..6f05359673735 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocations.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocations.java @@ -22,21 +22,21 @@ protected PreserveAllAllocations(List nodes, List models) { @Override protected int calculateUsedCores(Node n, Model m) { - return m.currentAllocationByNodeId().get(n.id()) * m.threadsPerAllocation(); + return m.currentAllocationsByNodeId().get(n.id()) * m.threadsPerAllocation(); } @Override protected Map calculateAllocationsPerNodeToPreserve(Model m) { - return m.currentAllocationByNodeId().entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> 0)); + return m.currentAllocationsByNodeId().entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> 0)); } @Override protected int calculatePreservedAllocations(Model m) { - return m.currentAllocationByNodeId().values().stream().mapToInt(Integer::intValue).sum(); + return m.currentAllocationsByNodeId().values().stream().mapToInt(Integer::intValue).sum(); } @Override protected int addPreservedAllocations(Node n, Model m) { - return m.currentAllocationByNodeId().get(n.id()); + return m.currentAllocationsByNodeId().get(n.id()); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocation.java index 6b62db1d60702..aa79f5254d3b7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocation.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocation.java @@ -27,12 +27,12 @@ protected int calculateUsedCores(Node n, Model m) { @Override protected Map calculateAllocationsPerNodeToPreserve(Model m) { - return m.currentAllocationByNodeId().entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue() - 1)); + return m.currentAllocationsByNodeId().entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue() - 1)); } @Override protected int calculatePreservedAllocations(Model m) { - return (int) m.currentAllocationByNodeId().values().stream().filter(v -> v > 0).count(); + return (int) m.currentAllocationsByNodeId().values().stream().filter(v -> v > 0).count(); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/RandomizedAssignmentRounding.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/RandomizedAssignmentRounding.java index b1cb0000c98f3..e875de4f24b92 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/RandomizedAssignmentRounding.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/RandomizedAssignmentRounding.java @@ -168,7 +168,7 @@ private double decreasingQualityNodeOrder(Node n) { for (Model m : models) { Tuple index = Tuple.tuple(m, n); if (allocations.get(index) > 0) { - quality += (1 + (m.currentAllocationByNodeId().containsKey(n.id()) ? 1 : 0)) * allocations.get(index) * m + quality += (1 + (m.currentAllocationsByNodeId().containsKey(n.id()) ? 1 : 0)) * allocations.get(index) * m .threadsPerAllocation(); } } @@ -205,7 +205,7 @@ private void assignExcessCores(Node n) { } private double remainingModelOrder(Model m) { - return (m.currentAllocationByNodeId().isEmpty() ? 1 : 2) * -m.memoryBytes(); + return (m.currentAllocationsByNodeId().isEmpty() ? 
1 : 2) * -m.memoryBytes(); } private boolean hasSoftAssignments(Node n) { @@ -370,7 +370,7 @@ private double remainingNodeOrder( long remainingNodeMemory, int remainingModelAllocations ) { - return (m.currentAllocationByNodeId().containsKey(n.id()) ? 0 : 1) + (remainingNodeCores <= remainingModelAllocations * m + return (m.currentAllocationsByNodeId().containsKey(n.id()) ? 0 : 1) + (remainingNodeCores <= remainingModelAllocations * m .threadsPerAllocation() ? 0 : 0.5) + (0.01 * distance( remainingNodeCores, remainingModelAllocations * m.threadsPerAllocation() @@ -403,7 +403,7 @@ private static class ResourceTracker { for (Model m : models) { for (Node n : nodes) { - if (m.currentAllocationByNodeId().containsKey(n.id())) { + if (m.currentAllocationsByNodeId().containsKey(n.id())) { assignments.add(Tuple.tuple(m, n)); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index 2b41062ed3a4f..35e7f619a8e83 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -87,7 +87,7 @@ public DeploymentManager( this.pyTorchProcessFactory = Objects.requireNonNull(pyTorchProcessFactory); this.threadPool = Objects.requireNonNull(threadPool); this.executorServiceForDeployment = threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME); - this.executorServiceForProcess = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME); + this.executorServiceForProcess = threadPool.executor(MachineLearning.NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME); } public void startDeployment(TrainedModelDeploymentTask task, ActionListener listener) { diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java index 5b321c806cd5d..bc9ab284836bd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java @@ -29,15 +29,17 @@ import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentNodeService; import java.util.Map; +import java.util.Objects; import java.util.Optional; public class TrainedModelDeploymentTask extends CancellableTask implements StartTrainedModelDeploymentAction.TaskMatcher { private static final Logger logger = LogManager.getLogger(TrainedModelDeploymentTask.class); - private final TaskParams params; + private volatile TaskParams params; private final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService; private volatile boolean stopped; + private volatile boolean failed; private final SetOnce stoppedReasonHolder = new SetOnce<>(); private final SetOnce inferenceConfigHolder = new SetOnce<>(); private final XPackLicenseState licenseState; @@ -55,7 +57,7 @@ public TrainedModelDeploymentTask( LicensedFeature.Persistent licensedFeature ) { super(id, type, action, MlTasks.trainedModelAssignmentTaskDescription(taskParams.getModelId()), parentTask, headers); - this.params = taskParams; + this.params = Objects.requireNonNull(taskParams); this.trainedModelAssignmentNodeService = ExceptionsHelper.requireNonNull( trainedModelAssignmentNodeService, "trainedModelAssignmentNodeService" @@ -70,6 +72,16 @@ void init(InferenceConfig inferenceConfig) { } } + public void updateNumberOfAllocations(int numberOfAllocations) { + params = new TaskParams( + params.getModelId(), + params.getModelBytes(), + numberOfAllocations, + 
params.getThreadsPerAllocation(), + params.getQueueCapacity() + ); + } + public String getModelId() { return params.getModelId(); } @@ -145,6 +157,11 @@ public Optional modelStats() { } public void setFailed(String reason) { + failed = true; trainedModelAssignmentNodeService.failAssignment(this, reason); } + + public boolean isFailed() { + return failed; + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java index 700cdc8ffc798..7b4609a0df38e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java @@ -69,7 +69,7 @@ public NativePyTorchProcess createProcess( true, true, true, - false + false // We do not need a persist pipe. This is also why we use 3 threads per model assignment in the pytorch thread pool. ); executeProcess(processPipes, task); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/JobNodeSelector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/JobNodeSelector.java index b92b1a8a53d16..9326af7aa6785 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/JobNodeSelector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/JobNodeSelector.java @@ -201,7 +201,7 @@ public PersistentTasksCustomMetadata.Assignment selectNode( jobId, nodeNameAndMlAttributes(node), "This node is full. 
Number of opened jobs and allocated native inference processes [%s], %s [%s].", - currentLoad.getNumAssignedJobs(), + currentLoad.getNumAssignedJobsAndModels(), MAX_OPEN_JOBS_PER_NODE.getKey(), maxNumberOfOpenJobs ); @@ -234,7 +234,7 @@ public PersistentTasksCustomMetadata.Assignment selectNode( // If this will be the first job assigned to the node then it will need to // load the native code shared libraries, so add the overhead for this - if (currentLoad.getNumAssignedJobs() == 0) { + if (currentLoad.getNumAssignedJobsAndModels() == 0) { requiredMemoryForJob += MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes(); } long availableMemory = currentLoad.getMaxMlMemory() - currentLoad.getAssignedJobMemory(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/NodeLoad.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/NodeLoad.java index 89e75e7c23c47..a3ad30a5d0865 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/NodeLoad.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/NodeLoad.java @@ -27,7 +27,7 @@ public class NodeLoad { private final String error; private final int numAssignedAnomalyDetectorJobs; private final int numAssignedDataFrameAnalyticsJobs; - private final int numAssignedNativeInferenceJobs; + private final int numAssignedNativeInferenceModels; private final long assignedNativeCodeOverheadMemory; private final long assignedAnomalyDetectorMemory; private final long assignedDataFrameAnalyticsMemory; @@ -42,7 +42,7 @@ public class NodeLoad { String error, int numAssignedAnomalyDetectorJobs, int numAssignedDataFrameAnalyticsJobs, - int numAssignedNativeInferenceJobs, + int numAssignedNativeInferenceModels, long assignedNativeCodeOverheadMemory, long assignedAnomalyDetectorMemory, long assignedDataFrameAnalyticsMemory, @@ -56,7 +56,7 @@ public class NodeLoad { this.error = error; this.numAssignedAnomalyDetectorJobs = numAssignedAnomalyDetectorJobs; 
this.numAssignedDataFrameAnalyticsJobs = numAssignedDataFrameAnalyticsJobs; - this.numAssignedNativeInferenceJobs = numAssignedNativeInferenceJobs; + this.numAssignedNativeInferenceModels = numAssignedNativeInferenceModels; this.assignedNativeCodeOverheadMemory = assignedNativeCodeOverheadMemory; this.assignedAnomalyDetectorMemory = assignedAnomalyDetectorMemory; this.assignedDataFrameAnalyticsMemory = assignedDataFrameAnalyticsMemory; @@ -65,10 +65,10 @@ public class NodeLoad { } /** - * @return The total number of assigned jobs + * @return The total number of assigned jobs and models */ - public int getNumAssignedJobs() { - return numAssignedAnomalyDetectorJobs + numAssignedDataFrameAnalyticsJobs + numAssignedNativeInferenceJobs; + public int getNumAssignedJobsAndModels() { + return numAssignedAnomalyDetectorJobs + numAssignedDataFrameAnalyticsJobs + numAssignedNativeInferenceModels; } /** @@ -166,7 +166,8 @@ public long getFreeMemoryExcludingPerNodeOverhead() { * @return The number of jobs that can still be assigned to the node */ public int remainingJobs() { - return Math.max(maxJobs - getNumAssignedJobs(), 0); + // Native inference jobs use their own thread pool so they should not account towards the limit of open jobs. 
+ return Math.max(maxJobs - (getNumAssignedJobsAndModels() - numAssignedNativeInferenceModels), 0); } /** @@ -194,7 +195,7 @@ public boolean equals(Object o) { && useMemory == nodeLoad.useMemory && numAssignedAnomalyDetectorJobs == nodeLoad.numAssignedAnomalyDetectorJobs && numAssignedDataFrameAnalyticsJobs == nodeLoad.numAssignedDataFrameAnalyticsJobs - && numAssignedNativeInferenceJobs == nodeLoad.numAssignedNativeInferenceJobs + && numAssignedNativeInferenceModels == nodeLoad.numAssignedNativeInferenceModels && assignedNativeCodeOverheadMemory == nodeLoad.assignedNativeCodeOverheadMemory && assignedAnomalyDetectorMemory == nodeLoad.assignedAnomalyDetectorMemory && assignedDataFrameAnalyticsMemory == nodeLoad.assignedDataFrameAnalyticsMemory @@ -214,7 +215,7 @@ public int hashCode() { error, numAssignedAnomalyDetectorJobs, numAssignedDataFrameAnalyticsJobs, - numAssignedNativeInferenceJobs, + numAssignedNativeInferenceModels, assignedNativeCodeOverheadMemory, assignedAnomalyDetectorMemory, assignedDataFrameAnalyticsMemory, @@ -239,7 +240,7 @@ public static class Builder { private String error; private int numAssignedAnomalyDetectorJobs; private int numAssignedDataFrameAnalyticsJobs; - private int numAssignedNativeInferenceJobs; + private int numAssignedNativeInferenceModels; private long assignedNativeCodeOverheadMemory; private long assignedAnomalyDetectorMemory; private long assignedDataFrameAnalyticsMemory; @@ -254,7 +255,7 @@ public Builder(NodeLoad nodeLoad) { this.error = nodeLoad.error; this.numAssignedAnomalyDetectorJobs = nodeLoad.numAssignedAnomalyDetectorJobs; this.numAssignedDataFrameAnalyticsJobs = nodeLoad.numAssignedDataFrameAnalyticsJobs; - this.numAssignedNativeInferenceJobs = nodeLoad.numAssignedNativeInferenceJobs; + this.numAssignedNativeInferenceModels = nodeLoad.numAssignedNativeInferenceModels; this.assignedNativeCodeOverheadMemory = nodeLoad.assignedNativeCodeOverheadMemory; this.assignedAnomalyDetectorMemory = 
nodeLoad.assignedAnomalyDetectorMemory; this.assignedDataFrameAnalyticsMemory = nodeLoad.assignedDataFrameAnalyticsMemory; @@ -275,7 +276,8 @@ public long getFreeMemory() { } public int remainingJobs() { - return Math.max(maxJobs - getNumAssignedJobs(), 0); + // Native inference jobs use their own thread pool so they should not account towards the limit of open jobs. + return Math.max(maxJobs - (getNumAssignedJobs() - numAssignedNativeInferenceModels), 0); } public String getNodeId() { @@ -283,7 +285,7 @@ public String getNodeId() { } public int getNumAssignedJobs() { - return numAssignedAnomalyDetectorJobs + numAssignedDataFrameAnalyticsJobs + numAssignedNativeInferenceJobs; + return numAssignedAnomalyDetectorJobs + numAssignedDataFrameAnalyticsJobs + numAssignedNativeInferenceModels; } public Builder setMaxMemory(long maxMemory) { @@ -320,8 +322,8 @@ public Builder incNumAssignedDataFrameAnalyticsJobs() { return this; } - public Builder incNumAssignedNativeInferenceJobs() { - ++this.numAssignedNativeInferenceJobs; + public Builder incNumAssignedNativeInferenceModels() { + ++this.numAssignedNativeInferenceModels; return this; } @@ -390,7 +392,7 @@ public NodeLoad build() { error, numAssignedAnomalyDetectorJobs, numAssignedDataFrameAnalyticsJobs, - numAssignedNativeInferenceJobs, + numAssignedNativeInferenceModels, assignedNativeCodeOverheadMemory, assignedAnomalyDetectorMemory, assignedDataFrameAnalyticsMemory, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/NodeLoadDetector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/NodeLoadDetector.java index f3c1ad9023cc3..8ee4be599826b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/NodeLoadDetector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/NodeLoadDetector.java @@ -6,13 +6,15 @@ */ package org.elasticsearch.xpack.ml.job; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import 
org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.Strings; import org.elasticsearch.persistent.PersistentTasksCustomMetadata; import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState; -import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReason; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import org.elasticsearch.xpack.core.ml.utils.MemoryTrackedTaskState; import org.elasticsearch.xpack.core.ml.utils.MlTaskParams; @@ -33,6 +35,8 @@ public class NodeLoadDetector { + private static final Logger logger = LogManager.getLogger(NodeLoadDetector.class); + private final MlMemoryTracker mlMemoryTracker; /** @@ -101,7 +105,9 @@ public NodeLoad detectNodeLoad( .setMaxJobs(maxNumberOfOpenJobs) .setUseMemory(true); if (errors.isEmpty() == false) { - return nodeLoad.setError(Strings.collectionToCommaDelimitedString(errors)).build(); + String errorMsg = Strings.collectionToCommaDelimitedString(errors); + logger.warn("error detecting load for node [{}]: {}", node.getId(), errorMsg); + return nodeLoad.setError(errorMsg).build(); } updateLoadGivenTasks(nodeLoad, persistentTasks); updateLoadGivenModelAssignments(nodeLoad, assignmentMetadata); @@ -134,10 +140,10 @@ private void updateLoadGivenModelAssignments(NodeLoad.Builder nodeLoad, TrainedM if (trainedModelAssignmentMetadata != null && trainedModelAssignmentMetadata.modelAssignments().isEmpty() == false) { for (TrainedModelAssignment assignment : trainedModelAssignmentMetadata.modelAssignments().values()) { if (Optional.ofNullable(assignment.getNodeRoutingTable().get(nodeLoad.getNodeId())) - .map(RoutingStateAndReason::getState) + .map(RoutingInfo::getState) .orElse(RoutingState.STOPPED) .consumesMemory()) { - nodeLoad.incNumAssignedNativeInferenceJobs(); + 
nodeLoad.incNumAssignedNativeInferenceModels(); nodeLoad.incAssignedNativeInferenceMemory(assignment.getTaskParams().estimateMemoryUsageBytes()); } } diff --git a/x-pack/plugin/ml/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/ml/src/main/plugin-metadata/plugin-security.policy index a7cbf1a6f3fb4..1bf45f6d697a6 100644 --- a/x-pack/plugin/ml/src/main/plugin-metadata/plugin-security.policy +++ b/x-pack/plugin/ml/src/main/plugin-metadata/plugin-security.policy @@ -2,5 +2,6 @@ grant { // needed for Windows named pipes in machine learning permission java.io.FilePermission "\\\\.\\pipe\\*", "read,write"; + // needed for ojalgo linear programming solver permission java.lang.RuntimePermission "accessDeclaredMembers"; }; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java index 8384d6cf8255d..4ff352ea52af5 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java @@ -17,8 +17,8 @@ import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStatsTests; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState; -import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReason; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import java.net.InetAddress; @@ -48,12 +48,12 @@ public void testAddFailedRoutes_GivenNoTaskResponses() throws UnknownHostExcepti 0 ); - Map> badRoutes = new HashMap<>(); + Map> 
badRoutes = new HashMap<>(); for (var modelId : new String[] { "model1", "model2" }) { TrainedModelAssignment assignment = createAssignment(modelId); - Map nodeRoutes = new HashMap<>(); + Map nodeRoutes = new HashMap<>(); for (var nodeId : new String[] { "nodeA", "nodeB" }) { - nodeRoutes.put(nodeId, new RoutingStateAndReason(RoutingState.FAILED, "failure reason")); + nodeRoutes.put(nodeId, new RoutingInfo(1, 1, RoutingState.FAILED, "failure reason")); } badRoutes.put(assignment, nodeRoutes); } @@ -86,9 +86,9 @@ public void testAddFailedRoutes_GivenMixedResponses() throws UnknownHostExceptio nodeStatsList ); - Map> badRoutes = new HashMap<>(); - Map nodeRoutes = new HashMap<>(); - nodeRoutes.put("node3", new RoutingStateAndReason(RoutingState.FAILED, "failed on node3")); + Map> badRoutes = new HashMap<>(); + Map nodeRoutes = new HashMap<>(); + nodeRoutes.put("node3", new RoutingInfo(1, 1, RoutingState.FAILED, "failed on node3")); badRoutes.put(createAssignment("model1"), nodeRoutes); var response = new GetDeploymentStatsAction.Response(Collections.emptyList(), Collections.emptyList(), List.of(model1), 1); @@ -123,9 +123,9 @@ public void testAddFailedRoutes_TaskResultIsOverwritten() throws UnknownHostExce var response = new GetDeploymentStatsAction.Response(Collections.emptyList(), Collections.emptyList(), List.of(model1), 1); // failed state for node 2 conflicts with the task response - Map> badRoutes = new HashMap<>(); - Map nodeRoutes = new HashMap<>(); - nodeRoutes.put("node2", new RoutingStateAndReason(RoutingState.FAILED, "failed on node3")); + Map> badRoutes = new HashMap<>(); + Map nodeRoutes = new HashMap<>(); + nodeRoutes.put("node2", new RoutingInfo(1, 1, RoutingState.FAILED, "failed on node3")); badRoutes.put(createAssignment("model1"), nodeRoutes); var modified = TransportGetDeploymentStatsAction.addFailedRoutes(response, badRoutes, nodes); diff --git 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java index 02e355a19b489..83fcd64b9521c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java @@ -11,6 +11,8 @@ import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.LatchedActionListener; import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; @@ -26,35 +28,40 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.set.Sets; -import org.elasticsearch.persistent.PersistentTasksCustomMetadata; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.MlMetadata; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; -import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentStateAction; +import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction; import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfoUpdate; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState; import 
org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReason; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.job.NodeLoadDetector; -import org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutorTests; import org.elasticsearch.xpack.ml.process.MlMemoryTracker; import org.junit.Before; import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CountDownLatch; import java.util.function.Function; -import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.anEmptyMap; -import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasKey; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -62,6 +69,7 @@ public class TrainedModelAssignmentClusterServiceTests extends ESTestCase { private ClusterService clusterService; + private ThreadPool threadPool; private NodeLoadDetector nodeLoadDetector; @Before @@ -76,6 +84,9 @@ public void setupObjects() { ) ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + threadPool = mock(ThreadPool.class); + MlMemoryTracker memoryTracker = mock(MlMemoryTracker.class); when(memoryTracker.isRecentlyRefreshed()).thenReturn(true); nodeLoadDetector = new NodeLoadDetector(memoryTracker); @@ -88,8 +99,8 @@ public void testUpdateModelRoutingTable() { ClusterState currentState = ClusterState.builder(new 
ClusterName("testUpdateModelRoutingTable")) .nodes( DiscoveryNodes.builder() - .add(buildNode(nodeId, true, ByteSizeValue.ofGb(4).getBytes())) - .add(buildNode(startedNode, true, ByteSizeValue.ofGb(4).getBytes())) + .add(buildNode(nodeId, true, ByteSizeValue.ofGb(4).getBytes(), 8)) + .add(buildNode(startedNode, true, ByteSizeValue.ofGb(4).getBytes(), 8)) .build() ) .metadata( @@ -100,8 +111,8 @@ public void testUpdateModelRoutingTable() { .addNewAssignment( modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L)) - .addNewRoutingEntry(nodeId) - .addNewRoutingEntry(startedNode) + .addRoutingEntry(nodeId, new RoutingInfo(1, 1, RoutingState.STARTING, "")) + .addRoutingEntry(startedNode, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() ) @@ -112,7 +123,7 @@ public void testUpdateModelRoutingTable() { assertThatStoppingAssignmentPreventsMutation( state -> TrainedModelAssignmentClusterService.updateModelRoutingTable( state, - new UpdateTrainedModelAssignmentStateAction.Request(nodeId, modelId, started()) + new UpdateTrainedModelAssignmentRoutingInfoAction.Request(nodeId, modelId, started()) ), currentState ); @@ -124,7 +135,7 @@ public void testUpdateModelRoutingTable() { ClusterState newState = TrainedModelAssignmentClusterService.updateModelRoutingTable( currentState, - new UpdateTrainedModelAssignmentStateAction.Request(startedNode, modelId, started()) + new UpdateTrainedModelAssignmentRoutingInfoAction.Request(startedNode, modelId, started()) ); assertThat( TrainedModelAssignmentMetadata.fromState(newState) @@ -143,22 +154,14 @@ public void testUpdateModelRoutingTable() { ResourceNotFoundException.class, () -> TrainedModelAssignmentClusterService.updateModelRoutingTable( newState, - new UpdateTrainedModelAssignmentStateAction.Request( - "missingNode", - modelId, - new RoutingStateAndReason(RoutingState.STARTED, "") - ) + new UpdateTrainedModelAssignmentRoutingInfoAction.Request("missingNode", modelId, started()) ) ); expectThrows( 
ResourceNotFoundException.class, () -> TrainedModelAssignmentClusterService.updateModelRoutingTable( newState, - new UpdateTrainedModelAssignmentStateAction.Request( - nodeId, - "missingModel", - new RoutingStateAndReason(RoutingState.STARTED, "") - ) + new UpdateTrainedModelAssignmentRoutingInfoAction.Request(nodeId, "missingModel", started()) ) ); @@ -167,16 +170,28 @@ public void testUpdateModelRoutingTable() { // We should allow a "stopped" update on missing models and nodes as entries may have already been deleted TrainedModelAssignmentClusterService.updateModelRoutingTable( newState, - new UpdateTrainedModelAssignmentStateAction.Request("missingNode", modelId, new RoutingStateAndReason(RoutingState.STOPPED, "")) + new UpdateTrainedModelAssignmentRoutingInfoAction.Request( + "missingNode", + modelId, + RoutingInfoUpdate.updateStateAndReason(new RoutingStateAndReason(RoutingState.STOPPED, "")) + ) ); TrainedModelAssignmentClusterService.updateModelRoutingTable( newState, - new UpdateTrainedModelAssignmentStateAction.Request(nodeId, "missingModel", new RoutingStateAndReason(RoutingState.STOPPED, "")) + new UpdateTrainedModelAssignmentRoutingInfoAction.Request( + nodeId, + "missingModel", + RoutingInfoUpdate.updateStateAndReason(new RoutingStateAndReason(RoutingState.STOPPED, "")) + ) ); ClusterState updateState = TrainedModelAssignmentClusterService.updateModelRoutingTable( newState, - new UpdateTrainedModelAssignmentStateAction.Request(nodeId, modelId, new RoutingStateAndReason(RoutingState.STOPPED, "")) + new UpdateTrainedModelAssignmentRoutingInfoAction.Request( + nodeId, + modelId, + RoutingInfoUpdate.updateStateAndReason(new RoutingStateAndReason(RoutingState.STOPPED, "")) + ) ); assertThat( TrainedModelAssignmentMetadata.fromState(updateState).getModelAssignment(modelId).getNodeRoutingTable(), @@ -200,7 +215,7 @@ public void testRemoveAssignment() { ); ClusterState clusterStateWithAssignment = ClusterState.builder(new ClusterName("testRemoveAssignment")) - 
.nodes(DiscoveryNodes.builder().add(buildNode("test-node", true, ByteSizeValue.ofGb(4).getBytes())).build()) + .nodes(DiscoveryNodes.builder().add(buildNode("test-node", true, ByteSizeValue.ofGb(4).getBytes(), 8)).build()) .metadata( Metadata.builder() .putCustom( @@ -228,7 +243,7 @@ public void testRemoveAllAssignments() { ); ClusterState clusterStateWithAssignments = ClusterState.builder(new ClusterName("testRemoveAllAssignments")) - .nodes(DiscoveryNodes.builder().add(buildNode("test-node", true, ByteSizeValue.ofGb(4).getBytes())).build()) + .nodes(DiscoveryNodes.builder().add(buildNode("test-node", true, ByteSizeValue.ofGb(4).getBytes(), 8)).build()) .metadata( Metadata.builder() .putCustom(TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadataTests.randomInstance()) @@ -239,22 +254,22 @@ public void testRemoveAllAssignments() { assertThat(TrainedModelAssignmentMetadata.fromState(modified).modelAssignments(), is(anEmptyMap())); } - public void testCreateAssignment() { + public void testCreateAssignment() throws Exception { ClusterState currentState = ClusterState.builder(new ClusterName("testCreateAssignment")) .nodes( DiscoveryNodes.builder() - .add(buildNode("ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes())) - .add(buildNode("ml-node-without-room", true, 1000L)) - .add(buildNode("not-ml-node", false, ByteSizeValue.ofGb(4).getBytes())) - .add(buildNode("ml-node-shutting-down", true, ByteSizeValue.ofGb(4).getBytes())) - .add(buildOldNode("old-ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes())) + .add(buildNode("ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes(), 2)) + .add(buildNode("ml-node-without-room", true, 1000L, 2)) + .add(buildNode("not-ml-node", false, ByteSizeValue.ofGb(4).getBytes(), 2)) + .add(buildNode("ml-node-shutting-down", true, ByteSizeValue.ofGb(4).getBytes(), 2)) + .add(buildOldNode("old-ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes(), 2)) .build() ) 
.metadata(Metadata.builder().putCustom(NodesShutdownMetadata.TYPE, shutdownMetadata("ml-node-shutting-down"))) .build(); TrainedModelAssignmentClusterService trainedModelAssignmentClusterService = createClusterService(); - ClusterState newState = trainedModelAssignmentClusterService.createModelAssignment(currentState, newParams("new-model", 150)); + ClusterState newState = trainedModelAssignmentClusterService.createModelAssignment(currentState, newParams("new-model", 150, 4, 1)); TrainedModelAssignment createdAssignment = TrainedModelAssignmentMetadata.fromState(newState).getModelAssignment("new-model"); assertThat(createdAssignment, is(not(nullValue()))); @@ -262,7 +277,10 @@ public void testCreateAssignment() { assertThat(createdAssignment.getNodeRoutingTable(), hasKey("ml-node-with-room")); assertThat(createdAssignment.getNodeRoutingTable().get("ml-node-with-room").getState(), equalTo(RoutingState.STARTING)); assertThat(createdAssignment.getReason().isPresent(), is(true)); - assertThat(createdAssignment.getReason().get(), containsString("Not allocating on node [ml-node-without-room]")); + assertThat( + createdAssignment.getReason().get(), + containsString("Could not assign (more) allocations on node [ml-node-without-room]") + ); assertThat(createdAssignment.getAssignmentState(), equalTo(AssignmentState.STARTING)); expectThrows( @@ -271,232 +289,42 @@ public void testCreateAssignment() { ); } - public void testCreateAssignmentWhileResetModeIsTrue() { + public void testCreateAssignmentWhileResetModeIsTrue() throws InterruptedException { ClusterState currentState = ClusterState.builder(new ClusterName("testCreateAssignment")) - .nodes(DiscoveryNodes.builder().add(buildNode("ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes())).build()) + .nodes(DiscoveryNodes.builder().add(buildNode("ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes(), 8)).build()) .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new 
MlMetadata.Builder().isResetMode(true).build())) .build(); + when(clusterService.state()).thenReturn(currentState); TrainedModelAssignmentClusterService trainedModelAssignmentClusterService = createClusterService(); - expectThrows( - ElasticsearchStatusException.class, - () -> trainedModelAssignmentClusterService.createModelAssignment(currentState, newParams("new-model", 150)) - ); - - ClusterState stateWithoutReset = ClusterState.builder(new ClusterName("testCreateAssignment")) - .nodes(DiscoveryNodes.builder().add(buildNode("ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes())).build()) - .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().isResetMode(false).build())) - .build(); - // Shouldn't throw - trainedModelAssignmentClusterService.createModelAssignment(stateWithoutReset, newParams("new-model", 150)); - } - - public void testAddRemoveAssignmentNodes() { - ClusterState currentState = ClusterState.builder(new ClusterName("testAddRemoveAssignmentNodes")) - .nodes( - DiscoveryNodes.builder() - .add(buildNode("ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes())) - .add(buildNode("new-ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes())) - .add(buildNode("ml-node-without-room", true, 1000L)) - .add(buildNode("not-ml-node", false, ByteSizeValue.ofGb(4).getBytes())) - .add(buildNode("ml-node-shutting-down", true, ByteSizeValue.ofGb(4).getBytes())) - .add(buildOldNode("old-versioned-ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes())) - .build() - ) - .metadata( - Metadata.builder() - .putCustom(NodesShutdownMetadata.TYPE, shutdownMetadata("ml-node-shutting-down")) - .putCustom( - // We have to use deprecated name here as we have a node versioned before the rename - TrainedModelAssignmentMetadata.DEPRECATED_NAME, - TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment( - "model-1", - TrainedModelAssignment.Builder.empty(newParams("model-1", 10_000)) - .addNewRoutingEntry("ml-node-with-room") - 
.updateExistingRoutingEntry("ml-node-with-room", started()) - .addNewRoutingEntry("old-ml-node-with-room") - .updateExistingRoutingEntry("old-ml-node-with-room", started()) - .addNewRoutingEntry("ml-node-shutting-down") - ) - .addNewAssignment( - "model-2", - TrainedModelAssignment.Builder.empty(newParams("model-2", 10_000)) - .addNewRoutingEntry("old-ml-node-with-room") - .updateExistingRoutingEntry("old-ml-node-with-room", started()) - ) - .build() - ) - ) - .build(); - TrainedModelAssignmentClusterService trainedModelAssignmentClusterService = createClusterService(); - - // Stopping shouldn't cause any updates - assertThatStoppingAssignmentPreventsMutation(trainedModelAssignmentClusterService::addRemoveAssignmentNodes, currentState); - - ClusterState modified = trainedModelAssignmentClusterService.addRemoveAssignmentNodes(currentState); - TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.fromState(modified); - assertThat(trainedModelAssignmentMetadata.modelAssignments().keySet(), hasSize(2)); - assertThat(trainedModelAssignmentMetadata.modelAssignments(), allOf(hasKey("model-1"), hasKey("model-2"))); - - assertThat(trainedModelAssignmentMetadata.getModelAssignment("model-1").getNodeRoutingTable().keySet(), hasSize(2)); - assertThat( - trainedModelAssignmentMetadata.getModelAssignment("model-1").getNodeRoutingTable(), - allOf(hasKey("ml-node-with-room"), hasKey("new-ml-node-with-room")) - ); - assertNodeState(trainedModelAssignmentMetadata, "model-1", "ml-node-with-room", RoutingState.STARTED); - assertNodeState(trainedModelAssignmentMetadata, "model-1", "new-ml-node-with-room", RoutingState.STARTING); - assertThat(trainedModelAssignmentMetadata.modelAssignments().get("model-1").getAssignmentState(), equalTo(AssignmentState.STARTED)); - - assertThat(trainedModelAssignmentMetadata.getModelAssignment("model-2").getNodeRoutingTable().keySet(), hasSize(2)); - assertThat( - 
trainedModelAssignmentMetadata.getModelAssignment("model-2").getNodeRoutingTable(), - allOf(hasKey("ml-node-with-room"), hasKey("new-ml-node-with-room")) - ); - assertNodeState(trainedModelAssignmentMetadata, "model-2", "ml-node-with-room", RoutingState.STARTING); - assertNodeState(trainedModelAssignmentMetadata, "model-2", "new-ml-node-with-room", RoutingState.STARTING); - assertThat( - trainedModelAssignmentMetadata.modelAssignments().get("model-2").getAssignmentState(), - equalTo(AssignmentState.STARTING) - ); - } - public void testAddRemoveAllocationNodesPrioritizesAllocationsWithFewerNodes() { - ClusterState currentState = ClusterState.builder(new ClusterName("testAddRemoveAllocationNodes")) - .nodes( - DiscoveryNodes.builder() - .add(buildNode("ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes())) - .add(buildNode("new-ml-node-with-just-enough-room", true, ByteSizeValue.ofGb(8).getBytes())) - .add(buildNode("ml-node-without-room", true, 1000L)) - .add(buildNode("not-ml-node", false, ByteSizeValue.ofGb(4).getBytes())) - .add(buildNode("ml-node-shutting-down", true, ByteSizeValue.ofGb(4).getBytes())) - .add(buildOldNode("old-versioned-ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes())) - .build() - ) - .metadata( - Metadata.builder() - .putCustom(NodesShutdownMetadata.TYPE, shutdownMetadata("ml-node-shutting-down")) - .putCustom( - // We have to use deprecated name here as we have a node versioned before the rename - TrainedModelAssignmentMetadata.DEPRECATED_NAME, - TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment( - "model-1", - TrainedModelAssignment.Builder.empty(newParams("model-1", ByteSizeValue.ofGb(1).getBytes())) - .addNewRoutingEntry("ml-node-with-room") - .updateExistingRoutingEntry("ml-node-with-room", started()) - .addNewRoutingEntry("old-ml-node-with-room") - .updateExistingRoutingEntry("old-ml-node-with-room", started()) - .addNewRoutingEntry("ml-node-shutting-down") - ) - .addNewAssignment( - "model-2", - 
TrainedModelAssignment.Builder.empty(newParams("model-2", ByteSizeValue.ofGb(1).getBytes())) - .addNewRoutingEntry("ml-node-with-room") - ) - .addNewAssignment( - "model-3", - TrainedModelAssignment.Builder.empty(newParams("model-3", ByteSizeValue.ofGb(1).getBytes())) - ) - .build() - ) + CountDownLatch latch = new CountDownLatch(1); + trainedModelAssignmentClusterService.createNewModelAssignment( + newParams("new-model", 150), + new LatchedActionListener<>( + ActionListener.wrap( + trainedModelAssignment -> fail("assignment should have failed to be created because reset mode is set"), + e -> { + assertThat(e, is(instanceOf(ElasticsearchStatusException.class))); + assertThat(((ElasticsearchStatusException) e).status(), equalTo(RestStatus.CONFLICT)); + assertThat( + e.getMessage(), + equalTo("cannot create new assignment for model [new-model] while feature reset is in progress.") + ); + } + ), + latch ) - .build(); - TrainedModelAssignmentClusterService trainedModelAssignmentClusterService = createClusterService(); - - ClusterState modified = trainedModelAssignmentClusterService.addRemoveAssignmentNodes(currentState); - TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.fromState(modified); - assertThat(trainedModelAssignmentMetadata.modelAssignments(), allOf(hasKey("model-1"), hasKey("model-2"), hasKey("model-3"))); - - assertThat(trainedModelAssignmentMetadata.getModelAssignment("model-1").getNodeRoutingTable().keySet(), hasSize(1)); - assertThat(trainedModelAssignmentMetadata.getModelAssignment("model-1").getNodeRoutingTable(), allOf(hasKey("ml-node-with-room"))); - assertNodeState(trainedModelAssignmentMetadata, "model-1", "ml-node-with-room", RoutingState.STARTED); - assertThat(trainedModelAssignmentMetadata.modelAssignments().get("model-1").getAssignmentState(), equalTo(AssignmentState.STARTED)); - - assertThat(trainedModelAssignmentMetadata.getModelAssignment("model-2").getNodeRoutingTable().keySet(), hasSize(1)); - 
assertThat(trainedModelAssignmentMetadata.getModelAssignment("model-2").getNodeRoutingTable(), allOf(hasKey("ml-node-with-room"))); - assertNodeState(trainedModelAssignmentMetadata, "model-2", "ml-node-with-room", RoutingState.STARTING); - assertThat( - trainedModelAssignmentMetadata.modelAssignments().get("model-2").getAssignmentState(), - equalTo(AssignmentState.STARTING) - ); - - assertThat(trainedModelAssignmentMetadata.getModelAssignment("model-3").getNodeRoutingTable().keySet(), hasSize(1)); - assertThat( - trainedModelAssignmentMetadata.getModelAssignment("model-3").getNodeRoutingTable(), - allOf(hasKey("new-ml-node-with-just-enough-room")) - ); - assertNodeState(trainedModelAssignmentMetadata, "model-3", "new-ml-node-with-just-enough-room", RoutingState.STARTING); - assertThat( - trainedModelAssignmentMetadata.modelAssignments().get("model-3").getAssignmentState(), - equalTo(AssignmentState.STARTING) ); + latch.await(); } - public void testAddRemoveAllocationNodes_GivenNodeThatReachedMaxOpenJobs() { - - PersistentTasksCustomMetadata.Builder tasksBuilder = PersistentTasksCustomMetadata.builder(); - for (int i = 0; i < MachineLearning.DEFAULT_MAX_OPEN_JOBS_PER_NODE; i++) { - OpenJobPersistentTasksExecutorTests.addJobTask("job_id_" + i, "ml-node-full-load", null, tasksBuilder); - } - PersistentTasksCustomMetadata persistentTasks = tasksBuilder.build(); - - ClusterState currentState = ClusterState.builder(new ClusterName("testAddRemoveAllocationNodes")) - .nodes( - DiscoveryNodes.builder() - .add(buildNode("ml-node-full-load", true, ByteSizeValue.ofGb(4).getBytes())) - .add(buildNode("ml-node-no-load", true, ByteSizeValue.ofGb(4).getBytes())) - .build() - ) - .metadata( - Metadata.builder() - .putCustom( - TrainedModelAssignmentMetadata.NAME, - TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment( - "model-1", - TrainedModelAssignment.Builder.empty(newParams("model-1", 10_000)) - .addNewRoutingEntry("ml-node-no-load") - 
.updateExistingRoutingEntry("ml-node-no-load", started()) - ) - .build() - ) - .putCustom(PersistentTasksCustomMetadata.TYPE, persistentTasks) - ) - .build(); - TrainedModelAssignmentClusterService trainedModelAssignmentClusterService = createClusterService(); - - ClusterState modified = trainedModelAssignmentClusterService.addRemoveAssignmentNodes(currentState); - TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.fromState(modified); - assertThat(trainedModelAssignmentMetadata.modelAssignments().keySet(), contains("model-1")); - - assertThat(trainedModelAssignmentMetadata.getModelAssignment("model-1").getNodeRoutingTable().keySet(), hasSize(1)); - assertThat( - trainedModelAssignmentMetadata.getModelAssignment("model-1").getNodeRoutingTable().keySet(), - contains("ml-node-no-load") - ); - assertThat( - trainedModelAssignmentMetadata.getModelAssignment("model-1").getNodeRoutingTable().get("ml-node-no-load").getState(), - equalTo(RoutingState.STARTED) - ); - - TrainedModelAssignment allocation = trainedModelAssignmentMetadata.getModelAssignment("model-1"); - assertThat( - allocation.getReason().get(), - equalTo( - "Not allocating on node [ml-node-full-load]." - + " Reason: This node is full. Number of opened jobs and allocated native inference processes [512], " - + "xpack.ml.max_open_jobs [512]." 
- ) - ); - } - - public void testShouldAllocateModels() { + public void testShouldRebalanceModels() { String model1 = "model-1"; String model2 = "model-2"; String mlNode1 = "ml-node-with-room"; String mlNode2 = "new-ml-node-with-room"; - DiscoveryNode mlNode1Node = buildNode(mlNode1, true, ByteSizeValue.ofGb(4).getBytes()); - DiscoveryNode mlNode2Node = buildNode(mlNode2, true, ByteSizeValue.ofGb(4).getBytes()); + DiscoveryNode mlNode1Node = buildNode(mlNode1, true, ByteSizeValue.ofGb(4).getBytes(), 8); + DiscoveryNode mlNode2Node = buildNode(mlNode2, true, ByteSizeValue.ofGb(4).getBytes(), 8); ClusterState stateWithTwoNodes = ClusterState.builder(new ClusterName("testShouldAllocateModels")) .nodes(DiscoveryNodes.builder().add(mlNode1Node).add(mlNode2Node)) .build(); @@ -504,12 +332,12 @@ public void testShouldAllocateModels() { .nodes(DiscoveryNodes.builder().add(mlNode1Node)) .build(); ClusterState stateWithOneNodeNotMl = ClusterState.builder(new ClusterName("testShouldAllocateModels")) - .nodes(DiscoveryNodes.builder().add(mlNode1Node).add(buildNode("not-ml-node", false, ByteSizeValue.ofGb(4).getBytes()))) + .nodes(DiscoveryNodes.builder().add(mlNode1Node).add(buildNode("not-ml-node", false, ByteSizeValue.ofGb(4).getBytes(), 8))) .build(); // No metadata in the new state means no allocations, so no updates assertThat( - TrainedModelAssignmentClusterService.shouldAllocateModels( + TrainedModelAssignmentClusterService.shouldRebalanceModels( new ClusterChangedEvent( "test", ClusterState.builder(randomFrom(stateWithOneNodeNotMl, stateWithOneNode, stateWithTwoNodes)).build(), @@ -533,7 +361,7 @@ public void testShouldAllocateModels() { // Even with metadata changes, unless there are node changes, do nothing ClusterState randomState = randomFrom(stateWithOneNodeNotMl, stateWithOneNode, stateWithTwoNodes); assertThat( - TrainedModelAssignmentClusterService.shouldAllocateModels( + TrainedModelAssignmentClusterService.shouldRebalanceModels( new ClusterChangedEvent( 
"test", ClusterState.builder(randomState) @@ -557,7 +385,7 @@ public void testShouldAllocateModels() { // If the node removed is not even an ML node, we should not attempt to re-allocate assertThat( - TrainedModelAssignmentClusterService.shouldAllocateModels( + TrainedModelAssignmentClusterService.shouldRebalanceModels( new ClusterChangedEvent( "test", ClusterState.builder(stateWithOneNode) @@ -591,7 +419,7 @@ public void testShouldAllocateModels() { // If the node removed is an ML node, but no models are allocated to it, we should not attempt to re-allocate assertThat( - TrainedModelAssignmentClusterService.shouldAllocateModels( + TrainedModelAssignmentClusterService.shouldRebalanceModels( new ClusterChangedEvent( "test", ClusterState.builder(stateWithOneNode) @@ -625,7 +453,7 @@ public void testShouldAllocateModels() { // If a new ML node is added, we should attempt to re-allocate assertThat( - TrainedModelAssignmentClusterService.shouldAllocateModels( + TrainedModelAssignmentClusterService.shouldRebalanceModels( new ClusterChangedEvent( "test", ClusterState.builder(stateWithTwoNodes) @@ -659,7 +487,7 @@ public void testShouldAllocateModels() { // If a new ML node is added, but allocation is stopping, we should not re-allocate assertThat( - TrainedModelAssignmentClusterService.shouldAllocateModels( + TrainedModelAssignmentClusterService.shouldRebalanceModels( new ClusterChangedEvent( "test", ClusterState.builder(stateWithTwoNodes) @@ -696,7 +524,7 @@ public void testShouldAllocateModels() { // If a new ML node is added, but its shutting down, don't re-allocate assertThat( - TrainedModelAssignmentClusterService.shouldAllocateModels( + TrainedModelAssignmentClusterService.shouldRebalanceModels( new ClusterChangedEvent( "test", ClusterState.builder(stateWithTwoNodes) @@ -731,7 +559,7 @@ public void testShouldAllocateModels() { // If a ML node is removed and its routed to, re-allocate assertThat( - TrainedModelAssignmentClusterService.shouldAllocateModels( + 
TrainedModelAssignmentClusterService.shouldRebalanceModels( new ClusterChangedEvent( "test", ClusterState.builder(stateWithOneNode) @@ -742,13 +570,14 @@ public void testShouldAllocateModels() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)).addNewRoutingEntry(mlNode1) + TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( model2, TrainedModelAssignment.Builder.empty(newParams("model-2", 100)) - .addNewRoutingEntry(mlNode1) - .addNewRoutingEntry(mlNode2) + .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) + .addRoutingEntry(mlNode2, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() ) @@ -763,13 +592,14 @@ public void testShouldAllocateModels() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)).addNewRoutingEntry(mlNode1) + TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( model2, TrainedModelAssignment.Builder.empty(newParams("model-2", 100)) - .addNewRoutingEntry(mlNode1) - .addNewRoutingEntry(mlNode2) + .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) + .addRoutingEntry(mlNode2, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() ) @@ -783,7 +613,7 @@ public void testShouldAllocateModels() { // If a ML node is removed and its routed to, but the allocation is stopping, don't re-allocate assertThat( - TrainedModelAssignmentClusterService.shouldAllocateModels( + TrainedModelAssignmentClusterService.shouldRebalanceModels( new ClusterChangedEvent( "test", ClusterState.builder(stateWithOneNode) @@ -794,13 +624,14 @@ public void testShouldAllocateModels() { TrainedModelAssignmentMetadata.Builder.empty() 
.addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)).addNewRoutingEntry(mlNode1) + TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( model2, TrainedModelAssignment.Builder.empty(newParams("model-2", 100)) - .addNewRoutingEntry(mlNode1) - .addNewRoutingEntry(mlNode2) + .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) + .addRoutingEntry(mlNode2, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .stopAssignment("test") ) .build() @@ -816,13 +647,14 @@ public void testShouldAllocateModels() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)).addNewRoutingEntry(mlNode1) + TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( model2, TrainedModelAssignment.Builder.empty(newParams("model-2", 100)) - .addNewRoutingEntry(mlNode1) - .addNewRoutingEntry(mlNode2) + .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) + .addRoutingEntry(mlNode2, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() ) @@ -835,23 +667,21 @@ public void testShouldAllocateModels() { ); } - public void testShouldAllocateModels_WithNodeShutdowns() { + public void testShouldRebalanceModels_WithNodeShutdowns() { String clusterName = "testShouldAllocateModels_WithNodeShutdowns"; String model1 = "model-1"; - DiscoveryNode mlNode1 = buildNode("ml-node-1", true, ByteSizeValue.ofGb(4).getBytes()); - DiscoveryNode mlNode2 = buildNode("ml-node-2", true, ByteSizeValue.ofGb(4).getBytes()); - DiscoveryNode esNode1 = buildNode("es-node-1", false, ByteSizeValue.ofGb(4).getBytes()); - DiscoveryNode esNode2 = buildNode("es-node-2", false, ByteSizeValue.ofGb(4).getBytes()); - DiscoveryNode esNode3 = buildNode("es-node-3", 
false, ByteSizeValue.ofGb(4).getBytes()); + DiscoveryNode mlNode1 = buildNode("ml-node-1", true, ByteSizeValue.ofGb(4).getBytes(), 8); + DiscoveryNode mlNode2 = buildNode("ml-node-2", true, ByteSizeValue.ofGb(4).getBytes(), 8); + DiscoveryNode esNode1 = buildNode("es-node-1", false, ByteSizeValue.ofGb(4).getBytes(), 8); + DiscoveryNode esNode2 = buildNode("es-node-2", false, ByteSizeValue.ofGb(4).getBytes(), 8); + DiscoveryNode esNode3 = buildNode("es-node-3", false, ByteSizeValue.ofGb(4).getBytes(), 8); TrainedModelAssignmentMetadata fullModelAllocation = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100)) - .addNewRoutingEntry(mlNode1.getId()) - .updateExistingRoutingEntry(mlNode1.getId(), started()) - .addNewRoutingEntry(mlNode2.getId()) - .updateExistingRoutingEntry(mlNode2.getId(), started()) + .addRoutingEntry(mlNode1.getId(), new RoutingInfo(1, 1, RoutingState.STARTED, "")) + .addRoutingEntry(mlNode2.getId(), new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -871,7 +701,7 @@ public void testShouldAllocateModels_WithNodeShutdowns() { .build(); assertThat( - TrainedModelAssignmentClusterService.shouldAllocateModels(new ClusterChangedEvent("test", currentState, previousState)), + TrainedModelAssignmentClusterService.shouldRebalanceModels(new ClusterChangedEvent("test", currentState, previousState)), is(true) ); @@ -887,7 +717,7 @@ public void testShouldAllocateModels_WithNodeShutdowns() { ).build(); assertThat( - TrainedModelAssignmentClusterService.shouldAllocateModels(new ClusterChangedEvent("test", currentState, previousState)), + TrainedModelAssignmentClusterService.shouldRebalanceModels(new ClusterChangedEvent("test", currentState, previousState)), is(false) ); @@ -902,7 +732,7 @@ public void testShouldAllocateModels_WithNodeShutdowns() { ).build(); assertThat( - TrainedModelAssignmentClusterService.shouldAllocateModels(new ClusterChangedEvent("test", 
currentState, previousState)), + TrainedModelAssignmentClusterService.shouldRebalanceModels(new ClusterChangedEvent("test", currentState, previousState)), is(false) ); @@ -914,7 +744,7 @@ public void testShouldAllocateModels_WithNodeShutdowns() { ).build(); assertThat( - TrainedModelAssignmentClusterService.shouldAllocateModels(new ClusterChangedEvent("test", currentState, previousState)), + TrainedModelAssignmentClusterService.shouldRebalanceModels(new ClusterChangedEvent("test", currentState, previousState)), is(true) ); @@ -929,7 +759,7 @@ public void testShouldAllocateModels_WithNodeShutdowns() { ).build(); assertThat( - TrainedModelAssignmentClusterService.shouldAllocateModels(new ClusterChangedEvent("test", currentState, previousState)), + TrainedModelAssignmentClusterService.shouldRebalanceModels(new ClusterChangedEvent("test", currentState, previousState)), is(false) ); @@ -944,7 +774,7 @@ public void testShouldAllocateModels_WithNodeShutdowns() { ).build(); assertThat( - TrainedModelAssignmentClusterService.shouldAllocateModels(new ClusterChangedEvent("test", currentState, previousState)), + TrainedModelAssignmentClusterService.shouldRebalanceModels(new ClusterChangedEvent("test", currentState, previousState)), is(false) ); @@ -959,7 +789,7 @@ public void testShouldAllocateModels_WithNodeShutdowns() { ).build(); assertThat( - TrainedModelAssignmentClusterService.shouldAllocateModels(new ClusterChangedEvent("test", currentState, previousState)), + TrainedModelAssignmentClusterService.shouldRebalanceModels(new ClusterChangedEvent("test", currentState, previousState)), is(false) ); @@ -974,7 +804,7 @@ public void testShouldAllocateModels_WithNodeShutdowns() { ).build(); assertThat( - TrainedModelAssignmentClusterService.shouldAllocateModels(new ClusterChangedEvent("test", currentState, previousState)), + TrainedModelAssignmentClusterService.shouldRebalanceModels(new ClusterChangedEvent("test", currentState, previousState)), is(false) ); @@ -988,7 +818,7 @@ 
public void testShouldAllocateModels_WithNodeShutdowns() { ).build(); assertThat( - TrainedModelAssignmentClusterService.shouldAllocateModels(new ClusterChangedEvent("test", currentState, previousState)), + TrainedModelAssignmentClusterService.shouldRebalanceModels(new ClusterChangedEvent("test", currentState, previousState)), is(true) ); @@ -998,11 +828,221 @@ public void testShouldAllocateModels_WithNodeShutdowns() { currentState = fullyAllocated; assertThat( - TrainedModelAssignmentClusterService.shouldAllocateModels(new ClusterChangedEvent("test", currentState, previousState)), + TrainedModelAssignmentClusterService.shouldRebalanceModels(new ClusterChangedEvent("test", currentState, previousState)), is(true) ); } + public void testAreAssignedNodesRemoved_GivenRemovedNodeThatIsRouted() { + String modelId = "existing-model"; + String nodeId1 = "node-1"; + String nodeId2 = "node-2"; + Metadata metadata = Metadata.builder() + .putCustom( + TrainedModelAssignmentMetadata.NAME, + TrainedModelAssignmentMetadata.Builder.empty() + .addNewAssignment( + modelId, + TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L)) + .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) + .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) + ) + .build() + ) + .build(); + DiscoveryNode node1 = buildNode(nodeId1, true, ByteSizeValue.ofGb(4).getBytes(), 8); + DiscoveryNode node2 = buildNode(nodeId2, true, ByteSizeValue.ofGb(4).getBytes(), 8); + ClusterState previousState = ClusterState.builder(new ClusterName("testAreAssignedNodesRemoved")) + .nodes(DiscoveryNodes.builder().add(node1).add(node2).build()) + .metadata(metadata) + .build(); + ClusterState currentState = ClusterState.builder(new ClusterName("testAreAssignedNodesRemoved")) + .nodes(DiscoveryNodes.builder().add(node1).build()) + .metadata(metadata) + .build(); + ClusterChangedEvent event = new ClusterChangedEvent("test", currentState, previousState); + + 
assertThat(TrainedModelAssignmentClusterService.areAssignedNodesRemoved(event), is(true)); + } + + public void testAreAssignedNodesRemoved_GivenRemovedNodeThatIsNotRouted() { + String modelId = "existing-model"; + String nodeId1 = "node-1"; + String nodeId2 = "node-2"; + Metadata metadata = Metadata.builder() + .putCustom( + TrainedModelAssignmentMetadata.NAME, + TrainedModelAssignmentMetadata.Builder.empty() + .addNewAssignment( + modelId, + TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L)) + .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) + ) + .build() + ) + .build(); + DiscoveryNode node1 = buildNode(nodeId1, true, ByteSizeValue.ofGb(4).getBytes(), 8); + DiscoveryNode node2 = buildNode(nodeId2, true, ByteSizeValue.ofGb(4).getBytes(), 8); + ClusterState previousState = ClusterState.builder(new ClusterName("testAreAssignedNodesRemoved")) + .nodes(DiscoveryNodes.builder().add(node1).add(node2).build()) + .metadata(metadata) + .build(); + ClusterState currentState = ClusterState.builder(new ClusterName("testAreAssignedNodesRemoved")) + .nodes(DiscoveryNodes.builder().add(node1).build()) + .metadata(metadata) + .build(); + ClusterChangedEvent event = new ClusterChangedEvent("test", currentState, previousState); + + assertThat(TrainedModelAssignmentClusterService.areAssignedNodesRemoved(event), is(false)); + } + + public void testAreAssignedNodesRemoved_GivenShuttingDownNodeThatIsRouted() { + String modelId = "existing-model"; + String nodeId1 = "node-1"; + String nodeId2 = "node-2"; + TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.Builder.empty() + .addNewAssignment( + modelId, + TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L)) + .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) + .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) + ) + .build(); + DiscoveryNode node1 = buildNode(nodeId1, true, 
ByteSizeValue.ofGb(4).getBytes(), 8); + DiscoveryNode node2 = buildNode(nodeId2, true, ByteSizeValue.ofGb(4).getBytes(), 8); + ClusterState previousState = ClusterState.builder(new ClusterName("testAreAssignedNodesRemoved")) + .nodes(DiscoveryNodes.builder().add(node1).add(node2).build()) + .metadata(Metadata.builder().putCustom(TrainedModelAssignmentMetadata.NAME, trainedModelAssignmentMetadata)) + .build(); + ClusterState currentState = ClusterState.builder(new ClusterName("testAreAssignedNodesRemoved")) + .nodes(DiscoveryNodes.builder().add(node1).add(node2).build()) + .metadata( + Metadata.builder() + .putCustom(TrainedModelAssignmentMetadata.NAME, trainedModelAssignmentMetadata) + .putCustom( + NodesShutdownMetadata.TYPE, + new NodesShutdownMetadata( + Map.of( + nodeId1, + SingleNodeShutdownMetadata.builder() + .setNodeId(nodeId1) + .setType(SingleNodeShutdownMetadata.Type.REMOVE) + .setStartedAtMillis(System.currentTimeMillis()) + .setReason("test") + .build() + ) + ) + ) + ) + .build(); + ClusterChangedEvent event = new ClusterChangedEvent("test", currentState, previousState); + + assertThat(TrainedModelAssignmentClusterService.areAssignedNodesRemoved(event), is(true)); + } + + public void testAreAssignedNodesRemoved_GivenShuttingDownNodeThatIsNotRouted() { + String modelId = "existing-model"; + String nodeId1 = "node-1"; + String nodeId2 = "node-2"; + TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.Builder.empty() + .addNewAssignment( + modelId, + TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L)) + .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) + ) + .build(); + DiscoveryNode node1 = buildNode(nodeId1, true, ByteSizeValue.ofGb(4).getBytes(), 8); + DiscoveryNode node2 = buildNode(nodeId2, true, ByteSizeValue.ofGb(4).getBytes(), 8); + ClusterState previousState = ClusterState.builder(new ClusterName("testAreAssignedNodesRemoved")) + 
.nodes(DiscoveryNodes.builder().add(node1).add(node2).build()) + .metadata(Metadata.builder().putCustom(TrainedModelAssignmentMetadata.NAME, trainedModelAssignmentMetadata)) + .build(); + ClusterState currentState = ClusterState.builder(new ClusterName("testAreAssignedNodesRemoved")) + .nodes(DiscoveryNodes.builder().add(node1).add(node2).build()) + .metadata( + Metadata.builder() + .putCustom(TrainedModelAssignmentMetadata.NAME, trainedModelAssignmentMetadata) + .putCustom( + NodesShutdownMetadata.TYPE, + new NodesShutdownMetadata( + Map.of( + nodeId1, + SingleNodeShutdownMetadata.builder() + .setNodeId(nodeId1) + .setType(SingleNodeShutdownMetadata.Type.REMOVE) + .setStartedAtMillis(System.currentTimeMillis()) + .setReason("test") + .build() + ) + ) + ) + ) + .build(); + ClusterChangedEvent event = new ClusterChangedEvent("test", currentState, previousState); + + assertThat(TrainedModelAssignmentClusterService.areAssignedNodesRemoved(event), is(false)); + } + + public void testRemoveRoutingToUnassignableNodes() { + String modelId1 = "model-1"; + String modelId2 = "model-2"; + String nodeId1 = "node-1"; + String nodeId2 = "node-2"; + String nodeId3 = "node-3"; + Metadata metadata = Metadata.builder() + .putCustom( + TrainedModelAssignmentMetadata.NAME, + TrainedModelAssignmentMetadata.Builder.empty() + .addNewAssignment( + modelId1, + TrainedModelAssignment.Builder.empty(newParams(modelId1, 10_000L)) + .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) + .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) + .addRoutingEntry(nodeId3, new RoutingInfo(1, 1, RoutingState.STARTED, "")) + ) + .addNewAssignment( + modelId2, + TrainedModelAssignment.Builder.empty(newParams(modelId2, 10_000L)) + .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) + .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) + .addRoutingEntry(nodeId3, new RoutingInfo(1, 1, RoutingState.STARTED, "")) + ) + 
.build() + ) + .putCustom( + NodesShutdownMetadata.TYPE, + new NodesShutdownMetadata( + Map.of( + nodeId3, + SingleNodeShutdownMetadata.builder() + .setNodeId(nodeId3) + .setType(SingleNodeShutdownMetadata.Type.REMOVE) + .setStartedAtMillis(System.currentTimeMillis()) + .setReason("test") + .build() + ) + ) + ) + .build(); + DiscoveryNode node1 = buildNode(nodeId1, true, ByteSizeValue.ofGb(4).getBytes(), 8); + DiscoveryNode node3 = buildNode(nodeId3, true, ByteSizeValue.ofGb(4).getBytes(), 8); + ClusterState currentState = ClusterState.builder(new ClusterName("testAreAssignedNodesRemoved")) + .nodes(DiscoveryNodes.builder().add(node1).add(node3).build()) + .metadata(metadata) + .build(); + + ClusterState resultState = TrainedModelAssignmentClusterService.removeRoutingToUnassignableNodes(currentState); + + TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.fromState(resultState); + assertThat(trainedModelAssignmentMetadata.modelAssignments(), is(aMapWithSize(2))); + for (String modelId : List.of(modelId1, modelId2)) { + TrainedModelAssignment assignment = trainedModelAssignmentMetadata.getModelAssignment(modelId); + assertThat(assignment, is(notNullValue())); + assertThat(assignment.getNodeRoutingTable(), is(aMapWithSize(1))); + assertThat(assignment.getNodeRoutingTable(), hasKey(nodeId1)); + } + } + private ClusterState.Builder csBuilderWithNodes(String name, DiscoveryNode... 
nodes) { var csBuilder = ClusterState.builder(new ClusterName(name)); var nodeBuilder = DiscoveryNodes.builder(); @@ -1039,7 +1079,7 @@ public void testSetAllocationToStopping() { ); ClusterState clusterStateWithAllocation = ClusterState.builder(new ClusterName("testSetAllocationToStopping")) - .nodes(DiscoveryNodes.builder().add(buildNode("test-node", true, ByteSizeValue.ofGb(4).getBytes())).build()) + .nodes(DiscoveryNodes.builder().add(buildNode("test-node", true, ByteSizeValue.ofGb(4).getBytes(), 8)).build()) .metadata( Metadata.builder() .putCustom( @@ -1087,14 +1127,14 @@ private void assertThatStoppingAssignmentPreventsMutation( } private TrainedModelAssignmentClusterService createClusterService() { - return new TrainedModelAssignmentClusterService(Settings.EMPTY, clusterService, nodeLoadDetector); + return new TrainedModelAssignmentClusterService(Settings.EMPTY, clusterService, threadPool, nodeLoadDetector); } - private static DiscoveryNode buildNode(String name, boolean isML, long nativeMemory) { - return buildNode(name, isML, nativeMemory, Version.CURRENT); + private static DiscoveryNode buildNode(String name, boolean isML, long nativeMemory, int allocatedProcessors) { + return buildNode(name, isML, nativeMemory, allocatedProcessors, Version.CURRENT); } - private static DiscoveryNode buildNode(String name, boolean isML, long nativeMemory, Version version) { + private static DiscoveryNode buildNode(String name, boolean isML, long nativeMemory, int allocatedProcessors, Version version) { return new DiscoveryNode( name, name, @@ -1102,26 +1142,32 @@ private static DiscoveryNode buildNode(String name, boolean isML, long nativeMem MapBuilder.newMapBuilder() .put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, String.valueOf(nativeMemory)) .put(MachineLearning.MAX_JVM_SIZE_NODE_ATTR, String.valueOf(10)) + .put(MachineLearning.ALLOCATED_PROCESSORS_NODE_ATTR, String.valueOf(allocatedProcessors)) .map(), isML ? 
DiscoveryNodeRole.roles() : Set.of(DiscoveryNodeRole.DATA_ROLE, DiscoveryNodeRole.MASTER_ROLE), version ); } - private static RoutingStateAndReason started() { - return new RoutingStateAndReason(RoutingState.STARTED, ""); + private static RoutingInfoUpdate started() { + return RoutingInfoUpdate.updateStateAndReason(new RoutingStateAndReason(RoutingState.STARTED, "")); } - private static DiscoveryNode buildOldNode(String name, boolean isML, long nativeMemory) { - return buildNode(name, isML, nativeMemory, Version.V_7_15_0); + private static DiscoveryNode buildOldNode(String name, boolean isML, long nativeMemory, int allocatedProcessors) { + return buildNode(name, isML, nativeMemory, allocatedProcessors, Version.V_7_15_0); } private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId, long modelSize) { - return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, 1, 1, 1024); + return newParams(modelId, modelSize, 1, 1); } - private static void assertNodeState(TrainedModelAssignmentMetadata metadata, String modelId, String nodeId, RoutingState routingState) { - assertThat(metadata.getModelAssignment(modelId).getNodeRoutingTable().get(nodeId).getState(), equalTo(routingState)); + private static StartTrainedModelDeploymentAction.TaskParams newParams( + String modelId, + long modelSize, + int numberOfAllocations, + int threadsPerAllocation + ) { + return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, threadsPerAllocation, numberOfAllocations, 1024); } private static NodesShutdownMetadata shutdownMetadata(String nodeId) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java index a9f23e2a56425..f0f07bbaaa472 100644 --- 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java @@ -49,32 +49,6 @@ protected TrainedModelAssignmentMetadata createTestInstance() { return new TrainedModelAssignmentMetadata(new HashMap<>()); } - public void testBuilderChanged_WhenAddingRemovingModel() { - TrainedModelAssignmentMetadata original = randomInstance(); - String newModel = "foo_model"; - - TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.Builder.fromMetadata(original); - assertThat(builder.isChanged(), is(false)); - - assertUnchanged(builder, b -> b.removeAssignment(newModel)); - - builder.addNewAssignment(newModel, TrainedModelAssignment.Builder.empty(randomParams(newModel))); - assertThat(builder.isChanged(), is(true)); - } - - public void testBuilderChangedWhenAssignmentChanged() { - String allocatedModelId = "test_model_id"; - TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.Builder.fromMetadata( - TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(allocatedModelId, TrainedModelAssignment.Builder.empty(randomParams(allocatedModelId))) - .build() - ); - assertThat(builder.isChanged(), is(false)); - - builder.getAssignment(allocatedModelId).addNewRoutingEntry("new-node"); - assertThat(builder.isChanged(), is(true)); - } - public void testIsAllocated() { String allocatedModelId = "test_model_id"; TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.Builder.empty() @@ -84,15 +58,6 @@ public void testIsAllocated() { assertThat(metadata.isAssigned("unknown_model_id"), is(false)); } - private static TrainedModelAssignmentMetadata.Builder assertUnchanged( - TrainedModelAssignmentMetadata.Builder builder, - Function function - ) { - function.apply(builder); - assertThat(builder.isChanged(), is(false)); - return 
builder; - } - private static StartTrainedModelDeploymentAction.TaskParams randomParams(String modelId) { return new StartTrainedModelDeploymentAction.TaskParams( modelId, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java index c607f2fd7382d..21af812cbeab3 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java @@ -32,9 +32,9 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.MlMetadata; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; -import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentStateAction; +import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState; -import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReason; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager; import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; @@ -55,6 +55,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; public class TrainedModelAssignmentNodeServiceTests extends ESTestCase { @@ -101,16 +102,22 @@ public void shutdown() throws InterruptedException { terminate(threadPool); } - public void testLoadQueuedModels() { 
+ public void testLoadQueuedModels_GivenNoQueuedModels() { TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService(); // When there are no queued models trainedModelAssignmentNodeService.loadQueuedModels(); verify(deploymentManager, never()).startDeployment(any(), any()); + } + + public void testLoadQueuedModels() { + TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService(); String modelToLoad = "loading-model"; String anotherModel = "loading-model-again"; + givenAssignmentsInClusterStateForModels(modelToLoad, anotherModel); + // Should only load each model once trainedModelAssignmentNodeService.prepareModelToLoad(newParams(modelToLoad)); trainedModelAssignmentNodeService.prepareModelToLoad(newParams(modelToLoad)); @@ -119,8 +126,8 @@ public void testLoadQueuedModels() { trainedModelAssignmentNodeService.loadQueuedModels(); ArgumentCaptor taskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); - ArgumentCaptor requestCapture = ArgumentCaptor.forClass( - UpdateTrainedModelAssignmentStateAction.Request.class + ArgumentCaptor requestCapture = ArgumentCaptor.forClass( + UpdateTrainedModelAssignmentRoutingInfoAction.Request.class ); verify(deploymentManager, times(2)).startDeployment(taskCapture.capture(), any()); verify(trainedModelAssignmentService, times(2)).updateModelAssignmentState(requestCapture.capture(), any()); @@ -128,12 +135,12 @@ public void testLoadQueuedModels() { assertThat(taskCapture.getAllValues().get(0).getModelId(), equalTo(modelToLoad)); assertThat(requestCapture.getAllValues().get(0).getModelId(), equalTo(modelToLoad)); assertThat(requestCapture.getAllValues().get(0).getNodeId(), equalTo(NODE_ID)); - assertThat(requestCapture.getAllValues().get(0).getRoutingState().getState(), equalTo(RoutingState.STARTED)); + assertThat(requestCapture.getAllValues().get(0).getUpdate().getStateAndReason().get().getState(), equalTo(RoutingState.STARTED)); 
assertThat(taskCapture.getAllValues().get(1).getModelId(), equalTo(anotherModel)); assertThat(requestCapture.getAllValues().get(1).getModelId(), equalTo(anotherModel)); assertThat(requestCapture.getAllValues().get(1).getNodeId(), equalTo(NODE_ID)); - assertThat(requestCapture.getAllValues().get(1).getRoutingState().getState(), equalTo(RoutingState.STARTED)); + assertThat(requestCapture.getAllValues().get(1).getUpdate().getStateAndReason().get().getState(), equalTo(RoutingState.STARTED)); // Since models are loaded, there shouldn't be any more loadings to occur trainedModelAssignmentNodeService.prepareModelToLoad(newParams(anotherModel)); @@ -144,6 +151,7 @@ public void testLoadQueuedModels() { public void testLoadQueuedModelsWhenFailureIsRetried() { String modelToLoad = "loading-model"; String failedModelToLoad = "failed-search-loading-model"; + givenAssignmentsInClusterStateForModels(modelToLoad, failedModelToLoad); withSearchingLoadFailure(failedModelToLoad); TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService(); @@ -155,8 +163,8 @@ public void testLoadQueuedModelsWhenFailureIsRetried() { trainedModelAssignmentNodeService.loadQueuedModels(); ArgumentCaptor startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); - ArgumentCaptor requestCapture = ArgumentCaptor.forClass( - UpdateTrainedModelAssignmentStateAction.Request.class + ArgumentCaptor requestCapture = ArgumentCaptor.forClass( + UpdateTrainedModelAssignmentRoutingInfoAction.Request.class ); verify(deploymentManager, times(3)).startDeployment(startTaskCapture.capture(), any()); // Only the successful one is notifying, the failed one keeps retrying but not notifying as it is never successful @@ -165,7 +173,7 @@ public void testLoadQueuedModelsWhenFailureIsRetried() { assertThat(startTaskCapture.getAllValues().get(0).getModelId(), equalTo(modelToLoad)); assertThat(requestCapture.getAllValues().get(0).getModelId(), equalTo(modelToLoad)); 
assertThat(requestCapture.getAllValues().get(0).getNodeId(), equalTo(NODE_ID)); - assertThat(requestCapture.getAllValues().get(0).getRoutingState().getState(), equalTo(RoutingState.STARTED)); + assertThat(requestCapture.getAllValues().get(0).getUpdate().getStateAndReason().get().getState(), equalTo(RoutingState.STARTED)); assertThat(startTaskCapture.getAllValues().get(1).getModelId(), equalTo(failedModelToLoad)); assertThat(startTaskCapture.getAllValues().get(2).getModelId(), equalTo(failedModelToLoad)); @@ -194,6 +202,8 @@ public void testLoadQueuedModelsWhenTaskIsStopped() throws Exception { String modelToLoad = "loading-model"; String stoppedModelToLoad = "stopped-loading-model"; + givenAssignmentsInClusterStateForModels(modelToLoad, stoppedModelToLoad); + // Only one model should be loaded, the other should be stopped trainedModelAssignmentNodeService.prepareModelToLoad(newParams(modelToLoad)); trainedModelAssignmentNodeService.prepareModelToLoad(newParams(stoppedModelToLoad)); @@ -206,26 +216,26 @@ public void testLoadQueuedModelsWhenTaskIsStopped() throws Exception { assertThat(stoppedTaskCapture.getValue().getModelId(), equalTo(stoppedModelToLoad)); }); ArgumentCaptor startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); - ArgumentCaptor requestCapture = ArgumentCaptor.forClass( - UpdateTrainedModelAssignmentStateAction.Request.class + ArgumentCaptor requestCapture = ArgumentCaptor.forClass( + UpdateTrainedModelAssignmentRoutingInfoAction.Request.class ); verify(deploymentManager, times(1)).startDeployment(startTaskCapture.capture(), any()); assertBusy(() -> verify(trainedModelAssignmentService, times(3)).updateModelAssignmentState(requestCapture.capture(), any())); boolean seenStopping = false; for (int i = 0; i < 3; i++) { - UpdateTrainedModelAssignmentStateAction.Request request = requestCapture.getAllValues().get(i); + UpdateTrainedModelAssignmentRoutingInfoAction.Request request = requestCapture.getAllValues().get(i); 
assertThat(request.getNodeId(), equalTo(NODE_ID)); if (request.getModelId().equals(stoppedModelToLoad)) { if (seenStopping) { - assertThat(request.getRoutingState().getState(), equalTo(RoutingState.STOPPED)); + assertThat(request.getUpdate().getStateAndReason().get().getState(), equalTo(RoutingState.STOPPED)); } else { - assertThat(request.getRoutingState().getState(), equalTo(RoutingState.STOPPING)); + assertThat(request.getUpdate().getStateAndReason().get().getState(), equalTo(RoutingState.STOPPING)); seenStopping = true; } } else { assertThat(request.getModelId(), equalTo(modelToLoad)); - assertThat(request.getRoutingState().getState(), equalTo(RoutingState.STARTED)); + assertThat(request.getUpdate().getStateAndReason().get().getState(), equalTo(RoutingState.STARTED)); } } assertThat(startTaskCapture.getAllValues().get(0).getModelId(), equalTo(modelToLoad)); @@ -236,6 +246,7 @@ public void testLoadQueuedModelsWhenTaskIsStopped() throws Exception { public void testLoadQueuedModelsWhenOneFails() throws InterruptedException { String modelToLoad = "loading-model"; String failedModelToLoad = "failed-loading-model"; + givenAssignmentsInClusterStateForModels(modelToLoad, failedModelToLoad); withLoadFailure(failedModelToLoad); TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService(); @@ -253,8 +264,8 @@ public void testLoadQueuedModelsWhenOneFails() throws InterruptedException { latch.await(5, TimeUnit.SECONDS); ArgumentCaptor startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); - ArgumentCaptor requestCapture = ArgumentCaptor.forClass( - UpdateTrainedModelAssignmentStateAction.Request.class + ArgumentCaptor requestCapture = ArgumentCaptor.forClass( + UpdateTrainedModelAssignmentRoutingInfoAction.Request.class ); verify(deploymentManager, times(2)).startDeployment(startTaskCapture.capture(), any()); verify(trainedModelAssignmentService, times(2)).updateModelAssignmentState(requestCapture.capture(), any()); @@ -265,12 
+276,12 @@ public void testLoadQueuedModelsWhenOneFails() throws InterruptedException { assertThat(startTaskCapture.getAllValues().get(0).getModelId(), equalTo(modelToLoad)); assertThat(requestCapture.getAllValues().get(0).getModelId(), equalTo(modelToLoad)); assertThat(requestCapture.getAllValues().get(0).getNodeId(), equalTo(NODE_ID)); - assertThat(requestCapture.getAllValues().get(0).getRoutingState().getState(), equalTo(RoutingState.STARTED)); + assertThat(requestCapture.getAllValues().get(0).getUpdate().getStateAndReason().get().getState(), equalTo(RoutingState.STARTED)); assertThat(startTaskCapture.getAllValues().get(1).getModelId(), equalTo(failedModelToLoad)); assertThat(requestCapture.getAllValues().get(1).getModelId(), equalTo(failedModelToLoad)); assertThat(requestCapture.getAllValues().get(1).getNodeId(), equalTo(NODE_ID)); - assertThat(requestCapture.getAllValues().get(1).getRoutingState().getState(), equalTo(RoutingState.FAILED)); + assertThat(requestCapture.getAllValues().get(1).getUpdate().getStateAndReason().get().getState(), equalTo(RoutingState.FAILED)); assertThat(stopTaskCapture.getValue().getModelId(), equalTo(failedModelToLoad)); @@ -306,15 +317,18 @@ public void testClusterChangedWithResetMode() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelOne, - TrainedModelAssignment.Builder.empty(newParams(modelOne)).addNewRoutingEntry(NODE_ID) + TrainedModelAssignment.Builder.empty(newParams(modelOne)) + .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( modelTwo, - TrainedModelAssignment.Builder.empty(newParams(modelTwo)).addNewRoutingEntry(NODE_ID) + TrainedModelAssignment.Builder.empty(newParams(modelTwo)) + .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( notUsedModel, - TrainedModelAssignment.Builder.empty(newParams(notUsedModel)).addNewRoutingEntry("some-other-node") + 
TrainedModelAssignment.Builder.empty(newParams(notUsedModel)) + .addRoutingEntry("some-other-node", new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() ) @@ -349,6 +363,7 @@ public void testClusterChanged() throws Exception { String modelTwo = "model-2"; String notUsedModel = "model-3"; String previouslyUsedModel = "model-4"; + givenAssignmentsInClusterStateForModels(modelOne, modelTwo, previouslyUsedModel); ClusterChangedEvent event = new ClusterChangedEvent( "testClusterChanged", ClusterState.builder(new ClusterName("testClusterChanged")) @@ -360,15 +375,18 @@ public void testClusterChanged() throws Exception { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelOne, - TrainedModelAssignment.Builder.empty(newParams(modelOne)).addNewRoutingEntry(NODE_ID) + TrainedModelAssignment.Builder.empty(newParams(modelOne)) + .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( modelTwo, TrainedModelAssignment.Builder.empty(newParams(modelTwo)) - .addNewRoutingEntry(NODE_ID) + .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .updateExistingRoutingEntry( NODE_ID, - new RoutingStateAndReason( + new RoutingInfo( + 1, + 1, randomFrom(RoutingState.STARTED, RoutingState.STARTING), randomAlphaOfLength(10) ) @@ -377,10 +395,12 @@ public void testClusterChanged() throws Exception { .addNewAssignment( previouslyUsedModel, TrainedModelAssignment.Builder.empty(newParams(modelTwo)) - .addNewRoutingEntry(NODE_ID) + .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .updateExistingRoutingEntry( NODE_ID, - new RoutingStateAndReason( + new RoutingInfo( + 1, + 1, randomFrom(RoutingState.STOPPED, RoutingState.FAILED, RoutingState.STOPPING), randomAlphaOfLength(10) ) @@ -388,7 +408,8 @@ public void testClusterChanged() throws Exception { ) .addNewAssignment( notUsedModel, - TrainedModelAssignment.Builder.empty(newParams(notUsedModel)).addNewRoutingEntry("some-other-node") 
+ TrainedModelAssignment.Builder.empty(newParams(notUsedModel)) + .addRoutingEntry("some-other-node", new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() ) @@ -411,15 +432,18 @@ public void testClusterChanged() throws Exception { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelOne, - TrainedModelAssignment.Builder.empty(newParams(modelOne)).addNewRoutingEntry(NODE_ID) + TrainedModelAssignment.Builder.empty(newParams(modelOne)) + .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( modelTwo, - TrainedModelAssignment.Builder.empty(newParams(modelTwo)).addNewRoutingEntry("some-other-node") + TrainedModelAssignment.Builder.empty(newParams(modelTwo)) + .addRoutingEntry("some-other-node", new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( notUsedModel, - TrainedModelAssignment.Builder.empty(newParams(notUsedModel)).addNewRoutingEntry("some-other-node") + TrainedModelAssignment.Builder.empty(newParams(notUsedModel)) + .addRoutingEntry("some-other-node", new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() ) @@ -438,8 +462,8 @@ public void testClusterChanged() throws Exception { assertThat(stoppedTaskCapture.getAllValues().get(0).getModelId(), equalTo(modelTwo)); }); ArgumentCaptor startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); - ArgumentCaptor requestCapture = ArgumentCaptor.forClass( - UpdateTrainedModelAssignmentStateAction.Request.class + ArgumentCaptor requestCapture = ArgumentCaptor.forClass( + UpdateTrainedModelAssignmentRoutingInfoAction.Request.class ); verify(deploymentManager, times(1)).startDeployment(startTaskCapture.capture(), any()); verify(trainedModelAssignmentService, times(1)).updateModelAssignmentState(requestCapture.capture(), any()); @@ -447,7 +471,7 @@ public void testClusterChanged() throws Exception { assertThat(startTaskCapture.getAllValues().get(0).getModelId(), equalTo(modelOne)); 
assertThat(requestCapture.getAllValues().get(0).getModelId(), equalTo(modelOne)); assertThat(requestCapture.getAllValues().get(0).getNodeId(), equalTo(NODE_ID)); - assertThat(requestCapture.getAllValues().get(0).getRoutingState().getState(), equalTo(RoutingState.STARTED)); + assertThat(requestCapture.getAllValues().get(0).getUpdate().getStateAndReason().get().getState(), equalTo(RoutingState.STARTED)); event = new ClusterChangedEvent( "testClusterChanged", @@ -460,7 +484,8 @@ public void testClusterChanged() throws Exception { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelOne, - TrainedModelAssignment.Builder.empty(newParams(modelOne)).addNewRoutingEntry(NODE_ID) + TrainedModelAssignment.Builder.empty(newParams(modelOne)) + .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() ) @@ -476,6 +501,100 @@ public void testClusterChanged() throws Exception { verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService); } + public void testClusterChanged_GivenAllStartedAssignments_AndNonMatchingTargetAllocations() throws Exception { + final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService(); + final DiscoveryNodes nodes = DiscoveryNodes.builder() + .localNodeId(NODE_ID) + .add( + new DiscoveryNode( + NODE_ID, + NODE_ID, + buildNewFakeTransportAddress(), + Collections.emptyMap(), + DiscoveryNodeRole.roles(), + Version.CURRENT + ) + ) + .build(); + String modelOne = "model-1"; + String modelTwo = "model-2"; + givenAssignmentsInClusterStateForModels(modelOne, modelTwo); + trainedModelAssignmentNodeService.prepareModelToLoad(newParams(modelOne)); + trainedModelAssignmentNodeService.prepareModelToLoad(newParams(modelTwo)); + trainedModelAssignmentNodeService.loadQueuedModels(); + + ClusterChangedEvent event = new ClusterChangedEvent( + "shouldUpdateAllocations", + ClusterState.builder(new ClusterName("shouldUpdateAllocations")) + .nodes(nodes) + .metadata( + 
Metadata.builder() + .putCustom( + TrainedModelAssignmentMetadata.NAME, + TrainedModelAssignmentMetadata.Builder.empty() + .addNewAssignment( + modelOne, + TrainedModelAssignment.Builder.empty(newParams(modelOne)) + .addRoutingEntry(NODE_ID, new RoutingInfo(1, 3, RoutingState.STARTED, "")) + ) + .addNewAssignment( + modelTwo, + TrainedModelAssignment.Builder.empty(newParams(modelTwo)) + .addRoutingEntry(NODE_ID, new RoutingInfo(2, 1, RoutingState.STARTED, "")) + ) + .build() + ) + .build() + ) + .build(), + ClusterState.EMPTY_STATE + ); + + trainedModelAssignmentNodeService.clusterChanged(event); + + assertBusy(() -> { + ArgumentCaptor updatedTasks = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); + ArgumentCaptor updatedAllocations = ArgumentCaptor.forClass(Integer.class); + verify(deploymentManager, times(2)).updateNumAllocations(updatedTasks.capture(), updatedAllocations.capture(), any(), any()); + assertThat(updatedTasks.getAllValues().get(0).getModelId(), equalTo(modelOne)); + assertThat(updatedTasks.getAllValues().get(1).getModelId(), equalTo(modelTwo)); + }); + ArgumentCaptor startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); + ArgumentCaptor updateCapture = ArgumentCaptor.forClass( + UpdateTrainedModelAssignmentRoutingInfoAction.Request.class + ); + verify(deploymentManager, times(2)).startDeployment(startTaskCapture.capture(), any()); + verify(trainedModelAssignmentService, times(2)).updateModelAssignmentState(updateCapture.capture(), any()); + + assertThat(startTaskCapture.getAllValues().get(0).getModelId(), equalTo(modelOne)); + assertThat(startTaskCapture.getAllValues().get(1).getModelId(), equalTo(modelTwo)); + assertThat(updateCapture.getAllValues().get(0).getModelId(), equalTo(modelOne)); + assertThat(updateCapture.getAllValues().get(0).getNodeId(), equalTo(NODE_ID)); + assertThat(updateCapture.getAllValues().get(0).getUpdate().getStateAndReason().get().getState(), equalTo(RoutingState.STARTED)); + 
assertThat(updateCapture.getAllValues().get(1).getModelId(), equalTo(modelTwo)); + assertThat(updateCapture.getAllValues().get(1).getNodeId(), equalTo(NODE_ID)); + assertThat(updateCapture.getAllValues().get(1).getUpdate().getStateAndReason().get().getState(), equalTo(RoutingState.STARTED)); + + verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService); + } + + private void givenAssignmentsInClusterStateForModels(String... modelIds) { + TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.Builder.empty(); + for (String modelId : modelIds) { + builder.addNewAssignment( + modelId, + TrainedModelAssignment.Builder.empty(newParams(modelId)) + .addRoutingEntry("test-node", new RoutingInfo(1, 1, RoutingState.STARTING, "")) + ); + } + + ClusterState currentState = ClusterState.builder(new ClusterName("testLoadQueuedModels")) + .metadata(Metadata.builder().putCustom(TrainedModelAssignmentMetadata.NAME, builder.build()).build()) + .build(); + + when(clusterService.state()).thenReturn(currentState); + } + @SuppressWarnings({ "rawtypes", "unchecked" }) private void withLoadFailure(String modelId) { doAnswer(invocationOnMock -> { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java new file mode 100644 index 0000000000000..3c660f8494da5 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java @@ -0,0 +1,482 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.assignment; + +import org.elasticsearch.ResourceAlreadyExistsException; +import org.elasticsearch.Version; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodeRole; +import org.elasticsearch.common.collect.MapBuilder; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.job.NodeLoad; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.anEmptyMap; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasKey; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; + +public class TrainedModelAssignmentRebalancerTests extends ESTestCase { + + public void testRebalance_GivenNoAssignments() throws Exception { + TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer( + TrainedModelAssignmentMetadata.Builder.empty().build(), + Map.of(), + Optional.empty() + ).rebalance().build(); + assertThat(result.modelAssignments().isEmpty(), is(true)); + } + + public void testRebalance_GivenModelToAddAlreadyExists() { + String modelId = "model-to-add"; + StartTrainedModelDeploymentAction.TaskParams taskParams = newParams(modelId, 1024L, 1, 1); + TrainedModelAssignmentMetadata currentMetadata = 
TrainedModelAssignmentMetadata.Builder.empty() + .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(taskParams)) + .build(); + expectThrows( + ResourceAlreadyExistsException.class, + () -> new TrainedModelAssignmentRebalancer(currentMetadata, Map.of(), Optional.of(taskParams)).rebalance() + ); + } + + public void testRebalance_GivenFirstModelToAdd_NoMLNodes() throws Exception { + String modelId = "model-to-add"; + StartTrainedModelDeploymentAction.TaskParams taskParams = newParams(modelId, 1024L, 1, 1); + TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty().build(); + + TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, Map.of(), Optional.of(taskParams)) + .rebalance() + .build(); + + TrainedModelAssignment assignment = result.getModelAssignment(modelId); + assertThat(assignment, is(notNullValue())); + assertThat(assignment.getAssignmentState(), equalTo(AssignmentState.STARTING)); + assertThat(assignment.getNodeRoutingTable(), is(anEmptyMap())); + assertThat(assignment.getReason().isPresent(), is(true)); + assertThat(assignment.getReason().get(), equalTo("No ML nodes exist in the cluster")); + } + + public void testRebalance_GivenFirstModelToAdd_NotEnoughProcessors() throws Exception { + String modelId = "model-to-add"; + StartTrainedModelDeploymentAction.TaskParams taskParams = newParams(modelId, 1024L, 1, 4); + TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty().build(); + Map nodeLoads = new HashMap<>(); + long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes(); + nodeLoads.put(buildNode("node-1", nodeMemoryBytes, 3), NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build()); + + TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.of(taskParams)) + .rebalance() + .build(); + + TrainedModelAssignment assignment = result.getModelAssignment(modelId); + 
assertThat(assignment, is(notNullValue())); + assertThat(assignment.getAssignmentState(), equalTo(AssignmentState.STARTING)); + assertThat(assignment.getNodeRoutingTable(), is(anEmptyMap())); + assertThat(assignment.getReason().isPresent(), is(true)); + assertThat( + assignment.getReason().get(), + equalTo( + "Could not assign (more) allocations on node [node-1]. Reason: This node has insufficient allocated processors. " + + "Available processors [3], free processors [3], processors required for each allocation of this model [4]" + ) + ); + } + + public void testRebalance_GivenFirstModelToAdd_NotEnoughMemory() throws Exception { + String modelId = "model-to-add"; + StartTrainedModelDeploymentAction.TaskParams taskParams = newParams(modelId, ByteSizeValue.ofGb(2).getBytes(), 1, 1); + TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty().build(); + Map nodeLoads = new HashMap<>(); + long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes(); + nodeLoads.put(buildNode("node-1", nodeMemoryBytes, 3), NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build()); + + TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.of(taskParams)) + .rebalance() + .build(); + + TrainedModelAssignment assignment = result.getModelAssignment(modelId); + assertThat(assignment, is(notNullValue())); + assertThat(assignment.getAssignmentState(), equalTo(AssignmentState.STARTING)); + assertThat(assignment.getNodeRoutingTable(), is(anEmptyMap())); + assertThat(assignment.getReason().isPresent(), is(true)); + assertThat( + assignment.getReason().get(), + containsString("Could not assign (more) allocations on node [node-1]. 
Reason: This node has insufficient available memory.") + ); + } + + public void testRebalance_GivenFirstModelToAdd_ErrorDetectingNodeLoad() throws Exception { + String modelId = "model-to-add"; + StartTrainedModelDeploymentAction.TaskParams taskParams = newParams(modelId, ByteSizeValue.ofGb(2).getBytes(), 1, 1); + TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty().build(); + Map nodeLoads = new HashMap<>(); + long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes(); + nodeLoads.put( + buildNode("node-1", nodeMemoryBytes, 3), + NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).setError("error detecting load").build() + ); + + TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.of(taskParams)) + .rebalance() + .build(); + + TrainedModelAssignment assignment = result.getModelAssignment(modelId); + assertThat(assignment, is(notNullValue())); + assertThat(assignment.getAssignmentState(), equalTo(AssignmentState.STARTING)); + assertThat(assignment.getNodeRoutingTable(), is(anEmptyMap())); + assertThat(assignment.getReason().isPresent(), is(true)); + assertThat( + assignment.getReason().get(), + containsString("Could not assign (more) allocations on node [node-1]. 
Reason: error detecting load") + ); + } + + public void testRebalance_GivenProblemsOnMultipleNodes() throws Exception { + String modelId = "model-to-add"; + StartTrainedModelDeploymentAction.TaskParams taskParams = newParams(modelId, ByteSizeValue.ofGb(2).getBytes(), 1, 4); + TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty().build(); + Map nodeLoads = new HashMap<>(); + nodeLoads.put( + buildNode("node-1", ByteSizeValue.ofGb(1).getBytes(), 8), + NodeLoad.builder("node-1").setMaxMemory(ByteSizeValue.ofGb(1).getBytes()).build() + ); + nodeLoads.put( + buildNode("node-2", ByteSizeValue.ofGb(10).getBytes(), 3), + NodeLoad.builder("node-2").setMaxMemory(ByteSizeValue.ofGb(10).getBytes()).build() + ); + + TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.of(taskParams)) + .rebalance() + .build(); + + TrainedModelAssignment assignment = result.getModelAssignment(modelId); + assertThat(assignment, is(notNullValue())); + assertThat(assignment.getAssignmentState(), equalTo(AssignmentState.STARTING)); + assertThat(assignment.getNodeRoutingTable(), is(anEmptyMap())); + assertThat(assignment.getReason().isPresent(), is(true)); + assertThat( + assignment.getReason().get(), + containsString("Could not assign (more) allocations on node [node-1]. Reason: This node has insufficient available memory.") + ); + assertThat( + assignment.getReason().get(), + containsString("Could not assign (more) allocations on node [node-2]. 
Reason: This node has insufficient allocated processors.") + ); + } + + public void testRebalance_GivenFirstModelToAdd_FitsFully() throws Exception { + String modelId = "model-to-add"; + StartTrainedModelDeploymentAction.TaskParams taskParams = newParams(modelId, 1024L, 1, 1); + TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty().build(); + Map nodeLoads = new HashMap<>(); + long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes(); + nodeLoads.put(buildNode("node-1", nodeMemoryBytes, 4), NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build()); + + TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.of(taskParams)) + .rebalance() + .build(); + + TrainedModelAssignment assignment = result.getModelAssignment(modelId); + assertThat(assignment, is(notNullValue())); + assertThat(assignment.getAssignmentState(), equalTo(AssignmentState.STARTING)); + assertThat(assignment.getNodeRoutingTable(), is(aMapWithSize(1))); + assertThat(assignment.getNodeRoutingTable(), hasKey("node-1")); + assertThat(assignment.getNodeRoutingTable().get("node-1").getCurrentAllocations(), equalTo(1)); + assertThat(assignment.getNodeRoutingTable().get("node-1").getTargetAllocations(), equalTo(1)); + assertThat(assignment.getNodeRoutingTable().get("node-1").getState(), equalTo(RoutingState.STARTING)); + assertThat(assignment.getReason().isPresent(), is(false)); + } + + public void testRebalance_GivenModelToAdd_AndPreviousAssignments_AndTwoNodes_AllFit() throws Exception { + String modelToAddId = "model-to-add"; + String previousModelId = "previous-model"; + StartTrainedModelDeploymentAction.TaskParams taskParams = newParams(modelToAddId, 1024L, 1, 2); + TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() + .addNewAssignment( + previousModelId, + TrainedModelAssignment.Builder.empty(newParams(previousModelId, 1024L, 3, 2)) + 
.addRoutingEntry("node-1", new RoutingInfo(2, 2, RoutingState.STARTED, "")) + .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) + ) + .build(); + Map nodeLoads = new HashMap<>(); + long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes(); + nodeLoads.put(buildNode("node-1", nodeMemoryBytes, 4), NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build()); + nodeLoads.put(buildNode("node-2", nodeMemoryBytes, 4), NodeLoad.builder("node-2").setMaxMemory(nodeMemoryBytes).build()); + + TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.of(taskParams)) + .rebalance() + .build(); + + assertThat(result.modelAssignments(), is(aMapWithSize(2))); + + { + TrainedModelAssignment assignment = result.getModelAssignment(modelToAddId); + assertThat(assignment, is(notNullValue())); + assertThat(assignment.getAssignmentState(), equalTo(AssignmentState.STARTING)); + assertThat(assignment.getNodeRoutingTable(), is(aMapWithSize(1))); + assertThat(assignment.getNodeRoutingTable(), hasKey("node-2")); + assertThat(assignment.getNodeRoutingTable().get("node-2").getCurrentAllocations(), equalTo(1)); + assertThat(assignment.getNodeRoutingTable().get("node-2").getTargetAllocations(), equalTo(1)); + assertThat(assignment.getNodeRoutingTable().get("node-2").getState(), equalTo(RoutingState.STARTING)); + assertThat(assignment.getReason().isPresent(), is(false)); + } + { + TrainedModelAssignment assignment = result.getModelAssignment(previousModelId); + assertThat(assignment, is(notNullValue())); + assertThat(assignment.getAssignmentState(), equalTo(AssignmentState.STARTED)); + assertThat(assignment.getNodeRoutingTable(), is(aMapWithSize(2))); + assertThat(assignment.getNodeRoutingTable(), hasKey("node-1")); + assertThat(assignment.getNodeRoutingTable(), hasKey("node-2")); + assertThat(assignment.getNodeRoutingTable().get("node-1").getCurrentAllocations(), equalTo(2)); + 
assertThat(assignment.getNodeRoutingTable().get("node-1").getTargetAllocations(), equalTo(2)); + assertThat(assignment.getNodeRoutingTable().get("node-1").getState(), equalTo(RoutingState.STARTED)); + assertThat(assignment.getNodeRoutingTable().get("node-2").getCurrentAllocations(), equalTo(1)); + assertThat(assignment.getNodeRoutingTable().get("node-2").getTargetAllocations(), equalTo(1)); + assertThat(assignment.getNodeRoutingTable().get("node-2").getState(), equalTo(RoutingState.STARTED)); + assertThat(assignment.getReason().isPresent(), is(false)); + } + } + + public void testRebalance_GivenPreviousAssignments_AndNewNode() throws Exception { + String previousModel1Id = "previous-model-1"; + String previousModel2Id = "previous-model-2"; + TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() + .addNewAssignment( + previousModel1Id, + TrainedModelAssignment.Builder.empty(newParams(previousModel1Id, 1024L, 3, 2)) + .addRoutingEntry("node-1", new RoutingInfo(2, 2, RoutingState.STARTED, "")) + .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) + ) + .addNewAssignment( + previousModel2Id, + TrainedModelAssignment.Builder.empty(newParams(previousModel2Id, 1024L, 4, 1)) + .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) + ) + .build(); + Map nodeLoads = new HashMap<>(); + long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes(); + nodeLoads.put(buildNode("node-1", nodeMemoryBytes, 4), NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build()); + nodeLoads.put(buildNode("node-2", nodeMemoryBytes, 4), NodeLoad.builder("node-2").setMaxMemory(nodeMemoryBytes).build()); + nodeLoads.put(buildNode("node-3", nodeMemoryBytes, 4), NodeLoad.builder("node-3").setMaxMemory(nodeMemoryBytes).build()); + + TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.empty()) + .rebalance() + .build(); + + 
assertThat(result.modelAssignments(), is(aMapWithSize(2))); + + { + TrainedModelAssignment assignment = result.getModelAssignment(previousModel1Id); + assertThat(assignment, is(notNullValue())); + assertThat(assignment.getAssignmentState(), equalTo(AssignmentState.STARTED)); + assertThat(assignment.getNodeRoutingTable(), is(aMapWithSize(2))); + assertThat(assignment.getNodeRoutingTable(), hasKey("node-1")); + assertThat(assignment.getNodeRoutingTable(), hasKey("node-2")); + assertThat(assignment.getNodeRoutingTable().get("node-1").getCurrentAllocations(), equalTo(2)); + assertThat(assignment.getNodeRoutingTable().get("node-1").getTargetAllocations(), equalTo(2)); + assertThat(assignment.getNodeRoutingTable().get("node-1").getState(), equalTo(RoutingState.STARTED)); + assertThat(assignment.getNodeRoutingTable().get("node-2").getCurrentAllocations(), equalTo(1)); + assertThat(assignment.getNodeRoutingTable().get("node-2").getTargetAllocations(), equalTo(1)); + assertThat(assignment.getNodeRoutingTable().get("node-2").getState(), equalTo(RoutingState.STARTED)); + assertThat(assignment.getReason().isPresent(), is(false)); + } + { + TrainedModelAssignment assignment = result.getModelAssignment(previousModel2Id); + assertThat(assignment, is(notNullValue())); + assertThat(assignment.getAssignmentState(), equalTo(AssignmentState.STARTED)); + assertThat(assignment.getNodeRoutingTable(), is(aMapWithSize(2))); + assertThat(assignment.getNodeRoutingTable(), hasKey("node-2")); + assertThat(assignment.getNodeRoutingTable(), hasKey("node-3")); + assertThat(assignment.getNodeRoutingTable().get("node-2").getCurrentAllocations(), equalTo(1)); + assertThat(assignment.getNodeRoutingTable().get("node-2").getTargetAllocations(), equalTo(2)); + assertThat(assignment.getNodeRoutingTable().get("node-2").getState(), equalTo(RoutingState.STARTED)); + assertThat(assignment.getNodeRoutingTable().get("node-3").getCurrentAllocations(), equalTo(2)); + 
assertThat(assignment.getNodeRoutingTable().get("node-3").getTargetAllocations(), equalTo(2)); + assertThat(assignment.getNodeRoutingTable().get("node-3").getState(), equalTo(RoutingState.STARTING)); + assertThat(assignment.getReason().isPresent(), is(false)); + } + } + + public void testRebalance_GivenPreviousAssignments_AndRemovedNode_AndRemainingNodeNotLargeEnough() throws Exception { + String previousModel1Id = "previous-model-1"; + String previousModel2Id = "previous-model-2"; + TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() + .addNewAssignment( + previousModel1Id, + TrainedModelAssignment.Builder.empty(newParams(previousModel1Id, 1024L, 3, 2)) + .addRoutingEntry("node-1", new RoutingInfo(2, 2, RoutingState.STARTED, "")) + .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) + ) + .addNewAssignment( + previousModel2Id, + TrainedModelAssignment.Builder.empty(newParams(previousModel2Id, 1024L, 4, 1)) + .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) + ) + .build(); + Map nodeLoads = new HashMap<>(); + long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes(); + nodeLoads.put(buildNode("node-1", nodeMemoryBytes, 4), NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build()); + + TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.empty()) + .rebalance() + .build(); + + assertThat(result.modelAssignments(), is(aMapWithSize(2))); + + { + TrainedModelAssignment assignment = result.getModelAssignment(previousModel1Id); + assertThat(assignment, is(notNullValue())); + assertThat(assignment.getAssignmentState(), equalTo(AssignmentState.STARTED)); + assertThat(assignment.getNodeRoutingTable(), is(aMapWithSize(1))); + assertThat(assignment.getNodeRoutingTable(), hasKey("node-1")); + assertThat(assignment.getNodeRoutingTable().get("node-1").getCurrentAllocations(), equalTo(2)); + 
assertThat(assignment.getNodeRoutingTable().get("node-1").getTargetAllocations(), equalTo(2)); + assertThat(assignment.getNodeRoutingTable().get("node-1").getState(), equalTo(RoutingState.STARTED)); + assertThat(assignment.getReason().isPresent(), is(true)); + assertThat( + assignment.getReason().get(), + equalTo( + "Could not assign (more) allocations on node [node-1]. Reason: This node has insufficient allocated processors. " + + "Available processors [4], free processors [0], processors required for each allocation of this model [2]" + ) + ); + } + { + TrainedModelAssignment assignment = result.getModelAssignment(previousModel2Id); + assertThat(assignment, is(notNullValue())); + assertThat(assignment.getAssignmentState(), equalTo(AssignmentState.STARTING)); + assertThat(assignment.getNodeRoutingTable(), is(anEmptyMap())); + assertThat(assignment.getReason().isPresent(), is(true)); + assertThat( + assignment.getReason().get(), + equalTo( + "Could not assign (more) allocations on node [node-1]. Reason: This node has insufficient allocated processors. 
" + + "Available processors [4], free processors [0], processors required for each allocation of this model [1]" + ) + ); + } + } + + public void testRebalance_GivenPreviousAssignments_AndRemovedNode_AndRemainingNodeLargeEnough() throws Exception { + String previousModel1Id = "previous-model-1"; + String previousModel2Id = "previous-model-2"; + TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() + .addNewAssignment( + previousModel1Id, + TrainedModelAssignment.Builder.empty(newParams(previousModel1Id, 1024L, 3, 2)) + .addRoutingEntry("node-1", new RoutingInfo(2, 2, RoutingState.STARTED, "")) + .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) + ) + .addNewAssignment( + previousModel2Id, + TrainedModelAssignment.Builder.empty(newParams(previousModel2Id, 1024L, 1, 1)) + .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) + ) + .build(); + Map nodeLoads = new HashMap<>(); + long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes(); + nodeLoads.put(buildNode("node-1", nodeMemoryBytes, 7), NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build()); + + TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.empty()) + .rebalance() + .build(); + + assertThat(result.modelAssignments(), is(aMapWithSize(2))); + + { + TrainedModelAssignment assignment = result.getModelAssignment(previousModel1Id); + assertThat(assignment, is(notNullValue())); + assertThat(assignment.getAssignmentState(), equalTo(AssignmentState.STARTED)); + assertThat(assignment.getNodeRoutingTable(), is(aMapWithSize(1))); + assertThat(assignment.getNodeRoutingTable(), hasKey("node-1")); + assertThat(assignment.getNodeRoutingTable().get("node-1").getCurrentAllocations(), equalTo(2)); + assertThat(assignment.getNodeRoutingTable().get("node-1").getTargetAllocations(), equalTo(3)); + assertThat(assignment.getNodeRoutingTable().get("node-1").getState(), 
equalTo(RoutingState.STARTED)); + assertThat(assignment.getReason().isPresent(), is(false)); + } + { + TrainedModelAssignment assignment = result.getModelAssignment(previousModel2Id); + assertThat(assignment, is(notNullValue())); + assertThat(assignment.getAssignmentState(), equalTo(AssignmentState.STARTING)); + assertThat(assignment.getNodeRoutingTable(), is(aMapWithSize(1))); + assertThat(assignment.getNodeRoutingTable(), hasKey("node-1")); + assertThat(assignment.getNodeRoutingTable().get("node-1").getCurrentAllocations(), equalTo(1)); + assertThat(assignment.getNodeRoutingTable().get("node-1").getTargetAllocations(), equalTo(1)); + assertThat(assignment.getNodeRoutingTable().get("node-1").getState(), equalTo(RoutingState.STARTING)); + assertThat(assignment.getReason().isPresent(), is(false)); + } + } + + public void testRebalance_GivenFailedAssignment_RestartsAssignment() throws Exception { + String modelId = "model-1"; + TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() + .addNewAssignment( + modelId, + TrainedModelAssignment.Builder.empty(newParams(modelId, 1024L, 1, 1)) + .addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.FAILED, "some error")) + ) + .build(); + Map nodeLoads = new HashMap<>(); + long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes(); + nodeLoads.put(buildNode("node-1", nodeMemoryBytes, 4), NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build()); + + TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.empty()) + .rebalance() + .build(); + + assertThat(result.modelAssignments(), is(aMapWithSize(1))); + + TrainedModelAssignment assignment = result.getModelAssignment(modelId); + assertThat(assignment, is(notNullValue())); + assertThat(assignment.getAssignmentState(), equalTo(AssignmentState.STARTING)); + assertThat(assignment.getNodeRoutingTable(), is(aMapWithSize(1))); + assertThat(assignment.getNodeRoutingTable(), 
hasKey("node-1")); + assertThat(assignment.getNodeRoutingTable().get("node-1").getCurrentAllocations(), equalTo(1)); + assertThat(assignment.getNodeRoutingTable().get("node-1").getTargetAllocations(), equalTo(1)); + assertThat(assignment.getNodeRoutingTable().get("node-1").getState(), equalTo(RoutingState.STARTING)); + assertThat(assignment.getReason().isPresent(), is(false)); + } + + private static StartTrainedModelDeploymentAction.TaskParams newParams( + String modelId, + long modelSize, + int numberOfAllocations, + int threadsPerAllocation + ) { + return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, threadsPerAllocation, numberOfAllocations, 1024); + } + + private static DiscoveryNode buildNode(String name, long nativeMemory, int allocatedProcessors) { + return new DiscoveryNode( + name, + name, + buildNewFakeTransportAddress(), + MapBuilder.newMapBuilder() + .put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, String.valueOf(nativeMemory)) + .put(MachineLearning.MAX_JVM_SIZE_NODE_ATTR, String.valueOf(10)) + .put(MachineLearning.ALLOCATED_PROCESSORS_NODE_ATTR, String.valueOf(allocatedProcessors)) + .map(), + DiscoveryNodeRole.roles(), + Version.CURRENT + ); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java index 196fd01652963..0e5d8a69eafbc 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java @@ -382,18 +382,18 @@ private static Model randomModel(String idSuffix) { } private static void assertPreviousAssignmentsAreSatisfied(List models, AssignmentPlan assignmentPlan) { - for (Model m : models.stream().filter(m -> 
m.currentAllocationByNodeId().isEmpty() == false).toList()) { + for (Model m : models.stream().filter(m -> m.currentAllocationsByNodeId().isEmpty() == false).toList()) { Map assignments = assignmentPlan.assignments(m).get(); Set assignedNodeIds = new HashSet<>(); int allocations = 0; for (Map.Entry e : assignments.entrySet()) { assignedNodeIds.add(e.getKey().id()); - if (m.currentAllocationByNodeId().containsKey(e.getKey().id())) { + if (m.currentAllocationsByNodeId().containsKey(e.getKey().id())) { assertThat(e.getValue(), greaterThanOrEqualTo(1)); } allocations += e.getValue(); } - assertThat(m.currentAllocationByNodeId().keySet(), everyItem(in(assignedNodeIds))); + assertThat(m.currentAllocationsByNodeId().keySet(), everyItem(in(assignedNodeIds))); assertThat(allocations, greaterThanOrEqualTo(m.getPreviouslyAssignedAllocations())); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java index a9083c02f619c..7add808f37978 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java @@ -60,13 +60,13 @@ public void testGivenPreviousAssignments() { assertThat(modelsPreservingAllocations.get(0).memoryBytes(), equalTo(30L)); assertThat(modelsPreservingAllocations.get(0).allocations(), equalTo(1)); assertThat(modelsPreservingAllocations.get(0).threadsPerAllocation(), equalTo(1)); - assertThat(modelsPreservingAllocations.get(0).currentAllocationByNodeId(), equalTo(Map.of("n_1", 0))); + assertThat(modelsPreservingAllocations.get(0).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0))); assertThat(modelsPreservingAllocations.get(1).id(), equalTo("m_2")); 
assertThat(modelsPreservingAllocations.get(1).memoryBytes(), equalTo(50L)); assertThat(modelsPreservingAllocations.get(1).allocations(), equalTo(3)); assertThat(modelsPreservingAllocations.get(1).threadsPerAllocation(), equalTo(4)); - assertThat(modelsPreservingAllocations.get(1).currentAllocationByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 0))); + assertThat(modelsPreservingAllocations.get(1).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 0))); AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(model1, model2)) .assignModelToNode(model1, node1, 2) @@ -78,6 +78,10 @@ public void testGivenPreviousAssignments() { assertThat(plan.assignments(model1).get(), equalTo(Map.of(node1, 3))); assertThat(plan.assignments(model2).get(), equalTo(Map.of(node1, 1, node2, 2))); + assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(20L)); + assertThat(plan.getRemainingNodeCores("n_1"), equalTo(1)); + assertThat(plan.getRemainingNodeMemory("n_2"), equalTo(50L)); + assertThat(plan.getRemainingNodeCores("n_2"), equalTo(0)); } public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments() { @@ -91,5 +95,7 @@ public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments plan = preserveAllAllocations.mergePreservedAllocations(plan); assertThat(plan.assignments(model).isPresent(), is(true)); assertThat(plan.assignments(model).get(), equalTo(Map.of(node, 2))); + assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(70L)); + assertThat(plan.getRemainingNodeCores("n_1"), equalTo(0)); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java index f32e7fc26ba71..7c8ea92cd8d49 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java +++ 
b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java @@ -60,13 +60,13 @@ public void testGivenPreviousAssignments() { assertThat(modelsPreservingAllocations.get(0).memoryBytes(), equalTo(30L)); assertThat(modelsPreservingAllocations.get(0).allocations(), equalTo(1)); assertThat(modelsPreservingAllocations.get(0).threadsPerAllocation(), equalTo(1)); - assertThat(modelsPreservingAllocations.get(0).currentAllocationByNodeId(), equalTo(Map.of("n_1", 0))); + assertThat(modelsPreservingAllocations.get(0).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0))); assertThat(modelsPreservingAllocations.get(1).id(), equalTo("m_2")); assertThat(modelsPreservingAllocations.get(1).memoryBytes(), equalTo(50L)); assertThat(modelsPreservingAllocations.get(1).allocations(), equalTo(4)); assertThat(modelsPreservingAllocations.get(1).threadsPerAllocation(), equalTo(4)); - assertThat(modelsPreservingAllocations.get(1).currentAllocationByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 1))); + assertThat(modelsPreservingAllocations.get(1).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 1))); AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(model1, model2)) .assignModelToNode(model1, node1, 2) @@ -79,6 +79,10 @@ public void testGivenPreviousAssignments() { assertThat(plan.assignments(model1).get(), equalTo(Map.of(node1, 3))); assertThat(plan.assignments(model2).get(), equalTo(Map.of(node1, 1, node2, 2))); + assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(20L)); + assertThat(plan.getRemainingNodeCores("n_1"), equalTo(1)); + assertThat(plan.getRemainingNodeMemory("n_2"), equalTo(50L)); + assertThat(plan.getRemainingNodeCores("n_2"), equalTo(0)); } public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments() { @@ -92,5 +96,7 @@ public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments plan = 
preserveOneAllocation.mergePreservedAllocations(plan); assertThat(plan.assignments(model).isPresent(), is(true)); assertThat(plan.assignments(model).get(), equalTo(Map.of(node, 1))); + assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(70L)); + assertThat(plan.getRemainingNodeCores("n_1"), equalTo(2)); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java index 66acac9cf7fb1..10b2813603d59 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java @@ -26,7 +26,7 @@ import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; -import static org.elasticsearch.xpack.ml.MachineLearning.JOB_COMMS_THREAD_POOL_NAME; +import static org.elasticsearch.xpack.ml.MachineLearning.NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME; import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -50,12 +50,12 @@ public void managerSetup() { "xpack.ml.utility_thread_pool" ), new ScalingExecutorBuilder( - JOB_COMMS_THREAD_POOL_NAME, + NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME, 1, 4, TimeValue.timeValueMinutes(10), false, - "xpack.ml.job_comms_thread_pool" + "xpack.ml.native_inference_comms_thread_pool" ) ); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java index 27a842f57e958..c3f8871bab8be 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java +++ 
b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java @@ -18,8 +18,8 @@ import org.elasticsearch.persistent.PersistentTasksCustomMetadata; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState; -import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReason; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import org.elasticsearch.xpack.core.ml.job.config.JobState; import org.elasticsearch.xpack.ml.MachineLearning; @@ -128,12 +128,12 @@ public void testNodeLoadDetection() { TrainedModelAssignment.Builder.empty( new StartTrainedModelDeploymentAction.TaskParams("model1", MODEL_MEMORY_REQUIREMENT, 1, 1, 1024) ) - .addNewRoutingEntry("_node_id4") - .addNewFailedRoutingEntry("_node_id2", "test") - .addNewRoutingEntry("_node_id1") + .addRoutingEntry("_node_id4", new RoutingInfo(1, 1, RoutingState.STARTING, "")) + .addRoutingEntry("_node_id2", new RoutingInfo(1, 1, RoutingState.FAILED, "test")) + .addRoutingEntry("_node_id1", new RoutingInfo(1, 1, RoutingState.STARTING, "")) .updateExistingRoutingEntry( "_node_id1", - new RoutingStateAndReason(randomFrom(RoutingState.STOPPED, RoutingState.FAILED), "test") + new RoutingInfo(1, 1, randomFrom(RoutingState.STOPPED, RoutingState.FAILED), "test") ) ) .build() @@ -144,28 +144,28 @@ public void testNodeLoadDetection() { NodeLoad load = nodeLoadDetector.detectNodeLoad(cs, nodes.get("_node_id1"), 10, 30, false); assertThat(load.getAssignedJobMemory(), equalTo(52428800L)); assertThat(load.getNumAllocatingJobs(), equalTo(2)); - assertThat(load.getNumAssignedJobs(), equalTo(2)); + assertThat(load.getNumAssignedJobsAndModels(), equalTo(2)); assertThat(load.getMaxJobs(), equalTo(10)); assertThat(load.getMaxMlMemory(), equalTo(0L)); 
load = nodeLoadDetector.detectNodeLoad(cs, nodes.get("_node_id2"), 5, 30, false); assertThat(load.getAssignedJobMemory(), equalTo(41943040L)); assertThat(load.getNumAllocatingJobs(), equalTo(1)); - assertThat(load.getNumAssignedJobs(), equalTo(1)); + assertThat(load.getNumAssignedJobsAndModels(), equalTo(1)); assertThat(load.getMaxJobs(), equalTo(5)); assertThat(load.getMaxMlMemory(), equalTo(0L)); load = nodeLoadDetector.detectNodeLoad(cs, nodes.get("_node_id3"), 5, 30, false); assertThat(load.getAssignedJobMemory(), equalTo(0L)); assertThat(load.getNumAllocatingJobs(), equalTo(0)); - assertThat(load.getNumAssignedJobs(), equalTo(0)); + assertThat(load.getNumAssignedJobsAndModels(), equalTo(0)); assertThat(load.getMaxJobs(), equalTo(5)); assertThat(load.getMaxMlMemory(), equalTo(0L)); load = nodeLoadDetector.detectNodeLoad(cs, nodes.get("_node_id4"), 5, 30, false); assertThat(load.getAssignedJobMemory(), equalTo(398458880L)); assertThat(load.getNumAllocatingJobs(), equalTo(0)); - assertThat(load.getNumAssignedJobs(), equalTo(2)); + assertThat(load.getNumAssignedJobsAndModels(), equalTo(2)); assertThat(load.getMaxJobs(), equalTo(5)); assertThat(load.getMaxMlMemory(), equalTo(0L)); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadTests.java index a2b1696c2b3c2..b20876ac6364a 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadTests.java @@ -19,12 +19,12 @@ public void testIncrementCounts() { .incNumAssignedAnomalyDetectorJobs() .incNumAssignedDataFrameAnalyticsJobs() .incNumAssignedDataFrameAnalyticsJobs() - .incNumAssignedNativeInferenceJobs() - .incNumAssignedNativeInferenceJobs() - .incNumAssignedNativeInferenceJobs() + .incNumAssignedNativeInferenceModels() + .incNumAssignedNativeInferenceModels() + 
.incNumAssignedNativeInferenceModels() .build(); - assertThat(nodeLoad.getNumAssignedJobs(), equalTo(6)); - assertThat(nodeLoad.remainingJobs(), equalTo(4)); + assertThat(nodeLoad.getNumAssignedJobsAndModels(), equalTo(6)); + assertThat(nodeLoad.remainingJobs(), equalTo(7)); } } diff --git a/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java b/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java index 25621a903f7fa..f1c7c04905bea 100644 --- a/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java +++ b/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java @@ -68,6 +68,7 @@ public void setLogging() throws IOException { loggingSettings.setJsonEntity(""" {"persistent" : { "logger.org.elasticsearch.xpack.ml.inference.assignment" : "TRACE", + "logger.org.elasticsearch.xpack.ml.process.assignment.planning" : "TRACE", "logger.org.elasticsearch.xpack.ml.inference.deployment" : "TRACE", "logger.org.elasticsearch.xpack.ml.process.logging" : "TRACE" }}"""); diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java index fe63ab6f56da4..2c7cb5dbb14c6 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java @@ -141,11 +141,6 @@ private void waitForDeploymentStarted(String modelId) throws Exception { List> stats = (List>) map.get("trained_model_stats"); assertThat(stats, hasSize(1)); var stat = stats.get(0); - assertThat( - stat.toString(), - 
XContentMapValues.extractValue("deployment_stats.allocation_status.state", stat), - equalTo("fully_allocated") - ); assertThat(stat.toString(), XContentMapValues.extractValue("deployment_stats.state", stat), equalTo("started")); }, 30, TimeUnit.SECONDS); }