Adding node_count to ML Usage (#33850) #33863

Merged
MachineLearningFeatureSetUsage.java
@@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -24,28 +25,39 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
public static final String DETECTORS = "detectors";
public static final String FORECASTS = "forecasts";
public static final String MODEL_SIZE = "model_size";
public static final String NODE_COUNT = "node_count";

private final Map<String, Object> jobsUsage;
private final Map<String, Object> datafeedsUsage;
private final int nodeCount;

public MachineLearningFeatureSetUsage(boolean available, boolean enabled, Map<String, Object> jobsUsage,
Map<String, Object> datafeedsUsage) {
Map<String, Object> datafeedsUsage, int nodeCount) {
super(XPackField.MACHINE_LEARNING, available, enabled);
this.jobsUsage = Objects.requireNonNull(jobsUsage);
this.datafeedsUsage = Objects.requireNonNull(datafeedsUsage);
this.nodeCount = nodeCount;
}

public MachineLearningFeatureSetUsage(StreamInput in) throws IOException {
super(in);
this.jobsUsage = in.readMap();
this.datafeedsUsage = in.readMap();
if (in.getVersion().onOrAfter(Version.V_6_5_0)) {
this.nodeCount = in.readInt();
} else {
this.nodeCount = -1;
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeMap(jobsUsage);
out.writeMap(datafeedsUsage);
if (out.getVersion().onOrAfter(Version.V_6_5_0)) {
out.writeInt(nodeCount);
}
}

@Override
@@ -57,6 +69,9 @@ protected void innerXContent(XContentBuilder builder, Params params) throws IOEx
if (datafeedsUsage != null) {
builder.field(DATAFEEDS_FIELD, datafeedsUsage);
}
if (nodeCount >= 0) {
builder.field(NODE_COUNT, nodeCount);
}
}

}
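The constructor and stream changes above follow the standard wire backwards-compatibility pattern: node_count is only read or written when the other side of the stream is on 6.5.0 or later, and a -1 sentinel marks "value unknown" so that innerXContent can skip the node_count field for old-format streams. A minimal standalone sketch of that pattern, with invented helper names and shown only for illustration:

import java.io.IOException;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;

final class NodeCountWireCompat {

    // Only write the field when the receiving side understands it (6.5.0+);
    // older receivers never see the extra int on the wire.
    static void writeNodeCount(StreamOutput out, int nodeCount) throws IOException {
        if (out.getVersion().onOrAfter(Version.V_6_5_0)) {
            out.writeInt(nodeCount);
        }
    }

    // Read the field from 6.5.0+ senders; otherwise fall back to the -1
    // sentinel, which the usage class interprets as "do not emit node_count".
    static int readNodeCount(StreamInput in) throws IOException {
        if (in.getVersion().onOrAfter(Version.V_6_5_0)) {
            return in.readInt();
        }
        return -1;
    }
}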
MachineLearningFeatureSet.java
@@ -12,6 +12,7 @@
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.inject.Inject;
@@ -132,7 +133,22 @@ public Map<String, Object> nativeCodeInfo() {
@Override
public void usage(ActionListener<XPackFeatureSet.Usage> listener) {
ClusterState state = clusterService.state();
new Retriever(client, MlMetadata.getMlMetadata(state), available(), enabled()).execute(listener);
new Retriever(client, MlMetadata.getMlMetadata(state), available(), enabled(), mlNodeCount(state)).execute(listener);
}

private int mlNodeCount(final ClusterState clusterState) {
if (enabled == false) {
return 0;
}

int mlNodeCount = 0;
for (DiscoveryNode node : clusterState.getNodes()) {
String enabled = node.getAttributes().get(MachineLearning.ML_ENABLED_NODE_ATTR);
if (Boolean.parseBoolean(enabled)) {
++mlNodeCount;
}
}
return mlNodeCount;
}

public static class Retriever {
@@ -143,19 +159,22 @@ public static class Retriever {
private final boolean enabled;
private Map<String, Object> jobsUsage;
private Map<String, Object> datafeedsUsage;
private int nodeCount;

public Retriever(Client client, MlMetadata mlMetadata, boolean available, boolean enabled) {
public Retriever(Client client, MlMetadata mlMetadata, boolean available, boolean enabled, int nodeCount) {
this.client = Objects.requireNonNull(client);
this.mlMetadata = mlMetadata;
this.available = available;
this.enabled = enabled;
this.jobsUsage = new LinkedHashMap<>();
this.datafeedsUsage = new LinkedHashMap<>();
this.nodeCount = nodeCount;
}

public void execute(ActionListener<Usage> listener) {
if (enabled == false) {
listener.onResponse(new MachineLearningFeatureSetUsage(available, enabled, Collections.emptyMap(), Collections.emptyMap()));
listener.onResponse(
new MachineLearningFeatureSetUsage(available, enabled, Collections.emptyMap(), Collections.emptyMap(), 0));
return;
}

@@ -164,11 +183,9 @@ public void execute(ActionListener<Usage> listener) {
ActionListener.wrap(response -> {
addDatafeedsUsage(response);
listener.onResponse(new MachineLearningFeatureSetUsage(
available, enabled, jobsUsage, datafeedsUsage));
available, enabled, jobsUsage, datafeedsUsage, nodeCount));
},
error -> {
listener.onFailure(error);
}
listener::onFailure
);

// Step 1. Extract usage from jobs stats and then request stats for all datafeeds
@@ -181,9 +198,7 @@ public void execute(ActionListener<Usage> listener) {
client.execute(GetDatafeedsStatsAction.INSTANCE, datafeedStatsRequest,
datafeedStatsListener);
},
error -> {
listener.onFailure(error);
}
listener::onFailure
);

// Step 0. Kick off the chain of callbacks by requesting jobs stats
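The new mlNodeCount method counts the cluster's nodes whose MachineLearning.ML_ENABLED_NODE_ATTR attribute parses as true. An equivalent stream-based count, offered only as a sketch (DiscoveryNodes is iterable over DiscoveryNode, which is also what makes the for-each loop above work):

import java.util.stream.StreamSupport;

import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.xpack.ml.MachineLearning;

final class MlNodeCounter {

    // Mirrors the for-loop in mlNodeCount: a node counts as an ML node when
    // its ML_ENABLED_NODE_ATTR attribute is the string "true".
    static long countMlNodes(ClusterState clusterState) {
        return StreamSupport.stream(clusterState.getNodes().spliterator(), false)
                .filter(node -> Boolean.parseBoolean(node.getAttributes().get(MachineLearning.ML_ENABLED_NODE_ATTR)))
                .count();
    }
}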
MachineLearningFeatureSet tests
@@ -6,14 +6,19 @@
package org.elasticsearch.xpack.ml;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
@@ -46,7 +51,11 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.notNullValue;
@@ -223,6 +232,49 @@ public void testUsage() throws Exception {
}
}

public void testNodeCount() throws Exception {
when(licenseState.isMachineLearningAllowed()).thenReturn(true);
int nodeCount = randomIntBetween(1, 3);
givenNodeCount(nodeCount);
Settings.Builder settings = Settings.builder().put(commonSettings);
settings.put("xpack.ml.enabled", true);
MachineLearningFeatureSet featureSet = new MachineLearningFeatureSet(TestEnvironment.newEnvironment(settings.build()),
clusterService, client, licenseState);

PlainActionFuture<Usage> future = new PlainActionFuture<>();
featureSet.usage(future);
XPackFeatureSet.Usage usage = future.get();

assertThat(usage.available(), is(true));
assertThat(usage.enabled(), is(true));

BytesStreamOutput out = new BytesStreamOutput();
usage.writeTo(out);
XPackFeatureSet.Usage serializedUsage = new MachineLearningFeatureSetUsage(out.bytes().streamInput());

XContentSource source;
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
serializedUsage.toXContent(builder, ToXContent.EMPTY_PARAMS);
source = new XContentSource(builder);
}
assertThat(source.getValue("node_count"), equalTo(nodeCount));

BytesStreamOutput oldOut = new BytesStreamOutput();
oldOut.setVersion(Version.V_6_0_0);
usage.writeTo(oldOut);
StreamInput oldInput = oldOut.bytes().streamInput();
oldInput.setVersion(Version.V_6_0_0);
XPackFeatureSet.Usage oldSerializedUsage = new MachineLearningFeatureSetUsage(oldInput);

XContentSource oldSource;
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
oldSerializedUsage.toXContent(builder, ToXContent.EMPTY_PARAMS);
oldSource = new XContentSource(builder);
}

assertNull(oldSource.getValue("node_count"));
}

public void testUsageGivenMlMetadataNotInstalled() throws Exception {
when(licenseState.isMachineLearningAllowed()).thenReturn(true);
Settings.Builder settings = Settings.builder().put(commonSettings);
@@ -286,6 +338,37 @@ private void givenJobs(List<Job> jobs, List<GetJobsStatsAction.Response.JobStats
}).when(client).execute(same(GetJobsStatsAction.INSTANCE), any(), any());
}

private void givenNodeCount(int nodeCount) {
DiscoveryNodes.Builder nodesBuilder = DiscoveryNodes.builder();
for (int i = 0; i < nodeCount; i++) {
Map<String, String> attrs = new HashMap<>();
attrs.put(MachineLearning.ML_ENABLED_NODE_ATTR, Boolean.toString(true));
Set<DiscoveryNode.Role> roles = new HashSet<>();
roles.add(DiscoveryNode.Role.DATA);
roles.add(DiscoveryNode.Role.MASTER);
roles.add(DiscoveryNode.Role.INGEST);
nodesBuilder.add(new DiscoveryNode(randomAlphaOfLength(i+1),
new TransportAddress(TransportAddress.META_ADDRESS, 9100 + i),
attrs,
roles,
Version.CURRENT));
}
for (int i = 0; i < randomIntBetween(1, 3); i++) {
Map<String, String> attrs = new HashMap<>();
Set<DiscoveryNode.Role> roles = new HashSet<>();
roles.add(DiscoveryNode.Role.DATA);
roles.add(DiscoveryNode.Role.MASTER);
roles.add(DiscoveryNode.Role.INGEST);
nodesBuilder.add(new DiscoveryNode(randomAlphaOfLength(i+1),
new TransportAddress(TransportAddress.META_ADDRESS, 9300 + i),
attrs,
roles,
Version.CURRENT));
}
ClusterState clusterState = new ClusterState.Builder(ClusterState.EMPTY_STATE).nodes(nodesBuilder.build()).build();
when(clusterService.state()).thenReturn(clusterState);
}

private void givenDatafeeds(List<GetDatafeedsStatsAction.Response.DatafeedStats> datafeedStats) {
doAnswer(invocationOnMock -> {
ActionListener<GetDatafeedsStatsAction.Response> listener =
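Finally, a hedged sketch of what the new field looks like once a usage object is rendered to JSON. The constructor call matches the new signature in this PR; the exact wrapper layout and the names of the jobs/datafeeds fields are assumptions, and only node_count comes from this change:

import java.util.Collections;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.xpack.core.ml.MachineLearningFeatureSetUsage;

public class UsageXContentDemo {
    public static void main(String[] args) throws Exception {
        // Two ML-enabled nodes, no jobs or datafeeds configured.
        MachineLearningFeatureSetUsage usage = new MachineLearningFeatureSetUsage(
                true, true, Collections.emptyMap(), Collections.emptyMap(), 2);
        try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
            usage.toXContent(builder, ToXContent.EMPTY_PARAMS);
            // Expected to include "node_count":2 next to the existing
            // available/enabled/jobs/datafeeds fields, roughly:
            // {"available":true,"enabled":true,"jobs":{},"datafeeds":{},"node_count":2}
            System.out.println(Strings.toString(builder));
        }
    }
}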